
A modular implementation of the Free Transformer architecture with conditional VAE


Free Transformer

Free Transformer: A Llama-style decoder architecture with explicit latent plans, conditional VAE training, and benchmark comparisons against standard Transformers.

Designed for PyTorch training on modern GPUs, with FSDP support, mixed precision, and gradient checkpointing.


What Is the Free Transformer?

Traditional autoregressive Transformers generate each token by conditioning only on the sequence so far ("reactive" behavior). The Free Transformer adds a latent planning mechanism: it first chooses a stochastic abstract plan (Z), then generates tokens to fit that plan.
This conditional VAE architecture is intended to maintain high-level coherence, improve controllable generation, and enable richer sequence modeling.

Architecture Overview

flowchart TD
    subgraph "Training Mode"
        A[Input Tokens] --> B[Embedding Layer]
        B --> C["Decoder Blocks 1..L/2"]
        C --> D["Encoder Block
        (Non-causal, learned query ζ)"]
        D --> E[Encoder Readout FC]
        E --> F["Binary Mapper
        Diff. discrete plan Z"]
        F --> G["Inject Z into model
        via Post-sampler FC"]
        C --> G
        G --> H["Decoder Blocks L/2+1..L"]
        H --> I[Output Logits]
    end

    subgraph "Inference Mode"
        AA[Prompt] --> BB[Embedding Layer]
        BB --> CC["Decoder Blocks 1..L/2"]
        subgraph "Plan Sampling"
            DD["Sample Random Z
            (Uniform prior)"]
        end
        DD --> GG[Inject Z via FC]
        CC --> GG
        GG --> HH["Decoder Blocks L/2+1..L"]
        HH --> II[Generate Tokens]
    end
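
As a rough, framework-free illustration of the Inference Mode path above, the sketch below draws a plan Z from a uniform prior and adds its projection to a hidden vector. The bit count, the toy hidden width, and the helper names (sample_plan_bits, bits_to_index, inject_plan) are assumptions made for this example and are not taken from the package; plain Python lists stand in for tensors.

```python
import random


def sample_plan_bits(num_bits, rng):
    """Draw num_bits independent fair coin flips (a uniform prior over plans)."""
    return [rng.randint(0, 1) for _ in range(num_bits)]


def bits_to_index(bits):
    """Interpret the bit list as a binary integer, most significant bit first."""
    index = 0
    for bit in bits:
        index = (index << 1) | bit
    return index


def inject_plan(hidden, plan_index, projection):
    """Add the projection row selected by the plan to a hidden vector.

    Multiplying a one-hot vector by a weight matrix selects one row, so a
    fully connected layer applied to a one-hot Z reduces to a table lookup.
    """
    row = projection[plan_index]
    return [h + r for h, r in zip(hidden, row)]


rng = random.Random(0)
num_bits = 3      # hypothetical: 2**3 = 8 possible plans
hidden_dim = 4    # hypothetical toy width

# Stand-in for the learned injection weights: one row per possible plan.
projection = [[rng.uniform(-1.0, 1.0) for _ in range(hidden_dim)]
              for _ in range(2 ** num_bits)]

bits = sample_plan_bits(num_bits, rng)
plan_index = bits_to_index(bits)

# Stand-in for the output of the first half of the decoder blocks.
hidden = [0.0] * hidden_dim
injected = inject_plan(hidden, plan_index, projection)
```

Sampling each bit independently with probability 0.5 is the same as sampling the plan index uniformly from all 2**num_bits values, which is why the prior needs no learned parameters in this sketch.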

Features

🏗️ Architecture

  • Llama-style backbone: RMSNorm, SwiGLU, RoPE, Grouped-Query Attention (GQA)
  • Latent Planning: Explicit plan variable Z with differentiable binary coding
  • Conditional VAE: Reconstruction + KL loss with free bits regularization
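
To make the "Reconstruction + KL loss with free bits" item concrete, here is a minimal sketch of one common formulation, written with the standard library only. It assumes a categorical posterior over plans, a uniform prior, and a per-token threshold below which the KL term costs nothing; the function names, the threshold value, and the exact thresholding form are illustrative assumptions, not the package's documented loss.

```python
import math


def kl_to_uniform(probs):
    """KL(q || uniform) in nats for a categorical q given as probabilities."""
    k = len(probs)
    # Zero-probability entries contribute nothing to the sum.
    return sum(p * math.log(p * k) for p in probs if p > 0.0)


def free_bits_kl(per_token_kl, threshold):
    """Average the per-token KL after subtracting a free allowance.

    Tokens whose KL is below the threshold contribute zero, so the encoder
    can use up to that much information without being penalized.
    """
    clipped = [max(0.0, kl - threshold) for kl in per_token_kl]
    return sum(clipped) / len(clipped)


def cvae_loss(reconstruction_nll, per_token_kl, threshold, kl_weight=1.0):
    """Total loss: reconstruction term plus the weighted, thresholded KL."""
    return reconstruction_nll + kl_weight * free_bits_kl(per_token_kl, threshold)


# A uniform posterior carries no information about the plan (KL = 0);
# a one-hot posterior over 4 plans carries log(4) nats.
kl_uniform = kl_to_uniform([0.25, 0.25, 0.25, 0.25])
kl_one_hot = kl_to_uniform([1.0, 0.0, 0.0, 0.0])

# Hypothetical numbers: reconstruction NLL of 2.0, two tokens, threshold 0.5.
loss = cvae_loss(2.0, [0.1, 1.0], threshold=0.5)
```

Other free-bits variants exist (for example, taking the maximum of the KL and the threshold instead of subtracting it), so check losses.py for the form this project actually uses.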

⚡ Performance & Scaling

  • FSDP Support: Multi-GPU training with PyTorch Fully Sharded Data Parallel
  • Mixed Precision: Automatic Mixed Precision (AMP) with gradient scaling
  • Memory Efficient: Gradient checkpointing and optimized attention patterns
  • Modern Optimizations: bfloat16, efficient parameter sharding

🔧 Development & Training

  • Flexible Training: Switchable inference/training flows with mode selection
  • Synthetic + Real Data: Fast prototyping with built-in synthetic data generation
  • Comprehensive Testing: Unit/integration tests, benchmark comparisons
  • Quality Assurance: Type checking, linting, formatting, CI-ready

📦 Usability

  • Extensible API: Modular classes, CLI scripts, YAML configuration
  • Docker Support: Containerized demos and development environment
  • Documentation: API references, architecture guides, examples

Installation

Using UV:

curl -LsSf https://astral.sh/uv/install.sh | sh # Install UV
uv venv --python 3.10
source .venv/bin/activate
uv pip install -e ".[dev]"

# Or after PyPI release
uv pip install free-transformer

Standard pip (after PyPI release):

pip install free-transformer

Quick Start with Docker

The fastest way to try the Free Transformer is using Docker:

# Clone the repository
git clone https://github.com/udapy/free-transformer.git
cd free-transformer

# Run the demo (requires Docker and nvidia-docker for GPU)
docker-compose up free-transformer-demo

This will:

  1. Generate small synthetic training data
  2. Train both baseline and Free Transformer models
  3. Compare their performance

For detailed Docker instructions, see docker/README.md.

Docker Options

GPU Version (Recommended):

# Build and run with GPU support
make docker-build
make docker-demo

CPU Version:

# Build and run CPU-only version
make docker-build-cpu
make docker-run-cpu

Interactive Development:

# Start interactive container for development
make docker-interactive

Manual Installation & Quick Start Demo

  1. Generate Small Synthetic Data

    make generate-data-small
    
  2. Train Baseline Transformer

    make train-baseline
    
  3. Train Free Transformer

    make train-free
    
  4. Run Model Comparison

    make compare
    

Or run the full pipeline:

make demo

Check results in:

  • checkpoints/baseline/
  • checkpoints/free/
  • results/comparison/results.json
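
The layout of results.json is not documented here, so the following sketch makes no assumption about its keys: it simply flattens whatever nested JSON objects it finds into "key.subkey = value" lines for a quick look. The function name summarize_results is hypothetical and not part of the package.

```python
import json
from pathlib import Path


def summarize_results(path):
    """Load a JSON file and return flat 'key.subkey = value' lines."""
    data = json.loads(Path(path).read_text())
    lines = []

    def walk(prefix, node):
        # Recurse into nested objects; treat everything else as a leaf value.
        if isinstance(node, dict):
            for key, value in node.items():
                walk(f"{prefix}.{key}" if prefix else str(key), value)
        else:
            lines.append(f"{prefix} = {node}")

    walk("", data)
    return lines


if __name__ == "__main__":
    results_path = Path("results/comparison/results.json")
    if results_path.exists():
        print("\n".join(summarize_results(results_path)))
    else:
        print(f"{results_path} not found; run `make compare` first.")
```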

Python API Example

import torch  # needed for torch.randint below

from free_transformer import FreeTransformer, ModelConfig

config = ModelConfig(
    vocab_size=1000,
    hidden_dim=128,
    num_layers=6,
    num_heads=4,
    latent_dim=8,
)

model = FreeTransformer(config)

# Training mode
tokens = torch.randint(0, 1000, (2, 128))
logits, z_logits = model(tokens, mode='training')

# Inference/generation
generated = model.generate(tokens[:, :10], max_new_tokens=20)

Repository Structure

free-transformer/
├── src/free_transformer/
│   ├── model.py
│   ├── baseline.py
│   ├── encoder.py
│   ├── latent.py
│   ├── injection.py
│   ├── losses.py
│   ├── synthetic_data.py
│   ├── train_utils.py
│   └── config.py
├── examples/
│   ├── train_baseline.py
│   ├── train_free.py
│   ├── eval_compare.py
│   └── generate_data.py
├── configs/
│   ├── baseline.yaml
│   └── free_transformer.yaml
├── docker/
│   ├── demo.sh
│   └── README.md
├── tests/
│   ├── unit/
│   ├── integration/
│   └── test_comparison.py
├── docs/
├── Dockerfile
├── Dockerfile.cpu
├── docker-compose.yml
├── Makefile
├── pyproject.toml
├── .python-version
├── LICENSE
└── README.md

Testing & Quality

Run all tests:

make test

Quality checks:

make quality

Advanced Features

🚀 Multi-GPU Training

# FSDP training with automatic GPU detection
make train-baseline-fsdp
make train-free-fsdp

# Or use torchrun directly
torchrun --nproc_per_node=auto examples/train_free.py --config configs/free_transformer.yaml --use-fsdp

📊 Custom Datasets

  • Plug in HuggingFace datasets via config files
  • Built-in synthetic data generation for quick prototyping
  • Extensible data loading pipeline

🔧 Extensibility

  • Modular architecture for easy customization
  • Add custom loss objectives, attention mechanisms, or model components
  • Hook system for training callbacks and monitoring

⚠️ Current Limitations

  • DeepSpeed: Not yet implemented (FSDP is the current distributed training solution)
  • Flash Attention: Uses standard PyTorch attention (Flash Attention integration planned)
  • Inference Optimizations: No quantization or specialized inference backends yet

Documentation

  • Architecture: docs/architecture.md
  • API: auto-generated documentation (see docs/)
  • Example configs and usage tips included.

License

MIT License (see LICENSE)


Contributing

PRs and issues are welcome; see CONTRIBUTING.md

Before submitting code, run:

make test
make quality

FAQ

Can I use this for real-world (non-synthetic) data?
Yes! Edit configs and use HuggingFace datasets.

How do I run distributed training?
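Launch with torchrun and the --use-fsdp flag, or use the make train-baseline-fsdp and make train-free-fsdp targets (see Multi-GPU Training). The docs and Makefile have further details.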

How do I change architecture parameters?
Edit YAML config files for layer size, latent dim, number of blocks, etc.

Can I run this without installing dependencies locally?
Yes! Use Docker: docker-compose up free-transformer-demo for a complete demo.

What if I don't have a GPU?
Use the CPU Docker image: make docker-build-cpu && make docker-run-cpu

