
A library written in JAX that helps you use DeepMind's mctx on gym-style environments.


muax 😘

Muax helps you use DeepMind's mctx on gym-style environments.

Installation

You can install the released version of muax through PyPI:

pip install muax

Getting started

Muax wraps mctx's high-level policy, muzero_policy, with a few helper functions, so using muax feels similar to using policies such as DQN or PPO. For instance, a typical loop for interacting with the environment looks like this (snippet from muax/test):

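import jax

# Assumes `env` is a gym-style environment with the Gymnasium-style API
# (reset() returns (obs, info), step() returns a 5-tuple), `model` is a
# muax.MuZero instance and `num_simulations` is the number of MCTS
# simulations to run per step.
random_seed = 0
key = jax.random.PRNGKey(random_seed)
obs, info = env.reset(seed=random_seed)
done = False
episode_reward = 0
for t in range(env.spec.max_episode_steps):
    key, subkey = jax.random.split(key)  # fresh randomness for each search
    a = model.act(subkey, obs,
                  num_simulations=num_simulations,
                  temperature=0.)  # use deterministic actions during testing
    obs_next, r, done, truncated, info = env.step(a)
    episode_reward += r
    if done or truncated:
        break
    obs = obs_next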
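The temperature argument of model.act controls how the final action is chosen from the search result. As a rough illustration of the usual MuZero convention (a hypothetical helper, not muax's or mctx's actual code), an action is picked from the root visit counts: temperature 0 takes the most-visited action, while a positive temperature samples in proportion to count ** (1 / temperature).

```python
import random
from typing import Sequence


def select_action(visit_counts: Sequence[int], temperature: float,
                  rng: random.Random) -> int:
    """Pick an action from root visit counts (hypothetical helper).

    temperature == 0: greedy, the most-visited action (ties go to the lowest index).
    temperature > 0: sample with probability proportional to
    count ** (1 / temperature).
    """
    num_actions = len(visit_counts)
    if temperature == 0.0:
        # Deterministic choice: argmax over the visit counts.
        return max(range(num_actions), key=lambda a: visit_counts[a])
    max_count = max(visit_counts)
    if max_count == 0:
        # No visits recorded: fall back to a uniform choice.
        return rng.randrange(num_actions)
    # Dividing by the max count first keeps the powers from overflowing.
    weights = [(c / max_count) ** (1.0 / temperature) for c in visit_counts]
    # random.choices normalises the weights itself.
    return rng.choices(range(num_actions), weights=weights, k=1)[0]


rng = random.Random(0)
counts = [3, 40, 7]  # e.g. root visit counts after 50 simulations
print(select_action(counts, 0.0, rng))  # → 1 (greedy)
print(select_action(counts, 1.0, rng))  # stochastic
```

Using temperature 0 at test time, as in the loop above, makes evaluation reproducible; a positive temperature during training keeps some exploration.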

See cartpole.ipynb for a basic training example (the notebook should run on Colab).

  1. To train a MuZero model, you need to define representation_fn, prediction_fn and dynamic_fn with Haiku. muax/nn provides an example that defines an MLP with a single hidden layer.
import jax 
jax.config.update('jax_platform_name', 'cpu')

import muax
from muax import nn 

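support_size = 10    # categorical value/reward support covers the integers -10..10
embedding_size = 8   # size of the hidden (latent) state
num_actions = 2      # CartPole has two discrete actions
full_support_size = int(support_size * 2 + 1)  # 21 bins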

repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(nn.Dynamic, embedding_size, num_actions, full_support_size)
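full_support_size is the number of bins in MuZero's categorical representation of values and rewards: one bin per integer from -support_size to support_size, i.e. 2 * 10 + 1 = 21 here. The sketch below shows the usual two-hot encoding of a scalar onto those bins and its inverse. It is a simplified illustration of the technique, not muax's implementation; in particular, MuZero also applies an invertible scaling to the scalar before encoding, which is left out, and the function names are hypothetical.

```python
import math
from typing import List


def scalar_to_support(x: float, support_size: int) -> List[float]:
    """Two-hot encode x over the integer bins -support_size..support_size.

    Simplified sketch: the invertible value rescaling used by MuZero is omitted.
    """
    full_support_size = 2 * support_size + 1
    x = max(-support_size, min(support_size, x))  # clip into the support range
    low = math.floor(x)
    frac = x - low
    probs = [0.0] * full_support_size
    probs[low + support_size] += 1.0 - frac  # weight on the lower bin
    if frac > 0:
        probs[low + 1 + support_size] += frac  # weight on the upper bin
    return probs


def support_to_scalar(probs: List[float], support_size: int) -> float:
    """Recover the scalar as the expectation over the bin values."""
    bins = range(-support_size, support_size + 1)
    return sum(p * v for p, v in zip(probs, bins))


support_size = 10
probs = scalar_to_support(3.7, support_size)
print(len(probs))  # → 21
print(round(support_to_scalar(probs, support_size), 6))  # → 3.7
```

Values outside [-support_size, support_size] are clipped, so support_size should be large enough for the returns and rewards the environment produces.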
  2. muax has a built-in episode tracer and replay buffer to track and store the trajectories collected from the environment. The first argument of muax.PNStep (10 in the code below) is n, the number of steps used for n-step bootstrapping.
discount = 0.99
tracer = muax.PNStep(10, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)
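To make the role of n concrete, here is the standard n-step bootstrapped value target for one finished episode, z_t = r_t + γ r_{t+1} + ... + γ^(n-1) r_{t+n-1} + γ^n v_{t+n}, written in plain Python. It is an illustration of the technique, not muax.PNStep itself: the function name n_step_returns is hypothetical, it treats the episode end as terminal (no bootstrapping past it), and it does not model the third argument of PNStep (0.5 above).

```python
from typing import List, Sequence


def n_step_returns(rewards: Sequence[float], values: Sequence[float],
                   n: int, discount: float) -> List[float]:
    """n-step bootstrapped value targets for one finished episode.

    rewards[t] is the reward received after acting at step t, and values[t]
    is the value estimate of the state at step t. The episode end is treated
    as terminal, so nothing is bootstrapped beyond it.
    """
    T = len(rewards)
    targets = []
    for t in range(T):
        g = 0.0
        horizon = min(n, T - t)  # fewer than n rewards remain near the end
        for i in range(horizon):
            g += discount ** i * rewards[t + i]
        if t + n < T:
            # Bootstrap from the value estimate n steps ahead.
            g += discount ** n * values[t + n]
        targets.append(g)
    return targets


print(n_step_returns([1, 1, 1, 1], [10, 10, 10, 10], n=2, discount=0.5))
# → [4.0, 4.0, 1.5, 1.0]
```

A larger n relies more on observed rewards and less on the model's own value estimates, at the cost of higher variance.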
  3. muax uses optax to build the optimizer that updates the weights:
gradient_transform = muax.model.optimizer(init_value=0.02, peak_value=0.02, end_value=0.002, warmup_steps=5000, transition_steps=5000)
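The arguments of muax.model.optimizer share their names with optax.warmup_exponential_decay_schedule (init_value, peak_value, warmup_steps, transition_steps, end_value), so it seems likely, though this is an assumption, that the learning rate follows a schedule of that kind. The pure-Python sketch below shows what such a schedule computes: a linear warmup from init_value to peak_value, then exponential decay floored at end_value. The decay_rate used here is a hypothetical value that the call above does not set.

```python
def warmup_exponential_decay(step: int, init_value: float, peak_value: float,
                             end_value: float, warmup_steps: int,
                             transition_steps: int, decay_rate: float) -> float:
    """Linear warmup to peak_value, then exponential decay floored at end_value.

    Intended to follow optax.warmup_exponential_decay_schedule with
    staircase=False and transition_begin=0, for decay_rate < 1.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear interpolation from init_value to peak_value.
        return init_value + (peak_value - init_value) * step / warmup_steps
    decayed = peak_value * decay_rate ** ((step - warmup_steps) / transition_steps)
    return max(decayed, end_value)  # end_value acts as a lower bound


# Values from the call above; decay_rate=0.5 is an illustrative assumption.
for step in (0, 5000, 10000, 20000, 50000):
    lr = warmup_exponential_decay(step, 0.02, 0.02, 0.002, 5000, 5000, 0.5)
    print(step, round(lr, 5))
```

Because init_value equals peak_value in the call above, the warmup phase would be flat under this reading, and the rate would only start to fall after warmup_steps.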
  4. Now we are ready to call muax.fit to fit the model to the CartPole environment:
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1', 
                    max_episodes=1000,
                    max_training_steps=10000,
                    tracer=tracer,
                    buffer=buffer,
                    k_steps=10,
                    sample_per_trajectory=1,
                    num_trajectory=32,
                    tensorboard_dir='/content/tensorboard/cartpole',
                    model_save_path='/content/models/cartpole',
                    save_name='cartpole_model_params',
                    random_seed=0,
                    log_all_metrics=True)
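During MuZero training, the learned model is unrolled for k steps from each sampled position, and k_steps above presumably sets that length (an assumption about muax.fit). The toy sketch below shows the general shape of such an unroll. representation, prediction and dynamics are hypothetical plain-Python stand-ins for repr_fn, pred_fn and dy_fn, not the Haiku modules from muax/nn, and the loss computation is omitted.

```python
from typing import Callable, List, Sequence, Tuple

State = List[float]


def representation(obs: Sequence[float]) -> State:
    # Stand-in for repr_fn: encode an observation into a hidden state.
    return [0.5 * o for o in obs]


def prediction(state: State) -> Tuple[List[float], float]:
    # Stand-in for pred_fn: uniform policy logits over 2 actions, plus a value.
    return [0.0, 0.0], sum(state)


def dynamics(state: State, action: int) -> Tuple[State, float]:
    # Stand-in for dy_fn: next hidden state and predicted reward.
    shift = 1.0 if action == 1 else -1.0
    return [s + shift for s in state], 1.0


def unroll(obs: Sequence[float], actions: Sequence[int],
           repr_f: Callable[[Sequence[float]], State],
           pred_f: Callable[[State], Tuple[List[float], float]],
           dyn_f: Callable[[State, int], Tuple[State, float]],
           ) -> List[Tuple[List[float], float, float]]:
    """Return (policy_logits, value, reward) for the root and each of the k steps."""
    state = repr_f(obs)
    logits, value = pred_f(state)
    outputs = [(logits, value, 0.0)]  # no reward is predicted at the root
    for a in actions:  # len(actions) == k
        state, reward = dyn_f(state, a)
        logits, value = pred_f(state)
        outputs.append((logits, value, reward))
    return outputs


for step, (logits, value, reward) in enumerate(
        unroll([1.0, 2.0], [1, 0, 1], representation, prediction, dynamics)):
    print(step, value, reward)
```

In MuZero, each of the k + 1 predictions is then compared with targets taken from the replay buffer: the search policy, an n-step value target and the observed reward.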

The full training script:

import muax
from muax import nn 

support_size = 10 
embedding_size = 8
discount = 0.99
num_actions = 2
full_support_size = int(support_size * 2 + 1)

repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(nn.Dynamic, embedding_size, num_actions, full_support_size)

tracer = muax.PNStep(10, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)

gradient_transform = muax.model.optimizer(init_value=0.02, peak_value=0.02, end_value=0.002, warmup_steps=5000, transition_steps=5000)

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1', 
                    max_episodes=1000,
                    max_training_steps=10000,
                    tracer=tracer,
                    buffer=buffer,
                    k_steps=10,
                    sample_per_trajectory=1,
                    num_trajectory=32,
                    tensorboard_dir='/content/tensorboard/cartpole',
                    model_save_path='/content/models/cartpole',
                    save_name='cartpole_model_params',
                    random_seed=0,
                    log_all_metrics=True)
  5. After training finishes, you can use TensorBoard to inspect the training progress:
%load_ext tensorboard 
%tensorboard --logdir=/content/tensorboard/cartpole

As the figure below shows, the model solves the environment in about 500 episodes (~30k updates).

[Figure: TensorBoard training curves for CartPole-v1]

  6. We can also run more test episodes with the best parameters:
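import gymnasium as gym  # assumption: Gymnasium (or gym>=0.26) provides gym.make with render_mode
import jax
from muax.test import test

# repr_fn, pred_fn, dy_fn, discount, gradient_transform, support_size and
# model_path come from the training script above.
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model.load(model_path)

env_id = 'CartPole-v1'
test_env = gym.make(env_id, render_mode='rgb_array')
test_key = jax.random.PRNGKey(0)
test(model, test_env, test_key, num_simulations=50, num_test_episodes=100, random_seed=None)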
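The test call above presumably runs num_test_episodes evaluation episodes and summarizes the returns. A rough sketch of that kind of evaluation is below, computing the mean and standard deviation of episode returns. evaluate, CountdownEnv and the policy argument are hypothetical; the toy environment only imitates the Gymnasium-style reset/step signature used in the interaction loop earlier, and none of this is muax.test.test.

```python
import random
import statistics
from typing import Any, Callable, Dict, List, Optional, Tuple


class CountdownEnv:
    """Hypothetical toy env with a Gymnasium-style reset/step signature.

    Every step pays reward 1.0; an episode ends after a random number of
    steps between 1 and max_steps, drawn at reset.
    """

    def __init__(self, max_steps: int = 10) -> None:
        self.max_steps = max_steps
        self._rng = random.Random()
        self._t = 0
        self._limit = 1

    def reset(self, seed: Optional[int] = None) -> Tuple[int, Dict[str, Any]]:
        if seed is not None:
            self._rng.seed(seed)
        self._t = 0
        self._limit = self._rng.randint(1, self.max_steps)
        return self._t, {}

    def step(self, action: int) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
        self._t += 1
        terminated = self._t >= self._limit
        return self._t, 1.0, terminated, False, {}


def evaluate(policy: Callable[[Any], int], env: CountdownEnv,
             num_episodes: int, seed: int = 0) -> Tuple[float, float]:
    """Run num_episodes episodes and return (mean, std) of the episode returns."""
    returns: List[float] = []
    for episode in range(num_episodes):
        obs, info = env.reset(seed=seed + episode)  # distinct seed per episode
        episode_return, done = 0.0, False
        while not done:
            obs, reward, terminated, truncated, info = env.step(policy(obs))
            episode_return += reward
            done = terminated or truncated
        returns.append(episode_return)
    return statistics.mean(returns), statistics.pstdev(returns)


mean_return, std_return = evaluate(lambda obs: 0, CountdownEnv(), num_episodes=100)
print(f"mean return {mean_return:.2f} +/- {std_return:.2f}")
```

Seeding each episode separately keeps the evaluation reproducible while still covering different initial conditions.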

Alternatively, users can write their own training loop; cartpole.ipynb contains one example.

More examples can be found under the example directory.

Download files

Source distribution: muax-0.0.2.8.3.tar.gz (28.2 kB)

Built distribution: muax-0.0.2.8.3-py3-none-any.whl (32.2 kB, Python 3)
