
Reinforcement Learning with Pixyz (PixyzRL)


PixyzRL: A Reinforcement Learning Framework with Probabilistic Generative Models



Documentation | Examples | GitHub

What is PixyzRL?

PixyzRL is a reinforcement learning (RL) framework based on probabilistic generative models and Bayesian theory. Built on top of the Pixyz library, it provides a modular and flexible design to enable uncertainty-aware decision-making and improve sample efficiency. PixyzRL supports:

  • Probabilistic Policy Optimization (e.g., PPO, A2C)
  • On-policy and Off-policy Learning
  • Memory Management for RL (Replay Buffer, Rollout Buffer)
  • Integration with Gymnasium environments
  • Logging and Model Training Utilities
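The two buffer types in the list above play different roles. A rollout buffer serves on-policy methods such as PPO and A2C: it holds one batch of experience from the current policy, and that batch is used once and then discarded. A replay buffer serves off-policy methods: it keeps a fixed-capacity window of past transitions and samples from it repeatedly. The sketch below shows only that difference, in plain Python. The classes ToyRolloutBuffer and ToyReplayBuffer and their methods are hypothetical illustrations and are not PixyzRL's pixyzrl.memory API.

```python
import random
from collections import deque


class ToyRolloutBuffer:
    """Hypothetical on-policy store: one rollout, consumed in full, then cleared."""

    def __init__(self):
        self.transitions = []

    def add(self, transition):
        self.transitions.append(transition)

    def drain(self):
        # Return all transitions in collection order and empty the buffer,
        # since on-policy updates should not reuse data from an older policy.
        batch, self.transitions = self.transitions, []
        return batch


class ToyReplayBuffer:
    """Hypothetical off-policy store: fixed capacity, oldest entries evicted first."""

    def __init__(self, capacity, seed=None):
        self.transitions = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add(self, transition):
        self.transitions.append(transition)

    def sample(self, batch_size):
        # Uniform sampling without replacement; stored data remains for reuse.
        return self.rng.sample(list(self.transitions), batch_size)


rollout = ToyRolloutBuffer()
for step in range(5):
    rollout.add({"step": step})
batch = rollout.drain()
print(len(batch), len(rollout.transitions))  # 5 0
```

The only design difference is lifetime. The rollout buffer empties itself on read, while the replay buffer's deque silently drops the oldest transition once capacity is reached.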

Installation

Requirements

  • Python 3.10+
  • PyTorch 2.5.1+
  • Gymnasium (for environment interaction)

Install PixyzRL

Using pip

pip install torch torchvision torchaudio pixyz "gymnasium[box2d]" torchrl
pip install pixyzrl

Install from Source

git clone https://github.com/ItoMasaki/PixyzRL.git
cd PixyzRL
pip install -e .

Quick Start

1. Set Up Environment

from pixyzrl.environments import Env

env = Env("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

2. Define Actor and Critic Networks

import torch
from pixyz.distributions import Categorical, Deterministic
from torch import nn

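class Actor(Categorical):
    """Policy p(a|o): a categorical distribution over the discrete actions (state_dim/action_dim from step 1)."""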
    def __init__(self):
        super().__init__(var=["a"], cond_var=["o"], name="p")
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, o: torch.Tensor):
        return {"probs": self.net(o)}

class Critic(Deterministic):
    """Value function f(v|o): a deterministic state-value estimate."""
    def __init__(self):
        super().__init__(var=["v"], cond_var=["o"], name="f")
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, o: torch.Tensor):
        return {"v": self.net(o)}

actor = Actor()
critic = Critic()
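The Actor's network ends in nn.Softmax, so for each observation it outputs action probabilities ("probs") that parameterize the categorical policy p(a|o). The plain-Python sketch below walks through that computation for one observation. It applies a softmax to raw scores, samples an action as a one-hot vector, and evaluates log p(a|o), which is the quantity that enters PPO's probability ratio. The one-hot form is an assumption, based on the rollout buffer later declaring the CartPole action field with shape (2,). The helper names softmax, sample_one_hot and log_prob are hypothetical, and this is not how Pixyz implements Categorical internally.

```python
import math
import random


def softmax(logits):
    # Subtract the max for numerical stability; same result as nn.Softmax(dim=-1) on one row.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_one_hot(probs, rng):
    # Inverse-CDF sampling of a category index, returned as a one-hot list.
    u = rng.random()
    cumulative = 0.0
    index = len(probs) - 1  # fallback in case round-off leaves the CDF slightly below 1
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            index = i
            break
    return [1.0 if i == index else 0.0 for i in range(len(probs))]


def log_prob(probs, one_hot):
    # log p(a|o) for a one-hot action: the log of the selected action's probability.
    return math.log(sum(p * a for p, a in zip(probs, one_hot)))


rng = random.Random(0)
probs = softmax([2.0, 0.5])  # CartPole-v1 has 2 discrete actions
action = sample_one_hot(probs, rng)
print(action, log_prob(probs, action))
```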

2.1 Display distributions as LaTeX

>>> from pixyzrl.utils import print_latex
>>> print_latex(actor)
p(a|o)

>>> print_latex(critic)
f(v|o)

3. Prepare PPO and Buffer

from pixyzrl.models import PPO
from pixyzrl.memory import RolloutBuffer
from pixyzrl.trainer import OnPolicyTrainer

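agent = PPO(actor, critic, entropy_coef=0.0, mse_coef=1.0)  # no entropy bonus; value-loss weight 1.0

# Each field declares its per-step shape; "map" links the field to the variable
# name used by the distributions and the loss (o: observation, v: value,
# a: action, r: return, A: advantage).
buffer = RolloutBuffer(
    2048,  # buffer size
    {
        "obs": {"shape": (state_dim,), "map": "o"},  # (4,) for CartPole-v1
        "value": {"shape": (1,), "map": "v"},
        "action": {"shape": (action_dim,), "map": "a"},  # (2,) for CartPole-v1
        "reward": {"shape": (1,)},
        "done": {"shape": (1,)},
        "returns": {"shape": (1,), "map": "r"},
        "advantages": {"shape": (1,), "map": "A"},
    },
    "cpu",  # device
    1,
)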
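The buffer reserves "returns" (mapped to r) and "advantages" (mapped to A), which the PPO loss consumes, but this page does not say how PixyzRL computes them. One common choice is Generalized Advantage Estimation (GAE). The sketch below assumes GAE with hypothetical defaults gamma=0.99 and lam=0.95. The function compute_gae is illustrative and is not part of PixyzRL. It treats each done flag as a true episode end, so it does not bootstrap across that boundary.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Return (returns, advantages) for one rollout using GAE(gamma, lambda).

    dones[t] is 1.0 if the episode ended at step t; last_value is the critic's
    estimate V(o_T) for the observation that follows the final step.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        # One-step TD error: delta_t = r_t + gamma * V(o_{t+1}) - V(o_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of TD errors, reset at episode boundaries
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # Value targets for the critic's MSE term: advantage plus baseline
    returns = [a + v for a, v in zip(advantages, values)]
    return returns, advantages


returns, advantages = compute_gae(
    rewards=[1.0, 1.0, 1.0], values=[0.0, 0.0, 0.0],
    dones=[0.0, 0.0, 1.0], last_value=5.0, gamma=1.0, lam=1.0,
)
print(advantages)  # [3.0, 2.0, 1.0]
```

With gamma = lam = 1 and zero values, each advantage reduces to the undiscounted reward-to-go. The final step is terminal, so last_value is ignored.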

3.1 Display model as LaTeX

>>> from pixyzrl.utils import print_latex
>>> print_latex(agent)
mean \left(1.0 MSE(f(v|o), r) - min \left(A clip(\frac{p(a|o)}{old(a|o)}, 0.8, 1.2), A \frac{p(a|o)}{old(a|o)}\right) \right)
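The printed expression is the clipped PPO objective. It takes a batch mean of 1.0 · MSE(f(v|o), r) minus min(A · clip(ratio, 0.8, 1.2), A · ratio), where ratio = p(a|o) / old(a|o). The factor 1.0 matches mse_coef=1.0. No entropy term appears because entropy_coef=0.0. The bounds 0.8 and 1.2 correspond to a clipping range of 0.2, although the PPO argument that sets it is not shown on this page. The per-sample plain-Python sketch below computes the same quantity. The function ppo_loss and its clip_eps parameter are hypothetical and are not PixyzRL's implementation.

```python
import math


def ppo_loss(logp_new, logp_old, advantages, values, returns,
             mse_coef=1.0, clip_eps=0.2):
    """Batch mean of mse_coef * (v - r)^2 - min(A * clip(ratio), A * ratio)."""
    total = 0.0
    for lp, lp_old, adv, v, r in zip(logp_new, logp_old, advantages, values, returns):
        # Probability ratio p(a|o) / old(a|o), computed from log-probabilities
        ratio = math.exp(lp - lp_old)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic surrogate: the smaller of the clipped and unclipped terms
        surrogate = min(adv * clipped, adv * ratio)
        # Squared error between the critic's value and the return target
        value_loss = (v - r) ** 2
        total += mse_coef * value_loss - surrogate
    return total / len(logp_new)


# Ratio 2.0 with a positive advantage is clipped to 1.2, so the loss is -1.2
print(ppo_loss([math.log(2.0)], [0.0], [1.0], [0.0], [0.0]))
```

Taking the minimum means the policy gains nothing from pushing the ratio beyond 1.2 when A > 0. When A < 0, the unclipped term dominates once the ratio exceeds 1.2.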

4. Training with Trainer

trainer = OnPolicyTrainer(env, buffer, agent, "cpu")
trainer.train(1000)
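This page does not describe what OnPolicyTrainer does internally or what the argument to train(1000) counts. A generic on-policy cycle runs as follows: collect a fixed-length rollout with the current policy, compute returns and advantages, run several epochs of shuffled minibatch updates on that rollout, then discard it. The sketch below shows that cycle against a stub environment that follows Gymnasium's reset/step return convention. CountingEnv, train_on_policy and all of its parameters are hypothetical illustrations, not PixyzRL's trainer.

```python
import random


class CountingEnv:
    """Hypothetical stub following Gymnasium's reset/step return convention."""

    def __init__(self, episode_length=10):
        self.episode_length = episode_length
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return [0.0, 0.0, 0.0, 0.0], {}  # (observation, info)

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.episode_length
        return [float(self.t)] * 4, 1.0, terminated, False, {}


def train_on_policy(env, policy, update, rollout_length=2048, num_iterations=3,
                    epochs=4, minibatch_size=64, seed=0):
    rng = random.Random(seed)
    obs, _ = env.reset(seed=seed)
    num_updates = 0
    for _ in range(num_iterations):
        rollout = []  # refilled from scratch every iteration
        # 1) Collect a fixed-length rollout with the current policy
        for _ in range(rollout_length):
            action = policy(obs)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            rollout.append((obs, action, reward, terminated or truncated))
            obs = next_obs
            if terminated or truncated:
                obs, _ = env.reset()
        # 2) Returns and advantages would be computed here (e.g. with GAE)
        # 3) Several epochs of shuffled minibatch updates on the same rollout
        indices = list(range(len(rollout)))
        for _ in range(epochs):
            rng.shuffle(indices)
            for start in range(0, len(indices), minibatch_size):
                update([rollout[i] for i in indices[start:start + minibatch_size]])
                num_updates += 1
        # 4) The rollout goes out of scope; the next iteration uses fresh data
    return num_updates


batches = []
n = train_on_policy(CountingEnv(), lambda obs: 0, batches.append,
                    rollout_length=100, num_iterations=2, epochs=3, minibatch_size=32)
print(n)  # 24 (2 iterations x 3 epochs x 4 minibatches)
```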

Directory Structure

PixyzRL
├── docs
│   └── pixyz
│       └── README.pixyz.md
├── examples  # Example scripts
├── pixyzrl
│   ├── environments  # Environment wrappers
│   ├── models
│   │   ├── on_policy  # On-policy models (e.g., PPO, A2C)
│   │   └── off_policy  # Off-policy models (e.g., DQN)
│   ├── memory  # Experience replay & rollout buffer
│   ├── trainer  # Training utilities
│   ├── losses  # Loss function definitions
│   ├── logger  # Logging utilities
│   └── utils.py
└── pyproject.toml

Future Work

  • Implement Deep Q-Network (DQN)
  • Implement Dreamer (model-based RL)
  • Integrate with ChatGPT for automatic architecture generation
  • Integrate with Genesis

License

PixyzRL is released under the MIT License.

Community & Support

For questions and discussions, please visit:



Download files


Source Distribution

pixyzrl-0.1.0.tar.gz (25.0 kB)

Uploaded Source

Built Distribution


pixyzrl-0.1.0-py3-none-any.whl (28.3 kB)

Uploaded Python 3

File details

Details for the file pixyzrl-0.1.0.tar.gz.

File metadata

  • Download URL: pixyzrl-0.1.0.tar.gz
  • Upload date:
  • Size: 25.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for pixyzrl-0.1.0.tar.gz
Algorithm Hash digest
SHA256 4a5846a9658fb4c34d2845512f4d9494e9f01d7a8448516f6e82ee7ef555d41c
MD5 c1c2c29cb31289c013409283e69d67f6
BLAKE2b-256 b6702859079f6d685ba7b5fe44d5f7bd0b5996bcd8c141b82f8db225ed0bc66b
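You can check a downloaded archive against the SHA256 digest listed above with Python's standard hashlib module. The sketch below assumes the sdist was saved as pixyzrl-0.1.0.tar.gz in the current directory. The helper name sha256_of_file is hypothetical. The file is read in chunks so that large archives do not need to fit in memory.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED_SDIST_SHA256 = "4a5846a9658fb4c34d2845512f4d9494e9f01d7a8448516f6e82ee7ef555d41c"

# Usage, assuming the sdist is in the current directory:
# assert sha256_of_file("pixyzrl-0.1.0.tar.gz") == EXPECTED_SDIST_SHA256
```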


File details

Details for the file pixyzrl-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: pixyzrl-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 28.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for pixyzrl-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 416707f9459d8f275a1e6aeae5bd0d57a5feccbad1bbdbd4a50c33de27dd3072
MD5 304b568eb33ad09a798c9af7d99a431d
BLAKE2b-256 d5ed9101df3283b08def2a9faa5c725bd8673c8676d258239f80ccc9dbe247d1

