
MINIMAL-MJX


This repository contains starter code for MJX-based RL from the Dynamic Mobility Lab. The code is built off of MuJoCo Playground and tailored to make policy training, saving, and evaluation easy, with some nice-to-haves implemented:

  1. Swappable backend, allowing fast evaluation with the NumPy/C++ MuJoCo backend and fast training with JAX/MJX
  2. A base environment class containing generic reward functions and other useful utilities
  3. Configuration files that let you set parameters for your environments and for PPO

There is a requirements.txt in the repository root that lists the packages needed to run this code.

Steps to take to 'do RL'!

We will work off of the Cheetah environment from DeepMind's dm_control suite. This guide is not meant to be comprehensive, but should give you enough information to fill in the gaps as you develop and explore the codebase.

  1. The setup basics

    Each environment has a reset and step function, which operate on a Markov Decision Process (MDP) state. In our case, this is represented as a dataclass that looks like this

        from dataclasses import dataclass

        import mujoco
        import numpy as np


        @dataclass
        class MujocoState:
            data: mujoco.MjData  # full simulator state
            obs: np.ndarray      # observation passed to the policy
            reward: float        # scalar reward for the last step
            done: bool           # episode-termination flag
            metrics: dict        # logged diagnostics
            info: dict           # carry between consecutive steps
    

    Note that you should treat info as a 'carry' variable between consecutive step calls. Store anything here that you need in future timesteps and that is not ordinarily stored in data (e.g. a history of system states x).
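    As a toy illustration of the carry pattern, here is a minimal step function that accumulates an observation history in info. The State class below is a simplified stand-in for MujocoState (no MuJoCo dependency), and the dynamics are invented for the example:

```python
from dataclasses import dataclass


@dataclass
class State:
    # Simplified stand-in for MujocoState (data/metrics omitted).
    obs: list
    reward: float
    done: bool
    info: dict


def step(state: State, action: list) -> State:
    # Append the current obs to a history carried in `info`,
    # so future steps can look at past system states.
    history = state.info.get("obs_history", []) + [state.obs]
    # Toy dynamics for illustration only: add the action to the obs.
    new_obs = [o + a for o, a in zip(state.obs, action)]
    return State(
        obs=new_obs,
        reward=0.0,
        done=False,
        info={**state.info, "obs_history": history},
    )
```

    Anything you write into info in one step is visible in the next, which is exactly what you need for history-dependent observations or rewards.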

  2. Simulate your environment to make sure things look good

    Environments should be simulated and inspected before training. Once you are done developing your env, create a configuration file to specify variables specific to training runs, such as PPO parameters or reward weights.

    For the Cheetah, one is already created. To simulate the environment, run

    python3 -m envs.simulate envs/dmcontrol/config/cheetah.yaml
    

    Under visualization, you should see a metrics plot as well as a video of your environment. Note that the policy here simply outputs all zeros (see envs/simulate.py). Inspect these files to make sure things look good.
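    For reference, the zero-output placeholder policy amounts to a few lines; the function name and signature here are illustrative, not the actual envs/simulate.py code:

```python
import numpy as np


def zero_policy(obs: np.ndarray, action_size: int) -> np.ndarray:
    # Stand-in for a trained policy during environment inspection:
    # ignore the observation and command all zeros.
    return np.zeros(action_size)
```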

    For extra good measure, change the backend parameter in your config file to jnp to make sure your environment is JAX compatible.
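    Conceptually, the backend swap boils down to selecting an array module by name and writing your environment against that module. The helper below is a hypothetical sketch of the idea, not the repository's actual mechanism:

```python
import numpy as np


def get_backend(name: str):
    """Return the array module for a backend name ('np' or 'jnp').

    Hypothetical helper for illustration; the repo's actual
    backend-switching code may differ.
    """
    if name == "np":
        return np
    if name == "jnp":
        import jax.numpy as jnp  # requires a working JAX install
        return jnp
    raise ValueError(f"unknown backend: {name}")
```

    Code written against the returned module (e.g. xp.zeros, xp.concatenate) then runs unchanged on either backend, which is why a jnp smoke test catches most JAX-compatibility problems early.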

  3. Train a policy

    Once your environment is ready and JAX compatible, you can train by running

    python3 -m learning.training.begin_run envs/dmcontrol/config/cheetah.yaml
    

    Note that this will train directly in your terminal session. A bash script at learning/training/train.sh opens a tmux session for this (useful if you want to train for long periods of time). Make sure to change the conda environment name in this script.

    Within the save directory specified in your config file, a new directory will be created in which intermediate and final training results are saved.

  4. Roll out your trained policy

    Once your policy has finished training, simply run

    python3 -m eval.rollout_policy envs/dmcontrol/config/cheetah.yaml
    

    This will roll out your policy, plot the metrics of your reward function, and save a video of the result. Make sure to use the np backend for quick evaluation!

    If you would like to roll out older policies, check the directories of your past training runs. Each contains a config; simply pass that config instead of the one above to run that specific older policy. Note that your code might have changed since then, so this config also records the git commit at which training was run.
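    A lightweight way to use that recorded commit is to compare it against your current checkout before rolling out an old policy. The git_commit key name and the function below are hypothetical illustrations of the check, not the repository's actual code:

```python
def training_commit_matches(cfg: dict, current_commit: str) -> bool:
    # `git_commit` is a hypothetical key name for the hash recorded
    # in a past run's config at training time.
    recorded = cfg.get("git_commit")
    if recorded is None:
        return False
    if recorded != current_commit:
        print("config commit differs from current checkout; "
              "the code may have changed since this policy was trained")
        return False
    return True
```

    If the check fails, either check out the recorded commit or be aware that the environment code the policy was trained against may no longer match what you are running.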
