RAT: Reinforced Adaptive Transformer

[Figure: RAT architecture diagram]

RAT (Reinforced Adaptive Transformer) is a next-generation transformer architecture featuring adaptive attention mechanisms powered by reinforcement learning. It combines cutting-edge techniques like Rotary Position Embeddings, SwiGLU feed-forward networks, and temporal convolutions for superior language modeling performance.

✨ Key Features

  • 🧠 Adaptive Policy Attention: Dynamic head gating using multiple RL-based policy networks (conceptual sketch after this list)
  • 🔄 Rotary Position Embeddings: Enhanced positional understanding with RoPE
  • 🚀 SwiGLU Feed-Forward: Efficient activation for better expressiveness
  • ⏰ Temporal Convolutions: Sequence modeling with depthwise convolutions
  • 📊 Advanced Logging: Comprehensive training monitoring and debugging
  • 🛡️ Error Handling: Robust validation and graceful failure recovery
  • 💾 Auto-Checkpointing: Automatic model saving with training state
  • 🎯 Optimized Generation: Multiple sampling strategies with KV caching
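
How the adaptive head gating works is not spelled out above, so the following is a conceptual sketch only, not the package's AdaptivePolicyAttention code: several small "policy" networks read pooled token features and emit per-head gates in [0, 1] that scale each attention head's output. The reinforcement-learning training of the policies is omitted; in this toy version the gates would simply be learned by backpropagation.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Conceptual sketch only -- NOT the actual AdaptivePolicyAttention implementation.
class GatedHeadAttentionSketch(nn.Module):
    def __init__(self, d_model=256, n_heads=8, n_policies=3):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        # one tiny "policy" network per gate source; their gates are averaged
        self.policies = nn.ModuleList(
            nn.Linear(d_model, n_heads) for _ in range(n_policies)
        )

    def forward(self, x):                                    # x: (B, S, d_model)
        B, S, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        heads = F.scaled_dot_product_attention(q, k, v)      # (B, H, S, head_dim)
        gates = torch.stack(
            [torch.sigmoid(p(x.mean(dim=1))) for p in self.policies]
        ).mean(dim=0)                                         # (B, H), in [0, 1]
        heads = heads * gates[:, :, None, None]               # gate each head
        return self.out(heads.transpose(1, 2).reshape(B, S, -1))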

🏗️ Architecture Components

Core Components

  • NeuroForge: Main transformer model with adaptive attention
  • AdaptivePolicyAttention: Multi-policy attention with reinforcement learning
  • NeuroForgeBlock: Transformer block with attention, FFN, and temporal conv
  • SwiGLUFeedForward: Efficient feed-forward network (standard formulation sketched below)
  • RotaryPositionEmbedding: Rotary positional encodings
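
For reference, a textbook SwiGLU feed-forward layer looks like the sketch below. This is the standard formulation (gated SiLU, per Shazeer's "GLU Variants Improve Transformer") shown for clarity; the package's SwiGLUFeedForward may differ in details such as bias terms or hidden width.

import torch.nn as nn
import torch.nn.functional as F

# Standard SwiGLU block: down( SiLU(x @ W_gate) * (x @ W_up) )
class SwiGLU(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))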

Training & Inference

  • NeuroForgeTrainer: Advanced trainer with logging and checkpointing
  • NeuroForgeGenerator: Optimized text generation with multiple strategies
  • NeuroForgeDataset: Enhanced dataset with preprocessing and validation

Utilities

  • NeuroForgeLogger: Comprehensive logging system
  • ModelCheckpoint: Automatic checkpoint management
  • Configuration validation: Input sanitization and error checking

🚀 Quick Start

Installation

From PyPI (Recommended)

pip install rat-transformer

From Source

# Clone the repository
git clone https://github.com/arjun988/RAT.git
cd RAT

# Install in development mode
pip install -e .

# Or install with optional dependencies
pip install -e ".[dev,training,serving]"

Basic Usage

from rat import RAT, RATTrainer, RATGenerator
from transformers import AutoTokenizer

# Initialize model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = RAT(vocab_size=tokenizer.vocab_size)

# Training
trainer = RATTrainer(model, tokenizer)
# ... training code ...

# Generation
generator = RATGenerator(model, tokenizer)
text = generator.generate("Hello, how are you?", max_len=50)
print(text)
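
The snippet above elides the actual training step. As a rough illustration only, it could be filled in with a plain PyTorch loop like the one below; this assumes the model can be called directly on a batch of token ids and returns next-token logits of shape (batch, seq_len, vocab_size), which may not match RATTrainer's real API.

import torch
import torch.nn.functional as F

# Hypothetical fill-in for "# ... training code ..." above.
# Reuses `model` and `tokenizer` from the previous snippet.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

batch = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
input_ids = batch["input_ids"]

model.train()
logits = model(input_ids)                      # assumed: (batch, seq_len, vocab)
loss = F.cross_entropy(                        # shifted next-token prediction
    logits[:, :-1].reshape(-1, logits.size(-1)),
    input_ids[:, 1:].reshape(-1),
)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()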

Command Line Interface

# Train a model
rat-train --config config.json --output-dir ./checkpoints

# Generate text
rat-generate --model-path checkpoints/model.pt --prompt "Hello world"

# Run tests
rat-test --quick

# Evaluate model
rat-eval --model-path model.pt --dataset wikitext

Advanced Configuration

# Custom model configuration
model = RAT(
    vocab_size=50000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_policies=5,
    dropout=0.1,
    use_rope=True,
    use_checkpointing=True
)

# Advanced training
trainer = RATTrainer(
    model=model,
    tokenizer=tokenizer,
    lr=1e-4,
    max_steps=100000,
    grad_clip=1.0,
    checkpoint_dir="./checkpoints"
)
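
As a rough sanity check on the larger configuration above: a standard decoder-only transformer carries roughly 12 · d_model² · n_layers parameters in its blocks, which for d_model=1024 and n_layers=24 is about 300M, plus about 51M for a 50,000-token embedding table (and as much again if the output head is untied). So expect something in the 350-400M parameter range before counting the extra policy networks, whose overhead is not documented here.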

📊 Performance & Benchmarks

  • Parameter Efficiency: Better performance with fewer parameters
  • Training Stability: Advanced optimization and regularization
  • Generation Quality: Superior text coherence and diversity
  • Memory Optimization: Gradient checkpointing and KV caching (gradient checkpointing illustrated below)
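
For background on the memory/compute trade-off, the sketch below shows how gradient checkpointing is typically applied to transformer blocks with torch.utils.checkpoint: activations are recomputed during the backward pass instead of being stored. This is generic PyTorch, not RAT's internals; the use_checkpointing flag presumably wraps the model's blocks in a similar way.

import torch
from torch.utils.checkpoint import checkpoint

# Generic illustration: trade extra compute for lower activation memory.
def run_blocks(blocks, hidden_states, use_checkpointing=True):
    for block in blocks:
        if use_checkpointing and hidden_states.requires_grad:
            hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
        else:
            hidden_states = block(hidden_states)
    return hidden_states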

🔧 Configuration

Model Parameters

  • vocab_size: Size of token vocabulary
  • d_model: Model dimension (must be divisible by n_heads)
  • n_layers: Number of transformer layers
  • n_heads: Number of attention heads
  • n_policies: Number of RL policies for attention gating
  • max_seq_len: Maximum sequence length
  • dropout: Dropout probability

Training Parameters

  • lr: Learning rate
  • warmup_steps: Learning rate warmup steps
  • weight_decay: Weight decay for regularization
  • grad_clip: Gradient clipping threshold
  • accum_steps: Gradient accumulation steps
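
The exact schema expected by rat-train's --config file is not documented here. Purely as an illustration, a configuration built from the parameters above could be written out as follows; the key names mirror the lists above, but whether the CLI accepts this exact layout is an assumption.

import json

# Hypothetical config.json contents; the real schema may differ.
config = {
    # model
    "vocab_size": 50257,      # GPT-2 tokenizer vocabulary size
    "d_model": 512,           # must be divisible by n_heads
    "n_layers": 12,
    "n_heads": 8,
    "n_policies": 3,
    "max_seq_len": 1024,
    "dropout": 0.1,
    # training
    "lr": 5e-4,
    "warmup_steps": 1000,
    "weight_decay": 0.01,
    "grad_clip": 1.0,
    "accum_steps": 4,
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)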

🧪 Testing & Validation

Run the comprehensive test suite:

python test_rat.py

The test suite validates:

  • ✅ Component functionality
  • ✅ Training pipeline
  • ✅ Text generation
  • ✅ Memory usage
  • ✅ Gradient flow
  • ✅ Error handling

📈 Training Tips

  1. Batch Size: Start with smaller batches and increase gradually
  2. Learning Rate: Use 1e-4 for large models, 5e-4 for smaller ones
  3. Gradient Accumulation: Use gradient accumulation for larger effective batch sizes (see the sketch after these tips)
  4. Checkpointing: Enable automatic saving every 1000 steps
  5. Monitoring: Watch perplexity and loss curves
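
The gradient-accumulation tip can be made concrete with plain PyTorch (this is generic, not RATTrainer-specific): scale each micro-batch loss by 1/accum_steps so the accumulated gradient matches what one large batch would give, then step the optimizer every accum_steps micro-batches.

import torch
import torch.nn as nn

# Self-contained toy example of gradient accumulation.
torch.manual_seed(0)
model = nn.Linear(16, 4)                       # stand-in for a full model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accum_steps = 4

micro_batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(8)]

optimizer.zero_grad()
for step, (x, y) in enumerate(micro_batches):
    loss = nn.functional.cross_entropy(model(x), y)
    (loss / accum_steps).backward()            # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()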

🤝 Contributing

We welcome contributions! Please:

  1. Fork the repository
  2. Create a feature branch
  3. Add tests for new functionality
  4. Ensure all tests pass
  5. Submit a pull request

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Inspired by modern transformer architectures
  • Built on PyTorch and Hugging Face Transformers
  • Thanks to the research community for advancing transformer models

NeuroForge: Forging the future of adaptive transformers

Download files

Download the file for your platform.

Source Distribution

rat_transformer-0.1.0.tar.gz (20.8 kB)

Built Distribution

rat_transformer-0.1.0-py3-none-any.whl (20.7 kB)

File details

Details for the file rat_transformer-0.1.0.tar.gz.

File metadata

  • Download URL: rat_transformer-0.1.0.tar.gz
  • Upload date:
  • Size: 20.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.6

File hashes

Hashes for rat_transformer-0.1.0.tar.gz:

  • SHA256: fca270e968da1373de99008ab263f9de4ca8844c60c13fb7bbb7d96281b4bf5f
  • MD5: 835552ef7c418c5b27ac5c14567097f1
  • BLAKE2b-256: cca3306b8a7e37f9f090088bbf2c95d384f55da78da65152edb1cecb133cb5b8


File details

Details for the file rat_transformer-0.1.0-py3-none-any.whl.

File hashes

Hashes for rat_transformer-0.1.0-py3-none-any.whl:

  • SHA256: c5892fd57191620f3421f471f0a0ccc295d9c1719e242083c82b37d8625a6fec
  • MD5: 05693d3e50640c04db83d2b44fdc00de
  • BLAKE2b-256: 8d6bcdcce3fd4b57367c402198be18ef2b59bb45e6705d4eee92fac567621814

