A library written in JAX that provides help for using DeepMind's mctx on gym-style environments.

muax 😘

Muax provides help for using DeepMind's mctx on gym-style environments.

Installation

You can install the released version of muax through PyPI:

pip install muax

Getting started

Muax provides helper functions built around mctx's high-level policy muzero_policy. Using muax is similar to using policies such as DQN or PPO. For instance, a typical loop for interacting with the environment looks like this (code snippet from muax/test):

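import jax

# `env` (a gym-style environment), `model` (a muax.MuZero instance) and
# `num_simulations` are assumed to be defined already.
random_seed = 0
key = jax.random.PRNGKey(random_seed)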
obs, info = env.reset(seed=random_seed)
done = False
episode_reward = 0
for t in range(env.spec.max_episode_steps):
    key, subkey = jax.random.split(key)
    a = model.act(subkey, obs, 
                  num_simulations=num_simulations,
                  temperature=0.) # Use deterministic actions during testing
    obs_next, r, done, truncated, info = env.step(a)
    episode_reward += r
    if done or truncated:
        break 
    obs = obs_next
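
The comment on temperature=0. can be made concrete with a small sketch. In MuZero-style agents, the action is commonly chosen from the visit counts of the search root: a temperature of 0 picks the most visited action, while a positive temperature samples in proportion to count ** (1 / temperature). The function select_action and the visit counts below are hypothetical and only illustrate that idea; how model.act applies the temperature internally is not shown here.

```python
import random

def select_action(visit_counts, temperature, rng=random):
    """Pick an action index from root visit counts (illustrative helper).

    temperature == 0 -> greedy: the most visited action.
    temperature > 0  -> sample with probability proportional to
                        count ** (1 / temperature).
    """
    actions = range(len(visit_counts))
    if temperature == 0:
        return max(actions, key=lambda a: visit_counts[a])
    weights = [c ** (1.0 / temperature) for c in visit_counts]
    return rng.choices(actions, weights=weights, k=1)[0]

counts = [5, 40, 5]  # made-up visit counts for three actions
greedy_action = select_action(counts, temperature=0.0)
sampled_action = select_action(counts, temperature=1.0, rng=random.Random(0))
```

With temperature 0 the result is always the same for the same counts, which is why it suits testing; a positive temperature keeps some exploration during training.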

See cartpole.ipynb for a basic training example (the notebook should be runnable on Colab).

  1. To train a MuZero model, the user needs to define representation_fn, prediction_fn and dynamic_fn with Haiku. muax/nn provides an example that defines an MLP with a single hidden layer.
import jax 
jax.config.update('jax_platform_name', 'cpu')

import muax
from muax import nn 

support_size = 10 
embedding_size = 8
num_actions = 2
full_support_size = int(support_size * 2 + 1)

repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(nn.Dynamic, embedding_size, num_actions, full_support_size)
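
The value full_support_size = support_size * 2 + 1 corresponds to one bin per integer in the range -support_size..support_size. A minimal sketch of how a scalar can be spread over such a support ("two-hot" encoding) and recovered as an expectation follows. The helpers scalar_to_support and support_to_scalar are hypothetical names, not muax functions, and the invertible value transform used in the MuZero paper is omitted; whether muax applies one is not stated here.

```python
import math

def scalar_to_support(x, support_size):
    """Spread a scalar over the integer bins -support_size..support_size."""
    x = max(-support_size, min(support_size, x))  # clip to the representable range
    low = math.floor(x)
    upper_weight = x - low  # share of probability given to the bin above
    probs = [0.0] * (2 * support_size + 1)
    probs[low + support_size] += 1.0 - upper_weight
    if upper_weight > 0.0:
        probs[low + support_size + 1] += upper_weight
    return probs

def support_to_scalar(probs, support_size):
    """Recover the scalar as the expectation over the bin values."""
    return sum(p * (i - support_size) for i, p in enumerate(probs))

support_size = 10
probs = scalar_to_support(3.25, support_size)      # 21 bins, mass on bins 3 and 4
recovered = support_to_scalar(probs, support_size)
```

Here the scalar 3.25 puts weight 0.75 on the bin for 3 and 0.25 on the bin for 4, so the expectation returns the original value.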
  1. muax has a built-in episode tracer and replay buffer to track and store trajectories from interacting with environments. The first parameter of muax.PNStep (10 in the following code) is the n for n-step bootstrapping.
discount = 0.99
tracer = muax.PNStep(10, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)
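
As a reminder of what n-step bootstrapping computes, the sketch below sums n discounted rewards and adds the discounted value estimate of the state reached after those n steps. The function n_step_return is a hypothetical illustration and not the muax.PNStep implementation.

```python
def n_step_return(rewards, bootstrap_value, discount):
    """Discounted sum of the given rewards plus a discounted bootstrap value.

    rewards:         the n rewards r_t, ..., r_{t+n-1}
    bootstrap_value: value estimate of the state reached after those n steps
    """
    n = len(rewards)
    discounted_rewards = sum(discount ** i * r for i, r in enumerate(rewards))
    return discounted_rewards + discount ** n * bootstrap_value

# 3-step example: 1 + 0.5 * 1 + 0.25 * 1 + 0.125 * 5
target = n_step_return([1.0, 1.0, 1.0], bootstrap_value=5.0, discount=0.5)
```

A larger n relies more on observed rewards and less on the value estimate, at the cost of higher variance.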
  1. muax uses optax to build the optimizer that updates the weights.
gradient_transform = muax.model.optimizer(init_value=0.02, peak_value=0.02, end_value=0.002, warmup_steps=5000, transition_steps=5000)
  1. Now we are ready to call the muax.fit function to fit the model to the CartPole environment.
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1', 
                    max_episodes=1000,
                    max_training_steps=10000,
                    tracer=tracer,
                    buffer=buffer,
                    k_steps=10,
                    sample_per_trajectory=1,
                    num_trajectory=32,
                    tensorboard_dir='/content/tensorboard/cartpole',
                    model_save_path='/content/models/cartpole',
                    save_name='cartpole_model_params',
                    random_seed=0,
                    log_all_metrics=True)

The full training script:

import muax
from muax import nn 

support_size = 10 
embedding_size = 8
discount = 0.99
num_actions = 2
full_support_size = int(support_size * 2 + 1)

repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(nn.Dynamic, embedding_size, num_actions, full_support_size)

tracer = muax.PNStep(10, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)

gradient_transform = muax.model.optimizer(init_value=0.02, peak_value=0.02, end_value=0.002, warmup_steps=5000, transition_steps=5000)

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1', 
                    max_episodes=1000,
                    max_training_steps=10000,
                    tracer=tracer,
                    buffer=buffer,
                    k_steps=10,
                    sample_per_trajectory=1,
                    num_trajectory=32,
                    tensorboard_dir='/content/tensorboard/cartpole',
                    model_save_path='/content/models/cartpole',
                    save_name='cartpole_model_params',
                    random_seed=0,
                    log_all_metrics=True)
  1. After training is done, TensorBoard can be used to inspect the training process.
%load_ext tensorboard 
%tensorboard --logdir=tensorboard/cartpole

In the figure below, the model is able to solve the environment in about 500 episodes (about 30k updates).

[Figure: tensorboard example]

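  1. We can also run more tests with the best parameters.
import gymnasium as gym  # assumption: the Gymnasium package; any gym-style API with render_mode should work
import jax
from muax.test import test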

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model.load(model_path)

env_id = 'CartPole-v1'
test_env = gym.make(env_id, render_mode='rgb_array')
test_key = jax.random.PRNGKey(0)
test(model, test_env, test_key, num_simulations=50, num_test_episodes=100, random_seed=None)
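
For readers who want to see what an evaluation over several episodes involves, here is a minimal sketch that averages episode returns over a gym-style environment. Both evaluate and FixedHorizonEnv are hypothetical stand-ins written for this example; this is not the implementation of muax.test.test, and what that function returns or logs is not shown here.

```python
class FixedHorizonEnv:
    """Hypothetical stand-in environment with a gym-style reset/step API."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return 0, {}  # observation, info

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        terminated = False
        truncated = self.t >= self.horizon  # episode ends after `horizon` steps
        return self.t, reward, terminated, truncated, {}

def evaluate(act_fn, env, num_episodes):
    """Run `num_episodes` episodes and return the average episode return."""
    returns = []
    for episode in range(num_episodes):
        obs, info = env.reset(seed=episode)
        episode_return, done = 0.0, False
        while not done:
            obs, reward, terminated, truncated, info = env.step(act_fn(obs))
            episode_return += reward
            done = terminated or truncated
        returns.append(episode_return)
    return sum(returns) / len(returns)

# A policy that always picks action 1 earns a reward of 1 at every step.
average_return = evaluate(lambda obs: 1, FixedHorizonEnv(horizon=5), num_episodes=3)
```

In a real evaluation, act_fn would wrap a call such as model.act with temperature=0., as in the interaction loop shown earlier.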

Alternatively, users can write their own training loop. One example is in cartpole.ipynb.

More examples can be found under the example directory.
