
RAT: Reinforced Adaptive Transformer



(Figure: RAT architecture diagram)

RAT (Reinforced Adaptive Transformer) is a transformer architecture whose attention heads are gated dynamically by reinforcement-learning policy networks. It combines Rotary Position Embeddings (RoPE), SwiGLU feed-forward networks, and temporal convolutions for strong language modeling performance.

✨ Key Features

  • 🧠 Adaptive Policy Attention: Dynamic head gating using multiple RL-based policy networks
  • 🔄 Rotary Position Embeddings: Enhanced positional understanding with RoPE
  • 🚀 SwiGLU Feed-Forward: Efficient activation for better expressiveness
  • ⏰ Temporal Convolutions: Sequence modeling with depthwise convolutions
  • 📊 Advanced Logging: Comprehensive training monitoring and debugging
  • 🛡️ Error Handling: Robust validation and graceful failure recovery
  • 💾 Auto-Checkpointing: Automatic model saving with training state
  • 🎯 Optimized Generation: Multiple sampling strategies with KV caching
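
To illustrate the RoPE feature above: rotary embeddings rotate pairs of query/key channels by a position-dependent angle. The sketch below uses the split-half pairing and is purely illustrative; the package's RotaryPositionEmbedding may differ in details.

```python
import torch

def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of x (seq_len, dim) by position-dependent angles."""
    seq_len, dim = x.shape
    half = dim // 2
    # One frequency per channel pair, geometrically spaced from 1 down to ~1/base.
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]
    # Standard 2D rotation applied to each (x1, x2) channel pair.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(8, 64)
q_rot = apply_rope(q)
```

Because each pair is only rotated, the per-token norm is unchanged, and position 0 (angle 0) is left exactly as-is.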

🏗️ Architecture Components

Core Components

  • RAT: Main transformer model with adaptive attention
  • AdaptivePolicyAttention: Multi-policy attention with reinforcement learning
  • RATBlock: Transformer block with attention, FFN, and temporal conv
  • SwiGLUFeedForward: Efficient feed-forward network
  • RotaryPositionEmbedding: Rotary positional encodings
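
For reference, a generic SwiGLU feed-forward block looks roughly like the sketch below; this shows the idea (a SiLU-gated hidden layer), not the package's exact SwiGLUFeedForward.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Generic SwiGLU feed-forward: SiLU-gated hidden units, then projection."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(gate) * up acts as an input-dependent gate on each hidden unit.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

ffn = SwiGLU(d_model=64, d_hidden=256)
y = ffn(torch.randn(2, 10, 64))
```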

Training & Inference

  • RATTrainer: Advanced trainer with logging and checkpointing
  • RATGenerator: Optimized text generation with multiple strategies
  • RATDataset: Enhanced dataset with preprocessing and validation
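
RATGenerator's exact strategies aren't documented here, but one common strategy, temperature plus top-k sampling over next-token logits, can be sketched as follows (illustrative, not the package's API):

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                      top_k: int = 50) -> int:
    """Sample one token id from a (vocab_size,) logits vector."""
    logits = logits / max(temperature, 1e-6)
    if top_k > 0:
        # Mask everything outside the top_k highest logits.
        kth = torch.topk(logits, min(top_k, logits.numel())).values[-1]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))

token = sample_next_token(torch.randn(1000), temperature=0.8, top_k=40)
```

With `top_k=1` this degenerates to greedy decoding; lowering the temperature sharpens the distribution toward the same result.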

Utilities

  • RATLogger: Comprehensive logging system
  • ModelCheckpoint: Automatic checkpoint management
  • Configuration validation: Input sanitization and error checking
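
The package's AdaptivePolicyAttention is not reproduced here, but the general idea of policy-driven head gating can be sketched: a small policy network scores each attention head from the input and scales that head's output accordingly. Everything below (names, shapes) is a hypothetical illustration.

```python
import torch
import torch.nn as nn

class HeadGate(nn.Module):
    """Illustrative head gating: a small policy net scores each attention
    head from the pooled input and scales head outputs by those gates."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.policy = nn.Sequential(nn.Linear(d_model, n_heads), nn.Sigmoid())

    def forward(self, x: torch.Tensor, head_out: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model); head_out: (batch, n_heads, seq, d_head)
        gates = self.policy(x.mean(dim=1))         # (batch, n_heads) in (0, 1)
        return head_out * gates[:, :, None, None]  # soft on/off per head

gate = HeadGate(d_model=64, n_heads=8)
out = gate(torch.randn(2, 10, 64), torch.randn(2, 8, 10, 8))
```

In an RL setting, the gate probabilities would additionally be trained against a reward signal rather than by backpropagation alone.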

🚀 Quick Start

Installation

From PyPI (Recommended)

pip install rat-transformer

From Source

# Clone the repository
git clone https://github.com/arjun988/RAT.git
cd RAT

# Install in development mode
pip install -e .

# Or install with optional dependencies
pip install -e ".[dev,training,serving]"

Basic Usage

from rat import RAT, RATTrainer, RATGenerator
from transformers import AutoTokenizer

# Initialize model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = RAT(vocab_size=tokenizer.vocab_size)

# Training
trainer = RATTrainer(model, tokenizer)
# ... training code ...

# Generation
generator = RATGenerator(model, tokenizer)
text = generator.generate("Hello, how are you?", max_len=50)
print(text)

Command Line Interface

# Train a model
rat-train --config config.json --output-dir ./checkpoints

# Generate text
rat-generate --model-path checkpoints/model.pt --prompt "Hello world"

# Run tests
rat-test --quick

# Evaluate model
rat-eval --model-path model.pt --dataset wikitext

Advanced Configuration

# Custom model configuration
model = RAT(
    vocab_size=50000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_policies=5,
    dropout=0.1,
    use_rope=True,
    use_checkpointing=True
)

# Advanced training
trainer = RATTrainer(
    model=model,
    tokenizer=tokenizer,
    lr=1e-4,
    max_steps=100000,
    grad_clip=1.0,
    checkpoint_dir="./checkpoints"
)

📊 Performance Highlights

  • Parameter Efficiency: Adaptive head gating is designed to get more out of fewer parameters
  • Training Stability: Warmup, weight decay, and gradient clipping for stable optimization
  • Generation Quality: Multiple sampling strategies for coherent, diverse output
  • Memory Optimization: Gradient checkpointing and KV caching reduce memory use
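
The KV-caching idea can be sketched in a few lines: cache past keys and values so each decoding step attends over the full history without recomputing it. This is an illustrative single-head version, not the package's cache.

```python
import torch

def attend_with_cache(q, k_new, v_new, cache):
    """Append this step's key/value to the cache, then attend over history.

    q, k_new, v_new: (1, d) tensors for the current step; cache: dict of lists.
    """
    cache["k"].append(k_new)
    cache["v"].append(v_new)
    k = torch.cat(cache["k"], dim=0)           # (t, d): all keys so far
    v = torch.cat(cache["v"], dim=0)           # (t, d): all values so far
    scores = (q @ k.T) / k.shape[-1] ** 0.5    # (1, t) scaled dot products
    return torch.softmax(scores, dim=-1) @ v   # (1, d) attention output

cache = {"k": [], "v": []}
d = 16
for _ in range(5):  # five decoding steps, each reusing the cached K/V
    out = attend_with_cache(torch.randn(1, d), torch.randn(1, d),
                            torch.randn(1, d), cache)
```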

🔧 Configuration

Model Parameters

  • vocab_size: Size of token vocabulary
  • d_model: Model dimension (must be divisible by n_heads)
  • n_layers: Number of transformer layers
  • n_heads: Number of attention heads
  • n_policies: Number of RL policies for attention gating
  • max_seq_len: Maximum sequence length
  • dropout: Dropout probability
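
A minimal validator in the spirit of the constraints above (for example, d_model divisible by n_heads) might look like this; it is a hypothetical sketch, not the package's actual validation code.

```python
def validate_config(vocab_size: int, d_model: int, n_layers: int,
                    n_heads: int, dropout: float) -> None:
    """Raise ValueError on obviously invalid model parameters."""
    if vocab_size <= 0 or n_layers <= 0 or n_heads <= 0:
        raise ValueError("vocab_size, n_layers and n_heads must be positive")
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    if not 0.0 <= dropout < 1.0:
        raise ValueError(f"dropout must be in [0, 1), got {dropout}")

# Matches the Advanced Configuration example above, so it passes silently.
validate_config(vocab_size=50000, d_model=1024, n_layers=24, n_heads=16,
                dropout=0.1)
```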

Training Parameters

  • lr: Learning rate
  • warmup_steps: Learning rate warmup steps
  • weight_decay: Weight decay for regularization
  • grad_clip: Gradient clipping threshold
  • accum_steps: Gradient accumulation steps
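
Gradient accumulation (accum_steps) can be illustrated with a toy loop: each micro-batch loss is divided by accum_steps so the accumulated gradient averages over the group, approximating one larger batch. This is a generic sketch, not RATTrainer's internals.

```python
import torch

torch.manual_seed(0)

accum_steps = 4
w = torch.zeros(3, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

for step, x in enumerate(torch.randn(8, 3)):   # 8 micro-batches
    loss = ((w - x) ** 2).sum() / accum_steps  # scale so gradients average
    loss.backward()                            # grads accumulate in w.grad
    if (step + 1) % accum_steps == 0:
        opt.step()                             # one update per accum_steps
        opt.zero_grad()
```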

🧪 Testing & Validation

Run the comprehensive test suite:

python test_rat.py

The test suite validates:

  • ✅ Component functionality
  • ✅ Training pipeline
  • ✅ Text generation
  • ✅ Memory usage
  • ✅ Gradient flow
  • ✅ Error handling

📈 Training Tips

  1. Batch Size: Start with smaller batches and increase gradually
  2. Learning Rate: Use 1e-4 for large models, 5e-4 for smaller ones
  3. Gradient Accumulation: Use for effective larger batch sizes
  4. Checkpointing: Enable automatic saving every 1000 steps
  5. Monitoring: Watch perplexity and loss curves
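
The warmup advice in tip 2 and the warmup_steps parameter are commonly implemented as linear warmup followed by cosine decay; a sketch under that assumption (the package's actual schedule may differ):

```python
import math

def lr_at(step: int, base_lr: float = 1e-4, warmup_steps: int = 1000,
          max_steps: int = 100000) -> float:
    """Linear warmup to base_lr, then cosine decay to zero (illustrative)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
```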

🤝 Contributing

We welcome contributions! Please:

  1. Fork the repository
  2. Create a feature branch
  3. Add tests for new functionality
  4. Ensure all tests pass
  5. Submit a pull request

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Inspired by modern transformer architectures
  • Built on PyTorch and Hugging Face Transformers
  • Thanks to the research community for advancing transformer models

RAT: Reinforced Adaptive Transformer - Revolutionizing language models with reinforcement learning



Download files

Download the file for your platform.

Source Distribution

  • rat_transformer-0.1.1.tar.gz (20.8 kB)

Built Distribution

  • rat_transformer-0.1.1-py3-none-any.whl (20.7 kB)

File details

rat_transformer-0.1.1.tar.gz

  • Size: 20.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.6

File hashes

  • SHA256: 5fa5784d42d71b7b171fdd786bb8e29ed3902bbb4f2456a628c49647a3aedc81
  • MD5: a9d1984a332caeae6cec1b10616cfa76
  • BLAKE2b-256: dbbfbbbb3df74ed78066a666b05b4bd35bedf7a80dad625d1428cd7b7acb8034

File details

rat_transformer-0.1.1-py3-none-any.whl

File hashes

  • SHA256: dddbf51d4007c5c17c4869d08ce26e29a43a154df62b28ebca1a76c48e213c98
  • MD5: 45208e007f5118c9cd72fdd838d10c74
  • BLAKE2b-256: dc28f9c428750ba4e8f4e6f66d8ff3a8e61a1e7b4fd5ee3c9e58144f82d16571
