Gridworld navigation for Reinforcement Learning with JAX
NAVIX
What is NAVIX?
NAVIX is MiniGrid in JAX: ~2000x faster, with Autograd and XLA support. A superficial performance comparison with MiniGrid is available in the docs.
Installation
NAVIX supports the same operating systems as JAX. Please follow the JAX installation guide to install the correct version of JAX for your OS.
You might want to follow the same guide to set up your favourite backend (e.g. CPU, GPU, or TPU).
Then, install navix with:
```bash
pip install navix
```
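After installing JAX, it can be worth checking which backend it picked up before running NAVIX. The snippet below is a minimal sketch that only uses core JAX calls (jax.default_backend, jax.devices, jax.jit); add_one is a throwaway helper introduced here for illustration.

```python
import jax
import jax.numpy as jnp

# Report the backend JAX selected (e.g. "cpu", "gpu" or "tpu") and the visible devices
print(jax.default_backend())
print(jax.devices())


# A tiny jitted function (illustrative helper) to confirm that XLA compilation works
@jax.jit
def add_one(x):
    return x + 1


print(add_one(jnp.arange(3)))
```

If the backend reported is "cpu" while you expected a GPU or TPU, the accelerator-specific JAX build from the JAX installation guide is probably not installed.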
Examples
XLA compilation
One straightforward use case is to accelerate the environment with XLA compilation. For example, here we JIT-compile a full rollout with random actions, and vectorise it with jax.vmap to run 1000 environments in parallel.
You can find a superficial performance comparison with minigrid in the docs.
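```python
import jax
import jax.numpy as jnp
import navix as nx

N_TIMESTEPS = 100  # rollout length (illustrative value)


def run(seed):
    env = nx.environments.Room(16, 16, 8)
    key = jax.random.PRNGKey(seed)
    # Use independent keys for the reset and for the action sequence
    reset_key, action_key = jax.random.split(key)
    timestep = env.reset(reset_key)
    actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, 6)

    def body_fun(timestep, action):
        timestep = env.step(timestep, action)
        return timestep, ()

    # Unroll the episode with lax.scan and return only the final timestep
    return jax.lax.scan(body_fun, timestep, actions.astype(jnp.int32))[0]

# Compile the rollout and vectorise it over 1000 seeds
final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))
```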
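The same scan/vmap/jit pattern can be tried without NAVIX installed. The sketch below uses a hypothetical toy gridworld: toy_step, GRID_SIZE, N_TIMESTEPS and the four-entry MOVES table are illustrative assumptions, not part of NAVIX. Each rollout is unrolled with jax.lax.scan, batched over seeds with jax.vmap, and compiled end to end with jax.jit.

```python
import jax
import jax.numpy as jnp

GRID_SIZE = 16      # hypothetical grid side length
N_TIMESTEPS = 100   # hypothetical rollout length

# Hypothetical action table: up, right, down, left as (row, col) offsets
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def toy_step(position, action):
    """Move the agent one cell and keep it inside the grid (toy dynamics)."""
    new_position = jnp.clip(position + MOVES[action], 0, GRID_SIZE - 1)
    return new_position, new_position  # (carry, per-step output)


def rollout(seed):
    key = jax.random.PRNGKey(seed)
    start_key, action_key = jax.random.split(key)
    # Random start cell and a random sequence of action indices
    start = jax.random.randint(start_key, (2,), 0, GRID_SIZE)
    actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, MOVES.shape[0])
    final_position, trajectory = jax.lax.scan(toy_step, start, actions)
    return final_position, trajectory


# Compile the batched rollout once and run 1000 environments in parallel
final_positions, trajectories = jax.jit(jax.vmap(rollout))(jnp.arange(1000))
```

Because each rollout is a pure function of its seed, jax.vmap can batch it and jax.jit can compile the whole batch into a single XLA program.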
Backpropagation through the environment
Another use case is to backpropagate through the environment transition function, for example to learn a world model.
TODO(epignatelli): add example.
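Until an official example is available, the idea can be sketched with a toy, fully differentiable transition function in plain JAX. This is not the NAVIX API: toy_transition, loss_fn and DT are hypothetical names, and the "environment" is a simple point moving with the action as its velocity. The sketch unrolls the transition with jax.lax.scan and uses jax.grad to differentiate the final-state loss with respect to the whole action sequence.

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical integration step of the toy dynamics


def toy_transition(state, action):
    """Toy differentiable dynamics: the action acts as a velocity (illustrative only)."""
    next_state = state + DT * action
    return next_state, next_state


def loss_fn(actions, start, goal):
    # Unroll the dynamics and penalise the squared distance to the goal
    final_state, _ = jax.lax.scan(toy_transition, start, actions)
    return jnp.sum((final_state - goal) ** 2)


start = jnp.zeros(2)
goal = jnp.array([1.0, -1.0])
actions = jnp.zeros((10, 2))  # 10 steps of 2-D actions

# Gradient of the loss with respect to every action in the sequence
grads = jax.grad(loss_fn)(actions, start, goal)

# One step of gradient descent on the action sequence
learning_rate = 1.0
new_actions = actions - learning_rate * grads
```

In a world-model setting the same jax.grad call would be taken with respect to the parameters of a learned transition function rather than the actions.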
Cite
If you use NAVIX, please consider citing it as:
```bibtex
@misc{navix,
  author = {Pignatelli, Eduardo},
  title = {Navix: Reinforcement Learning navigation with Autograd and XLA},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/epignatelli/navix}}
}
```
Download files
Hashes for Navix-0.1.1-py2.py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | eb671df1401e777d329effb56eeb2fe98f4e937ea570611017ffec6871f7163e
MD5 | b3e4744c3cd868590f971827c1a6019c
BLAKE2b-256 | d6733d33e1dccde2977b220631fa73af434034142595fe437b99576a16098723