Differentiable cognitive primitives for agentic AI

TorchAgentic

PyTorch Model Definitions for AI Agents

A comprehensive library of deep learning architectures for building trainable AI agents, including reinforcement learning models, transformer-based agents, memory-augmented networks, and multi-agent systems.

Features

  • 🧠 Core Architectures - MLP, CNN, RNN/LSTM/GRU backbones
  • 🎮 RL Models - DQN, PPO, A3C, SAC, TD3
  • 🔄 Transformers - Decision Transformer, Perceiver IO
  • 💾 Memory Networks - NTM, DNC (Differentiable Neural Computer)
  • 👥 Multi-Agent - MADDPG, QMIX, VDN
  • Utilities - Initialization, normalization, distributions
  • 🚀 PyTorch 2.0 Compile - torch.compile() support, with reported speedups of 2-3x (model and hardware dependent)

Installation

# Basic installation
pip install torchagentic

# With transformer support
pip install torchagentic[transformers]

# Full installation
pip install torchagentic[full]

# Development installation
pip install torchagentic[dev]

Quick Start

DQN Agent

import torch
from torchagentic import DQN, ModelConfig

# Create DQN model for Atari
model = DQN(
    config=ModelConfig(input_dim=4, action_dim=6),  # 4 stacked frames, 6 actions
    image_input=True,
)

# Forward pass
observations = torch.randn(32, 4, 84, 84)  # (batch, channels, height, width)
q_values = model.get_q_values(observations)

# Get action
action = model.get_action(observations, epsilon=0.1)
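
The `epsilon` argument above suggests epsilon-greedy exploration. The following is a minimal, dependency-free sketch of that rule on a plain list of Q-values. The helper name `epsilon_greedy` is hypothetical, and this is not necessarily how `get_action` is implemented inside the library.

```python
import random

def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, else the greedy one.

    q_values: list of Q-value estimates, one per discrete action.
    """
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore: uniform random action
    # exploit: index of the largest Q-value (first index wins on ties)
    return max(range(len(q_values)), key=lambda a: q_values[a])

q = [0.1, 0.7, 0.3, 0.2, 0.0, 0.5]  # six actions, as in the Atari example
greedy_action = epsilon_greedy(q, epsilon=0.0)  # epsilon=0 is always greedy
```

In training, epsilon is commonly annealed from a value near 1.0 toward a small floor so that early episodes explore and later ones exploit.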

PPO Actor-Critic

import torch
from torchagentic import PPOActorCritic, ModelConfig

# Create actor-critic for continuous control
model = PPOActorCritic(
    config=ModelConfig(
        input_dim=24,      # Observation dim
        action_dim=4,      # Action dim
        hidden_dims=[256, 256],
    ),
    continuous=True,
)

# Get action and value
observation = torch.randn(1, 24)
action, log_prob, entropy, value = model.get_action_and_value(observation)

Decision Transformer

import torch
from torchagentic import DecisionTransformer, ModelConfig

# Create Decision Transformer for offline RL
model = DecisionTransformer(
    config=ModelConfig(input_dim=17, action_dim=3),
    embed_dim=128,
    num_layers=3,
    max_seq_len=20,
)

# Forward with trajectory
states = torch.randn(1, 10, 17)
actions = torch.randn(1, 10, 3)
returns_to_go = torch.ones(1, 10, 1) * 100

predicted_actions = model(states, actions, returns_to_go)
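
The snippet above fills `returns_to_go` with a constant placeholder. With logged trajectories, returns-to-go are usually computed from the per-step rewards. A small sketch in plain Python, assuming undiscounted returns by default (the function name and the `gamma` parameter are illustrative, not part of the library):

```python
def returns_to_go(rewards, gamma=1.0):
    """Return R_t = sum over k >= t of gamma^(k - t) * r_k for every timestep t."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the return of the step after it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

rtg = returns_to_go([1.0, 0.0, 2.0, 1.0])  # [4.0, 3.0, 3.0, 1.0]
```

The resulting list can be turned into a tensor of shape (batch, seq_len, 1) to match the `returns_to_go` input shown above.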

Neural Turing Machine

import torch
from torchagentic import NeuralTuringMachine

# Create NTM
ntm = NeuralTuringMachine(
    input_size=10,
    memory_size=128,
    memory_dim=64,
    num_reads=4,
    num_writes=1,
)

# Process sequence
inputs = torch.randn(1, 50, 10)  # (batch, seq_len, input_dim)
outputs = []
hidden = None

for t in range(inputs.shape[1]):
    x = inputs[:, t:t+1, :]
    output, hidden = ntm(x, hidden)
    outputs.append(output)
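
For readers new to NTMs, the core read operation is content-based addressing: a key is compared with every memory row by cosine similarity, the similarities are sharpened by a strength factor and passed through a softmax, and the read vector is the weighted sum of rows. The sketch below shows that idea in plain Python. All names here are illustrative, and it makes no claim about how `NeuralTuringMachine` is implemented internally.

```python
import math

def cosine_similarity(u, v, eps=1e-8):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + eps)

def content_weights(memory, key, beta=1.0):
    """Softmax over beta-scaled cosine similarities between key and each row."""
    scores = [beta * cosine_similarity(row, key) for row in memory]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def read(memory, weights):
    """Weighted sum of memory rows."""
    width = len(memory[0])
    return [sum(w * row[j] for w, row in zip(weights, memory)) for j in range(width)]

memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights = content_weights(memory, key=[1.0, 0.0], beta=10.0)
read_vector = read(memory, weights)
```

A larger `beta` concentrates the weights on the best-matching row; a smaller one blends several rows.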

Multi-Agent (MADDPG)

import torch
from torchagentic import MADDPGAgent

# Create MADDPG for 3 agents
model = MADDPGAgent(
    num_agents=3,
    obs_dim=10,
    action_dim=2,
    hidden_dims=[256, 256],
    shared_params=True,
)

# Get actions
observations = torch.randn(32, 3, 10)  # (batch, num_agents, obs_dim)
actions = model.get_actions(observations)

# Get centralized Q-value
q_value = model.get_q_value(observations, actions)

PyTorch 2.0 Compilation

TorchAgentic provides built-in support for torch.compile() from PyTorch 2.0+, which can yield roughly 2-3x speedups for inference and training, depending on the model, hardware, and compile mode.

Basic Compilation

from torchagentic import MLPNetwork, ModelConfig

# Create model
model = MLPNetwork(ModelConfig(input_dim=64, action_dim=4))

# Compile for inference (low latency)
model.compile(mode="reduce-overhead")

# Alternatively, compile for training (dynamic=True enables dynamic-shape tracing)
model.compile(mode="default", dynamic=True)

# Check if compiled
print(model.is_compiled)  # True

Compilation Modes

| Mode | Use Case | Speedup |
|------|----------|---------|
| default | Balanced | 1.5-2x |
| reduce-overhead | Low latency inference | 2-3x |
| max-autotune | Maximum throughput | 2-4x |

RL-Specific Optimization

from torchagentic import DQN, ModelConfig, optimize_for_inference, optimize_for_training

# Create DQN
model = DQN(ModelConfig(input_dim=4, action_dim=6), image_input=True)

# Optimize for inference (recommended for deployment)
model = optimize_for_inference(model, device="cuda")

# Optimize for training
model = optimize_for_training(model, device="cuda", batch_size=64)

Benchmark Example

import torch
from torchagentic import PPOActorCritic, ModelConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
model = PPOActorCritic(ModelConfig(input_dim=24, action_dim=4)).to(device)

# Compile
model.compile(mode="reduce-overhead")

# Benchmark
obs = torch.randn(1, 24, device=device)

# Illustrative timings (actual numbers depend on hardware and model size):
# Uncompiled: ~0.5ms
# Compiled: ~0.2ms (2.5x speedup)
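
The snippet above quotes timings without showing how they might be measured. One possible approach is a small timing helper with warm-up iterations, sketched here with the standard library only. `time_call` and `dummy_forward` are hypothetical names, and the stand-in workload exists only so the sketch runs without a model or GPU.

```python
import time

def time_call(fn, *args, warmup=10, iters=100):
    """Return the mean wall-clock time per call to fn(*args), in milliseconds."""
    for _ in range(warmup):  # warm-up absorbs one-off costs such as compilation
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1000.0

# Stand-in workload; replace with a model forward pass in practice.
def dummy_forward(n):
    return sum(i * i for i in range(n))

ms_per_call = time_call(dummy_forward, 1000)
```

To time a model, pass the model and the observation instead, for example `time_call(model, obs)`. On CUDA, kernels launch asynchronously, so call `torch.cuda.synchronize()` before reading the clock at the start and end of the timed region; otherwise the measurement mostly reflects launch overhead.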

Advanced Configuration

from torchagentic import CompileConfig, compile_model

config = CompileConfig(
    mode="max-autotune",
    dynamic=False,
    fullgraph=True,
    backend="inductor",
    options={
        "triton.cudagraphs": True,  # Enable CUDA graphs
        "max_autotune_gemm": True,  # Inductor spells this option with underscores
    },
)

compiled_model = compile_model(model, config=config)

Model Zoo

Core Architectures

| Model | Description | Use Case |
|-------|-------------|----------|
| MLPNetwork | Multi-layer perceptron | Simple environments |
| CNNNetwork | Convolutional network | Visual observations |
| NatureCNN | Nature DQN architecture | Atari games |
| ResNetNetwork | Residual CNN | Complex visual tasks |
| LSTMAgent | LSTM-based agent | Sequential decisions |
| GRUAgent | GRU-based agent | Sequential decisions |

RL Models

| Model | Algorithm | Action Space |
|-------|-----------|--------------|
| DQN | Deep Q-Network | Discrete |
| DuelingDQN | Dueling DQN | Discrete |
| NoisyDQN | Noisy Nets DQN | Discrete |
| PPOActorCritic | PPO | Both |
| A3CNetwork | A3C | Both |
| SACActor | SAC | Continuous |
| TD3Actor | TD3 | Continuous |

Transformer Models

| Model | Description | Use Case |
|-------|-------------|----------|
| TransformerAgent | Self-attention agent | Sequential tasks |
| DecisionTransformer | Offline RL transformer | Offline RL |
| PerceiverAgent | Perceiver IO | Large inputs |

Memory Networks

| Model | Description | Use Case |
|-------|-------------|----------|
| NeuralTuringMachine | NTM | Memory tasks |
| DifferentiableNeuralComputer | DNC | Complex memory |

Multi-Agent

| Model | Algorithm | Cooperation |
|-------|-----------|-------------|
| MADDPGAgent | MADDPG | Mixed |
| QMIXNetwork | QMIX | Cooperative |
| VDNNetwork | VDN | Cooperative |
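
As background for the cooperative entries, VDN factorizes the joint action-value as the sum of per-agent values, which lets each agent act greedily on its own Q-values while still maximizing the joint value. A toy sketch of that decomposition in plain Python (function names are illustrative and unrelated to the `VDNNetwork` API):

```python
def greedy_joint_action(per_agent_q):
    """Each agent independently picks the action with its highest Q-value."""
    return [max(range(len(q)), key=lambda a: q[a]) for q in per_agent_q]

def vdn_joint_q(per_agent_q, actions):
    """Joint value as the sum of each agent's Q-value for its chosen action."""
    return sum(q[a] for q, a in zip(per_agent_q, actions))

# per_agent_q[i][a] is agent i's Q-value for action a (3 agents, 2 actions each)
per_agent_q = [[0.2, 1.0], [0.5, 0.1], [0.0, 0.3]]
actions = greedy_joint_action(per_agent_q)
q_tot = vdn_joint_q(per_agent_q, actions)
```

QMIX generalizes this by replacing the plain sum with a learned mixing function that is monotonic in each agent's value.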

Examples

The examples/ directory contains complete, runnable examples:

| Example | Description |
|---------|-------------|
| examples/dqn_training.py | Full DQN training loop with replay buffer |
| examples/ppo_training.py | Complete PPO training with GAE and rollout buffer |
| examples/decision_transformer.py | Offline RL with Decision Transformer |
| examples/maddpg_training.py | Multi-agent training with MADDPG |
| examples/compile_example.py | PyTorch 2.0 compilation examples |
| examples/compile_benchmark.py | Performance benchmarks for compiled models |

To run an example:

python examples/dqn_training.py

Training Example

PPO Training Loop

import torch
from torchagentic import PPOActorCritic, ModelConfig

# Create model
model = PPOActorCritic(
    config=ModelConfig(input_dim=8, action_dim=2, hidden_dims=[64, 64]),
    continuous=False,
)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Training step
def ppo_update(observations, actions, old_log_probs, advantages, returns):
    # Get new log probs and values
    log_probs, entropies, values = model.evaluate_actions(observations, actions)
    
    # Policy loss: clipped surrogate objective with clip range epsilon = 0.2
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = ratio.clamp(0.8, 1.2) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    
    # Value loss
    value_loss = (values - returns).pow(2).mean()
    
    # Entropy bonus
    entropy_bonus = entropies.mean() * 0.01
    
    # Total loss
    loss = policy_loss + value_loss * 0.5 - entropy_bonus
    
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    
    return loss.item()
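
`ppo_update` expects precomputed `advantages` and `returns`. One common way to obtain them is Generalized Advantage Estimation (GAE). The following is a plain-Python sketch under the usual definitions; `compute_gae` and its parameter names are illustrative and are not taken from the library or its examples.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Compute GAE advantages and value targets for one rollout.

    rewards, values, dones: per-step lists of equal length.
    last_value: value estimate for the state after the final step.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])  # stop bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

advantages, returns = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.5, 0.5],
    dones=[False, False, True],
    last_value=0.0,
    gamma=1.0,
    lam=1.0,
)
```

With `gamma=1.0` and `lam=1.0` the value targets reduce to Monte Carlo returns, which makes the example easy to check by hand. Advantages are often standardized per batch before being passed to the update.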

DQN Training Loop

import torch
from torchagentic import DQN, ModelConfig
import torch.nn.functional as F

# Create models
policy_net = DQN(config=ModelConfig(input_dim=4, action_dim=2), image_input=False)
target_net = DQN(config=ModelConfig(input_dim=4, action_dim=2), image_input=False)
target_net.load_state_dict(policy_net.state_dict())

optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)

def dqn_update(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones):
    # Current Q values
    q_values = policy_net(batch_states)
    q_values = q_values.gather(1, batch_actions.unsqueeze(-1)).squeeze(-1)
    
    # Target Q values (discount factor gamma = 0.99; batch_dones is a float tensor of 0/1 flags)
    with torch.no_grad():
        next_q_values = target_net(batch_next_states)
        next_q_max = next_q_values.max(1)[0]
        targets = batch_rewards + (1 - batch_dones) * 0.99 * next_q_max
    
    # Loss
    loss = F.mse_loss(q_values, targets)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item()

# Update target network periodically
def update_target():
    target_net.load_state_dict(policy_net.state_dict())
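
`dqn_update` takes batches of transitions, which normally come from a replay buffer. A minimal uniform replay buffer can be sketched with the standard library as follows. The `ReplayBuffer` class here is a hypothetical helper for illustration, not the one used in examples/dqn_training.py.

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-capacity buffer with uniform sampling; oldest transitions are evicted first."""

    def __init__(self, capacity, seed=None):
        self.buffer = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, float(done)))

    def sample(self, batch_size):
        batch = self.rng.sample(list(self.buffer), batch_size)
        # Transpose a list of transitions into per-field lists.
        states, actions, rewards, next_states, dones = zip(*batch)
        return list(states), list(actions), list(rewards), list(next_states), list(dones)

    def __len__(self):
        return len(self.buffer)

buffer = ReplayBuffer(capacity=3, seed=0)
for step in range(5):
    buffer.push([float(step)] * 4, step % 2, 1.0, [float(step + 1)] * 4, step == 4)
states, actions, rewards, next_states, dones = buffer.sample(batch_size=2)
```

Before calling `dqn_update`, the sampled lists would be converted to tensors: float32 for states, rewards, and dones, and int64 for actions, since `gather` requires an int64 index.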

API Reference

Base Classes

from torchagentic import BaseAgentModel, ModelConfig

# Configuration
config = ModelConfig(
    input_dim=64,
    action_dim=4,
    hidden_dims=[256, 256],
    activation="relu",
    dropout=0.0,
)

# All models inherit from BaseAgentModel
model = YourModel(config)

# Common methods
model.forward(x)                    # Forward pass
model.get_action(obs, deterministic)  # Get action
model.get_value(obs)                # Get value estimate
model.save("path.pt")               # Save checkpoint
model.load("path.pt")               # Load checkpoint
model.get_num_params()              # Count parameters

Utilities

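import torch
import torch.nn as nn
from torchagentic import orthogonal_init_, RunningNorm, DiagGaussian

# Weight initialization (example layer for illustration)
layer = nn.Linear(64, 256)
orthogonal_init_(layer, gain=1.0)

# Normalization
norm = RunningNorm(256)

# Distributions (example mean and standard deviation for a 4-dim action)
mean, std = torch.zeros(1, 4), torch.ones(1, 4)
dist = DiagGaussian(mean, std)
action = dist.sample()
log_prob = dist.log_prob(action)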
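
Running normalization of observations is typically built on streaming estimates of mean and variance. The sketch below uses Welford's algorithm on a scalar stream. It is an assumption about the general technique, not a description of how `RunningNorm` works, and `RunningMeanStd` is a hypothetical name.

```python
import math

class RunningMeanStd:
    """Welford-style running mean and (population) variance for a scalar stream."""

    def __init__(self, eps=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean
        self.eps = eps

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    @property
    def var(self):
        return self.m2 / self.count if self.count > 0 else 1.0

    def normalize(self, x):
        return (x - self.mean) / math.sqrt(self.var + self.eps)

stats = RunningMeanStd()
for x in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    stats.update(x)
normalized = stats.normalize(7.0)  # about 1.0: one standard deviation above the mean
```

The single-pass update avoids storing past observations and stays numerically stable over long training runs.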

Contributing

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/my-feature)
  3. Commit changes (git commit -am 'Add new feature')
  4. Push to branch (git push origin feature/my-feature)
  5. Create a Pull Request

License

MIT License - see LICENSE for details.

Citation

@software{torchagentic2024,
  title = {TorchAgentic: PyTorch Model Definitions for AI Agents},
  author = {Liodon AI},
  year = {2024},
  url = {https://github.com/liodon-ai/torchagentic}
}
