muax 😘
Muax provides help for using DeepMind's mctx on gym-style environments.
Installation
You can install the released version of muax from PyPI:
pip install muax
Getting started
Muax provides some functions around mctx's high-level policy muzero_policy. Using muax is similar to using policies such as DQN or PPO. For instance, a typical loop for interacting with the environment looks like this (code snippet from muax/test):
random_seed = 0
key = jax.random.PRNGKey(random_seed)
obs, info = env.reset(seed=random_seed)
done = False
episode_reward = 0
for t in range(env.spec.max_episode_steps):
    key, subkey = jax.random.split(key)
    a = model.act(subkey, obs,
                  num_simulations=num_simulations,
                  temperature=0.)  # Use deterministic actions during testing
    obs_next, r, done, truncated, info = env.step(a)
    episode_reward += r
    if done or truncated:
        break
    obs = obs_next
Check cartpole.ipynb for a basic training example (the notebook should be runnable on Colab).
- To train a MuZero model, the user needs to define the representation_fn, prediction_fn and dynamic_fn with Haiku. muax/nn provides an example of defining an MLP with a single hidden layer:
import jax
jax.config.update('jax_platform_name', 'cpu')
import muax
from muax import nn
support_size = 10
embedding_size = 8
num_actions = 2
full_support_size = int(support_size * 2 + 1)
repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(nn.Dynamic, embedding_size, num_actions, full_support_size)
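The built-in Representation, Prediction and Dynamic modules are plain Haiku modules, so you can substitute your own architectures. Below is a minimal sketch of a custom representation network for image observations; the class itself is hypothetical, and the assumption that it can be passed to nn._init_representation_func in the same way as the built-in module is based on the example above.

import jax
import haiku as hk

class ConvRepresentation(hk.Module):
    # Hypothetical convolutional representation network (a sketch, not part of muax)
    def __init__(self, embedding_dim, name='conv_representation'):
        super().__init__(name=name)
        self.embedding_dim = embedding_dim

    def __call__(self, obs):
        x = hk.Conv2D(32, kernel_shape=3, stride=2)(obs)
        x = jax.nn.relu(x)
        x = hk.Flatten()(x)
        # Project to the embedding size shared with the prediction/dynamic networks
        return hk.Linear(self.embedding_dim)(x)

# Assumed to follow the same pattern as the built-in modules above
repr_fn = nn._init_representation_func(ConvRepresentation, embedding_size)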
muax has a built-in episode tracer and replay buffer to track and store trajectories from interacting with environments. The first parameter of muax.PNStep (10 in the following code) is the n for n-step bootstrapping.
discount = 0.99
tracer = muax.PNStep(10, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)
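For reference, an n-step bootstrapped return sums the next n discounted rewards and then bootstraps the tail with a value estimate. A minimal standalone sketch (independent of how muax implements it internally):

def n_step_return(rewards, bootstrap_value, discount):
    # rewards: the next n rewards r_{t+1}, ..., r_{t+n} (fewer at episode end)
    # bootstrap_value: value estimate v(s_{t+n}) for the truncated tail
    g = bootstrap_value * discount ** len(rewards)
    for i, r in enumerate(rewards):
        g += (discount ** i) * r
    return g

# e.g. with discount=0.99 and n=10, a constant reward of 1 and a bootstrap value of 20:
# n_step_return([1.0] * 10, 20.0, 0.99) ≈ 27.65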
muax leverages optax to build an optimizer for updating the model weights:
gradient_transform = muax.model.optimizer(init_value=0.02, peak_value=0.02, end_value=0.002, warmup_steps=5000, transition_steps=5000)
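muax.model.optimizer returns an optax gradient transformation with a warm-up and decay learning-rate schedule. Its internals are not shown here, but a comparable setup in plain optax might look like the following; the choice of Adam, the gradient clipping and the decay_rate value are assumptions for illustration:

import optax

lr_schedule = optax.warmup_exponential_decay_schedule(
    init_value=0.02, peak_value=0.02, warmup_steps=5000,
    transition_steps=5000, decay_rate=0.1, end_value=0.002)
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),   # assumed; common for MuZero-style training
    optax.adam(learning_rate=lr_schedule))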
- Now we are ready to call the muax.fit function to fit the model to the CartPole environment:
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1',
                      max_episodes=1000,
                      max_training_steps=10000,
                      tracer=tracer,
                      buffer=buffer,
                      k_steps=10,
                      sample_per_trajectory=1,
                      num_trajectory=32,
                      tensorboard_dir='/content/tensorboard/cartpole',
                      model_save_path='/content/models/cartpole',
                      save_name='cartpole_model_params',
                      random_seed=0,
                      log_all_metrics=True)
The full training script:
import muax
from muax import nn
support_size = 10
embedding_size = 8
discount = 0.99
num_actions = 2
full_support_size = int(support_size * 2 + 1)
repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(nn.Dynamic, embedding_size, num_actions, full_support_size)
tracer = muax.PNStep(10, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)
gradient_transform = muax.model.optimizer(init_value=0.02, peak_value=0.02, end_value=0.002, warmup_steps=5000, transition_steps=5000)
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1',
                      max_episodes=1000,
                      max_training_steps=10000,
                      tracer=tracer,
                      buffer=buffer,
                      k_steps=10,
                      sample_per_trajectory=1,
                      num_trajectory=32,
                      tensorboard_dir='/content/tensorboard/cartpole',
                      model_save_path='/content/models/cartpole',
                      save_name='cartpole_model_params',
                      random_seed=0,
                      log_all_metrics=True)
- After training is done, one can use TensorBoard to inspect the training run:
%load_ext tensorboard
%tensorboard --logdir=tensorboard/cartpole
As the training curves show, the model is able to solve the environment in roughly 500 episodes (about 30k updates).
- We can also run more tests with the best parameters:
import jax
import gym  # or: import gymnasium as gym

from muax.test import test

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)
model.load(model_path)

env_id = 'CartPole-v1'
test_env = gym.make(env_id, render_mode='rgb_array')
test_key = jax.random.PRNGKey(0)
test(model, test_env, test_key, num_simulations=50, num_test_episodes=100, random_seed=None)
Alternatively, users can easily write their own training loop; one example can be found in cartpole.ipynb.
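As a rough outline, a hand-written loop has the same shape as the test loop shown earlier: act with model.act, step the environment, record the transition, and periodically update the model. The sketch below only uses calls that appear elsewhere in this README; the recording and update steps are left as comments because their exact signatures are not documented here (see cartpole.ipynb for a complete example).

import gym  # or: import gymnasium as gym
import jax

# model, tracer and buffer are constructed as in the sections above
env = gym.make('CartPole-v1')
max_episodes = 1000  # illustrative value
key = jax.random.PRNGKey(0)
for episode in range(max_episodes):
    obs, info = env.reset()
    for t in range(env.spec.max_episode_steps):
        key, subkey = jax.random.split(key)
        # Non-zero temperature for exploration during training
        a = model.act(subkey, obs, num_simulations=50, temperature=1.)
        obs_next, r, done, truncated, info = env.step(a)
        # ... record (obs, a, r, done) with the tracer and push finished
        #     trajectories into the buffer here ...
        if done or truncated:
            break
        obs = obs_next
    # ... sample k-step segments from the buffer and update the model here ...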
More examples can be found under the example directory.