MINIMAL-MJX

This repository contains starter code for MJX-based RL from the Dynamic Mobility Lab. The code is built off of MuJoCo Playground and tailored to make policy training, saving, and evaluation easy, and it implements some nice-to-haves:

  1. Swappable backend, allowing fast evaluation using NumPy/C++ MuJoCo and fast training using JAX/MJX
  2. Base environment class that contains generic reward functions and other useful utilities
  3. Configuration files that let you set parameters for your environments and for PPO

A requirements.txt in the repository root lists the packages needed to run this code.

Steps to take to 'do RL'!

We will work off of the Cheetah environment from DeepMind's "DM Control". This guide is not meant to be comprehensive, but it should be enough to fill in the gaps as you develop and explore the codebase.

  1. The setup basics

    Each environment has a reset and a step function, which primarily operate on a Markov Decision Process (MDP) state. In our case, this is represented as a class that looks like this:

        from dataclasses import dataclass

        import mujoco
        import numpy as np

        @dataclass
        class MujocoState:
            data: mujoco.MjData  # MuJoCo simulation state
            obs: np.ndarray      # observation passed to the policy
            reward: float
            done: bool
            metrics: dict        # logged quantities (e.g. reward terms)
            info: dict           # carry between consecutive steps
    

    Note that you should treat info as a 'carry' variable between consecutive step calls. Store things here that you need in future timesteps and that are not ordinarily stored in data (e.g., a history of system states x).
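    To illustrate the carry pattern, here is a minimal sketch with a toy stand-in for the state class above (no MuJoCo dependency; the field names mirror MujocoState, but the environment dynamics are made up for illustration):

```python
from dataclasses import dataclass, field, replace

# Toy stand-in for the MDP state above, showing how `info`
# carries data between consecutive step() calls.
@dataclass
class State:
    obs: float
    reward: float
    done: bool
    info: dict = field(default_factory=dict)

def reset() -> State:
    # Seed the carry: an empty history of observations.
    return State(obs=0.0, reward=0.0, done=False, info={"obs_history": []})

def step(state: State, action: float) -> State:
    obs = state.obs + action
    # Extend the history carried in `info` instead of recomputing it.
    history = state.info["obs_history"] + [obs]
    return replace(state, obs=obs, reward=-abs(obs),
                   done=len(history) >= 3, info={"obs_history": history})

state = reset()
while not state.done:
    state = step(state, action=1.0)
print(state.info["obs_history"])  # -> [1.0, 2.0, 3.0]
```

    Each step reads the previous carry out of state.info and writes an updated one, which is exactly how a real environment can keep a state history across timesteps.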

  2. Simulate your environment to make sure things look good

    Environments should be simulated and inspected before training. Once you are done developing your env, make a configuration file to specify variables specific to training runs, such as PPO hyperparameters or reward weights.

    For the Cheetah, one is already created. To simulate the environment, run

    python3 -m envs.simulate envs/dmcontrol/config/cheetah.yaml
    

    Under visualization, you should see a metrics plot as well as a video of your environment. Note that the policy here simply outputs all zeros (see envs/simulate.py). Inspect these files to make sure things look good.
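    Such a zero-action policy can be sketched in one function (the observation and action dimensions here are placeholders, not the repo's actual Cheetah dimensions; see envs/simulate.py for the real implementation):

```python
import numpy as np

# Hypothetical zero-action policy: ignore the observation and
# emit zeros of the action dimension, useful for sanity-checking
# an environment before any training has happened.
def zero_policy(obs: np.ndarray, action_dim: int = 6) -> np.ndarray:
    return np.zeros(action_dim)

print(zero_policy(np.ones(17)))  # -> [0. 0. 0. 0. 0. 0.]
```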

    For good measure, change the backend parameter in your config file to jnp to verify that your environment is JAX compatible.
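    A config for a new environment might look something like the following sketch. The key names here are illustrative assumptions, not the repo's actual schema; check the existing Cheetah config at envs/dmcontrol/config/cheetah.yaml for the real structure:

```yaml
# Hypothetical config sketch -- key names are illustrative only.
env_name: cheetah
backend: np          # switch to jnp to check JAX compatibility
save_dir: results/
reward_weights:
  forward_velocity: 1.0
  control_cost: 0.1
ppo:
  num_timesteps: 10000000
  learning_rate: 3.0e-4
```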

  3. Train a policy

    Once your environment is ready and JaX compatible, you can train by running

    python3 -m learning.training.begin_run envs/dmcontrol/config/cheetah.yaml
    

    Note that this trains directly in your terminal session. A bash script has been provided at learning/training/train.sh that opens a tmux session for this (useful if you want to train for long periods of time). Make sure to change the conda environment name in this script.
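    The spirit of such a launcher script can be sketched as follows. The session and environment names are placeholders, and this is not the repo's actual train.sh; edit CONDA_ENV to match your setup:

```shell
#!/usr/bin/env bash
# Hypothetical tmux launcher sketch, in the spirit of learning/training/train.sh.
SESSION="mjx_train"
CONDA_ENV="mjx"   # change to your conda environment name
CONFIG="envs/dmcontrol/config/cheetah.yaml"
CMD="conda activate ${CONDA_ENV} && python3 -m learning.training.begin_run ${CONFIG}"
echo "would run in tmux: ${CMD}"
# Uncomment to actually launch a detached tmux session:
# tmux new-session -d -s "${SESSION}" "bash -lc '${CMD}'"
```

    Running training inside tmux means you can detach and close your SSH connection without killing the run, then reattach later with tmux attach.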

    Within the save directory specified in your config file, a new subdirectory will be created in which intermediate and final training results are saved.

  4. Roll out your trained policy

    Once your policy is finished training, simply run

    python3 -m eval.rollout_policy envs/dmcontrol/config/cheetah.yaml
    

    This will roll out your policy, plot the metrics of your reward function, and save a video of the result. Make sure to use the np backend for quick evaluation!

    If you would like to roll out older policies, check the directories of your past training runs. Each contains a config, and you can simply pass that config instead of the one above to run that specific older policy. Note that your code might have changed since then, so the saved config records the git commit at which training was run.
