A modular implementation of the Free Transformer architecture with conditional VAE
Free Transformer
Free Transformer: A Llama-style decoder architecture with explicit latent plans, conditional VAE training, and benchmark comparisons against standard Transformers.
Designed for efficient PyTorch training on modern GPUs, with full FSDP support, mixed precision, and gradient checkpointing.
What Is the Free Transformer?
Traditional autoregressive Transformers generate each token by conditioning only on the sequence so far ("reactive" behavior).
Free Transformer introduces a latent planning mechanism: it first chooses a stochastic abstract plan (Z), then generates tokens to fit that plan.
This conditional VAE architecture is designed to scale. It aims to maintain high-level coherence, improve controllable generation, and enable richer sequence modeling.
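The difference between reactive and planned generation can be shown with a toy, pure-Python sketch. This is not the package's implementation. The "decoder" is a hypothetical arithmetic stand-in (toy_next_token), and the sketch simplifies how Z is used: it draws a single plan Z once per sequence, from a uniform prior over 2^num_bits values.

```python
import random


def sample_plan(num_bits: int, rng: random.Random) -> int:
    """Draw a plan index Z uniformly from {0, ..., 2**num_bits - 1}.

    Each bit is a fair coin, which is equivalent to a uniform prior
    over all 2**num_bits possible plans.
    """
    return sum(rng.getrandbits(1) << i for i in range(num_bits))


def toy_next_token(prefix: list[int], plan: int, vocab_size: int) -> int:
    # Hypothetical stand-in for the decoder: the next token depends on the
    # prefix (the "reactive" part) and on the fixed plan (the "planned" part).
    return (sum(prefix) + plan) % vocab_size


def generate(prompt, num_new_tokens, num_bits=4, vocab_size=10, seed=0):
    rng = random.Random(seed)
    plan = sample_plan(num_bits, rng)  # chosen once, before any new token
    tokens = list(prompt)
    for _ in range(num_new_tokens):
        tokens.append(toy_next_token(tokens, plan, vocab_size))
    return plan, tokens
```

Changing the seed changes the plan, and with it every continuation of the same prompt. That is the behavior the latent variable is meant to add on top of a purely reactive decoder.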
Architecture Overview
```mermaid
flowchart TD
    subgraph "Training Mode"
        A[Input Tokens] --> B[Embedding Layer]
        B --> C["Decoder Blocks 1..L/2"]
        C --> D["Encoder Block<br/>(Non-causal, learned query ζ)"]
        D --> E[Encoder Readout FC]
        E --> F["Binary Mapper<br/>Diff. discrete plan Z"]
        F --> G["Inject Z into model<br/>via Post-sampler FC"]
        C --> G
        G --> H["Decoder Blocks L/2+1..L"]
        H --> I[Output Logits]
    end
    subgraph "Inference Mode"
        AA[Prompt] --> BB[Embedding Layer]
        BB --> CC["Decoder Blocks 1..L/2"]
        subgraph "Plan Sampling"
            DD["Sample Random Z<br/>(Uniform prior)"]
        end
        DD --> GG[Inject Z via FC]
        CC --> GG
        GG --> HH["Decoder Blocks L/2+1..L"]
        HH --> II[Generate Tokens]
    end
```
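The Binary Mapper stage of the flowchart can be sketched as follows. The sketch assumes, as the name suggests, that Z is encoded by H independent bits and represented as a one-hot vector over 2^H values. The helper names are hypothetical.

- In training, the bit logits would come from the Encoder Readout FC.
- In inference mode, every bit is a fair coin, which gives the uniform prior.

A PyTorch version would also need a gradient path through the discrete sample, commonly a straight-through estimator. This pure-Python sketch only shows the forward sampling.

```python
import math
import random


def binary_mapper(bit_logits, rng):
    """Sample H bits from per-bit logits and return (index, one_hot, prob).

    index   : the sampled plan in [0, 2**H)
    one_hot : a list of length 2**H with a single 1.0 at `index`
    prob    : the probability the bit distribution assigned to `index`
    """
    probs = [1.0 / (1.0 + math.exp(-logit)) for logit in bit_logits]  # sigmoid
    bits = [1 if rng.random() < p else 0 for p in probs]
    index = sum(bit << i for i, bit in enumerate(bits))
    prob = math.prod(p if bit else 1.0 - p for p, bit in zip(probs, bits))
    one_hot = [0.0] * (2 ** len(bit_logits))
    one_hot[index] = 1.0
    return index, one_hot, prob


def sample_prior(num_bits, rng):
    # Inference mode: no encoder, so every bit logit is 0 (a fair coin),
    # which makes all 2**num_bits plans equally likely.
    return binary_mapper([0.0] * num_bits, rng)
```

Working with H bits instead of a direct 2^H-way softmax keeps the readout small: only H logits are needed to address 2^H plans.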
Features
🏗️ Architecture
- Llama-style backbone: RMSNorm, SwiGLU, RoPE, Grouped-Query Attention (GQA)
- Latent Planning: Explicit plan variable Z with differentiable binary coding
- Conditional VAE: Reconstruction + KL loss with free bits regularization
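The conditional VAE objective combines a reconstruction term with a KL term that is only penalized above a "free bits" budget. The sketch below assumes:

- independent Bernoulli bits for Z;
- a uniform prior;
- a per-token budget kappa.

The names cvae_loss, kappa and beta are illustrative, and losses.py may organize this differently.

```python
import math


def bernoulli_kl_to_uniform(bit_probs):
    """KL( prod_h Bernoulli(p_h) || uniform over 2**H plans ), in nats.

    Per bit: p*log(2p) + (1-p)*log(2(1-p)); a fair bit contributes 0,
    a deterministic bit contributes log 2.
    """
    kl = 0.0
    for p in bit_probs:
        for q in (p, 1.0 - p):
            if q > 0.0:  # 0 * log 0 is taken as 0
                kl += q * math.log(2.0 * q)
    return kl


def free_bits_kl(bit_probs, kappa):
    # Free bits: KL below the budget kappa is not penalized at all.
    return max(0.0, bernoulli_kl_to_uniform(bit_probs) - kappa)


def cvae_loss(token_nll, bit_probs_per_token, kappa, beta=1.0):
    """Mean token NLL (reconstruction) + beta * mean free-bits KL."""
    recon = sum(token_nll) / len(token_nll)
    kl = sum(free_bits_kl(p, kappa) for p in bit_probs_per_token)
    return recon + beta * kl / len(bit_probs_per_token)
```

The budget leaves Z room to carry up to kappa nats of information without penalty, which discourages collapse to the prior. Above the budget, the KL term limits how much of the sequence the encoder can copy into Z.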
⚡ Performance & Scaling
- FSDP Support: Multi-GPU training with PyTorch Fully Sharded Data Parallel
- Mixed Precision: Automatic Mixed Precision (AMP) with gradient scaling
- Memory Efficient: Gradient checkpointing and optimized attention patterns
- Modern Optimizations: bfloat16, efficient parameter sharding
Development & Training
- Flexible Training: Switchable inference/training flows with mode selection
- Synthetic + Real Data: Fast prototyping with built-in synthetic data generation
- Comprehensive Testing: Unit/integration tests, benchmark comparisons
- Quality Assurance: Type checking, linting, formatting, CI-ready
Usability
- Extensible API: Modular classes, CLI scripts, YAML configuration
- Docker Support: Containerized demos and development environment
- Documentation: API references, architecture guides, examples
Installation
Using UV:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh  # Install UV
uv venv --python 3.10
source .venv/bin/activate
uv pip install -e ".[dev]"
# Or after PyPI release
uv pip install free-transformer
```
Standard pip (after PyPI release):
```bash
pip install free-transformer
```
Quick Start with Docker
The fastest way to try the Free Transformer is using Docker:
```bash
# Clone the repository
git clone https://github.com/udapy/free-transformer.git
cd free-transformer
# Run the demo (requires Docker and nvidia-docker for GPU)
docker-compose up free-transformer-demo
```
This will:
- Generate small synthetic training data
- Train both baseline and Free Transformer models
- Compare their performance
For detailed Docker instructions, see docker/README.md.
Docker Options
GPU Version (Recommended):
```bash
# Build and run with GPU support
make docker-build
make docker-demo
```
CPU Version:
```bash
# Build and run CPU-only version
make docker-build-cpu
make docker-run-cpu
```
Interactive Development:
```bash
# Start interactive container for development
make docker-interactive
```
Manual Installation & Quick Start Demo
```bash
# 1. Generate small synthetic data
make generate-data-small
# 2. Train the baseline Transformer
make train-baseline
# 3. Train the Free Transformer
make train-free
# 4. Run the model comparison
make compare
```
Or run the full pipeline:
```bash
make demo
```
Check results in:
- checkpoints/baseline/
- checkpoints/free/
- results/comparison/results.json
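A small helper can inspect the comparison output without assuming its exact schema. It flattens whatever nested JSON the comparison step writes into dotted keys. The function name is hypothetical, and the default path is results/comparison/results.json.

```python
import json
from pathlib import Path


def summarize_results(path="results/comparison/results.json"):
    """Load the comparison output and flatten it to {'a.b.c': value}.

    No particular schema is assumed: nested dicts are walked recursively,
    so whatever metrics were written can be listed as flat key/value pairs.
    """
    data = json.loads(Path(path).read_text())
    flat = {}

    def walk(node, prefix):
        if isinstance(node, dict):
            for key, value in node.items():
                walk(value, f"{prefix}.{key}" if prefix else str(key))
        else:
            flat[prefix] = node

    walk(data, "")
    return flat
```

After `make compare`, iterating over `sorted(summarize_results().items())` and printing each pair gives a quick side-by-side view of the baseline and Free Transformer metrics.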
Python API Example
```python
import torch

from free_transformer import FreeTransformer, ModelConfig

config = ModelConfig(
    vocab_size=1000,
    hidden_dim=128,
    num_layers=6,
    num_heads=4,
    latent_dim=8,
)
model = FreeTransformer(config)

# Training mode: Z is produced by the encoder path (see the flowchart)
tokens = torch.randint(0, 1000, (2, 128))  # (batch, seq_len)
logits, z_logits = model(tokens, mode='training')

# Inference/generation: Z is sampled instead of encoded
generated = model.generate(tokens[:, :10], max_new_tokens=20)
```
Repository Structure
```
free-transformer/
├── src/free_transformer/
│   ├── model.py
│   ├── baseline.py
│   ├── encoder.py
│   ├── latent.py
│   ├── injection.py
│   ├── losses.py
│   ├── synthetic_data.py
│   ├── train_utils.py
│   └── config.py
├── examples/
│   ├── train_baseline.py
│   ├── train_free.py
│   ├── eval_compare.py
│   └── generate_data.py
├── configs/
│   ├── baseline.yaml
│   └── free_transformer.yaml
├── docker/
│   ├── demo.sh
│   └── README.md
├── tests/
│   ├── unit/
│   ├── integration/
│   └── test_comparison.py
├── docs/
├── Dockerfile
├── Dockerfile.cpu
├── docker-compose.yml
├── Makefile
├── pyproject.toml
├── .python-version
├── LICENSE
└── README.md
```
Testing & Quality
Run all tests:
```bash
make test
```
Quality checks:
```bash
make quality
```
Advanced Features
Multi-GPU Training
```bash
# FSDP training with automatic GPU detection
make train-baseline-fsdp
make train-free-fsdp
# Or use torchrun directly
torchrun --nproc_per_node=auto examples/train_free.py --config configs/free_transformer.yaml --use-fsdp
```
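Under torchrun, each worker learns its position in the job from environment variables that torchrun sets (RANK, LOCAL_RANK, WORLD_SIZE). The sketch below reads them, with fallbacks for a single-process run. The helper name is hypothetical, and examples/train_free.py may handle this differently.

```python
import os


def distributed_context():
    """Read the process coordinates exported by torchrun.

    When the script is started with plain `python`, these variables are
    absent, and the defaults describe a single-process run.
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return {
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        "is_distributed": world_size > 1,
        "is_main_process": rank == 0,  # e.g. only rank 0 writes checkpoints
    }
```

In a PyTorch training script, these values typically feed `torch.distributed.init_process_group(backend="nccl")` and `torch.cuda.set_device(local_rank)` before the model is wrapped in `FullyShardedDataParallel`.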
Custom Datasets
- Plug in HuggingFace datasets via config files
- Built-in synthetic data generation for quick prototyping
- Extensible data loading pipeline
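For quick prototyping, synthetic data needs only enough hidden structure for a latent plan to be useful. The generator below is a hypothetical sketch, not the package's synthetic_data.py:

- each sequence follows one of num_modes hidden patterns;
- a fraction `noise` of its tokens is replaced with random tokens.

```python
import random


def make_synthetic_sequences(num_sequences, seq_len, vocab_size,
                             num_modes=4, noise=0.1, seed=0):
    """Generate token sequences that each follow one hidden 'mode'.

    A sequence with mode m emits m, m+1, m+2, ... (mod vocab_size), except
    that each position is replaced by a random token with probability
    `noise`. The mode is never given to the model; it must be inferred.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(num_sequences):
        mode = rng.randrange(num_modes)
        seq = []
        for t in range(seq_len):
            if rng.random() < noise:
                seq.append(rng.randrange(vocab_size))
            else:
                seq.append((mode + t) % vocab_size)
        data.append(seq)
    return data
```

Because the mode is a single hidden choice per sequence, this kind of data gives a quick sanity check for whether the latent Z picks up sequence-level structure.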
Extensibility
- Modular architecture for easy customization
- Add custom loss objectives, attention mechanisms, or model components
- Hook system for training callbacks and monitoring
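A hook system for training callbacks can be as small as an event-to-callbacks registry. This is a hypothetical sketch: the HookRegistry class and the event names are illustrative, not the package's API.

```python
from collections import defaultdict
from typing import Any, Callable


class HookRegistry:
    """Minimal event -> callbacks registry (illustrative only)."""

    def __init__(self) -> None:
        self._hooks: dict[str, list[Callable[..., Any]]] = defaultdict(list)

    def register(self, event: str, fn: Callable[..., Any]) -> None:
        self._hooks[event].append(fn)

    def fire(self, event: str, **payload: Any) -> None:
        # Callbacks run in registration order; unknown events are a no-op.
        for fn in self._hooks.get(event, []):
            fn(**payload)
```

A training loop would call something like `hooks.fire("step_end", step=step, loss=loss)` once per step. Logging, checkpointing, or KL monitoring can then be attached without editing the loop itself.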
⚠️ Current Limitations
- DeepSpeed: Not yet implemented (FSDP is the current distributed training solution)
- Flash Attention: Uses standard PyTorch attention (Flash Attention integration planned)
- Inference Optimizations: No quantization or specialized inference backends yet
Documentation
- Architecture: docs/architecture.md
- API: auto-generated documentation (see docs/)
- Example configs and usage tips are included.
License
MIT License. See LICENSE.
Contributing
PRs and issues are welcome; see CONTRIBUTING.md.
Before submitting code, run:
```bash
make test
make quality
```
FAQ
Can I use this for real-world (non-synthetic) data?
Yes! Edit the YAML configs to load a HuggingFace dataset instead of the built-in synthetic data.
How do I run distributed training?
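Use the FSDP Makefile targets (make train-baseline-fsdp, make train-free-fsdp), or launch examples/train_free.py with torchrun and the --use-fsdp flag. See the docs and the Makefile for details.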
How do I change architecture parameters?
Edit YAML config files for layer size, latent dim, number of blocks, etc.
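When editing architecture parameters, a few constraints are worth checking before launching a run. The dataclass below is a hypothetical mirror of the ModelConfig fields used in the Python API example, not the package's class. Two of its rules need explanation:

- The even-layer rule is an assumption made here, so that the split at L/2 in the architecture diagram is exact.
- The even head dimension follows from RoPE rotating dimensions in pairs.

```python
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in for ModelConfig, used only for validation."""

    vocab_size: int = 1000
    hidden_dim: int = 128
    num_layers: int = 6
    num_heads: int = 4
    latent_dim: int = 8

    @property
    def head_dim(self) -> int:
        return self.hidden_dim // self.num_heads

    def validate(self) -> None:
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")
        if self.head_dim % 2 != 0:
            # RoPE rotates pairs of dimensions, so head_dim must be even.
            raise ValueError("head_dim must be even for RoPE")
        if self.num_layers < 2 or self.num_layers % 2 != 0:
            # Assumption of this sketch: an exact split at L/2.
            raise ValueError("num_layers should be an even number >= 2")
        if self.latent_dim < 1:
            raise ValueError("latent_dim must be positive")
```

With the defaults above (which match the API example), validate() passes and head_dim is 32.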
Can I run this without installing dependencies locally?
Yes! Use Docker: docker-compose up free-transformer-demo for a complete demo.
What if I don't have a GPU?
Use the CPU Docker image: make docker-build-cpu && make docker-run-cpu
Download files
File details
Details for the file free_transformer-0.1.0.tar.gz.
File metadata
- Download URL: free_transformer-0.1.0.tar.gz
- Upload date:
- Size: 21.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2c49fa2226e97999d64de29f2a15a022fa545228445ca18aed619615445d79a9 |
| MD5 | f5d47693423ddd9620d4ea4dd4cbe406 |
| BLAKE2b-256 | a1022d1b80443ebef5812f2fddeb1dac9b2352c254ebb9c7a6b4ca2341784c6a |
File details
Details for the file free_transformer-0.1.0-py3-none-any.whl.
File metadata
- Download URL: free_transformer-0.1.0-py3-none-any.whl
- Upload date:
- Size: 21.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e6ef166066f8c9bf3c66b483960d67c1b3e93d1cd2f1c18b4972af747d397410 |
| MD5 | 58eeb0502fb5f2aebacb99b8380dccf9 |
| BLAKE2b-256 | 35a96bf2240fc9a2b2f1a1c8e8693043900053afc93990071e88741a7fec702f |