
Reinforcement Learning with Pixyz (PixyzRL)


PixyzRL: A Reinforcement Learning Framework with Probabilistic Generative Models




What is PixyzRL?

PixyzRL is a reinforcement learning (RL) framework based on probabilistic generative models and Bayesian theory. Built on top of the Pixyz library, it provides a modular and flexible design to enable uncertainty-aware decision-making and improve sample efficiency. PixyzRL supports:

  • Probabilistic Policy Optimization (e.g., PPO, A2C)
  • On-policy and Off-policy Learning
  • Memory Management for RL (Replay Buffer, Rollout Buffer)
  • Integration with Gymnasium environments
  • Logging and Model Training Utilities

Installation

Requirements

  • Python 3.10+
  • PyTorch 2.5.1+
  • Gymnasium (for environment interaction)

Install PixyzRL

Using pip

pip install torch torchvision torchaudio pixyz "gymnasium[box2d]" torchrl
pip install pixyzrl

Install from Source

git clone https://github.com/ItoMasaki/PixyzRL.git
cd PixyzRL
pip install -e .

Quick Start

1. Set Up Environment

from pixyzrl.environments import Env

env = Env("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

2. Define Actor and Critic Networks

import torch
from pixyz.distributions import Categorical, Deterministic
from torch import nn

class Actor(Categorical):
    def __init__(self):
        super().__init__(var=["a"], cond_var=["o"], name="p")
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, o: torch.Tensor):
        return {"probs": self.net(o)}

class Critic(Deterministic):
    def __init__(self):
        super().__init__(var=["v"], cond_var=["o"], name="f")
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, o: torch.Tensor):
        return {"v": self.net(o)}

actor = Actor()
critic = Critic()
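The Actor above ends in nn.Softmax, so it outputs a probability vector, and the Categorical distribution samples an action from it. The stdlib-only sketch below mimics that step without PyTorch: it applies a numerically stable softmax to some logits, draws an action index, and one-hot encodes it. The one-hot encoding is an assumption, inferred from the two-element action field in the buffer below for CartPole's two actions; softmax, sample_categorical, and one_hot are hypothetical helpers, not part of PixyzRL or Pixyz.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_categorical(probs, rng=random):
    """Draw index i with probability probs[i] via inverse-CDF sampling."""
    u = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    return len(probs) - 1  # guard against floating-point round-off


def one_hot(index, size):
    """Encode a discrete action as a length-`size` vector (assumed buffer layout)."""
    return [1.0 if i == index else 0.0 for i in range(size)]


probs = softmax([2.0, 0.5])  # two actions, as in CartPole
action = sample_categorical(probs, random.Random(0))
print(probs, action, one_hot(action, 2))
print(math.log(probs[action]))  # log p(a|o), the quantity used in the PPO ratio
```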

2.1 Display distributions as LaTeX

>>> from pixyzrl.utils import print_latex
>>> print_latex(actor)
p(a|o)

>>> print_latex(critic)
f(v|o)

3. Prepare PPO and Buffer

from pixyzrl.models import PPO
from pixyzrl.memory import RolloutBuffer
from pixyzrl.trainer import OnPolicyTrainer

agent = PPO(actor, critic, entropy_coef=0.0, mse_coef=1.0)

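buffer = RolloutBuffer(
    2048,  # buffer size (number of transitions)
    {
        # "map" binds each field to the variable name used by the actor (o, a),
        # the critic (v), and the PPO loss (r, A)
        "obs": {"shape": (state_dim,), "map": "o"},
        "value": {"shape": (1,), "map": "v"},
        "action": {"shape": (action_dim,), "map": "a"},  # 2 for CartPole
        "reward": {"shape": (1,)},
        "done": {"shape": (1,)},
        "returns": {"shape": (1,), "map": "r"},
        "advantages": {"shape": (1,), "map": "A"},
    },
    "cpu",  # device
    1,  # number of parallel environments (assumed meaning)
)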
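The buffer reserves returns (mapped to r) and advantages (mapped to A), which the PPO loss consumes. A common way to fill them is Generalized Advantage Estimation (GAE). The sketch below is a plain-Python GAE for a single environment; whether PixyzRL uses GAE, and which gamma and lam it uses, are assumptions here, not something this page confirms. compute_gae is a hypothetical helper.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Return (advantages, returns) for one rollout using GAE(gamma, lambda).

    rewards, values, dones: per-step lists of equal length.
    last_value: critic estimate V(o_T) for the observation after the final step.
    dones[t] == 1.0 means the episode ended at step t, so nothing is bootstrapped past it.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        # TD error: delta_t = r_t + gamma * V(o_{t+1}) * (1 - done_t) - V(o_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of TD errors, reset at episode boundaries
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # Returns are the critic's regression targets: A_t + V(o_t)
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


adv, ret = compute_gae([1.0, 1.0, 1.0], [0.5, 0.5, 0.5], [0.0, 0.0, 1.0], last_value=0.0)
print(adv, ret)
```

With gamma=1 and lam=1 the advantages reduce to the undiscounted reward-to-go minus the value baseline, which is a quick way to sanity-check the loop.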

3.1 Display model as LaTeX

>>> print_latex(agent)
mean \left(1.0 MSE(f(v|o), r) - min \left(A clip(\frac{p(a|o)}{old(a|o)}, 0.8, 1.2), A \frac{p(a|o)}{old(a|o)}\right) \right)
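Read term by term, the printed objective is the PPO clipped surrogate plus a value-regression term weighted by mse_coef=1.0. No entropy term appears, consistent with entropy_coef=0.0, and the clip range [0.8, 1.2] corresponds to a clipping parameter of 0.2. The plain-Python sketch below evaluates the same expression on a batch of precomputed numbers. It illustrates the formula only and is not PixyzRL's implementation; the function name ppo_loss and forming the ratio from log-probabilities are assumptions.

```python
import math


def ppo_loss(logp_new, logp_old, advantages, values, returns,
             clip_eps=0.2, mse_coef=1.0):
    """mean( mse_coef * (v - r)^2 - min(A * clip(ratio, 1-eps, 1+eps), A * ratio) ).

    Mirrors the printed objective term by term; all inputs are per-sample lists.
    """
    terms = []
    for lp_new, lp_old, a, v, r in zip(logp_new, logp_old, advantages, values, returns):
        ratio = math.exp(lp_new - lp_old)  # p(a|o) / old(a|o)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        surrogate = min(a * clipped, a * ratio)  # pessimistic (clipped) surrogate
        value_loss = (v - r) ** 2  # per-sample squared error; its mean is the MSE
        terms.append(mse_coef * value_loss - surrogate)
    return sum(terms) / len(terms)


# Two samples: ratio 1.0 with A=2, and ratio 2.0 (clipped to 1.2) with A=1
print(ppo_loss([0.0, math.log(2.0)], [0.0, 0.0], [2.0, 1.0], [1.0, 0.5], [1.0, 0.5]))
```

Taking the minimum of the clipped and unclipped terms means the policy gains nothing from pushing the ratio outside [0.8, 1.2] in the direction the advantage favors, while a move in the harmful direction is still penalized in full.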

4. Training with Trainer

trainer = OnPolicyTrainer(env, buffer, agent, "cpu")
trainer.train(1000)
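OnPolicyTrainer hides the usual on-policy cycle: fill a rollout of fixed length with the current policy, pass it to an update step, then discard it, because on-policy data no longer matches the policy once it has been updated. The outline below shows that cycle in plain Python with toy stand-ins. Everything in it (ToyEnv, collect_rollout, train, and what the iteration count means) is hypothetical and only illustrates the pattern; it does not describe OnPolicyTrainer's internals or what the 1000 passed to trainer.train counts.

```python
import random


class ToyEnv:
    """Hypothetical environment: every episode lasts `horizon` steps, reward 1 per step."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0]

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon
        return [float(self.t)], 1.0, done


def collect_rollout(env, policy, rollout_len):
    """Run the current policy for `rollout_len` steps, resetting at episode ends."""
    rollout = []
    obs = env.reset()
    for _ in range(rollout_len):
        action = policy(obs)
        next_obs, reward, done = env.step(action)
        rollout.append({"obs": obs, "action": action, "reward": reward, "done": float(done)})
        obs = env.reset() if done else next_obs
    return rollout


def train(env, policy, update, iterations, rollout_len):
    """Alternate between collecting fresh on-policy data and updating on it."""
    for _ in range(iterations):
        rollout = collect_rollout(env, policy, rollout_len)
        update(rollout)  # e.g. compute advantages, then optimize the PPO loss
        # The rollout goes out of scope here: after the update it is stale.


updates = []
rng = random.Random(0)
train(ToyEnv(), policy=lambda obs: rng.randrange(2), update=updates.append,
      iterations=3, rollout_len=8)
print(len(updates), len(updates[0]), sum(s["done"] for s in updates[0]))
```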

Directory Structure

PixyzRL
├── docs
│   └── pixyz
│       └── README.pixyz.md
├── examples  # Example scripts
├── pixyzrl
│   ├── environments  # Environment wrappers
│   ├── models
│   │   ├── on_policy  # On-policy models (e.g., PPO, A2C)
│   │   └── off_policy  # Off-policy models (e.g., DQN)
│   ├── memory  # Experience replay & rollout buffer
│   ├── trainer  # Training utilities
│   ├── losses  # Loss function definitions
│   ├── logger  # Logging utilities
│   └── utils.py
└── pyproject.toml

Future Work

  • Implement Deep Q-Network (DQN)
  • Implement Dreamer (model-based RL)
  • Integrate with ChatGPT for automatic architecture generation
  • Integrate with Genesis

License

PixyzRL is released under the MIT License.

Community & Support

For questions and discussions, please visit:



Download files


Source Distribution

pixyzrl-0.3.0.tar.gz (25.0 kB)

Uploaded Source

Built Distribution


pixyzrl-0.3.0-py3-none-any.whl (28.3 kB)

Uploaded Python 3

File details

Details for the file pixyzrl-0.3.0.tar.gz.

File metadata

  • Download URL: pixyzrl-0.3.0.tar.gz
  • Upload date:
  • Size: 25.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for pixyzrl-0.3.0.tar.gz
Algorithm Hash digest
SHA256 ba76d583000ff9ea46104eb1ce9128cee68a318aa7a06342a851b6650665edf3
MD5 0088c9e782f4b8c4dfb6d5370270a304
BLAKE2b-256 61bbcf37dd27b428d9b0cd02cdcc62dcc2f076a42c7da4b4255862304459848c


File details

Details for the file pixyzrl-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: pixyzrl-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 28.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for pixyzrl-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 872354b4b82e726a34b60af3f977600290df329f22082acf020933d197eb8011
MD5 af25c13ab9a876b10640f0a56badc827
BLAKE2b-256 023247d3b6e275e3652077ab34132c0123dab996c2f70e82e43d4a1ba621417d

