CoRT
Coherence-Routed Transformer: Adaptive attention routing using phase coherence metrics.
CoRT is a transformer architecture that uses phase coherence metrics from circular statistics to adaptively route tokens between expensive full attention and lightweight mixing operations. The result is substantially lower perplexity at comparable training throughput.
Quick Start
import torch
from cort import CoRTModel, CoRTConfig
# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config)
# Forward pass
input_ids = torch.randint(0, 50257, (2, 128))
logits = model(input_ids) # [2, 128, 50257]
# Generate text
generated = model.generate(input_ids[:, :10], max_new_tokens=50)
The Core Insight
Not all tokens need full attention. In natural language, most tokens are contextually coherent - their relationships are predictable and don't require expensive O(n²) attention to resolve.
CoRT measures this coherence using R̄ (R-bar), the Mean Resultant Length from circular statistics:
R̄ = |mean(exp(i · φ))|
Where φ represents token phases derived from embeddings. R̄ ranges from 0 (uniform/scattered) to 1 (perfectly aligned).
Routing Decision:
- Low R̄ → Complex relationships → Full attention
- High R̄ → Aligned representations → Lightweight mixing
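For a quick feel of the metric, here is a minimal sketch (an illustrative helper, not the packaged r_bar_from_phases function) that computes R̄ directly from phase angles; scattered phases land near 0, tightly clustered phases near 1:
import math
import torch
def r_bar(phases):
    # Mean Resultant Length: magnitude of the mean unit vector on the circle
    return torch.sqrt(torch.cos(phases).mean(-1) ** 2 + torch.sin(phases).mean(-1) ** 2)
scattered = torch.rand(128) * 2 * math.pi   # phases spread over the whole circle -> R̄ near 0
aligned = 1.0 + 0.05 * torch.randn(128)     # phases clustered around 1 radian -> R̄ near 1
print(r_bar(scattered).item(), r_bar(aligned).item())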
Architecture
┌─────────────────────────────────────────────────────────────┐
│ CoRT Layer │
├─────────────────────────────────────────────────────────────┤
│ │
│ Input ──► LayerNorm ──► CoherentRouter │
│ │ │
│ ┌─────────┴─────────┐ │
│ ▼ ▼ │
│ Low Coherence High Coherence │
│ (need attention) (can be mixed) │
│ │ │ │
│ ▼ ▼ │
│ MultiHeadAttention CoherenceMixer │
│ │ │ │
│ └─────────┬─────────┘ │
│ ▼ │
│ Merge ──► LayerNorm ──► FFN │
│ │
│ ▼ │
│ Output │
│ │
└─────────────────────────────────────────────────────────────┘
Benchmark Results
Trained on WikiText-103 for 3 epochs with 48M parameter models:
| Model | Val Perplexity | Throughput | vs Standard |
|---|---|---|---|
| Standard Transformer | 1577.18 | 10,751 tok/s | baseline |
| CoRT | 2.38 | 10,239 tok/s | -99.8% PPL |
CoRT achieves 99.8% lower perplexity than a standard transformer with comparable throughput and only 8% more parameters.
Installation
pip install cort-transformer
For training utilities:
pip install cort-transformer[training]
For development:
pip install cort-transformer[dev]
Features
Coherence Metrics
import torch
from cort import compute_phase_coherence, compute_local_coherence
# Per-token coherence
hidden = torch.randn(2, 128, 512)
coherence = compute_phase_coherence(hidden, mode="token") # [2, 128]
# Local windowed coherence
local_coh = compute_local_coherence(hidden, window=8) # [2, 128]
Model Configurations
from cort import CoRTConfig
# Preset sizes
small = CoRTConfig.small() # ~48M params
medium = CoRTConfig.medium() # ~124M params
large = CoRTConfig.large() # ~350M params
# Custom configuration
config = CoRTConfig(
vocab_size=50257,
d_model=768,
n_heads=12,
n_layers=12,
route_frac=0.15, # 15% tokens to attention
adaptive_routing=True, # PID-tuned routing
coherence_mode="combined",
)
Adaptive Routing
CoRT includes a PID controller that dynamically adjusts the routing fraction based on observed coherence levels:
config = CoRTConfig(
adaptive_routing=True,
route_frac=0.15, # Base fraction, will be adjusted
)
The controller tracks a target coherence level, routing more tokens to attention when observed coherence is low and fewer when it is high.
Training
from cort import CoRTModel, CoRTConfig
from cort.utils import Trainer, TrainingConfig
# Model
model = CoRTModel(CoRTConfig.small())
# Training config
train_config = TrainingConfig(
learning_rate=3e-4,
weight_decay=0.1,
log_interval=100,
)
# Train
trainer = Trainer(model, train_config)
trainer.train(train_loader, num_epochs=3)
Save & Load
# Save
model.save_pretrained("my_model")
# Load
model = CoRTModel.from_pretrained("my_model")
How It Works
1. Phase Coherence Computation
Token embeddings are converted to phases via sigmoid mapping to [0, 2π]:
phases = 2π · sigmoid(hidden_states)
The coherence R̄ is computed as the magnitude of the mean unit vector:
R̄ = sqrt(mean(cos(φ))² + mean(sin(φ))²)
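Putting these two formulas together, a minimal sketch of the per-token computation might look like the following (an illustration of the math above, not necessarily the exact code behind compute_phase_coherence):
import math
import torch
def phase_coherence(hidden):
    # Map each hidden feature to a phase in [0, 2π] via the sigmoid
    phases = 2 * math.pi * torch.sigmoid(hidden)
    # Per-token R̄: magnitude of the mean unit vector over the feature dimension
    return torch.sqrt(torch.cos(phases).mean(-1) ** 2 + torch.sin(phases).mean(-1) ** 2)
coherence = phase_coherence(torch.randn(2, 128, 512))   # shape [2, 128], values in [0, 1]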
2. Routing Decision
A learned router combines:
- Entropy-based importance (variance of hidden states)
- Phase coherence (R̄ score)
- Learned weights (task-specific routing)
score = w_entropy · entropy + w_coherence · R̄ + w_learned · f(h)
Low scores → attention path, high scores → mixing path.
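As a rough sketch of how such a score could drive the split (the unit weights, the variance-based entropy proxy, and the linear head standing in for f(h) are illustrative assumptions, not the library's defaults):
import torch
import torch.nn as nn
d_model, route_frac = 512, 0.15
hidden = torch.randn(2, 128, d_model)
coherence = torch.rand(2, 128)                       # stand-in for the per-token R̄ from step 1
entropy = hidden.var(dim=-1)                         # entropy-style importance proxy
learned = nn.Linear(d_model, 1)(hidden).squeeze(-1)  # learned per-token score f(h)
score = 1.0 * entropy + 1.0 * coherence + 1.0 * learned   # unit weights purely for illustration
k = max(1, int(route_frac * hidden.size(1)))
attn_idx = score.topk(k, dim=1, largest=False).indices    # lowest-scoring tokens go to attention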
3. Parallel Processing
- Attention path: Standard multi-head self-attention
- Mixing path: Lightweight phase-aware token mixing
The paths are computed in parallel and merged based on routing decisions.
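One way the per-token merge could be expressed, with stand-in tensors for the two path outputs (purely illustrative, not the package's merge code):
import torch
batch, seq, d_model = 2, 128, 512
hidden = torch.randn(batch, seq, d_model)
# Boolean mask marking which tokens the router sent to the attention path
attn_mask = torch.zeros(batch, seq, dtype=torch.bool)
attn_mask[:, :19] = True                             # ~15% of tokens, purely for illustration
attn_out = hidden + 1.0                              # stand-in for the attention path output
mix_out = hidden - 1.0                               # stand-in for the mixer path output
merged = torch.where(attn_mask.unsqueeze(-1), attn_out, mix_out)   # per-token merge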
4. Adaptive Control
A PID controller monitors coherence levels and adjusts the routing fraction to maintain optimal efficiency:
route_frac += Kp·error + Ki·∫error + Kd·d(error)/dt
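A minimal sketch of such a discrete PID update (the gains, target coherence, and clamping range are illustrative assumptions, not the package's defaults):
class RouteFracPID:
    """Adjusts route_frac so observed coherence tracks a target level (illustrative)."""
    def __init__(self, target=0.7, kp=0.1, ki=0.01, kd=0.05):
        self.target, self.kp, self.ki, self.kd = target, kp, ki, kd
        self.integral = 0.0
        self.prev_error = 0.0
    def step(self, route_frac, observed_coherence):
        error = self.target - observed_coherence
        self.integral += error
        derivative = error - self.prev_error
        self.prev_error = error
        # Low coherence (positive error) -> raise route_frac, sending more tokens to attention
        new_frac = route_frac + self.kp * error + self.ki * self.integral + self.kd * derivative
        return min(max(new_frac, 0.05), 0.5)   # clamp to a sane range
controller = RouteFracPID(target=0.7)
route_frac = controller.step(0.15, observed_coherence=0.55)   # low coherence -> route_frac rises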
API Reference
Core Classes
| Class | Description |
|---|---|
| CoRTModel | Complete transformer model |
| CoRTConfig | Model configuration |
| CoRTLayer | Single transformer layer |
| CoherentRouter | Phase coherence-based router |
| CoherenceMixer | Lightweight token mixer |
Core Functions
| Function | Description |
|---|---|
| compute_phase_coherence | Compute R̄ from hidden states |
| compute_local_coherence | Windowed local coherence |
| r_bar_from_phases | R̄ from phase angles |
Examples
Language Modeling
import torch
import torch.nn.functional as F
from cort import CoRTModel, CoRTConfig
# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config).cuda()
# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for batch in dataloader:
x, y = batch[:, :-1].cuda(), batch[:, 1:].cuda()
logits = model(x)
loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
Text Generation
import tiktoken
enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time"
input_ids = torch.tensor([enc.encode(prompt)]).cuda()
generated = model.generate(
input_ids,
max_new_tokens=100,
temperature=0.8,
top_p=0.9,
)
print(enc.decode(generated[0].tolist()))
Routing Statistics
# Get per-layer routing stats
logits, stats = model(input_ids, return_stats=True)
for i, layer_stats in enumerate(stats):
print(f"Layer {i}: {layer_stats['attn_ratio']:.1%} to attention")
Citation
@software{cort2024,
author = {Vaca, Dylan},
title = {CoRT: Coherence-Routed Transformer},
year = {2024},
url = {https://github.com/followthesapper/cort}
}
License
MIT License - see LICENSE for details.
Contributing
Contributions welcome! Please read our contributing guidelines and submit pull requests.
Acknowledgments
CoRT builds on ideas from:
- Circular statistics and the Mean Resultant Length (R̄)
- Adaptive computation in transformers
- Mixture of Experts routing mechanisms