MINIMAL-MJX

This repository contains starter code for MJX-based RL from the Dynamic Mobility Lab. The code is built off of MuJoCo Playground and tailored to make policy training, saving, and evaluation easy, with some nice-to-haves:

  1. A swappable backend, allowing fast evaluation with NumPy/C++ MuJoCo and fast training with JAX/MJX
  2. A base environment class containing generic reward functions and other useful utilities
  3. Configuration files for setting parameters for your envs and for PPO

A requirements.txt in the repository root lists the packages needed to run this code.
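
To install them:

pip install -r requirements.txt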

Steps to take to 'do RL'!

We will work off of the Cheetah environment from DeepMind's "DM Control". This guide is not meant to be comprehensive, but it should be enough information to fill in the gaps as you develop and look around the codebase.

  1. The setup basics

    Each environment has a reset and step function, which operate on a Markov Decision Process (MDP) state. In our case, this state is represented as a dataclass that looks like this:

        from dataclasses import dataclass

        import mujoco
        import numpy as np

        @dataclass
        class MujocoState:
            data: mujoco.MjData
            obs: np.ndarray
            reward: float
            done: bool
            metrics: dict
            info: dict
    

    Note that you should treat info as a 'carry' variable between consecutive step calls: store anything here that you need in future timesteps but that is not ordinarily stored in data (e.g., a history of system states x).
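
    For example, a step function might maintain a rolling history of joint positions in info. The sketch below (continuing from the imports above) is illustrative only; the helpers _get_obs, _get_reward, and _get_done are hypothetical names, not the actual base-class API:

        def step(self, state: MujocoState, action: np.ndarray) -> MujocoState:
            state.data.ctrl[:] = action
            mujoco.mj_step(self.model, state.data)  # advance the physics

            # Carry a history of joint positions between timesteps via `info`.
            history = state.info.get("qpos_history", [])
            history = history[-9:] + [state.data.qpos.copy()]  # keep last 10

            return MujocoState(
                data=state.data,
                obs=self._get_obs(state.data),
                reward=self._get_reward(state.data, action),
                done=self._get_done(state.data),
                metrics=state.metrics,
                info={**state.info, "qpos_history": history},
            )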

  2. Simulate your environment to make sure things look good

    Environments should be simulated and inspected before training. Once you are done developing your env, write a configuration file specifying variables specific to training runs, such as PPO hyperparameters or reward weights.

    For the Cheetah, one has already been created. To simulate the environment, run

    python3 -m envs.simulate envs/dmcontrol/config/cheetah.yaml
    

    Under visualization, you should see a metrics plot as well as a video of your environment. Note that the policy here simply outputs all zeros (see envs/simulate.py). Inspect these files to make sure things look good.
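
    As a concrete picture of what that zero policy amounts to (a sketch, not the exact code in envs/simulate.py):

        import numpy as np

        def zero_policy(obs: np.ndarray, num_actions: int) -> np.ndarray:
            # Ignore the observation and command zero actuation, so the
            # video shows only the passive dynamics of the system.
            return np.zeros(num_actions)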

    For good measure, change the backend parameter in your config file to jnp to make sure your environment is JAX compatible.
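
    The swap works because NumPy and jax.numpy share most of their API, so environment code written against a module handle runs under either backend. A minimal sketch of the idea (not the repository's actual mechanism):

        import numpy as np
        import jax.numpy as jnp

        def make_reward_fn(xp):
            # `xp` is either numpy or jax.numpy; the same function runs
            # eagerly on CPU (np) or traces/jits under JAX (jnp).
            def reward(qvel):
                return -xp.sum(xp.square(qvel))
            return reward

        np_reward = make_reward_fn(np)    # fast evaluation backend
        jnp_reward = make_reward_fn(jnp)  # MJX training backend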

  3. Train a policy

    Once your environment is ready and JAX compatible, you can train by running

    python3 -m learning.training.begin_run envs/dmcontrol/config/cheetah.yaml
    

    Note that this trains directly in your terminal session. A bash script at learning/training/train.sh opens a tmux session for this instead (useful if you want to train for long periods of time). Make sure to set your conda environment name in this script.

    Within the save directory specified in your config file, a new subdirectory will be created in which intermediate and final training results are saved.

  4. Roll out your trained policy

    Once your policy has finished training, simply run

    python3 -m eval.rollout_policy envs/dmcontrol/config/cheetah.yaml
    

    This will roll out your policy, plot the metrics of your reward function, and save a video of the result. Make sure to use the np backend for quick evaluation!
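
    Under the hood, a rollout is just a reset/step loop that records rewards and frames. A sketch of the pattern (method names are illustrative, not the repository's exact API):

        def rollout(env, policy, max_steps=1000):
            state = env.reset()
            rewards, frames = [], []
            for _ in range(max_steps):
                state = env.step(state, policy(state.obs))
                rewards.append(state.reward)
                frames.append(env.render(state))  # hypothetical render hook
                if state.done:
                    break
            return rewards, frames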

    If you would like to roll out older policies, check the directories of your past training runs. Each contains a saved config; simply pass that config instead of the one above to run that specific older policy. Since your code may have changed since then, the saved config also records the git commit at which training was run.
