Common backend for JAX or NumPy.
Project description
Jumpy is a common backend for NumPy and, optionally, JAX:
- If JAX is installed and JAX inputs are provided, the `jax.numpy` function is run.
- If JAX is installed and the function is jitted, the `jax.numpy` function is run.
- Otherwise, the Jumpy function returns the NumPy outputs.
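The dispatch rules above can be sketched in a few lines. This is a minimal illustration, not Jumpy's actual implementation: `pick_backend` and `clip_norm` are hypothetical names, and the sketch assumes a recent JAX (0.4 or later) in which tracers created under `jax.jit` also pass `isinstance(x, jax.Array)`, so jitted calls take the JAX branch while NumPy inputs fall through to NumPy.

```python
import numpy as np

try:  # JAX is optional in this sketch, as it is for Jumpy
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None
    jnp = None


def pick_backend(*args):
    """Hypothetical dispatcher: jax.numpy for JAX arrays, NumPy otherwise.

    Assumes JAX >= 0.4, where tracers created under jax.jit also satisfy
    isinstance(x, jax.Array), so jitted calls take the JAX branch.
    """
    if jax is not None and any(isinstance(a, jax.Array) for a in args):
        return jnp
    return np


def clip_norm(x, max_norm):
    """Hypothetical backend-agnostic function: rescale x so its L2 norm is at most max_norm."""
    xp = pick_backend(x)
    norm = xp.linalg.norm(x)
    scale = xp.minimum(1.0, max_norm / (norm + 1e-12))  # epsilon avoids division by zero
    return x * scale
```

With NumPy inputs, `clip_norm(np.array([3.0, 4.0]), 1.0)` stays a NumPy array; passing a `jnp` array, or calling it under `jax.jit`, routes the same code through `jax.numpy`.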
Several functions (e.g. `vmap`, `scan`) are available when JAX is installed.
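For readers new to these transforms, the NumPy loops below spell out what `jax.vmap` (with the default `in_axes=0`) and `jax.lax.scan` compute. They are hypothetical reference helpers for illustration, not Jumpy functions, and they assume array-valued inputs and outputs (not arbitrary pytrees) and a non-empty leading axis.

```python
import numpy as np


def vmap_numpy(f, xs):
    """Hypothetical reference for jax.vmap(f)(xs) with in_axes=0.

    Applies f to each slice along the leading axis and stacks the results.
    """
    return np.stack([f(x) for x in xs])


def scan_numpy(f, init, xs):
    """Hypothetical reference for jax.lax.scan(f, init, xs).

    f(carry, x) -> (new_carry, y); returns (final_carry, stacked ys).
    Assumes xs has at least one element along its leading axis.
    """
    carry = init
    ys = []
    for x in xs:  # loop over the leading axis, as lax.scan does
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def running_sum(carry, x):
    total = carry + x
    return total, total  # carry the running total and also emit it


row_norms = vmap_numpy(np.linalg.norm, np.array([[3.0, 4.0], [6.0, 8.0]]))
final, partial_sums = scan_numpy(running_sum, 0.0, np.array([1.0, 2.0, 3.0]))
```

With JAX installed, `jax.vmap` and `jax.lax.scan` follow the same contracts but produce traced, compilable operations instead of Python loops.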
Jumpy lets you write framework-agnostic code that is easy to debug by running it as plain NumPy, yet is just as performant as JAX when jitted.
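As a sketch of the debug-then-jit workflow, the function below is written once against an array module `xp` and run first with NumPy, where intermediate values are concrete and easy to print or step through, then compiled with `jax.jit` when JAX is available. Passing `xp` explicitly is an assumption made to keep the example short; Jumpy itself selects the backend from the inputs as described above, and `softmax` here is a hypothetical example function.

```python
import functools

import numpy as np

try:  # JAX is optional; the NumPy path works without it
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None
    jnp = None


def softmax(x, xp=np):
    """Hypothetical example: numerically stable softmax written against an array module `xp`."""
    shifted = x - xp.max(x, axis=-1, keepdims=True)  # subtract the max to avoid overflow in exp
    exp = xp.exp(shifted)
    return exp / xp.sum(exp, axis=-1, keepdims=True)


# Debug path: plain NumPy, so intermediates are concrete values you can print or inspect in pdb.
probs = softmax(np.array([1.0, 2.0, 3.0]))

if jax is not None:
    # Fast path: the same function body, traced and compiled by jax.jit.
    fast_softmax = jax.jit(functools.partial(softmax, xp=jnp))
    jax_probs = fast_softmax(jnp.array([1.0, 2.0, 3.0]))
    print(np.allclose(np.asarray(jax_probs), probs, atol=1e-6))
```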
We maintain this repository primarily to enable writing Gymnasium and PettingZoo wrappers that can be applied to both standard NumPy and hardware-accelerated JAX-based environments; however, the package can be used for much more.
Installing Jumpy
To install Jumpy from PyPI, run `pip install jax-jumpy[jax]` to include JAX, or `pip install jax-jumpy` to install it without JAX.
Alternatively, to install Jumpy from source, clone this repository, `cd` into it, and run `pip install .`
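After installing, one quick way to confirm which variant you have is to query the installed distribution and check whether `jax` is importable. This sketch uses only the standard library; `describe_install` is a hypothetical helper, and the distribution name `jax-jumpy` is taken from the install commands above.

```python
import importlib.metadata
import importlib.util


def describe_install():
    """Hypothetical helper: report the jax-jumpy version and whether JAX is importable."""
    try:
        jumpy_version = importlib.metadata.version("jax-jumpy")
    except importlib.metadata.PackageNotFoundError:
        jumpy_version = None  # jax-jumpy is not installed in this environment
    has_jax = importlib.util.find_spec("jax") is not None
    return {"jax-jumpy": jumpy_version, "jax_installed": has_jax}


print(describe_install())
```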
Contributing
Jumpy does not implement every NumPy or `jax.numpy` function.
If a function you need is missing, please open an issue or pull request; we will be happy to add it.
In the future, we are interested in adding optional support for PyTorch and are looking for pull requests to help complete this.
Download files
- Source Distribution: jax-jumpy-1.0.0.tar.gz
- Built Distribution: jax_jumpy-1.0.0-py3-none-any.whl
File details
Details for the file jax-jumpy-1.0.0.tar.gz.
File metadata
- Download URL: jax-jumpy-1.0.0.tar.gz
- Upload date:
- Size: 19.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad |
| MD5 | 54ab66eeed059a418c6caadb497a906e |
| BLAKE2b-256 | 526ab6affff68f172a4c8316d9ab9b7d952e865df15b854f158690991864e0fe |
File details
Details for the file jax_jumpy-1.0.0-py3-none-any.whl.
File metadata
- Download URL: jax_jumpy-1.0.0-py3-none-any.whl
- Upload date:
- Size: 20.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e |
| MD5 | 260866f1fd141ca864c3fbb0553b0cbe |
| BLAKE2b-256 | 7323338caee543d80584916da20f018aeb017764509d964fd347b97f41f97baa |
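To check a downloaded file against the SHA256 digests listed above, hash it and compare. This is a minimal standard-library sketch; `sha256_of` and `verify` are hypothetical helpers, and the expected digests are copied from the file details for version 1.0.0.

```python
import hashlib
import os

# SHA256 digests copied from the file details for jax-jumpy 1.0.0.
EXPECTED_SHA256 = {
    "jax-jumpy-1.0.0.tar.gz": "195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad",
    "jax_jumpy-1.0.0-py3-none-any.whl": "ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e",
}


def sha256_of(path, chunk_size=1 << 16):
    """Hypothetical helper: hex SHA256 digest of the file at `path`."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read fixed-size chunks so large files are not loaded into memory at once.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """Hypothetical helper: True if the file's SHA256 matches the published digest for its name."""
    expected = EXPECTED_SHA256.get(os.path.basename(path))
    if expected is None:
        raise ValueError(f"no published digest for {path!r}")
    return sha256_of(path) == expected
```

pip can also enforce digests itself in hash-checking mode, for example with a requirements line such as `jax-jumpy==1.0.0 --hash=sha256:<digest>` and `pip install --require-hashes -r requirements.txt`.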