Friendly Interface to JAX that behaves similarly to PyTorch while maintaining compatibility.
jaxfi (JAX Friendly Interface) - JAX with a PyTorch-like interface
- Working with CPU and GPU
- JAX modules are accessible directly
- 🔪 The Sharp Bits 🔪
- Notes
- Installation
- Changelog
News: an improved interface! import jaxfi as jaxm is all you need!
jaxfi provides a JAX-like module that behaves very similarly to PyTorch, so that, for example:
>>> import jaxfi as jaxm
jaxm.norm === torch.norm
jaxm.rand === torch.rand
jaxm.cat === torch.cat
jaxm.manual_seed === torch.manual_seed
Make sure to import this module before anything that might import jax (e.g., jaxopt).
# DO
import jaxfi as jaxm
import jaxopt
# DON'T!!!
import jaxopt
import jaxfi as jaxm
Working with CPU and GPU
JAX performs automatic device placement inside functions, so omit the
device argument when creating arrays inside functions; i.e., inside functions, specify only the dtype.
Placing arrays on the GPU or CPU is easy: either specify the device/dtype directly or
use jaxm.to to move an array to a specific device/dtype.
>>> jaxm.rand(2, device="cuda")
>>> jaxm.rand(2, device="gpu", dtype=jaxm.float64)
>>> jaxm.rand(2, device="cpu")
>>> jaxm.to(jaxm.zeros(2), "cuda")
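For comparison, the same kind of placement in plain JAX uses jax.devices and jax.device_put. This is a minimal sketch of the underlying JAX calls, not necessarily how jaxm.to is implemented; it targets the CPU so it runs on any machine.

```python
import jax
import jax.numpy as jnp

# First CPU device; jax.devices("gpu") would list GPU devices instead.
cpu = jax.devices("cpu")[0]

x = jnp.zeros(2)
# Cast to the desired dtype, then commit the array to the chosen device.
y = jax.device_put(x.astype(jnp.float32), cpu)
print(y.devices(), y.dtype)
```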
Arrays are created on the CPU by default; the default device and dtype can be set and queried using
jaxm.set_default_dtype(jaxm.float32)
jaxm.set_default_device("gpu")
jaxm.get_default_device()
jaxm.get_default_dtype()
The default dtype refers to the CPU default dtype. The default GPU dtype is always float32, but float64 arrays can be created on the GPU by specifying the dtype explicitly or by using jaxm.to.
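As background in plain JAX (how jaxfi manages this internally is not described here): JAX creates 32-bit arrays unless 64-bit mode is enabled through the jax_enable_x64 flag, which is best set at startup, before any arrays are created. A minimal sketch of that switch:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; do this at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)

a = jnp.zeros(3, dtype=jnp.float64)  # honored now that x64 mode is on
b = jnp.zeros(3, dtype=jnp.float32)  # 32-bit arrays can still be requested explicitly
print(a.dtype, b.dtype)  # float64 float32
```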
Otherwise, jaxm behaves like numpy (jax.numpy). Some functions are
taken directly from jax:
jaxm.grad === jax.grad
jaxm.jacobian === jax.jacobian
jaxm.hessian === jax.hessian
jaxm.jit === jax.jit
jaxm.vmap === jax.vmap
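Because these names refer to the jax functions themselves, standard JAX behavior applies. A small plain-JAX illustration, using the hypothetical test functions f(x) = sum(x**2) (gradient 2x, Hessian 2I) and g(x) = x**2 elementwise (Jacobian diag(2x)):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued test function: f(x) = sum_i x_i^2
    return jnp.sum(x ** 2)

def g(x):
    # Vector-valued test function for the Jacobian: g(x) = x^2 elementwise
    return x ** 2

x = jnp.array([1.0, 2.0, 3.0])

grad_f = jax.grad(f)(x)      # 2 * x -> [2., 4., 6.]
jac_g = jax.jacobian(g)(x)   # diag(2 * x)
hess_f = jax.hessian(f)(x)   # 2 * identity(3)
fast_f = jax.jit(f)(x)       # compiled f, same value -> 14.0
batched = jax.vmap(f)(jnp.stack([x, 2 * x]))  # f applied row-wise -> [14., 56.]
```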
JAX modules are accessible directly
Finally, the JAX modules themselves are available directly in jaxm:
>>> jaxm.jax
>>> jaxm.numpy
>>> jaxm.random
>>> jaxm.scipy
>>> jaxm.lax
🔪 The Sharp Bits 🔪
Random numbers are implemented using a global random key (which can also be
set manually using, e.g., jaxm.manual_seed(2023)). However, this means that parallelized
routines (e.g., under vmap) will generate the same random numbers.
# DON'T DO THIS
jaxm.jax.vmap(lambda _: jaxm.randn(10))(jaxm.arange(10)) # every row of random numbers is the same!
# DO THIS INSTEAD
n = 10
random_keys = jaxm.make_random_keys(n)
jaxm.jax.vmap(lambda key, idx: jaxm.randn(10, key=key))(random_keys, jaxm.arange(n))
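The same idea in plain JAX, where keys are always explicit: split one key into one subkey per vmapped element. jax.random.split and jax.random.normal are standard JAX; that jaxm.make_random_keys plays the role of the split here is an assumption.

```python
import jax
import jax.numpy as jnp

n = 10
key = jax.random.PRNGKey(2023)
keys = jax.random.split(key, n)  # one independent subkey per row

# Each vmapped call receives its own key, so the rows differ.
rows = jax.vmap(lambda k: jax.random.normal(k, (10,)))(keys)

# Reusing a single key for every row reproduces the pitfall: identical rows.
same = jax.vmap(lambda _: jax.random.normal(key, (10,)))(jnp.arange(n))
```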
Similarly, jit-ted functions will return the same random numbers on every call:
# DON'T DO THIS
f = jaxm.jit(lambda x: x * jaxm.randn(3))
f(1) # [-1.12918106, -2.04245763, -0.40538156]
f(1) # [-1.12918106, -2.04245763, -0.40538156]
f(1) # [-1.12918106, -2.04245763, -0.40538156]
# DO THIS
f = jaxm.jit(lambda x, key=None: x * jaxm.randn(3, key=key))
f(1) # [-1.12918106, -2.04245763, -0.40538156]
f(1, jaxm.make_random_key()) # [-2.58426713, 0.90726101, 2.1546499 ]
# jaxm.make_random_keys(n) is also available
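A plain-JAX sketch of the "DO THIS" pattern: pass the key as a traced argument and split it between calls, so the compiled function receives fresh randomness each time without recompiling. The function name scaled_noise is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_noise(x, key):
    # The key is a traced argument, so it is not baked into the compiled code.
    return x * jax.random.normal(key, (3,))

key = jax.random.PRNGKey(0)
key, k1 = jax.random.split(key)
key, k2 = jax.random.split(key)

a = scaled_noise(1.0, k1)
b = scaled_noise(1.0, k2)  # different key -> different numbers
c = scaled_noise(1.0, k1)  # same key -> same numbers
```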
Notes
I'm not affiliated with JAX or PyTorch in any way.
Installation
$ pip install jaxfi
The package name recently changed from jfi to jaxfi; PyPI hosts it as jaxfi.
Alternatively, to install from source, issue
$ pip install .
from the project root, or simply run
$ pip install git+https://github.com/rdyro/jaxfi-JAXFriendlyInterface.git
If you wish to let JAX (not jaxfi) work alongside PyTorch in the same virtual
environment, set (export) the environment variable JAXFI_LOAD_SYSTEM_CUDA_LIBS=true
before importing jaxfi or jax for the first time, e.g., by adding it to your shell configuration:
$ echo 'export JAXFI_LOAD_SYSTEM_CUDA_LIBS=true' >> ~/.bashrc
$ echo 'export JAXFI_LOAD_SYSTEM_CUDA_LIBS=true' >> ~/.zshrc
This will instruct jaxfi to dynamically load the system CUDA libraries.
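When a shell profile is not convenient (e.g., in a notebook), the variable can presumably also be set from Python via os.environ, provided this happens before jaxfi or jax is first imported in the process. A minimal sketch; the jaxfi import is left commented out so the snippet runs without jaxfi installed.

```python
import os

# Must run before jaxfi (or jax) is imported for the first time in this process.
os.environ["JAXFI_LOAD_SYSTEM_CUDA_LIBS"] = "true"

# import jaxfi as jaxm  # import only after the variable is set
```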
Changelog
- version 0.7.3
  - fixed random functions not accepting the key= keyword argument, needed for random number generation under jit
- version 0.7.0
  - jaxfi is now identical to jaxm, so that both import jaxfi as jaxm and from jaxfi import jaxm work
  - this change helps Pylance (at least in VSCode) resolve member fields in jaxfi
- version 0.6.6
  - random functions now (correctly) produce uncommitted arrays (see https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)
  - added a PyTorch-like randperm function (implemented as argsort(rand))
- version 0.6.5
  - added the ability to dynamically load the system CUDA libraries, allowing JAX to live in harmony with PyTorch; set the environment variable JAXFI_LOAD_SYSTEM_CUDA_LIBS=true to enable this feature
- version 0.6.3
  - jaxm.to now also moves numpy arrays, not just jax arrays, to a device and dtype
  - experimental auto_pmap function available, automatically assigning the first batch dimension to multiple devices, e.g., dividing 16 tasks across 6 CPUs
- version 0.6.0
  - official name change from jfi to jaxfi
- version 0.5.0
  - settled on the default numpy module copy behavior
  - omitting device when creating arrays in functions now works correctly
  - introduced more tests
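The 0.6.6 entry describes randperm as argsort(rand). A plain-JAX sketch of that idea, with the hypothetical helper name randperm_sketch: sorting indices by i.i.d. uniform samples yields a uniformly random permutation (ties are possible in floating point but very unlikely for moderate n). JAX itself also provides jax.random.permutation.

```python
import jax
import jax.numpy as jnp

def randperm_sketch(n, key):
    # Order the indices 0..n-1 by i.i.d. uniform draws: a random permutation.
    return jnp.argsort(jax.random.uniform(key, (n,)))

perm = randperm_sketch(8, jax.random.PRNGKey(0))
```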
File details
Details for the file jaxfi-0.7.3-py3-none-any.whl.
File metadata
- Download URL: jaxfi-0.7.3-py3-none-any.whl
- Upload date:
- Size: 17.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e64bdb89846b20cba2bb64dbcbc0f67d575fb4921df6f48321f86fe83026017b |
| MD5 | b52904844a7323ca64da68f7728a6aad |
| BLAKE2b-256 | f486d9acd807924b571f1107dab1a879c7696d9a38adc1c35a73a15b594fc942 |