Accelerated gridworld navigation with JAX for deep reinforcement learning
NAVIX: minigrid in JAX
What is NAVIX?
NAVIX is minigrid in JAX: it runs >1000x faster and supports autograd and XLA compilation. You can see a rough performance comparison here.
The library is in active development, and we are working on adding more environments and features. If you want to join the development and contribute, please open a discussion and let's have a chat!
Installation
NAVIX currently supports the same operating systems as JAX. You can find a description here.
You might want to follow the same guide to install JAX for your favourite accelerator (e.g., CPU, GPU, or TPU).
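After installing JAX, a quick way to confirm which accelerator it picked up is to query its default backend. This is a minimal check using JAX's own API, not part of NAVIX:

```python
import jax

# Platform name of the default backend, typically "cpu", "gpu" or "tpu"
backend = jax.default_backend()
# Devices JAX can place computations on
devices = jax.devices()
print(backend, devices)
```

If this prints `cpu` on a machine with a GPU, the accelerator-specific JAX build is probably not installed.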
Stable
Then, install the stable version of navix and its dependencies with:
```bash
pip install navix
```
Nightly
Or, if you prefer to install the latest version from source:
```bash
pip install git+https://github.com/epignatelli/navix
```
Examples
XLA compilation
One straightforward use case is to accelerate the computation of the environment with XLA compilation. For example, here we vectorise the environment to run 1000 environments in parallel, and compile the full rollout.
You can find a partial performance comparison with minigrid in the docs.
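```python
import jax
import jax.numpy as jnp
import navix as nx

# Number of steps per rollout (illustrative value; not set in the original example)
N_TIMESTEPS = 1000

def run(seed):
    env = nx.environments.Room(16, 16, 8)
    key = jax.random.PRNGKey(seed)
    timestep = env.reset(key)
    # Sample a random sequence of actions
    actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6)

    def body_fun(timestep, action):
        timestep = env.step(timestep, jnp.asarray(action))
        return timestep, ()

    # Unroll the environment and keep only the final timestep (the scan carry)
    return jax.lax.scan(body_fun, timestep, jnp.asarray(actions, dtype=jnp.int32))[0]

# Vectorise over 1000 seeds and compile the whole batched rollout
final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))
```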
Backpropagation through the environment
Another use case is to backpropagate through the environment transition function, for example to learn a world model.
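The general pattern can be sketched in plain JAX. This is not a NAVIX example: the linear transition `transition(params, state, action)` below is a hypothetical, continuous stand-in for a differentiable environment step, and the data, parameter values and learning rate are illustrative. The point is that `jax.grad` flows back through every step of a `jax.lax.scan` rollout, which is what fitting a world model to observed trajectories requires:

```python
import jax
import jax.numpy as jnp

# Hypothetical differentiable transition (NOT navix's Environment.step):
# a linear model s' = A s + B a with learnable matrices A and B.
def transition(params, state, action):
    return params["A"] @ state + params["B"] @ action

def rollout(params, init_state, actions):
    """Unroll the transition over a sequence of actions and return all visited states."""
    def body(state, action):
        next_state = transition(params, state, action)
        return next_state, next_state
    _, states = jax.lax.scan(body, init_state, actions)
    return states  # shape (T, state_dim)

def loss(params, init_state, actions, observed_states):
    # Multi-step prediction error: gradients flow back through every transition
    predicted = rollout(params, init_state, actions)
    return jnp.mean((predicted - observed_states) ** 2)

LEARNING_RATE = 0.01  # illustrative value

@jax.jit
def sgd_step(params, init_state, actions, observed_states):
    grads = jax.grad(loss)(params, init_state, actions, observed_states)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Synthetic trajectory generated by hypothetical ground-truth parameters
key = jax.random.PRNGKey(0)
true_params = {"A": 0.9 * jnp.eye(2), "B": jnp.eye(2)}
init_state = jnp.zeros(2)
actions = jax.random.normal(key, (5, 2))
observed = rollout(true_params, init_state, actions)

# Fit the model starting from a wrong guess
params = {"A": jnp.eye(2), "B": 0.5 * jnp.eye(2)}
initial_loss = loss(params, init_state, actions, observed)
for _ in range(100):
    params = sgd_step(params, init_state, actions, observed)
final_loss = loss(params, init_state, actions, observed)
print(initial_loss, final_loss)
```

NAVIX states are discrete grids, so differentiating through its own step function may need extra care; the sketch only illustrates the mechanics of backpropagating through an unrolled transition.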
TODO(epignatelli): add example.
Cite
If you use navix, please consider citing it as:
```bibtex
@misc{pignatelli2023navix,
  author = {Pignatelli, Eduardo},
  title = {Navix: Accelerated gridworld navigation with JAX},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/epignatelli/navix}}
}
```
Download files
Hashes for Navix-0.3.6-py2.py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 7258cfb2db10e0c270745946be3f219fd02713f88091f471d92d4785ddfc0d74
MD5 | fa8632bbc3f3e122caa9ea1467083960
BLAKE2b-256 | 75837c9e5f7419f8746c30e580af509fadcd6f5b03b713b2f00e6d48f27cde90
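To check a downloaded wheel against the SHA256 digest listed above, a minimal sketch using Python's standard `hashlib` module; the helper names `sha256_of` and `verify` are illustrative, not part of NAVIX:

```python
import hashlib

# SHA256 digest published for Navix-0.3.6-py2.py3-none-any.whl (from the table above)
EXPECTED_SHA256 = "7258cfb2db10e0c270745946be3f219fd02713f88091f471d92d4785ddfc0d74"

def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 digest matches the expected hex string (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

For example, `verify("Navix-0.3.6-py2.py3-none-any.whl")` returns `True` only if the file saved under that (assumed) local path is byte-identical to the published wheel.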