Skip to main content

Coherence-Routed Transformer: Adaptive attention routing using phase coherence metrics

Project description

CoRT

Coherence-Routed Transformer: Adaptive attention routing using phase coherence metrics.

PyPI version License: MIT Python 3.9+

CoRT is a transformer architecture that uses phase coherence metrics from circular statistics to intelligently route tokens between expensive full attention and lightweight mixing operations. The result is faster training with better perplexity.

Quick Start

from cort import CoRTModel, CoRTConfig

# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config)

# Forward pass
import torch
input_ids = torch.randint(0, 50257, (2, 128))
logits = model(input_ids)  # [2, 128, 50257]

# Generate text
generated = model.generate(input_ids[:, :10], max_new_tokens=50)

The Core Insight

Not all tokens need full attention. In natural language, most tokens are contextually coherent - their relationships are predictable and don't require expensive O(n²) attention to resolve.

CoRT measures this coherence using R̄ (R-bar), the Mean Resultant Length from circular statistics:

R̄ = |mean(exp(i · φ))|

Where φ represents token phases derived from embeddings. R̄ ranges from 0 (uniform/scattered) to 1 (perfectly aligned).

Routing Decision:

  • Low R̄ → Complex relationships → Full attention
  • High R̄ → Aligned representations → Lightweight mixing

Architecture

┌─────────────────────────────────────────────────────────────┐
│                        CoRT Layer                            │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│   Input ──► LayerNorm ──► CoherentRouter                    │
│                              │                               │
│                    ┌─────────┴─────────┐                    │
│                    ▼                   ▼                    │
│              Low Coherence       High Coherence             │
│              (need attention)    (can be mixed)             │
│                    │                   │                    │
│                    ▼                   ▼                    │
│            MultiHeadAttention    CoherenceMixer             │
│                    │                   │                    │
│                    └─────────┬─────────┘                    │
│                              ▼                               │
│                           Merge ──► LayerNorm ──► FFN       │
│                                                              │
│                              ▼                               │
│                           Output                             │
│                                                              │
└─────────────────────────────────────────────────────────────┘

Benchmark Results

Trained on WikiText-103 for 3 epochs with 48M parameter models:

Model Val Perplexity Throughput vs Standard
Standard Transformer 1577.18 10,751 tok/s baseline
CoRT 2.38 10,239 tok/s -99.8% PPL

CoRT achieves 99.8% lower perplexity than a standard transformer with comparable throughput and only 8% more parameters.

Installation

pip install cort-transformer

For training utilities:

pip install cort-transformer[training]

For development:

pip install cort-transformer[dev]

Features

Coherence Metrics

from cort import compute_phase_coherence, compute_local_coherence

# Per-token coherence
hidden = torch.randn(2, 128, 512)
coherence = compute_phase_coherence(hidden, mode="token")  # [2, 128]

# Local windowed coherence
local_coh = compute_local_coherence(hidden, window=8)  # [2, 128]

Model Configurations

from cort import CoRTConfig

# Preset sizes
small = CoRTConfig.small()   # ~48M params
medium = CoRTConfig.medium() # ~124M params
large = CoRTConfig.large()   # ~350M params

# Custom configuration
config = CoRTConfig(
    vocab_size=50257,
    d_model=768,
    n_heads=12,
    n_layers=12,
    route_frac=0.15,        # 15% tokens to attention
    adaptive_routing=True,   # PID-tuned routing
    coherence_mode="combined",
)

Adaptive Routing

CoRT includes a PID controller that dynamically adjusts the routing fraction based on observed coherence levels:

config = CoRTConfig(
    adaptive_routing=True,
    route_frac=0.15,  # Base fraction, will be adjusted
)

The controller targets optimal coherence, routing more tokens to attention when coherence is low and fewer when high.

Training

from cort import CoRTModel, CoRTConfig
from cort.utils import Trainer, TrainingConfig

# Model
model = CoRTModel(CoRTConfig.small())

# Training config
train_config = TrainingConfig(
    learning_rate=3e-4,
    weight_decay=0.1,
    log_interval=100,
)

# Train
trainer = Trainer(model, train_config)
trainer.train(train_loader, num_epochs=3)

Save & Load

# Save
model.save_pretrained("my_model")

# Load
model = CoRTModel.from_pretrained("my_model")

How It Works

1. Phase Coherence Computation

Token embeddings are converted to phases via sigmoid mapping to [0, 2π]:

phases = 2π · sigmoid(hidden_states)

The coherence R̄ is computed as the magnitude of the mean unit vector:

 = sqrt(mean(cos(φ))² + mean(sin(φ))²)

2. Routing Decision

A learned router combines:

  • Entropy-based importance (variance of hidden states)
  • Phase coherence (R̄ score)
  • Learned weights (task-specific routing)
score = w_entropy · entropy + w_coherence ·  + w_learned · f(h)

Low scores → attention path, high scores → mixing path.

3. Parallel Processing

  • Attention path: Standard multi-head self-attention
  • Mixing path: Lightweight phase-aware token mixing

The paths are computed in parallel and merged based on routing decisions.

4. Adaptive Control

A PID controller monitors coherence levels and adjusts the routing fraction to maintain optimal efficiency:

route_frac += Kp·error + Ki·error + Kd·d(error)/dt

API Reference

Core Classes

Class Description
CoRTModel Complete transformer model
CoRTConfig Model configuration
CoRTLayer Single transformer layer
CoherentRouter Phase coherence-based router
CoherenceMixer Lightweight token mixer

Core Functions

Function Description
compute_phase_coherence Compute R̄ from hidden states
compute_local_coherence Windowed local coherence
r_bar_from_phases R̄ from phase angles

Examples

Language Modeling

import torch
from cort import CoRTModel, CoRTConfig

# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config).cuda()

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    x, y = batch[:, :-1].cuda(), batch[:, 1:].cuda()

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Text Generation

import tiktoken

enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time"
input_ids = torch.tensor([enc.encode(prompt)]).cuda()

generated = model.generate(
    input_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_p=0.9,
)

print(enc.decode(generated[0].tolist()))

Routing Statistics

# Get per-layer routing stats
logits, stats = model(input_ids, return_stats=True)

for i, layer_stats in enumerate(stats):
    print(f"Layer {i}: {layer_stats['attn_ratio']:.1%} to attention")

Citation

@software{cort2024,
  author = {Vaca, Dylan},
  title = {CoRT: Coherence-Routed Transformer},
  year = {2024},
  url = {https://github.com/followthesapper/cort}
}

License

MIT License - see LICENSE for details.

Contributing

Contributions welcome! Please read our contributing guidelines and submit pull requests.

Acknowledgments

CoRT builds on ideas from:

  • Circular statistics and the Mean Resultant Length (R̄)
  • Adaptive computation in transformers
  • Mixture of Experts routing mechanisms

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

cort_transformer-0.1.0.tar.gz (25.6 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

cort_transformer-0.1.0-py3-none-any.whl (24.5 kB view details)

Uploaded Python 3

File details

Details for the file cort_transformer-0.1.0.tar.gz.

File metadata

  • Download URL: cort_transformer-0.1.0.tar.gz
  • Upload date:
  • Size: 25.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cort_transformer-0.1.0.tar.gz
Algorithm Hash digest
SHA256 ebd616685007818170fb4424f9af532d4ae7fcb2fecdbe29a7ba3690238d1ebf
MD5 b4f62cfb5af24b6dd24c0933806a4f24
BLAKE2b-256 d3e4eb47e0dda27f80fbd333577b41d27296f14888252a13164f1495d912bdb2

See more details on using hashes here.

File details

Details for the file cort_transformer-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for cort_transformer-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 3e039f1586a9c2a94bb664640c8149a273301a9c04daf6362a0854a97f615491
MD5 92a4fa45245975110fe80bdbecf3d38a
BLAKE2b-256 3da43ca4d634c0cd6d03eb02f8f26bc37ccf79b3381bd0050b19ec82b2a7d8f6

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page