RAT: Reinforced Adaptive Transformer
RAT (Reinforced Adaptive Transformer) is a next-generation transformer architecture featuring adaptive attention mechanisms powered by reinforcement learning. It combines Rotary Position Embeddings, SwiGLU feed-forward networks, and temporal convolutions to improve language modeling performance.
✨ Key Features
- 🧠 Adaptive Policy Attention: Dynamic head gating using multiple RL-based policy networks
- 🔄 Rotary Position Embeddings: Enhanced positional understanding with RoPE
- 🚀 SwiGLU Feed-Forward: Efficient activation for better expressiveness
- ⏰ Temporal Convolutions: Sequence modeling with depthwise convolutions
- 📊 Advanced Logging: Comprehensive training monitoring and debugging
- 🛡️ Error Handling: Robust validation and graceful failure recovery
- 💾 Auto-Checkpointing: Automatic model saving with training state
- 🎯 Optimized Generation: Multiple sampling strategies with KV caching
🏗️ Architecture Components
Core Components
- RAT: Main transformer model with adaptive attention
- AdaptivePolicyAttention: Multi-policy attention with reinforcement learning (a conceptual sketch follows this list)
- RATBlock: Transformer block with attention, FFN, and temporal convolution
- SwiGLUFeedForward: Efficient feed-forward network
- RotaryPositionEmbedding: Rotary positional encodings
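The head-gating idea behind AdaptivePolicyAttention is easiest to picture with a small, self-contained sketch: several policy networks score the attention heads, and the resulting gates scale each head's output. This is a conceptual illustration only; the class, tensor layout, and names below are assumptions made for the example, not the package's actual implementation, and the reinforcement-learning update that trains the policies is omitted.

```python
import torch
import torch.nn as nn

class HeadGatingSketch(nn.Module):
    """Conceptual head gating with multiple policy networks (not RAT's real code)."""

    def __init__(self, d_model: int, n_heads: int, n_policies: int):
        super().__init__()
        # Each policy maps token representations to one gate value per attention head.
        self.policies = nn.ModuleList(
            nn.Linear(d_model, n_heads) for _ in range(n_policies)
        )

    def forward(self, x: torch.Tensor, head_outputs: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model); head_outputs: (batch, n_heads, seq, head_dim)
        gates = torch.stack([p(x) for p in self.policies]).mean(dim=0)  # (batch, seq, n_heads)
        gates = torch.sigmoid(gates).permute(0, 2, 1).unsqueeze(-1)     # (batch, n_heads, seq, 1)
        return head_outputs * gates  # down-weight or emphasize individual heads

gating = HeadGatingSketch(d_model=64, n_heads=8, n_policies=3)
x, heads = torch.randn(2, 10, 64), torch.randn(2, 8, 10, 8)
print(gating(x, heads).shape)  # torch.Size([2, 8, 10, 8])
```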
Training & Inference
- RATTrainer: Advanced trainer with logging and checkpointing
- RATGenerator: Optimized text generation with multiple strategies
- RATDataset: Enhanced dataset with preprocessing and validation
Utilities
- RATLogger: Comprehensive logging system
- ModelCheckpoint: Automatic checkpoint management
- Configuration validation: Input sanitization and error checking
🚀 Quick Start
Installation
From PyPI (Recommended)
pip install rat-transformer
From Source
# Clone the repository
git clone https://github.com/arjun988/RAT.git
cd RAT
# Install in development mode
pip install -e .
# Or install with optional dependencies
pip install -e ".[dev,training,serving]"
Basic Usage
from rat import RAT, RATTrainer, RATGenerator
from transformers import AutoTokenizer
# Initialize model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = RAT(vocab_size=tokenizer.vocab_size)
# Training
trainer = RATTrainer(model, tokenizer)
# ... training code ...
# Generation
generator = RATGenerator(model, tokenizer)
text = generator.generate("Hello, how are you?", max_len=50)
print(text)
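The "# ... training code ..." placeholder above is intentionally open-ended. One plausible way to fill it in, reusing the constructor arguments documented later in this README, is sketched below; RATDataset's signature and the train() method name are assumptions, not a verified API.

```python
from rat import RATDataset

# Hypothetical training flow; RATDataset's signature and RATTrainer.train()
# are assumptions based on this README, not a confirmed API.
texts = ["The quick brown fox jumps over the lazy dog."] * 1000

dataset = RATDataset(texts, tokenizer, max_len=128)   # assumed signature
trainer = RATTrainer(
    model=model,
    tokenizer=tokenizer,
    lr=1e-4,
    max_steps=1000,
    grad_clip=1.0,
    checkpoint_dir="./checkpoints",
)
trainer.train(dataset)                                 # assumed entry point
```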
Command Line Interface
# Train a model
rat-train --config config.json --output-dir ./checkpoints
# Generate text
rat-generate --model-path checkpoints/model.pt --prompt "Hello world"
# Run tests
rat-test --quick
# Evaluate model
rat-eval --model-path model.pt --dataset wikitext
Advanced Configuration
# Custom model configuration
model = RAT(
vocab_size=50000,
d_model=1024,
n_layers=24,
n_heads=16,
n_policies=5,
dropout=0.1,
use_rope=True,
use_checkpointing=True
)
# Advanced training
trainer = RATTrainer(
model=model,
tokenizer=tokenizer,
lr=1e-4,
max_steps=100000,
grad_clip=1.0,
checkpoint_dir="./checkpoints"
)
📊 Performance & Benchmarks
- Parameter Efficiency: Designed to reach strong performance with fewer parameters
- Training Stability: Advanced optimization and regularization for stable convergence
- Generation Quality: Coherent and diverse generated text
- Memory Optimization: Gradient checkpointing and KV caching (a generic KV-cache sketch follows this list)
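The KV-caching point above is independent of RAT's internals and can be shown with a generic PyTorch sketch: keys and values from earlier decoding steps are kept in a cache, so each new token only attends against stored tensors instead of recomputing the whole prefix. The function below is illustrative, not the library's caching code.

```python
import torch

def attend_with_cache(q, k_new, v_new, cache=None):
    """One single-head attention step that reuses cached keys/values (generic sketch)."""
    if cache is None:
        k, v = k_new, v_new
    else:
        k = torch.cat([cache[0], k_new], dim=1)   # (batch, seq_so_far, d)
        v = torch.cat([cache[1], v_new], dim=1)
    scores = (q @ k.transpose(-2, -1)) / k.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ v
    return out, (k, v)   # feed (k, v) back in on the next decoding step

# Decode two tokens, reusing the cache from the first step for the second.
out1, cache = attend_with_cache(torch.randn(1, 1, 16), torch.randn(1, 1, 16), torch.randn(1, 1, 16))
out2, cache = attend_with_cache(torch.randn(1, 1, 16), torch.randn(1, 1, 16), torch.randn(1, 1, 16), cache)
```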
🔧 Configuration
Model Parameters
- vocab_size: Size of token vocabulary
- d_model: Model dimension (must be divisible by n_heads)
- n_layers: Number of transformer layers
- n_heads: Number of attention heads
- n_policies: Number of RL policies for attention gating
- max_seq_len: Maximum sequence length
- dropout: Dropout probability
Training Parameters
- lr: Learning rate
- warmup_steps: Learning rate warmup steps
- weight_decay: Weight decay for regularization
- grad_clip: Gradient clipping threshold
- accum_steps: Gradient accumulation steps
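rat-train reads a JSON config (see the CLI section above), but its exact schema is not documented in this README. A starting point is to write the model and training parameters listed above into config.json, as in the sketch below, and adjust the layout to whatever rat-train actually accepts.

```python
import json

# Assumed config layout assembled from the documented parameter names;
# the schema rat-train really expects may differ.
config = {
    "model": {
        "vocab_size": 50000,
        "d_model": 1024,
        "n_layers": 24,
        "n_heads": 16,          # d_model must be divisible by n_heads
        "n_policies": 5,
        "max_seq_len": 1024,
        "dropout": 0.1,
    },
    "training": {
        "lr": 1e-4,
        "warmup_steps": 2000,
        "weight_decay": 0.01,
        "grad_clip": 1.0,
        "accum_steps": 4,
    },
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)
```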
🧪 Testing & Validation
Run the comprehensive test suite:
python test_rat.py
The test suite validates:
- ✅ Component functionality
- ✅ Training pipeline
- ✅ Text generation
- ✅ Memory usage
- ✅ Gradient flow
- ✅ Error handling
📈 Training Tips
- Batch Size: Start with smaller batches and increase gradually
- Learning Rate: Use 1e-4 for large models, 5e-4 for smaller ones
- Gradient Accumulation: Use to reach larger effective batch sizes when memory is limited (see the sketch after this list)
- Checkpointing: Enable automatic saving every 1000 steps
- Monitoring: Watch perplexity and loss curves
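If you drive the training loop yourself instead of using RATTrainer, the gradient-accumulation and clipping tips translate into the generic PyTorch loop below. The model and tokenizer come from the Basic Usage example; the assumption that model(input_ids) returns logits of shape (batch, seq, vocab) is mine, not taken from the package's docs.

```python
import torch
import torch.nn.functional as F

accum_steps = 4   # effective batch size = accum_steps * per-step batch size
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# Toy batches of random token ids so the sketch runs end to end.
dataloader = [
    {"input_ids": torch.randint(0, tokenizer.vocab_size, (2, 32))} for _ in range(8)
]

optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    logits = model(batch["input_ids"])                # assumed to return (batch, seq, vocab) logits
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),  # predict token t+1 from token t
        batch["input_ids"][:, 1:].reshape(-1),
    ) / accum_steps
    loss.backward()
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # grad_clip
        optimizer.step()
        optimizer.zero_grad()
```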
🤝 Contributing
We welcome contributions! Please:
- Fork the repository
- Create a feature branch
- Add tests for new functionality
- Ensure all tests pass
- Submit a pull request
📄 License
This project is licensed under the MIT License - see the LICENSE file for details.
🙏 Acknowledgments
- Inspired by modern transformer architectures
- Built on PyTorch and Hugging Face Transformers
- Thanks to the research community for advancing transformer models
RAT: Reinforced Adaptive Transformer - Revolutionizing language models with reinforcement learning
Download files
File details
Details for the file rat_transformer-0.1.1.tar.gz.
File metadata
- Download URL: rat_transformer-0.1.1.tar.gz
- Upload date:
- Size: 20.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5fa5784d42d71b7b171fdd786bb8e29ed3902bbb4f2456a628c49647a3aedc81 |
| MD5 | a9d1984a332caeae6cec1b10616cfa76 |
| BLAKE2b-256 | dbbfbbbb3df74ed78066a666b05b4bd35bedf7a80dad625d1428cd7b7acb8034 |
File details
Details for the file rat_transformer-0.1.1-py3-none-any.whl.
File metadata
- Download URL: rat_transformer-0.1.1-py3-none-any.whl
- Upload date:
- Size: 20.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dddbf51d4007c5c17c4869d08ce26e29a43a154df62b28ebca1a76c48e213c98 |
| MD5 | 45208e007f5118c9cd72fdd838d10c74 |
| BLAKE2b-256 | dc28f9c428750ba4e8f4e6f66d8ff3a8e61a1e7b4fd5ee3c9e58144f82d16571 |