An open-ended space of 2D physics-based RL environments
Update: Kinetix was accepted at ICLR 2025 as an oral!
Kinetix
Kinetix is a framework for reinforcement learning in a 2D rigid-body physics world, written entirely in JAX. Kinetix can represent a huge array of physics-based tasks within a unified framework. We use Kinetix to investigate the training of large, general reinforcement learning agents by procedurally generating millions of tasks for training. You can play with Kinetix in our online editor, or have a look at the JAX physics engine and graphics library we made for Kinetix. Finally, see our docs for more information and more in-depth examples.
The above shows specialist agents trained on their respective levels.
📊 Paper TL;DR
We train a general agent on millions of procedurally generated physics tasks. Every task has the same goal: make the green shape touch the blue shape without the green shape touching any red shape. The agent acts by applying torque through motors and force through thrusters.
The above shows a general agent zero-shotting unseen randomly generated levels.
We then investigate the transfer capabilities of this agent to unseen handmade levels. We find that the agent can zero-shot simple physics problems, but still struggles with harder tasks.
The above shows a general agent zero-shotting unseen handmade levels.
📜 Basic Usage
Kinetix follows the interface established in gymnax:
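import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# EnvParams, StaticEnvParams, ObservationType, ActionType, make_kinetix_env,
# make_reset_fn_sample_kinetix_level and make_render_pixels come from the
# kinetix package; see the docs for their import paths.

# Use default parameters
env_params = EnvParams()
static_env_params = StaticEnvParams()
# Create the environment
env = make_kinetix_env(
    observation_type=ObservationType.PIXELS,
    action_type=ActionType.CONTINUOUS,
    reset_fn=make_reset_fn_sample_kinetix_level(env_params, static_env_params),
    env_params=env_params,
    static_env_params=static_env_params,
)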
# Reset the environment state (this resets to a random level)
_rngs = jax.random.split(jax.random.PRNGKey(0), 3)
obs, env_state = env.reset(_rngs[0], env_params)
# Take a step in the environment
action = env.action_space(env_params).sample(_rngs[1])
obs, env_state, reward, done, info = env.step(_rngs[2], env_state, action, env_params)
# Render environment
renderer = make_render_pixels(env_params, env.static_env_params)
pixels = renderer(env_state)
plt.imshow(pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1])
plt.show()
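Because the interface is purely functional (explicit RNG keys, state passed in and out), a whole rollout can be wrapped in jax.lax.scan and jit-compiled. The sketch below illustrates that pattern with a hypothetical ToyEnv that only mimics the reset/step signatures shown above. It is not part of Kinetix, it uses random actions instead of a policy, and, unlike a real gymnax-style environment, it does not auto-reset when an episode ends.

```python
import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in (NOT a Kinetix class) with gymnax-style signatures."""

    def reset(self, rng, params):
        state = jnp.zeros(())  # scalar position, starts at 0
        return state, state  # (obs, state)

    def step(self, rng, state, action, params):
        next_state = state + action
        distance = jnp.abs(next_state - params)  # params is the target position
        reward = -distance
        done = distance < 0.5
        return next_state, next_state, reward, done, {}


def rollout(env, params, rng, num_steps):
    """Roll out random actions for num_steps and return per-step rewards and dones."""
    rng, reset_rng = jax.random.split(rng)
    _, state = env.reset(reset_rng, params)

    def body(state, step_rng):
        act_rng, env_rng = jax.random.split(step_rng)
        action = jax.random.uniform(act_rng, (), minval=-1.0, maxval=1.0)
        _, state, reward, done, _ = env.step(env_rng, state, action, params)
        return state, (reward, done)

    step_rngs = jax.random.split(rng, num_steps)
    _, (rewards, dones) = jax.lax.scan(body, state, step_rngs)
    return rewards, dones


# The env object and the step count are static; everything else is traced.
jitted_rollout = jax.jit(rollout, static_argnums=(0, 3))
rewards, dones = jitted_rollout(ToyEnv(), jnp.float32(3.0), jax.random.PRNGKey(0), 16)
```

With a real Kinetix environment, the same structure should apply with env.action_space(env_params).sample in place of the uniform draw, though we have not verified that combination here.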
⬇️ Installation
To install Kinetix (tested with Python 3.10):
git clone https://github.com/FlairOx/Kinetix.git
cd Kinetix
pip install -e ".[dev]"
pre-commit install
Please see the JAX installation instructions to install JAX for your accelerator.
[!TIP] Setting
export JAX_COMPILATION_CACHE_DIR="$HOME/.jax_cache"
in your ~/.bashrc improves usability by caching JAX compilations.
Kinetix is also available on PyPI and can be installed using pip install kinetix-env.
🎯 Editor
We recommend using the KinetixJS editor, but also provide a native (less polished) Kinetix editor.
To open this editor run the following command.
python3 kinetix/editor.py
The controls in the editor are:
- Move between edit and play modes using the spacebar.
- In edit mode, the type of edit is shown by the icon at the top and is changed by scrolling the mouse wheel. For instance, by navigating to the rectangle editing function you can click to place a rectangle.
  - You can also press the number keys to cycle between modes.
- To open handmade levels press ctrl-O and navigate to the ones in the L folder.
- When playing a level use the arrow keys to control motors and the numeric keys (1, 2) to control thrusters.
📈 Experiments
We have three primary experiment files:
- SFL: Training on levels with high learnability; this is how we trained our best general agents.
- PLR: PLR/DR/ACCEL in the JAXUED style.
- PPO: Normal PPO in the PureJaxRL style.
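To make "high learnability" concrete: as we understand the SFL method, a level's learnability is p(1 - p), where p is the agent's current success rate on that level, so levels that are always solved or never solved score zero. The sketch below is a plain-Python illustration under that assumption. The function names are hypothetical and this is not the code in experiments/sfl.py.

```python
def learnability(success_rate: float) -> float:
    """Score a level by p * (1 - p); highest when the agent succeeds half the time."""
    return success_rate * (1.0 - success_rate)


def top_k_learnable(success_rates: dict[str, float], k: int) -> list[str]:
    """Return the names of the k levels with the highest learnability score."""
    ranked = sorted(
        success_rates,
        key=lambda name: learnability(success_rates[name]),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical success rates measured from rollouts on three levels.
rates = {"never_solved": 0.0, "coin_flip": 0.5, "mostly_solved": 0.9}
selected = top_k_learnable(rates, k=2)
```

Here the never-solved level is dropped, since it provides no learning signal under this score.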
To run experiments with default parameters run any of the following:
python3 experiments/sfl.py
python3 experiments/plr.py
python3 experiments/ppo.py
python3 experiments/plr.py ued.replay_prob=0 # for DR
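The last command works because PLR decides, at each update, whether to replay a level from its buffer or to sample a fresh one. With a replay probability of 0 it never replays, which leaves plain domain randomisation (DR). A simplified sketch of that decision follows. The function is hypothetical and stands in for the idea only, not for the logic in experiments/plr.py.

```python
import random


def choose_level_source(replay_prob: float, buffer_size: int, rng: random.Random) -> str:
    """Return "replay" to train on buffered levels, or "new" to sample fresh ones."""
    # Replaying is only possible once the buffer holds at least one level.
    if buffer_size > 0 and rng.random() < replay_prob:
        return "replay"
    return "new"


rng = random.Random(0)
# replay_prob=0 never replays, so every update trains on freshly sampled levels.
dr_choices = [choose_level_source(0.0, buffer_size=100, rng=rng) for _ in range(1000)]
```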
We use hydra for managing our configs. See the configs/ folder for all the hydra configs that will be used by default, or the docs.
If you want to run experiments with different configurations, you can either edit these configs or pass command line arguments as follows:
python3 experiments/sfl.py model.transformer_depth=8
These experiments use wandb for logging by default.
[!NOTE] Experiments tend to run faster when JAX's persistent compilation cache is enabled. You can enable it with, for instance:
export JAX_COMPILATION_CACHE_DIR=.jax_cache
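If you launch experiments from Python rather than a shell, the same variable can be set programmatically. The helper below is a hypothetical convenience, not part of Kinetix. It assumes JAX reads JAX_COMPILATION_CACHE_DIR when it is first imported, so it should be called before import jax.

```python
import os


def enable_jax_compilation_cache(cache_dir: str = "~/.jax_cache") -> str:
    """Point JAX's persistent compilation cache at cache_dir and return the path used.

    An existing JAX_COMPILATION_CACHE_DIR value is respected rather than overwritten.
    """
    default_path = os.path.abspath(os.path.expanduser(cache_dir))
    path = os.environ.setdefault("JAX_COMPILATION_CACHE_DIR", default_path)
    os.makedirs(path, exist_ok=True)  # make sure the directory exists
    return path
```

Calling enable_jax_compilation_cache() at the very top of a launcher script, before any JAX import, mirrors the export line above.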
🏋️ Training RL Agents
We provide several different ways to train RL agents. The three most common options are (a) training an agent on random levels, (b) training an agent on a single, hand-designed level, and (c) training an agent on a set of hand-designed levels.
[!WARNING] Kinetix has three different environment sizes, s, m and l.
When running any of the scripts, you have to set the env_size option accordingly. For instance, python3 experiments/ppo.py train_levels=random env_size=m would train on random m levels. Loading large levels into a small env size gives an error; for instance, python3 experiments/ppo.py train_levels=m env_size=s would fail.
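One way to catch a size mismatch before launching a long compile is a small pre-flight check. The sketch below assumes that a level's size is the first component of its path (as in s/h4_thrust_aim.json) and that a level fits any environment of the same or a larger size. Both are our reading of the warning above, and the helper is hypothetical rather than a Kinetix function.

```python
# Environment sizes ordered from smallest to largest.
SIZE_ORDER = {"s": 0, "m": 1, "l": 2}


def level_fits_env(level_path: str, env_size: str) -> bool:
    """Return True if a level such as "s/h2_one_wheel_car" fits the given env_size."""
    level_size = level_path.split("/", 1)[0]
    return SIZE_ORDER[level_size] <= SIZE_ORDER[env_size]


def check_levels(level_paths: list[str], env_size: str) -> list[str]:
    """Return the levels that are too large for env_size (empty list if all fit)."""
    return [path for path in level_paths if not level_fits_env(path, env_size)]
```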
Training on random levels
This is the default option, but we give the explicit command for completeness
python3 experiments/ppo.py train_levels=random
Training on a single hand-designed level
[!NOTE] Check the kinetix/levels/ folder for handmade levels for each size category. By default, the loading functions require a path relative to the kinetix/levels/ directory.
python3 experiments/ppo.py train_levels=s train_levels.train_levels_list='["s/h4_thrust_aim.json"]'
Training on a set of hand-designed levels
python3 experiments/ppo.py train_levels=s env_size=s eval=eval_auto
# python3 experiments/ppo.py train_levels=m env_size=m eval=eval_auto
# python3 experiments/ppo.py train_levels=l env_size=l eval=eval_auto
Or, on a custom set:
python3 experiments/ppo.py eval=eval_auto train_levels=l env_size=l train_levels.train_levels_list='["s/h2_one_wheel_car","l/h11_obstacle_avoidance"]'
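The quoting in the train_levels_list override is easy to get wrong by hand. When launching runs from a script, one option is to build the argument list in Python and let json.dumps produce the list literal. This is a sketch of that idea; ppo_command is a hypothetical helper, and only the option names shown in the commands above are used.

```python
import json


def ppo_command(env_size: str, levels: list[str]) -> list[str]:
    """Build the argv for a PPO run on a custom set of hand-designed levels."""
    # Compact separators avoid spaces inside the list literal.
    levels_literal = json.dumps(levels, separators=(",", ":"))
    return [
        "python3",
        "experiments/ppo.py",
        "eval=eval_auto",
        f"train_levels={env_size}",
        f"env_size={env_size}",
        f"train_levels.train_levels_list={levels_literal}",
    ]


cmd = ppo_command("l", ["s/h2_one_wheel_car", "l/h11_obstacle_avoidance"])
```

Passing cmd to subprocess.run(cmd, check=True) bypasses the shell, so no extra quoting around the list is needed.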
💨 Compilation Speed
Kinetix is complex, so compilation generally takes a long time. In particular, plr.py and sfl.py may take a long time before they start executing code. This can be a burden when you are implementing new features and just want to debug quickly. To make this easier, we provide two options, train_levels=dummy and env.dummy_env=True (e.g. python3 experiments/sfl.py train_levels=dummy env.dummy_env=True). These options replace the actual environment step and reset logic with no-ops, so compilation is much faster. However, no environment logic is executed, so this is only useful for checking syntax, shape and JAX errors, not for debugging learning issues.
❌ Errata
- The left wall was misplaced 5 cm to the left in all levels and all experiments in the paper (each level is a square with side lengths of 5 metres). This error has been fixed in the latest version of Jax2D, but we have pinned Kinetix to the old version for consistency and reproducibility with the original paper. Further improvements have been made since, so if you wish to reproduce the paper's results, please use Kinetix version 0.1.0, which is tagged on GitHub.
🔎 See Also
- 🌐 Kinetix.js: Kinetix reimplemented in JavaScript, with a live demo here.
- 🍎 Jax2D: The physics engine we made for Kinetix.
- 👨‍💻 JaxGL: The graphics library we made for Kinetix.
- 📋 Our Paper for more details and empirical results.
🙏 Acknowledgements
The permutation-invariant MLP model (enabled by setting model.permutation_invariant_mlp=True) was added by Anya Sims. Thanks to Thomas Foster for fixing some macOS-specific issues. We'd also like to thank Thomas Foster, Alex Goldie, Matthew Jackson, Sebastian Towers and Andrei Lupu for useful feedback.
📚 Citation
If you use Kinetix in your work, please cite it as follows:
@inproceedings{matthews2024kinetix,
  title={Kinetix: Investigating the Training of General Agents through Open-Ended Physics-Based Control Tasks},
  author={Michael Matthews and Michael Beukman and Chris Lu and Jakob Foerster},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://arxiv.org/abs/2410.23208}
}