Differentiable cognitive primitives for agentic AI
TorchAgentic
PyTorch Model Definitions for AI Agents
A comprehensive library of deep learning architectures for building trainable AI agents, including reinforcement learning models, transformer-based agents, memory-augmented networks, and multi-agent systems.
Features
- 🧠 Core Architectures - MLP, CNN, RNN/LSTM/GRU backbones
- 🎮 RL Models - DQN, PPO, A3C, SAC, TD3
- 🔄 Transformers - Decision Transformer, Perceiver IO
- 💾 Memory Networks - NTM, DNC (Differentiable Neural Computer)
- 👥 Multi-Agent - MADDPG, QMIX, VDN
- ⚡ Utilities - Initialization, normalization, distributions
- 🚀 PyTorch 2.0 Compile - torch.compile() support for 2-3x speedup
Installation
# Basic installation
pip install torchagentic
# With transformer support
pip install torchagentic[transformers]
# Full installation
pip install torchagentic[full]
# Development installation
pip install torchagentic[dev]
Quick Start
DQN Agent
import torch
from torchagentic import DQN, ModelConfig
# Create DQN model for Atari
model = DQN(
    config=ModelConfig(input_dim=4, action_dim=6),  # 4 stacked frames, 6 actions
    image_input=True,
)
# Forward pass
observations = torch.randn(32, 4, 84, 84)  # (batch, channels, height, width)
q_values = model.get_q_values(observations)
# Get action (epsilon-greedy exploration)
action = model.get_action(observations, epsilon=0.1)
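The `epsilon` argument above suggests epsilon-greedy action selection. The following is a minimal sketch of that rule in plain PyTorch, assuming Q-values of shape (batch, num_actions). The helper `epsilon_greedy` is hypothetical and is not part of the TorchAgentic API; the library's own `get_action` may differ in detail.

```python
import torch

def epsilon_greedy(q_values: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Pick one action per row of q_values, shape (batch, num_actions)."""
    batch_size, num_actions = q_values.shape
    device = q_values.device
    # Greedy choice: the action with the highest Q-value in each row
    greedy = q_values.argmax(dim=1)
    # Uniformly random alternative for each row
    random_actions = torch.randint(0, num_actions, (batch_size,), device=device)
    # Explore with probability epsilon, otherwise exploit
    explore = torch.rand(batch_size, device=device) < epsilon
    return torch.where(explore, random_actions, greedy)

q_values = torch.randn(32, 6)
actions = epsilon_greedy(q_values, epsilon=0.1)  # shape (32,)
```

With `epsilon=0.0` the helper reduces to a plain argmax, which is the usual setting for evaluation.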
PPO Actor-Critic
import torch
from torchagentic import PPOActorCritic, ModelConfig
# Create actor-critic for continuous control
model = PPOActorCritic(
    config=ModelConfig(
        input_dim=24,  # Observation dim
        action_dim=4,  # Action dim
        hidden_dims=[256, 256],
    ),
    continuous=True,
)
# Get action and value
observation = torch.randn(1, 24)
action, log_prob, entropy, value = model.get_action_and_value(observation)
Decision Transformer
import torch
from torchagentic import DecisionTransformer, ModelConfig
# Create Decision Transformer for offline RL
model = DecisionTransformer(
    config=ModelConfig(input_dim=17, action_dim=3),
    embed_dim=128,
    num_layers=3,
    max_seq_len=20,
)
# Forward with trajectory
states = torch.randn(1, 10, 17)             # (batch, seq_len, state_dim)
actions = torch.randn(1, 10, 3)             # (batch, seq_len, action_dim)
returns_to_go = torch.ones(1, 10, 1) * 100  # (batch, seq_len, 1), placeholder values
predicted_actions = model(states, actions, returns_to_go)
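The snippet above fills `returns_to_go` with a constant placeholder. In offline RL data, the return-to-go at step t is normally the sum of rewards from t to the end of the trajectory. A minimal sketch of that computation follows; the function name `returns_to_go` and the optional `gamma` argument are illustrative assumptions, not library API.

```python
def returns_to_go(rewards, gamma=1.0):
    """Return a list where entry t is the (optionally discounted) sum of rewards from t onward."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the sum already computed for t + 1
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

rtg = returns_to_go([1.0, 2.0, 3.0])  # [6.0, 5.0, 3.0]
```

The result can be converted with `torch.tensor(rtg).view(1, -1, 1)` to match the (batch, seq_len, 1) shape used above.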
Neural Turing Machine
import torch
from torchagentic import NeuralTuringMachine
# Create NTM
ntm = NeuralTuringMachine(
    input_size=10,
    memory_size=128,
    memory_dim=64,
    num_reads=4,
    num_writes=1,
)
# Process sequence one step at a time, carrying the hidden state forward
inputs = torch.randn(1, 50, 10)  # (batch, seq_len, input_dim)
outputs = []
hidden = None
for t in range(inputs.shape[1]):
    x = inputs[:, t:t+1, :]
    output, hidden = ntm(x, hidden)
    outputs.append(output)
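For readers unfamiliar with how an NTM reads from memory, here is a minimal sketch of content-based addressing as described in the original NTM design: cosine similarity between a key and every memory slot, sharpened by a strength `beta` and normalized with a softmax. The function `content_addressing` is hypothetical and written only for illustration; it is not claimed to match TorchAgentic's internal implementation.

```python
import torch
import torch.nn.functional as F

def content_addressing(memory, key, beta):
    """memory: (batch, slots, dim); key: (batch, dim); beta: (batch, 1), positive sharpness."""
    # Cosine similarity between the key and every memory slot -> (batch, slots)
    similarity = F.cosine_similarity(memory, key.unsqueeze(1), dim=-1)
    # Larger beta concentrates the weights on the best-matching slots
    return torch.softmax(beta * similarity, dim=-1)

memory = torch.randn(1, 128, 64)  # same sizes as memory_size=128, memory_dim=64 above
key = torch.randn(1, 64)
beta = torch.tensor([[5.0]])
weights = content_addressing(memory, key, beta)                   # (1, 128)
read_vector = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)  # (1, 64)
```

The read vector is a weighted average of memory rows, which keeps the whole read operation differentiable.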
Multi-Agent (MADDPG)
import torch
from torchagentic import MADDPGAgent
# Create MADDPG for 3 agents
model = MADDPGAgent(
    num_agents=3,
    obs_dim=10,
    action_dim=2,
    hidden_dims=[256, 256],
    shared_params=True,
)
# Get actions
observations = torch.randn(32, 3, 10)  # (batch, num_agents, obs_dim)
actions = model.get_actions(observations)
# Get centralized Q-value
q_value = model.get_q_value(observations, actions)
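In MADDPG the critic is centralized: it scores the joint observations and joint actions of all agents, while each actor sees only its own observation. The sketch below shows one common way to build that joint input in plain PyTorch. The `critic` network here is a stand-in defined for illustration and is an assumption, not the network used by `MADDPGAgent`.

```python
import torch
import torch.nn as nn

batch, num_agents, obs_dim, action_dim = 32, 3, 10, 2
observations = torch.randn(batch, num_agents, obs_dim)
actions = torch.randn(batch, num_agents, action_dim)

# Flatten the per-agent axes and concatenate: (32, 3*10 + 3*2) = (32, 36)
joint = torch.cat([observations.flatten(1), actions.flatten(1)], dim=-1)

# Illustrative centralized critic that maps the joint input to one Q-value
critic = nn.Sequential(
    nn.Linear(joint.shape[-1], 256),
    nn.ReLU(),
    nn.Linear(256, 1),
)
q_value = critic(joint)  # (32, 1)
```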
PyTorch 2.0 Compilation
TorchAgentic provides built-in support for torch.compile() from PyTorch 2.0+.
The project reports a 2-3x speedup for inference and training; actual gains depend on the model, batch size, and hardware.
Basic Compilation
from torchagentic import MLPNetwork, ModelConfig
# Create model
model = MLPNetwork(ModelConfig(input_dim=64, action_dim=4))
# Compile for inference (low latency)
model.compile(mode="reduce-overhead")
# Compile for training
model.compile(mode="default", dynamic=True)
# Check if compiled
print(model.is_compiled) # True
Compilation Modes
| Mode | Use Case | Speedup |
|---|---|---|
| default | Balanced | 1.5-2x |
| reduce-overhead | Low latency inference | 2-3x |
| max-autotune | Maximum throughput | 2-4x |
RL-Specific Optimization
from torchagentic import DQN, ModelConfig, optimize_for_inference, optimize_for_training
# Create DQN
model = DQN(ModelConfig(input_dim=4, action_dim=6), image_input=True)
# Optimize for inference (recommended for deployment)
model = optimize_for_inference(model, device="cuda")
# Optimize for training
model = optimize_for_training(model, device="cuda", batch_size=64)
Benchmark Example
import torch
from torchagentic import PPOActorCritic, ModelConfig
device = "cuda" if torch.cuda.is_available() else "cpu"
model = PPOActorCritic(ModelConfig(input_dim=24, action_dim=4)).to(device)
# Compile
model.compile(mode="reduce-overhead")
# Benchmark
obs = torch.randn(1, 24, device=device)
# Uncompiled: ~0.5ms
# Compiled: ~0.2ms (2.5x speedup)
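The timings in the comments above depend on hardware and model size. A minimal timing harness for reproducing such a measurement is sketched below. The helper `mean_latency_ms` and the small `net` are hypothetical, written in plain PyTorch, and stand in for any callable model.

```python
import time
import torch
import torch.nn as nn

def mean_latency_ms(fn, *args, warmup=10, iters=100):
    """Average wall-clock time of fn(*args) in milliseconds."""
    with torch.no_grad():
        # Warm-up runs absorb one-time costs such as compilation or caching
        for _ in range(warmup):
            fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued GPU work before timing
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed * 1000 / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.Sequential(nn.Linear(24, 64), nn.Tanh(), nn.Linear(64, 4)).to(device)
obs = torch.randn(1, 24, device=device)
eager_ms = mean_latency_ms(net, obs)
```

Passing a compiled model to the same function gives a like-for-like comparison; the warm-up loop matters there because the first compiled calls include compilation time.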
Advanced Configuration
from torchagentic import CompileConfig, compile_model
config = CompileConfig(
    mode="max-autotune",
    dynamic=False,
    fullgraph=True,
    backend="inductor",
    options={
        "triton.cudagraphs": True,  # Enable CUDA graphs
        "max_autotune.gemm": True,
    },
)
# `model` is any TorchAgentic model created as in the examples above
compiled_model = compile_model(model, config=config)
Model Zoo
Core Architectures
| Model | Description | Use Case |
|---|---|---|
| MLPNetwork | Multi-layer perceptron | Simple environments |
| CNNNetwork | Convolutional network | Visual observations |
| NatureCNN | Nature DQN architecture | Atari games |
| ResNetNetwork | Residual CNN | Complex visual tasks |
| LSTMAgent | LSTM-based agent | Sequential decisions |
| GRUAgent | GRU-based agent | Sequential decisions |
RL Models
| Model | Algorithm | Action Space |
|---|---|---|
| DQN | Deep Q-Network | Discrete |
| DuelingDQN | Dueling DQN | Discrete |
| NoisyDQN | Noisy Nets DQN | Discrete |
| PPOActorCritic | PPO | Both |
| A3CNetwork | A3C | Both |
| SACActor | SAC | Continuous |
| TD3Actor | TD3 | Continuous |
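As an illustration of how one of the variants in the table differs from plain DQN, the sketch below shows the dueling decomposition Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). The class `DuelingHead` is a hypothetical stand-alone module written in plain PyTorch; it is not taken from the library's `DuelingDQN`.

```python
import torch
import torch.nn as nn

class DuelingHead(nn.Module):
    """Maps features to Q-values through separate value and advantage streams."""

    def __init__(self, feature_dim: int, num_actions: int):
        super().__init__()
        self.value = nn.Linear(feature_dim, 1)
        self.advantage = nn.Linear(feature_dim, num_actions)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        v = self.value(features)      # (batch, 1)
        a = self.advantage(features)  # (batch, num_actions)
        # Subtracting the mean advantage makes the V/A split identifiable
        return v + a - a.mean(dim=1, keepdim=True)

head = DuelingHead(feature_dim=64, num_actions=6)
features = torch.randn(32, 64)
q_values = head(features)  # (32, 6)
```

Because the advantages are mean-centered, the average Q-value over actions equals the state value V(s).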
Transformer Models
| Model | Description | Use Case |
|---|---|---|
| TransformerAgent | Self-attention agent | Sequential tasks |
| DecisionTransformer | Offline RL transformer | Offline RL |
| PerceiverAgent | Perceiver IO | Large inputs |
Memory Networks
| Model | Description | Use Case |
|---|---|---|
| NeuralTuringMachine | NTM | Memory tasks |
| DifferentiableNeuralComputer | DNC | Complex memory |
Multi-Agent
| Model | Algorithm | Cooperation |
|---|---|---|
| MADDPGAgent | MADDPG | Mixed |
| QMIXNetwork | QMIX | Cooperative |
| VDNNetwork | VDN | Cooperative |
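The two cooperative methods in the table differ in how they combine per-agent Q-values into a team value. VDN simply sums them, as in the minimal sketch below; QMIX replaces the sum with a learned monotonic mixing network. The function `vdn_mix` is hypothetical and is not the interface of `VDNNetwork`.

```python
import torch

def vdn_mix(agent_qs: torch.Tensor) -> torch.Tensor:
    """agent_qs: (batch, num_agents) chosen-action Q-values -> (batch, 1) team Q-value."""
    return agent_qs.sum(dim=1, keepdim=True)

agent_qs = torch.tensor([[1.0, 2.0, 3.0],
                         [0.5, 0.5, 0.0]])
q_tot = vdn_mix(agent_qs)  # tensor([[6.], [1.]])
```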
Examples
The examples/ directory contains complete, runnable examples:
| Example | Description |
|---|---|
| examples/dqn_training.py | Full DQN training loop with replay buffer |
| examples/ppo_training.py | Complete PPO training with GAE and rollout buffer |
| examples/decision_transformer.py | Offline RL with Decision Transformer |
| examples/maddpg_training.py | Multi-agent training with MADDPG |
| examples/compile_example.py | PyTorch 2.0 compilation examples |
| examples/compile_benchmark.py | Performance benchmarks for compiled models |
To run an example:
python examples/dqn_training.py
Training Example
PPO Training Loop
import torch
from torchagentic import PPOActorCritic, ModelConfig
# Create model
model = PPOActorCritic(
    config=ModelConfig(input_dim=8, action_dim=2, hidden_dims=[64, 64]),
    continuous=False,
)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Training step
def ppo_update(observations, actions, old_log_probs, advantages, returns):
    # Get new log probs and values
    log_probs, entropies, values = model.evaluate_actions(observations, actions)
    # Policy loss (clipped surrogate, clip range 0.2)
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = ratio.clamp(0.8, 1.2) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    # Value loss
    value_loss = (values - returns).pow(2).mean()
    # Entropy bonus
    entropy_bonus = entropies.mean() * 0.01
    # Total loss
    loss = policy_loss + value_loss * 0.5 - entropy_bonus
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    return loss.item()
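The update above takes `advantages` and `returns` as inputs without showing where they come from. One common choice is Generalized Advantage Estimation (GAE), which the examples table also mentions. The following is a minimal sketch over plain Python lists; the function `compute_gae` and its argument names are assumptions for illustration and may not match `examples/ppo_training.py`.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Return (advantages, returns) for one rollout, processed backwards in time."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value  # bootstrap value for the state after the final step
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])  # zero out bootstrapping across episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # Value targets are the advantages added back onto the value estimates
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

advantages, returns = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
    last_value=0.0,
    gamma=1.0,
    lam=1.0,
)
# advantages == [3.0, 2.0, 1.0]
```

With `gamma=1`, `lam=1`, and zero value estimates, the advantages reduce to the plain reward-to-go, which makes the function easy to sanity-check.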
DQN Training Loop
import torch
from torchagentic import DQN, ModelConfig
import torch.nn.functional as F
# Create models
policy_net = DQN(config=ModelConfig(input_dim=4, action_dim=2), image_input=False)
target_net = DQN(config=ModelConfig(input_dim=4, action_dim=2), image_input=False)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)
def dqn_update(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones):
    # Current Q values for the actions that were taken
    q_values = policy_net(batch_states)
    q_values = q_values.gather(1, batch_actions.unsqueeze(-1)).squeeze(-1)
    # Target Q values (no gradient through the target network)
    with torch.no_grad():
        next_q_values = target_net(batch_next_states)
        next_q_max = next_q_values.max(1)[0]
        targets = batch_rewards + (1 - batch_dones) * 0.99 * next_q_max
    # Loss
    loss = F.mse_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
# Update target network periodically
def update_target():
    target_net.load_state_dict(policy_net.state_dict())
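The `dqn_update` function expects batches of transitions as tensors. A minimal replay buffer that produces such batches is sketched below. The class `ReplayBuffer` is hypothetical and may differ from the one in `examples/dqn_training.py`; it assumes states are flat lists of floats and actions are integer indices.

```python
import random
from collections import deque

import torch

class ReplayBuffer:
    """Fixed-size FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are dropped first

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        batch = random.sample(list(self.buffer), batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.tensor(list(states), dtype=torch.float32),
            torch.tensor(list(actions), dtype=torch.int64),
            torch.tensor(list(rewards), dtype=torch.float32),
            torch.tensor(list(next_states), dtype=torch.float32),
            torch.tensor(list(dones), dtype=torch.float32),  # float so (1 - done) works
        )

    def __len__(self):
        return len(self.buffer)

buffer = ReplayBuffer(capacity=1000)
for i in range(100):
    state = [float(i)] * 4
    next_state = [float(i + 1)] * 4
    buffer.push(state, i % 2, 1.0, next_state, i == 99)

states, actions, rewards, next_states, dones = buffer.sample(32)
```

The five returned tensors line up with the five arguments of `dqn_update`, in the same order.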
API Reference
Base Classes
from torchagentic import BaseAgentModel, ModelConfig
# Configuration
config = ModelConfig(
    input_dim=64,
    action_dim=4,
    hidden_dims=[256, 256],
    activation="relu",
    dropout=0.0,
)
# All models inherit from BaseAgentModel
model = YourModel(config)
# Common methods
model.forward(x) # Forward pass
model.get_action(obs, deterministic) # Get action
model.get_value(obs) # Get value estimate
model.save("path.pt") # Save checkpoint
model.load("path.pt") # Load checkpoint
model.get_num_params() # Count parameters
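For orientation, the sketch below shows what parameter counting and checkpointing typically look like in plain PyTorch. It is an assumption about what helpers such as `get_num_params`, `save`, and `load` wrap, not the library's actual implementation; `count_params` is a hypothetical function, and an in-memory buffer stands in for a file path.

```python
import io

import torch
import torch.nn as nn

def count_params(module: nn.Module, trainable_only: bool = True) -> int:
    """Total number of scalar parameters in a module."""
    return sum(
        p.numel() for p in module.parameters()
        if p.requires_grad or not trainable_only
    )

net = nn.Linear(4, 2)
num_params = count_params(net)  # 4 * 2 weights + 2 biases = 10

# Round-trip the weights through a state_dict checkpoint
checkpoint = io.BytesIO()
torch.save(net.state_dict(), checkpoint)
checkpoint.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(checkpoint))
```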
Utilities
from torchagentic import orthogonal_init_, RunningNorm, DiagGaussian
# Weight initialization
orthogonal_init_(layer, gain=1.0)
# Normalization
norm = RunningNorm(256)
# Distributions
dist = DiagGaussian(mean, std)
sample = dist.sample()
log_prob = dist.log_prob(action)
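The names `mean`, `std`, and `action` above are placeholders. As a point of comparison, a diagonal Gaussian can be built from standard `torch.distributions` classes, as in this minimal sketch. It is an assumed equivalent for illustration and is not claimed to be how `DiagGaussian` is implemented.

```python
import torch
from torch.distributions import Independent, Normal

mean = torch.zeros(32, 4)  # (batch, action_dim)
std = torch.ones(32, 4)

# Independent(..., 1) treats the last dimension as one event, so log_prob
# sums over action dimensions and returns one value per batch row
dist = Independent(Normal(mean, std), 1)
action = dist.sample()            # (32, 4)
log_prob = dist.log_prob(action)  # (32,)
entropy = dist.entropy()          # (32,)
```

The per-row log-probability is the quantity PPO needs when forming the probability ratio for continuous actions.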
Contributing
- Fork the repository
- Create a feature branch (git checkout -b feature/my-feature)
- Commit changes (git commit -am 'Add new feature')
- Push to branch (git push origin feature/my-feature)
- Create a Pull Request
License
MIT License - see LICENSE for details.
Citation
@software{torchagentic2024,
  title = {TorchAgentic: PyTorch Model Definitions for AI Agents},
  author = {Liodon AI},
  year = {2024},
  url = {https://github.com/liodon-ai/torchagentic}
}
File details
Details for the file torchagentic-0.2.0.tar.gz.
File metadata
- Download URL: torchagentic-0.2.0.tar.gz
- Upload date:
- Size: 68.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fcdb95b15b24d60cb47551369882933d2fc72683d7ceb0241e22154e8b9a0edd |
| MD5 | dfa18c4364beec572cb147be51075e6e |
| BLAKE2b-256 | a6d4d817b1908e8f0e6c52bbe7e81bcd5bcb3e8916c90a5b64a1e490c4399358 |
File details
Details for the file torchagentic-0.2.0-py3-none-any.whl.
File metadata
- Download URL: torchagentic-0.2.0-py3-none-any.whl
- Upload date:
- Size: 75.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1d9d38f3736128929909e913e96986ae488076cf32aad707e09c1882a337b00a |
| MD5 | d6a873c0504063b0d05f0f3fed464d57 |
| BLAKE2b-256 | d75fad90cec73c3897522f9635d5999d54e7f0ce86de3922e4f05d0d2ce8f777 |