Reinforcement Learning with Pixyz (PixyzRL)
PixyzRL: A Reinforcement Learning Framework with Probabilistic Generative Models
What is PixyzRL?
PixyzRL is a reinforcement learning (RL) framework based on probabilistic generative models and Bayesian theory. Built on top of the Pixyz library, it offers a modular, flexible design intended to support uncertainty-aware decision-making and to improve sample efficiency. PixyzRL supports:
- Probabilistic Policy Optimization (e.g., PPO, A2C)
- On-policy and Off-policy Learning
- Memory Management for RL (Replay Buffer, Rollout Buffer)
- Integration with Gymnasium environments
- Logging and Model Training Utilities
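The two buffer types listed above differ mainly in how collected data is reused. The sketch below is library-independent and is not PixyzRL's `RolloutBuffer` or replay buffer; the class names `ToyRolloutBuffer` and `ToyReplayBuffer` are hypothetical. It shows the usual distinction: an on-policy rollout buffer is filled in order and emptied after each update, while an off-policy replay buffer keeps a fixed-capacity window of past transitions and samples from it at random.

```python
from __future__ import annotations

import random
from collections import deque


class ToyRolloutBuffer:
    """On-policy storage (hypothetical): filled in order, consumed whole, then cleared."""

    def __init__(self, size: int):
        self.size = size
        self.data: list[dict] = []

    def add(self, transition: dict) -> None:
        if len(self.data) >= self.size:
            raise RuntimeError("rollout buffer is full; call get() first")
        self.data.append(transition)

    def full(self) -> bool:
        return len(self.data) == self.size

    def get(self) -> list[dict]:
        # Hand back the whole rollout in collection order and reset,
        # so the next update only sees data from the current policy.
        batch, self.data = self.data, []
        return batch


class ToyReplayBuffer:
    """Off-policy storage (hypothetical): fixed-capacity window, uniform random sampling."""

    def __init__(self, capacity: int, seed: int | None = None):
        self.data: deque = deque(maxlen=capacity)  # oldest transitions are dropped first
        self.rng = random.Random(seed)

    def add(self, transition: dict) -> None:
        self.data.append(transition)

    def sample(self, batch_size: int) -> list[dict]:
        # Sample without replacement from everything currently stored.
        return self.rng.sample(list(self.data), batch_size)
```

Clearing the rollout buffer after every update is what keeps training on-policy. Random sampling from the replay buffer breaks the correlation between consecutive transitions, which off-policy methods such as DQN rely on.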
Installation
Requirements
- Python 3.10+
- PyTorch 2.5.1+
- Gymnasium (for environment interaction)
Install PixyzRL
Using pip
```shell
pip install torch torchvision torchaudio pixyz "gymnasium[box2d]" torchrl
pip install pixyzrl
```
Install from Source
```shell
git clone https://github.com/ItoMasaki/PixyzRL.git
cd PixyzRL
pip install -e .
```
Quick Start
1. Set Up Environment
```python
from pixyzrl.environments import Env

env = Env("CartPole-v1")
state_dim = env.observation_space.shape[0]  # 4 for CartPole-v1
action_dim = env.action_space.n  # 2 for CartPole-v1
```
2. Define Actor and Critic Networks
```python
import torch
from pixyz.distributions import Categorical, Deterministic
from torch import nn

# state_dim and action_dim come from step 1.


class Actor(Categorical):
    """Policy p(a|o): a categorical distribution over discrete actions."""

    def __init__(self):
        super().__init__(var=["a"], cond_var=["o"], name="p")
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1),  # action probabilities
        )

    def forward(self, o: torch.Tensor):
        return {"probs": self.net(o)}


class Critic(Deterministic):
    """Value function f(v|o): a deterministic state-value estimate."""

    def __init__(self):
        super().__init__(var=["v"], cond_var=["o"], name="f")
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, o: torch.Tensor):
        return {"v": self.net(o)}


actor = Actor()
critic = Critic()
```
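Because the `Actor` ends in `nn.Softmax(dim=-1)`, its `forward` returns a probability vector over the `action_dim` actions, which the Categorical distribution p(a|o) receives as its `probs` parameter; the `Critic` returns one scalar value estimate per observation. As a rough pure-Python illustration of what a categorical head does with such a vector (the helpers `softmax`, `sample_action`, and `log_prob` are hypothetical and not part of Pixyz or PixyzRL):

```python
import math
import random


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max for numerical stability; same result as nn.Softmax(dim=-1).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(probs: list[float], rng: random.Random) -> int:
    # Draw action index i with probability probs[i].
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def log_prob(probs: list[float], action: int) -> float:
    # log p(a|o) for the chosen action.
    return math.log(probs[action])


probs = softmax([1.0, 1.0])
action = sample_action(probs, random.Random(0))
```

The log-probability of the chosen action is the quantity that PPO compares between the current policy and the policy that collected the data.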
2.1 Display Distributions as LaTeX
```python
>>> from pixyzrl.utils import print_latex
>>> print_latex(actor)
p(a|o)
>>> print_latex(critic)
f(v|o)
```
3. Prepare PPO and Buffer
```python
from pixyzrl.models import PPO
from pixyzrl.memory import RolloutBuffer
from pixyzrl.trainer import OnPolicyTrainer

agent = PPO(actor, critic, entropy_coef=0.0, mse_coef=1.0)

# Each field lists its stored shape; "map" binds it to the variable name
# used in the model and loss (o, v, a, r, A).
buffer = RolloutBuffer(
    2048,
    {
        "obs": {"shape": (4,), "map": "o"},  # matches state_dim for CartPole
        "value": {"shape": (1,), "map": "v"},
        "action": {"shape": (2,), "map": "a"},  # matches action_dim for CartPole
        "reward": {"shape": (1,)},
        "done": {"shape": (1,)},
        "returns": {"shape": (1,), "map": "r"},
        "advantages": {"shape": (1,), "map": "A"},
    },
    "cpu",
    1,
)
```
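The `returns` (r) and `advantages` (A) fields have to be computed after a rollout is collected. This README does not say which estimator PixyzRL uses; a common choice for PPO is generalized advantage estimation (GAE). A minimal pure-Python sketch, assuming GAE with a discount `gamma` and smoothing factor `lam` (hypothetical parameter names and defaults) and treating `dones[t]` as "the episode ended after step t":

```python
def compute_gae(
    rewards: list[float],
    values: list[float],
    dones: list[bool],
    last_value: float,
    gamma: float = 0.99,
    lam: float = 0.95,
) -> tuple[list[float], list[float]]:
    """Return (advantages, returns) for one rollout, computed backwards in time."""
    n = len(rewards)
    advantages = [0.0] * n
    gae = 0.0
    next_value = last_value  # critic's estimate for the state after the last step
    for t in reversed(range(n)):
        not_done = 0.0 if dones[t] else 1.0
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of TD errors, reset at episode boundaries
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # Returns are the critic's regression targets: A_t + V(s_t)
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

With `gamma=1.0`, `lam=1.0`, and all values zero, the advantages reduce to the plain reward-to-go of each step, which is a quick sanity check on the recursion.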
3.1 Display Model as LaTeX
```python
>>> print_latex(agent)
mean \left(1.0 MSE(f(v|o), r) - min \left(A clip(\frac{p(a|o)}{old(a|o)}, 0.8, 1.2), A \frac{p(a|o)}{old(a|o)}\right) \right)
```
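Reading the printed expression: p(a|o)/old(a|o) is the probability ratio between the current policy and the policy that collected the rollout, the clip bounds 0.8 and 1.2 correspond to a clipping range of ε = 0.2, and the critic's squared error against the returns r is weighted by `mse_coef=1.0`. No entropy term appears, presumably because `entropy_coef=0.0`. The loss subtracts the clipped surrogate, so minimizing it maximizes that surrogate. A plain-Python sketch of the same quantity over a batch, assuming the ratio is formed from log-probabilities (the function `ppo_loss` is hypothetical, not PixyzRL's implementation):

```python
import math


def ppo_loss(
    logp_new: list[float],
    logp_old: list[float],
    advantages: list[float],
    values: list[float],
    returns: list[float],
    clip_eps: float = 0.2,
    mse_coef: float = 1.0,
) -> float:
    """mean( mse_coef * (v - r)^2 - min(A * clip(ratio), A * ratio) )"""
    terms = []
    for lp, lo, adv, v, r in zip(logp_new, logp_old, advantages, values, returns):
        ratio = math.exp(lp - lo)  # p(a|o) / old(a|o)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic choice: the smaller of the clipped and unclipped objectives
        surrogate = min(adv * clipped, adv * ratio)
        terms.append(mse_coef * (v - r) ** 2 - surrogate)
    return sum(terms) / len(terms)
```

Taking the minimum means a large ratio cannot earn extra reward when the advantage is positive, while a harmful change (negative advantage) is still penalized in full.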
4. Train with the Trainer
```python
trainer = OnPolicyTrainer(env, buffer, agent, "cpu")
trainer.train(1000)
```
Directory Structure
```
PixyzRL
├── docs
│   └── pixyz
│       └── README.pixyz.md
├── examples            # Example scripts
├── pixyzrl
│   ├── environments    # Environment wrappers
│   ├── models
│   │   ├── on_policy   # On-policy models (e.g., PPO, A2C)
│   │   └── off_policy  # Off-policy models (e.g., DQN)
│   ├── memory          # Experience replay & rollout buffer
│   ├── trainer         # Training utilities
│   ├── losses          # Loss function definitions
│   ├── logger          # Logging utilities
│   └── utils.py
└── pyproject.toml
```
Future Work
- Implement Deep Q-Network (DQN)
- Implement Dreamer (model-based RL)
- Integrate with ChatGPT for automatic architecture generation
- Integrate with Genesis
License
PixyzRL is released under the MIT License.
Community & Support
For questions and discussions, please visit:
Download files
File details
Details for the file pixyzrl-0.3.0.tar.gz.
File metadata
- Download URL: pixyzrl-0.3.0.tar.gz
- Upload date:
- Size: 25.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ba76d583000ff9ea46104eb1ce9128cee68a318aa7a06342a851b6650665edf3 |
| MD5 | 0088c9e782f4b8c4dfb6d5370270a304 |
| BLAKE2b-256 | 61bbcf37dd27b428d9b0cd02cdcc62dcc2f076a42c7da4b4255862304459848c |
File details
Details for the file pixyzrl-0.3.0-py3-none-any.whl.
File metadata
- Download URL: pixyzrl-0.3.0-py3-none-any.whl
- Upload date:
- Size: 28.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 872354b4b82e726a34b60af3f977600290df329f22082acf020933d197eb8011 |
| MD5 | af25c13ab9a876b10640f0a56badc827 |
| BLAKE2b-256 | 023247d3b6e275e3652077ab34132c0123dab996c2f70e82e43d4a1ba621417d |
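The digests above let you check a downloaded file before installing it. A minimal sketch using Python's standard `hashlib` (the helper `file_sha256` and the `EXPECTED` mapping are illustrative names, not part of PixyzRL); compare its result with the SHA256 row for the file you downloaded:

```python
from __future__ import annotations

import hashlib
from pathlib import Path


def file_sha256(path: str | Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 digest of a file, read in chunks so large files use little memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 digests copied from the tables above.
EXPECTED = {
    "pixyzrl-0.3.0.tar.gz": "ba76d583000ff9ea46104eb1ce9128cee68a318aa7a06342a851b6650665edf3",
    "pixyzrl-0.3.0-py3-none-any.whl": "872354b4b82e726a34b60af3f977600290df329f22082acf020933d197eb8011",
}

# Usage: file_sha256("pixyzrl-0.3.0-py3-none-any.whl") == EXPECTED["pixyzrl-0.3.0-py3-none-any.whl"]
```

pip can also enforce these digests itself in hash-checking mode: a requirements line such as `pixyzrl==0.3.0 --hash=sha256:<digest>` installed with `pip install --require-hashes -r requirements.txt` (in that mode every dependency must be pinned with a hash as well).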