FlashHead: vLLM Plugin for Fast Language Model Head Inference

The dense classification head accounts for up to 60% of parameters in small LLMs and roughly half of decode-step compute. FlashHead replaces it with a training-free, hardware-friendly two-stage retrieval pipeline that delivers up to 2.0x model-level inference speedup while maintaining accuracy. It integrates via vLLM's official vllm.general_plugins entry point: no source patches, no custom Docker image.

FlashHead: Efficient Drop-In Replacement for the Classification Head in Language Model Inference

The standard LM head computes a dense matrix multiplication $h_t \times W_{\text{vocab}}$ at every decode step, scoring all vocabulary tokens regardless of relevance. FlashHead reframes this as a two-stage retrieval problem over clustered token embeddings: first identify which regions of vocabulary space are relevant, then score only those candidates.

⚡ Key Tradeoff: A dense head scores 128,256 tokens per step (for a 128K vocabulary). With c = 8,016 clusters of 16 tokens each and p = 256 probes, FlashHead scores only 8,016 + 256 × 16 = 12,112 entries (8,016 centroids plus 4,096 candidate tokens), roughly a 10× reduction, while multi-probe retrieval maintains near-perfect recall of the correct next token.
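
The count above can be reproduced in a few lines. This is only a worked check of the arithmetic; the function and variable names are illustrative and not part of the FlashHead API.

```python
def scored_entries(num_clusters: int, num_probes: int, cluster_size: int) -> int:
    """Entries scored per decode step: every centroid, then every token in each probed cluster."""
    return num_clusters + num_probes * cluster_size

vocab_size = 128_256
num_clusters = 8_016
cluster_size = vocab_size // num_clusters  # 16 tokens per equal-sized cluster
num_probes = 256

flashhead = scored_entries(num_clusters, num_probes, cluster_size)
print(flashhead)                         # 12112
print(round(vocab_size / flashhead, 1))  # 10.6
```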

Note. The offline clustering step runs once per model and adds zero overhead at inference time. Both stages use contiguous memory access patterns for GPU and edge accelerator efficiency.
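
To make the two stages concrete, here is a minimal pure-Python sketch of greedy decoding over a toy vocabulary. It assumes dot-product scoring and centroids computed as the mean embedding of each cluster; the real implementation runs on GPU and may differ in both respects, and every name below is hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def two_stage_argmax(hidden, centroids, clusters, embeddings, num_probes):
    """Greedy next-token choice that scores only tokens in the top-scoring clusters.

    centroids[i] is the centroid of clusters[i]; clusters[i] is a list of token ids.
    """
    # Stage 1: coarse scoring of every centroid against the hidden state.
    centroid_scores = [dot(hidden, c) for c in centroids]
    ranked = sorted(range(len(centroids)), key=lambda i: centroid_scores[i], reverse=True)
    probed = ranked[:num_probes]
    # Stage 2: exact scoring restricted to tokens in the probed clusters.
    candidates = [tok for i in probed for tok in clusters[i]]
    return max(candidates, key=lambda tok: dot(hidden, embeddings[tok]))

# Toy vocabulary: 6 tokens with 2-d embeddings, grouped into 3 clusters of 2.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [-1.0, 0.0], [-0.9, -0.1]]
clusters = [[0, 1], [2, 3], [4, 5]]
centroids = [[sum(embeddings[t][d] for t in c) / len(c) for d in range(2)] for c in clusters]

hidden = [0.2, 1.0]
dense = max(range(len(embeddings)), key=lambda tok: dot(hidden, embeddings[tok]))
fast = two_stage_argmax(hidden, centroids, clusters, embeddings, num_probes=1)
print(dense, fast)  # 2 2
```

In this toy case a single probe already recovers the dense argmax while scoring 3 centroids and 2 tokens instead of all 6 tokens.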

Four key ideas (see paper)

  • Equal-sized Clustering: Token embeddings are grouped into balanced clusters for predictable memory access and stable latency. Unlike hierarchical softmax, cluster sizes stay uniform, which is critical for GPU and edge accelerators.

  • Multi-Probe Retrieval: Instead of committing to a single cluster, FlashHead probes multiple centroids, much like a beam search over vocabulary space. This gives near-perfect recall with far fewer evaluations.

  • Full Decoding Support: Supports both greedy and sampling decoding. For sampling, clusters are selected proportionally to centroid probabilities, preserving the output distribution.

  • Selective Quantization: Stage 1 (coarse centroid scoring) runs in low precision; Stage 2 preserves accuracy. The head's quantization weakness becomes a structural advantage.
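
As an illustration of the equal-sized constraint, the sketch below assigns tokens to fixed centroids under a hard capacity limit. This greedy, capacity-constrained assignment is an assumption chosen for brevity, not the clustering procedure from the paper, and all names are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def assign_balanced(embeddings, centroids, cluster_size):
    """Assign each token to its best-scoring centroid that still has room."""
    assert len(embeddings) == len(centroids) * cluster_size
    clusters = [[] for _ in centroids]
    # Visit (token, centroid) pairs from most to least similar.
    pairs = sorted(
        ((dot(e, c), tok, ci)
         for tok, e in enumerate(embeddings)
         for ci, c in enumerate(centroids)),
        reverse=True,
    )
    assigned = set()
    for _, tok, ci in pairs:
        if tok not in assigned and len(clusters[ci]) < cluster_size:
            clusters[ci].append(tok)
            assigned.add(tok)
    return clusters

# Three tokens prefer centroid 0, but each cluster may hold only two.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.0, 1.0]]
centroids = [[1.0, 0.0], [0.0, 1.0]]
clusters = assign_balanced(embeddings, centroids, cluster_size=2)
print([sorted(c) for c in clusters])  # [[0, 1], [2, 3]]
```

Plain nearest-centroid assignment would yield clusters of size 3 and 1 here; the capacity limit pushes token 2 into its second choice so that every cluster has the same size.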

📦 Installation

Prerequisites: Python 3.10+ and vLLM >= 0.14.0

pip install flash-head

That's it. The plugin is discovered automatically by vLLM at startup. ✨

Install from source
git clone https://github.com/embedl/flash-head.git
cd flash-head
pip install .

🚀 Usage

⌨️ CLI

# FlashHead activates automatically for compatible models
vllm serve embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead \
    --host 0.0.0.0 --port 8000 \
    --gpu-memory-utilization 0.75 \
    --max-model-len 8192

# Disable without uninstalling
FLASHHEAD_ENABLED=0 vllm serve ...
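
Once the server is running, it can be queried like any other vLLM deployment. The sketch below builds a request for the OpenAI-compatible completions endpoint using only the standard library; it assumes the server is reachable at localhost on the port given above, and the helper name is illustrative.

```python
import json
import urllib.request

def build_completion_request(
    prompt,
    base_url="http://localhost:8000",
    model="embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead",
):
    """Build a POST request for vLLM's OpenAI-compatible /v1/completions endpoint."""
    payload = {"model": model, "prompt": prompt, "max_tokens": 50}
    return urllib.request.Request(
        f"{base_url}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

request = build_completion_request("Explain quantum computing.")

# Sending the request requires the server started above to be running:
# with urllib.request.urlopen(request) as response:
#     print(json.load(response)["choices"][0]["text"])
```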

🐍 Python

from vllm import LLM, SamplingParams

llm = LLM(
    model="embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead",
    trust_remote_code=True,
)
outputs = llm.generate(
    ["Explain quantum computing."],
    SamplingParams(max_tokens=50),
)
print(outputs[0].outputs[0].text)

The model's config.json contains "flash_head_cache_dir": "flash_head_assets" which signals FlashHead to activate. Standard models without this field are completely unaffected.

🔧 vLLM Plugin Integration

  1. Discovery: vLLM discovers the flash-head plugin via the vllm.general_plugins entry point at startup
  2. Patching: register() is called in every process, intercepting logits computation, sampling, and speculative decoding
  3. Inference: The worker lazily constructs the FlashHead module on GPU from the model's clustering cache

🛡️ Safety

FlashHead models use a custom architecture name (e.g., FlashHeadQwen3VLForConditionalGeneration). Without the plugin installed, vLLM does not recognize the architecture and refuses to load the model. Users cannot accidentally run at reduced speed.

| Scenario | Behavior |
|---|---|
| Plugin not installed | ❌ vLLM errors: architecture not supported |
| Plugin installed, FLASHHEAD_ENABLED=0 | ⏸️ Clean disable, model loads without FlashHead |
| Plugin installed, enabled | ✅ FlashHead loads on GPU, full speedup |
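
The scenarios can be summarized as a small decision function. This is a reading aid, not the plugin's code: it assumes that only the value "0" disables FlashHead and that an unset variable means enabled, and the function name and return strings are hypothetical.

```python
def expected_behavior(plugin_installed, env, config):
    """Mirror the documented scenarios for a given setup."""
    if "flash_head_cache_dir" not in config:
        return "standard model: unaffected"
    if not plugin_installed:
        return "error: architecture not supported"
    # Assumption: only the string "0" disables the plugin.
    if env.get("FLASHHEAD_ENABLED", "1") == "0":
        return "model loads without FlashHead"
    return "FlashHead loads on GPU"

flashhead_config = {
    "architectures": ["FlashHeadQwen3VLForConditionalGeneration"],
    "flash_head_cache_dir": "flash_head_assets",
}
print(expected_behavior(False, {}, flashhead_config))
print(expected_behavior(True, {"FLASHHEAD_ENABLED": "0"}, flashhead_config))
print(expected_behavior(True, {}, flashhead_config))
```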

🏗️ Supported Architectures

For the most recent list of supported architectures, see _FLASHHEAD_ARCHITECTURES:

_FLASHHEAD_ARCHITECTURES = {
    "FlashHeadLlamaForCausalLM": "vllm.model_executor.models.llama:LlamaForCausalLM",
    "FlashHeadQwen3ForCausalLM": "vllm.model_executor.models.qwen3:Qwen3ForCausalLM",
    "FlashHeadQwen3VLForConditionalGeneration": "vllm.model_executor.models.qwen3_vl:Qwen3VLForConditionalGeneration",
    "FlashHeadGemma3ForCausalLM": "vllm.model_executor.models.gemma3:Gemma3ForCausalLM",
}

📤 Publishing FlashHead Models

For the safety check to work, FlashHead models should use this config.json structure:

{
  "architectures": ["FlashHeadQwen3VLForConditionalGeneration"],
  "model_type": "qwen3_vl",
  "flash_head_cache_dir": "flash_head_assets"
}
  • architectures uses the FlashHead* prefix so vLLM rejects the model without the plugin
  • model_type stays standard so vLLM can resolve the base model class
  • flash_head_cache_dir points to the clustering cache directory
  • Do NOT include auto_map -- the plugin handles registration
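
Before uploading a model, these rules can be checked mechanically. The helper below is a hypothetical sketch and not part of the flash-head package; it inspects a parsed config.json and returns a list of problems.

```python
def check_flashhead_config(config):
    """Return a list of problems with a FlashHead config.json dict (empty if none)."""
    problems = []
    architectures = config.get("architectures", [])
    if not architectures or not all(a.startswith("FlashHead") for a in architectures):
        problems.append("architectures must use the FlashHead prefix")
    if "flash_head_cache_dir" not in config:
        problems.append("flash_head_cache_dir is missing")
    if "auto_map" in config:
        problems.append("auto_map must not be present")
    return problems

good = {
    "architectures": ["FlashHeadQwen3VLForConditionalGeneration"],
    "model_type": "qwen3_vl",
    "flash_head_cache_dir": "flash_head_assets",
}
bad = {"architectures": ["Qwen3VLForConditionalGeneration"], "auto_map": {}}

print(check_flashhead_config(good))  # []
print(len(check_flashhead_config(bad)))  # 3
```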

🗺️ Roadmap

  • ✅ Core FlashHead plugin for vLLM (greedy decoding)
  • ✅ Balanced clustering with multi-probe retrieval
  • ✅ Inference-time sampling across full vocabulary
  • ✅ Quantized model support
  • 🔄 Speculative decoding: Full FlashHead integration with vLLM's speculative decoding pipeline (in progress)
  • 🔄 EAGLE draft proposals: FlashHead-accelerated draft generation for EAGLE speculative decoding (in progress)
  • ⬜ Additional model architectures
  • ⬜ Benchmarks on additional edge platforms (Qualcomm, AMD, Intel, ...)

💡 Want a feature? Open an issue!

🤝 Contributing

We welcome contributions, feedback, and collaboration. Whether you're interested in adding support for new architectures, improving performance, or integrating FlashHead into your own inference stack -- we'd love to hear from you.

  • Report issues: Bug reports and feature requests help us improve. Open an issue.
  • Submit PRs: Code contributions for new architectures, optimizations, or bug fixes.
  • Research collaboration: Working on efficient inference, vocabulary approximation, or edge deployment? Reach out.
  • Model contributions: Publish FlashHead-optimized models to the HuggingFace collection.
  • Benchmarks: Run FlashHead on your hardware and submit results to the Edge Inference Benchmarks space.

📂 Project Structure

flash-head/
├── src/flash_head/
│   ├── __init__.py              # Plugin entry point (register)
│   ├── flash_head.py            # Core clustering-based head
│   ├── loading.py               # Model/asset loading from HF Hub
│   └── patches/                 # vLLM runtime patches
│
├── pyproject.toml
└── LICENSE

📖 Citation

If you use FlashHead in your research, please cite:

@article{tranheden2026flashhead,
  title={FlashHead: Efficient Drop-In Replacement for the Classification Head in Language Model Inference},
  author={Tranheden, Wilhelm and Ahmed, Shahnawaz and Dubhashi, Devdatt and Matthiesen, Jonna and von Essen, Hannes},
  journal={arXiv preprint arXiv:2603.14591},
  year={2026}
}

License

Free for non-commercial use under the Embedl Community License (v.1.0).

Interested in FlashHead?

Enterprise licensing, custom model optimization, and engineering support available.

models@embedl.com  •  embedl.com


© 2026 Embedl AB. All rights reserved.
