# Pax: Meta|Multi Agent Learning in JAX
Pax is an experiment runner for multi-agent and meta-agent research built on top of JAX. It supports "other-agent shaping", multi-agent RL, and single-agent RL experiments, with both evolutionary and RL-based optimisation.
*Pax* (noun): a period of peace that has been forced on a large area, such as an empire or even the whole world.
Pax is composed of three components: Environments, Agents, and Runners.
## Environments
Environments are similar to those in `gymnax`:
```python
import jax
import jax.numpy as jnp

from pax.envs.iterated_matrix_game import (
    IteratedMatrixGame,
    EnvParams,
)

# payoff matrix for the matrix game
payoff = [[2, 2], [0, 3], [3, 0], [1, 1]]
rng = jax.random.PRNGKey(0)

env = IteratedMatrixGame(num_inner_steps=5)
env_params = EnvParams(payoff_matrix=payoff)

# 0 = Defect, 1 = Cooperate
actions = (jnp.ones(()), jnp.ones(()))
obs, env_state = env.reset(rng, env_params)
done = False

while not done:
    obs, env_state, rewards, done, info = env.step(
        rng, env_state, actions, env_params
    )
```
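Because `step` is a pure function of its inputs, it can be compiled directly with `jit`; a minimal sketch reusing the names above:

```python
# a minimal sketch: step is functionally pure, so it can be
# compiled once with jit and reused on every iteration
step = jax.jit(env.step)
obs, env_state, rewards, done, info = step(
    rng, env_state, actions, env_params
)
```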
We can compose these with JAX's built-in transformations `jit`, `vmap`, `pmap`, and `lax.scan`. For example, the following batches over environment initialisations with `vmap` and unrolls the episode with `lax.scan`:
```python
import jax
import jax.numpy as jnp

from pax.envs.iterated_matrix_game import (
    IteratedMatrixGame,
    EnvParams,
)

# batch over env initialisations
num_envs = 2
payoff = [[2, 2], [0, 3], [3, 0], [1, 1]]
rollout_length = 50

rng = jnp.concatenate(
    [jax.random.PRNGKey(0), jax.random.PRNGKey(1)]
).reshape(num_envs, -1)

env = IteratedMatrixGame(num_inner_steps=rollout_length)
env_params = EnvParams(payoff_matrix=payoff)
action = jnp.ones((num_envs,), dtype=jnp.float32)

# we want to batch over rngs and actions
env.step = jax.vmap(
    env.step,
    in_axes=(0, None, 0, None),
    out_axes=(0, None, 0, 0, 0),
)
# reset must also be batched over rngs, sharing the env state
env.reset = jax.vmap(env.reset, (0, None), (0, None))
obs, env_state = env.reset(rng, env_params)

# let's scan the rollout for speed
def rollout(carry, unused):
    last_obs, env_state, env_rng = carry
    actions = (action, action)
    obs, env_state, rewards, done, info = env.step(
        env_rng, env_state, actions, env_params
    )
    return (obs, env_state, env_rng), (
        obs,
        actions,
        rewards,
        done,
    )

final_state, trajectory = jax.lax.scan(
    rollout, (obs, env_state, rng), None, length=rollout_length
)
```
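The scanned rollout is itself a pure function, so the whole unrolled episode can be wrapped in `jit` too; a minimal sketch reusing `rollout` and the names above:

```python
# a minimal sketch: compile the full scanned rollout into a single
# XLA computation, reusing rollout / rollout_length from above
@jax.jit
def batched_rollout(rng, obs, env_state):
    (last_obs, final_state, _), trajectory = jax.lax.scan(
        rollout, (obs, env_state, rng), None, length=rollout_length
    )
    return final_state, trajectory

final_state, (obs_t, acts_t, rews_t, done_t) = batched_rollout(
    rng, obs, env_state
)
# each stacked output now has a leading time axis of length rollout_length
```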
## Agents
The agent interface is as follows:
```python
import jax
import jax.numpy as jnp

# import an Agent implementation from Pax (interface shown schematically)
import Agent

args = {"hidden": 16, "observation_spec": 5}

rng = jax.random.PRNGKey(0)
bs = 1
init_hidden = jnp.zeros((bs, args["hidden"]))
obs = jnp.ones((bs, 5))

agent = Agent(args)
state, memory = agent.make_initial_state(rng, init_hidden)
action, state, memory = agent.policy(rng, obs, memory)

# traj_batch is a batch of trajectories collected by a runner
state, memory, stats = agent.update(
    traj_batch, obs, state, memory
)
memory = agent.reset_memory(memory, False)
```
Note that `make_initial_state`, `policy`, `update`, and `reset_memory` all support `jit`, `vmap`, and `lax.scan`, allowing you to compile more of your experiment to XLA.
```python
# batch the MemoryState but not the TrainingState
agent.batch_reset = jax.jit(
    jax.vmap(agent.reset_memory, (0, None), 0),
    static_argnums=1,
)

# _policy takes (state, obs, mem); share the state across the batch
agent.batch_policy = jax.jit(
    jax.vmap(agent._policy, (None, 0, 0), (0, None, 0))
)

# one shared TrainingState, one MemoryState per batch element
agent.batch_init = jax.vmap(
    agent.make_initial_state,
    (None, 0),
    (None, 0),
)
```
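As a sketch of how these batched helpers fit together (the batch size and hidden shape here are illustrative, not fixed by the Agent interface):

```python
# a sketch using the batched helpers above; num_envs and the hidden
# shape are illustrative values
num_envs = 2
init_hidden = jnp.zeros((num_envs, 1, args["hidden"]))

# one shared TrainingState, one MemoryState per environment
state, memory = agent.batch_init(rng, init_hidden)

obs = jnp.ones((num_envs, 5))
actions, state, memory = agent.batch_policy(state, obs, memory)

# reset every environment's memory while keeping the TrainingState
memory = agent.batch_reset(memory, False)
```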
## Runners
We can finally combine all of the above into our runner code. This is where you'd expect to write most of the custom logic for your own experimental setup:
```python
def _rollout(carry, unused):
    """Runner for one inner-episode step."""
    (
        rngs,
        obs,
        a1_state,
        a1_mem,
        env_state,
        env_params,
    ) = carry

    # unpack rngs (self.split is the runner's batched PRNG splitter)
    rngs = self.split(rngs, 4)

    action, a1_state, new_a1_mem = agent.batch_policy(
        a1_state,
        obs[0],
        a1_mem,
    )

    next_obs, env_state, rewards, done, info = env.step(
        rngs,
        env_state,
        (action, action),
        env_params,
    )

    # Sample is a NamedTuple holding one transition
    traj = Sample(
        obs[0],
        action,
        rewards[0],
        new_a1_mem.extras["log_probs"],
        new_a1_mem.extras["values"],
        done,
        a1_mem.hidden,
    )

    return (
        rngs,
        next_obs,
        a1_state,
        new_a1_mem,
        env_state,
        env_params,
    ), traj


agent = Agent(args)
state, memory = agent.make_initial_state(rng, init_hidden)

# rngs, obs and env_state come from a batched env.reset,
# as in the Environments section
for _ in range(num_updates):
    carry = (rngs, obs, state, memory, env_state, env_params)
    carry, trajectories = jax.lax.scan(
        _rollout, carry, None, length=rollout_length
    )
    rngs, obs, state, memory, env_state, env_params = carry
    state, memory, stats = agent.update(
        trajectories, obs[0], state, memory
    )
```
Note that this isn't even a fully optimised example: we could `jit` the outer loop too!
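A minimal sketch of that optimisation, assuming the names from the runner above: fuse one rollout plus one learner update into a single compiled step, keeping only the outer Python loop.

```python
# a sketch, assuming the runner names above: compile one full
# (rollout + update) into a single step and loop over it in Python
@jax.jit
def train_step(carry):
    carry, trajectories = jax.lax.scan(
        _rollout, carry, None, length=rollout_length
    )
    rngs, obs, state, memory, env_state, env_params = carry
    state, memory, stats = agent.update(
        trajectories, obs[0], state, memory
    )
    return (rngs, obs, state, memory, env_state, env_params), stats

carry = (rngs, obs, state, memory, env_state, env_params)
for _ in range(num_updates):
    carry, stats = train_step(carry)
```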
## Installation
Pax is written in pure Python, but depends on C++ code via JAX.
Because JAX installation differs depending on your CUDA version, Pax does not list JAX as a dependency in `requirements.txt`.
First, follow the instructions in the JAX repository to install JAX with the relevant accelerator support.
## General Information
The project entrypoint is `pax/experiment.py`. The simplest command to run a game would be:
```bash
python -m pax.experiment
```
We currently use WandB for logging and Hydra for configs. Hyperparameters are stored in `/conf/experiment` as `.yaml` files. Depending on your needs, you can specify hyperparameters through the CLI or by changing the `.yaml` files directly.
```bash
python -m pax.experiment +total_timesteps=1_000_000 +num_envs=10
```
We currently support two major environments: `MatrixGames` and `CoinGame`.
For `MatrixGames`, we support the ability to specify your own payoff matrix either through the CLI or the `yaml` files. For example, the common Iterated Prisoner's Dilemma is:
```bash
python -m pax.experiment +experiment/ipd=ppo ++payoff="[[-2,-2], [0,-3], [-3,0], [-1,-1]]" ++wandb.group="testing"
```
## Experiments
We store previous experiments as parity tests. We use Hydra to store these configs and keep track of good hyperparameters. As a rule for development, we try to retain backwards compatibility and allow all previous results to be replicated. These can be run easily with `python -m pax.experiment +experiment=NAME`; for example:

```bash
python -m pax.experiment +experiment/ipd=yaml ++wandb.group="testing"
```

We also provide a list of our existing experiments and expected results here.
## Loading and Saving
1. All models trained using Pax are saved to the `exp` folder by default.
2. To load a trained model:
    - If you have the model saved locally, specify `model_path = exp/...`. By default, Player 1 will be loaded with the parameters.
    - If you do not have the weights saved locally, specify the wandb run via `run_path={wandb-group}{wandb-project}{}` and `model_path = exp/...`; Player 1 will be loaded with the parameters.
3. To run evaluation, specify `eval: True` and the number of evaluation iterations via `num_seeds`.
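As a rough illustration of step 2b, W&B's public `restore` API can download a checkpoint file from a run before `model_path` is pointed at the local copy (the file name and run path below are hypothetical; Pax's actual loading code may differ):

```python
import wandb

# hypothetical file name and run path; substitute the checkpoint your
# run actually logged and your own {entity}/{project}/{run-id}
checkpoint = wandb.restore(
    "exp/model.ckpt", run_path="my-entity/my-project/abc123"
)
print(checkpoint.name)  # local path, usable as model_path
```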
## Citation
If you use Pax in any of your work, please cite:
```bibtex
@misc{pax,
    author = {Khan, Akbir and Willi, Timon and Kwan, Newton and Samvelyan, Mikayel and Lu, Chris},
    title = {Pax: Multi-Agent Learning in JAX},
    year = {2022},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/akbir/pax}},
}
```