JAX random number generation as a NumPy generator
rng-jax — NumPy random number generator API for JAX
This is a proof of concept only.
Wraps JAX's stateless random number generation in a class implementing the
numpy.random.Generator interface.
Example
>>> import rng_jax
>>> rng = rng_jax.Generator(42) # same arguments as jax.random.key()
>>> rng.standard_normal(3)
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
>>> rng.standard_normal(3)
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
Rationale
The Array API makes it possible to write array-agnostic Python
libraries. The rng-jax package extends this to random number
generation with NumPy and JAX. End users only need to provide an rng object, as
usual, which can be either a NumPy generator or a rng_jax.Generator instance
wrapping JAX's stateless random number generation.
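As a minimal sketch of what this looks like in practice, the hypothetical function `add_noise` below (not part of rng-jax) relies only on the `standard_normal` method of whatever rng it is given. It is shown with a NumPy generator; the intent described above is that a rng_jax.Generator could be passed in the same way, which is an assumption not verified here.

```python
import numpy as np


def add_noise(x, rng, scale=0.1):
    """Add Gaussian noise to x using only the Generator-style API of rng.

    Hypothetical helper for illustration; not part of rng-jax.
    """
    # standard_normal() belongs to the numpy.random.Generator interface,
    # which rng_jax.Generator is described as implementing.
    return x + scale * rng.standard_normal(x.shape)


x = np.zeros(4)
noisy = add_noise(x, np.random.default_rng(42))
```

The function never asks which library produced `rng`, so the choice of backend stays with the caller.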
How it works
The rng_jax.Generator class works in the obvious way: it keeps track of the
JAX key and calls jax.random.split() before every random operation, so that
each operation uses a fresh key.
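The idea can be sketched in a few lines. The class `SketchGenerator` below is a hypothetical stand-in written for illustration; it follows the description above but is not the actual rng_jax source, and details such as which half of the split is kept as state are assumptions.

```python
import jax


class SketchGenerator:
    """Illustrative stateful wrapper around a JAX key.

    Assumption: mirrors the behaviour described in the text,
    not the real rng_jax implementation.
    """

    def __init__(self, seed):
        self.key = jax.random.key(seed)

    def split(self):
        # Advance the stored key and hand out a fresh key for one operation.
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def standard_normal(self, size=None):
        # Normalise NumPy-style `size` (None, int, or tuple) to a JAX shape.
        if size is None:
            shape = ()
        elif isinstance(size, int):
            shape = (size,)
        else:
            shape = tuple(size)
        return jax.random.normal(self.split(), shape)


rng = SketchGenerator(42)
a = rng.standard_normal(3)
b = rng.standard_normal(3)  # drawn with a different key than `a`
```

Because the key is replaced on every call, repeated calls return different draws, while constructing a new instance with the same seed reproduces the same sequence.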
JIT and native JAX code
A stateful RNG cannot be passed into a compiled JAX function. In practice,
this is rarely an issue, since this package is meant to work in tandem with
the Array API, and array-agnostic code is not usually compiled at a low
level. Conversely, native JAX code usually expects a key anyway, not a
rng_jax.Generator instance.
To interface with a native JAX function expecting a key, use the .split()
method to obtain a new random key and advance the internal state of the
generator:
>>> import jax
>>> rng = rng_jax.Generator(42)
>>> key = rng.split()
>>> jax.random.normal(key, 3)
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
>>> key = rng.split()
>>> jax.random.normal(key, 3)
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
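The same pattern applies to a compiled function. The sketch below uses only JAX itself; the hypothetical function `noisy_sum` takes the key as an explicit argument, and the manual `jax.random.split()` call stands in for `rng.split()`, so nothing here depends on rng_jax internals.

```python
import jax
import jax.numpy as jnp


@jax.jit
def noisy_sum(key, x):
    # Native JAX style: the key is an explicit argument, so the compiled
    # function stays pure and holds no hidden RNG state.
    return jnp.sum(x + jax.random.normal(key, x.shape))


# Stand-in for `key = rng.split()`: derive a fresh key outside the
# compiled function, then pass it in.
state = jax.random.key(42)
state, key = jax.random.split(state)
result = noisy_sum(key, jnp.zeros(3))
```

Calling `noisy_sum` again with the same key gives the same result; a new draw requires a new key from the generator.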
Using the rng_jax.Generator class fully within a compiled JAX function
works without issue.