MINIMAL-MJX
This repository is starter code for MJX-based RL from the Dynamic Mobility Lab. The code is built off of MuJoCo Playground and tailored to make policy training, saving, and evaluation easy. It also implements some nice-to-haves:
- Swappable backend, allowing for fast evaluation using NumPy/C++ MuJoCo and fast training using JAX/MJX
- Base environment class that contains generic reward functions and other useful helper functions
- Configuration files that allow you to set parameters for your envs and for PPO
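To illustrate the swappable-backend idea, here is a minimal sketch, not the repository's actual implementation. It assumes the backend is selected by the strings `np` and `jnp` used in the config files; the names `get_backend` and `tracking_reward` are hypothetical. The reward is written against a generic array module `xp`, so the same function runs under NumPy for evaluation and under `jax.numpy` for training.

```python
import importlib

import numpy as np


def get_backend(name: str):
    """Return the array module for a backend name (hypothetical helper)."""
    if name == "np":
        return np
    if name == "jnp":
        # Imported lazily so that NumPy-only evaluation does not require JAX.
        return importlib.import_module("jax.numpy")
    raise ValueError(f"Unknown backend: {name!r}")


def tracking_reward(xp, velocity, target, sigma=0.25):
    """Exponential tracking reward written only with calls shared by both backends."""
    error = xp.square(velocity - target)
    return xp.exp(-error / sigma)


xp = get_backend("np")
reward = tracking_reward(xp, np.array(1.5), 2.0)
```

Passing the array module explicitly is one design choice among several; it keeps the reward code free of backend-specific imports.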
A requirements.txt in the top-level directory lists the packages needed to run this code.
Steps to take to 'do RL'!
We will work off of the Cheetah environment from DeepMind's "DM Control". This guide is not meant to be comprehensive, but it should give you enough information to fill in the gaps as you develop and look around the codebase.
The setup basics
Each environment has a `reset` and a `step` function, which primarily operate on a Markov Decision Process (MDP) state. In our case, this is represented as a class that looks like this:

    @dataclass
    class MujocoState:
        data: mujoco.MjData
        obs: np.ndarray
        reward: float
        done: bool
        metrics: dict
        info: dict

Note that you should treat `info` as a 'carry' variable between consecutive `step` calls. Store things here that you need in future timesteps and that are not ordinarily stored in `data` (e.g. a history of system states `x`).
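The 'carry' role of `info` can be sketched as follows. This is a toy example under stated assumptions: `data` is typed as `Any` and left as `None` so the sketch runs without MuJoCo, the dynamics are a stand-in for the simulator, and the key `x_history` and constant `HISTORY_LEN` are hypothetical names.

```python
from dataclasses import dataclass, replace
from typing import Any

import numpy as np

HISTORY_LEN = 3  # hypothetical number of past states to keep


@dataclass
class MujocoState:
    data: Any  # mujoco.MjData in the real code; Any keeps this sketch dependency-free
    obs: np.ndarray
    reward: float
    done: bool
    metrics: dict
    info: dict


def reset() -> MujocoState:
    obs = np.zeros(2)
    # Pre-allocate a fixed-size history so its shape never changes between steps.
    info = {"x_history": np.zeros((HISTORY_LEN, 2))}
    return MujocoState(data=None, obs=obs, reward=0.0, done=False, metrics={}, info=info)


def step(state: MujocoState, action: np.ndarray) -> MujocoState:
    new_obs = state.obs + action  # toy dynamics standing in for the simulator
    # Drop the oldest entry and append the newest one.
    history = np.concatenate([state.info["x_history"][1:], new_obs[None]], axis=0)
    info = {**state.info, "x_history": history}
    return replace(state, obs=new_obs, info=info)


state = reset()
for _ in range(4):
    state = step(state, np.ones(2))
```

A fixed-size buffer is used instead of a growing list because array shapes that stay constant across steps are generally easier to keep compatible with JAX.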
Simulate your environment to make sure things look good
Environments should be simulated and inspected before training. Once you are done developing your env, make a configuration file to specify variables specific to training runs, like PPO or reward weight parameters.
For the Cheetah, one is already created. To simulate the environment, run:

    python3 -m envs.simulate envs/dmcontrol/config/cheetah.yaml

Under `visualization`, you should see a metrics plot as well as a video of your environment. Note that the policy here is set to output all zeros (see `envs/simulate.py`). Inspect these files to make sure things look good.

For good measure, change the `backend` parameter in your config file to `jnp` to make sure your environment is JAX compatible.
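One common reason an environment works with NumPy but fails with JAX is in-place array mutation: JAX arrays are immutable and are updated through `x.at[idx].set(value)` instead. The sketch below shows one way to write such an update so that it works with either array type. The helper name `set_index` is hypothetical and is not part of this codebase.

```python
import numpy as np


def set_index(arr, idx, value):
    """Return a copy of `arr` with `arr[idx]` replaced by `value`."""
    if hasattr(arr, "at"):
        # JAX arrays expose functional updates through the `.at` property.
        return arr.at[idx].set(value)
    # NumPy path: copy first so the caller's array is left untouched,
    # matching the non-mutating behavior of the JAX path.
    out = np.array(arr, copy=True)
    out[idx] = value
    return out


contacts = np.zeros(3)
updated = set_index(contacts, 1, 5.0)
```

Both paths return a new array, so code that calls this helper behaves the same way regardless of the backend.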
Train a policy
Once your environment is ready and JAX compatible, you can train by running:

    python3 -m learning.training.begin_run envs/dmcontrol/config/cheetah.yaml

Note that this will train directly in your terminal session. A bash script has been provided at `learning/training/train.sh` that opens a `tmux` terminal for this (useful if you want to train for long periods of time). Make sure to change your conda environment name in this script.

Within the save directory specified in your config file, a new directory will be created in which intermediate and final training results will be saved.
Roll out your trained policy
Once your policy has finished training, simply run:

    python3 -m eval.rollout_policy envs/dmcontrol/config/cheetah.yaml

This will roll out your policy, plot metrics of your reward function, and save a video of the result. Make sure to use the `np` backend for quick evaluation!

If you would like to roll out older policies, check the directory of your past training runs. There should be a config there, and you can simply replace the config above with that one to run that specific older policy. Note that your code might have changed since then, so this config records the git commit at which the training was run.
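Before rolling out an older policy, it can help to check whether your working tree is at the commit recorded in that run's config. The sketch below is an assumption-laden illustration: how the commit is stored in the config is not shown here, so the recorded hash is passed in as a plain string, and both helper names are hypothetical.

```python
import subprocess


def current_commit(repo_dir="."):
    """Return the hash of HEAD, or None if git or the repository is unavailable."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=repo_dir,
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout.strip()


def commits_match(recorded, current):
    """Compare two hashes, allowing either one to be an abbreviated prefix."""
    if not recorded or not current:
        return False
    a, b = recorded.strip().lower(), current.strip().lower()
    return a.startswith(b) or b.startswith(a)


recorded = "abc1234"  # placeholder; read the real value from the old run's config
if not commits_match(recorded, current_commit()):
    print("Current code differs from the commit recorded for this run.")
```

If the hashes differ, checking out the recorded commit before evaluating is the safest way to reproduce the original behavior.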