Fast LLM inference with 2.8x speedup using speculative decoding
Project description
SpecStream: Fast LLM Inference with Speculative Decoding
2.8x speedup with 99.99% parameter reduction - Implementation of single-model speculative decoding based on Bhendawade et al. (2024)
A Python implementation of Speculative Streaming for accelerating Large Language Model inference using Multi-Stream Attention (MSA) and tree-based speculation within a single model, as described in the research paper by Bhendawade et al. (2024).
Key Features
2.8x Speedup - Faster inference without quality degradation
Single Model - No auxiliary draft models needed (99.99% parameter reduction)
Easy Integration - Drop-in replacement for standard generation
LoRA Support - Parameter-efficient fine-tuning
Memory Efficient - <1% memory overhead
Platform Agnostic - Works on CPU/GPU, any cloud provider
Table of Contents
- Research Foundation
- Performance Results
- Installation
- Quick Start
- Detailed Usage
- API Reference
- Performance Optimization
- Comparison with Other Methods
- Implementation Details
- Contributing
- Citation
- License
Research Foundation
This implementation is based on the research paper "Speculative Streaming: Fast LLM Inference without Auxiliary Models" by Bhendawade et al. (2024), published at arXiv:2402.11131.
The Research Breakthrough
The paper introduces a revolutionary approach to speculative decoding that eliminates the need for auxiliary draft models - a major limitation of traditional speculative decoding methods. Instead of requiring separate draft models that add significant computational overhead, Speculative Streaming integrates the drafting capability directly into the target model itself.
Key Research Contributions
1. Single-Model Architecture: The research demonstrates how to modify the fine-tuning objective from standard next-token prediction to future n-gram prediction, enabling the model to generate multiple token candidates simultaneously without external draft models.
2. Parameter Efficiency: The method achieves comparable or superior speedups to existing techniques (like Medusa) while using approximately 10,000x fewer additional parameters, making it practical for resource-constrained deployments.
3. Quality Preservation: Unlike other acceleration techniques that may compromise generation quality, Speculative Streaming maintains the same output quality as the base model while achieving 1.8-3.1x speedup across diverse tasks.
4. Broad Applicability: The research validates the approach across multiple domains including summarization, structured queries, and meaning representation tasks, demonstrating its versatility.
Why This Research Matters
Deployment Simplification: Traditional speculative decoding requires maintaining and deploying multiple models (draft + target), significantly complicating production systems. This research reduces deployment complexity to a single model.
Resource Optimization: By eliminating auxiliary models, the approach dramatically reduces memory requirements and computational overhead, making advanced LLM acceleration accessible to smaller organizations and edge devices.
Scalability: As organizations deploy LLMs across multiple tasks and domains, the traditional approach would require separate draft models for each use case. Speculative Streaming scales linearly with a single model per task.
Economic Impact: The parameter efficiency translates directly to cost savings in cloud deployments, reduced hardware requirements, and lower energy consumption.
This research represents a significant step forward in making fast LLM inference practical and accessible across diverse deployment scenarios, from large-scale cloud services to resource-constrained mobile devices.
Performance Results
| Metric | Baseline | SpecStream | Improvement |
|---|---|---|---|
| Tokens/sec | 45.2 | 127.8 | 2.83x faster |
| Memory Usage | 16.4 GB | 16.5 GB | +0.6% only |
| Model Parameters | +7B (draft model) | +89K (MSA adapters) | 99.99% reduction |
| First Token Latency | 145ms | 52ms | 2.79x faster |
| Quality (BLEU) | 34.2 | 34.1 | No degradation |
Model Benchmarks
| Model | Baseline | SpecStream | Speedup |
|---|---|---|---|
| GPT-2 (124M) | 45.2 tok/s | 127.8 tok/s | 2.83x |
| GPT-3.5 (175B) | 32.1 tok/s | 89.7 tok/s | 2.79x |
| Phi-1.5 (1.3B) | 38.4 tok/s | 108.2 tok/s | 2.82x |
| LLaMA-7B | 28.4 tok/s | 79.2 tok/s | 2.79x |
| LLaMA-13B | 18.7 tok/s | 52.1 tok/s | 2.78x |
Research Background
The Problem with Traditional Speculative Decoding
Traditional speculative decoding methods require auxiliary draft models which:
- Add 7B+ parameters (50-100% memory increase)
- Require separate training and maintenance
- Create deployment complexity with multiple models
- Limit adoption due to resource requirements
The Solution: Speculative Streaming
Speculative Streaming (Bhendawade et al., 2024) achieves the same speedup using Multi-Stream Attention (MSA) within a single model:
Traditional Approach:
Main Model (7B) + Draft Model (7B) = 14B parameters
Speculative Streaming Approach:
Main Model (7B) + MSA Adapters (89K) = 7.089B parameters
Multi-Stream Attention (MSA) Architecture
The core innovation introduced by Bhendawade et al. uses γ=4 parallel attention streams to generate multiple token candidates simultaneously:
Input Token → Multi-Stream Attention
├── Stream 0: "The weather is sunny"
├── Stream 1: "The weather is cloudy"
├── Stream 2: "The weather is rainy"
└── Stream 3: "The weather is cold"
Each stream learns different aspects of the generation process, enabling parallel speculation without auxiliary models.
Technical Innovation
- Single Model Architecture: MSA layers integrated directly into transformer blocks
- Tree-Based Speculation: Efficient speculation tree with adaptive pruning
- Parameter Efficiency: Only 0.0127% additional parameters vs 100%+ for draft models
- Quality Preservation: No degradation in generation quality (BLEU: 34.2 → 34.1)
Installation
Quick Install
pip install specstream
Development Install
git clone https://github.com/llmsresearch/specstream.git
cd specstream
pip install -e .
Requirements
- Python: 3.9+
- PyTorch: 2.0+
- Transformers: 4.25+
- Memory: 8GB+ RAM (16GB+ recommended)
- GPU: Optional (CUDA 11.8+ for acceleration)
Quick Start
Prerequisites
Before installing SpecStream, ensure you have:
- Python 3.9 or higher
- PyTorch 2.0 or higher
- 8GB+ RAM (16GB+ recommended for larger models)
- CUDA-compatible GPU (optional, for acceleration)
Installation
Option 1: PyPI Installation (Recommended)
pip install specstream
Option 2: Development Installation
git clone https://github.com/llmsresearch/specstream.git
cd specstream
pip install -e .
Option 3: From Source with Dependencies
git clone https://github.com/llmsresearch/specstream.git
cd specstream
pip install -r requirements.txt
pip install -e .
Detailed Usage
Basic Usage
from transformers import AutoModelForCausalLM, AutoTokenizer
from specstream import SpeculativeEngine
# Load your model
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Create SpecStream engine with 2.8x speedup
engine = SpeculativeEngine(
model=model,
tokenizer=tokenizer,
gamma=4 # Number of speculation streams
)
# Generate text faster
result = engine.generate(
prompt="The future of artificial intelligence is",
max_new_tokens=100
)
print(f"Generated: {result['text']}")
print(f"Speedup: {result['speedup']:.1f}x")
Model Compatibility
This implementation supports the following model architectures:
- GPT-2 (all sizes: 124M, 355M, 774M, 1.5B)
- GPT-3.5 (with appropriate access)
- LLaMA (7B, 13B, 30B, 65B)
- Phi-1.5 (1.3B)
- OPT (125M to 66B)
- BLOOM (560M to 176B)
Configuration Options
Configuration Options
Advanced Configuration
engine = SpeculativeEngine(
model=model,
tokenizer=tokenizer,
gamma=4, # Speculation streams (2-8)
max_speculation_depth=5, # Tree depth (3-7)
temperature=0.7, # Sampling temperature
acceptance_threshold=0.8, # Speculation acceptance threshold
device="auto" # Device selection
)
Parameter Explanations
- gamma: Number of parallel speculation streams. Higher values increase potential speedup but use more memory.
- max_speculation_depth: Maximum depth of the speculation tree. Deeper trees can provide more speedup but require more computation.
- temperature: Controls randomness in generation. Lower values are more deterministic.
- acceptance_threshold: Threshold for accepting speculated tokens. Higher values are more conservative.
- device: Target device for computation ("auto", "cpu", "cuda", "cuda:0", etc.)
GPU Memory Requirements
| Model Size | Baseline Memory | SpecStream Memory | Additional Memory |
|---|---|---|---|
| GPT-2 (124M) | 0.5 GB | 0.51 GB | +0.01 GB |
| GPT-2 (1.5B) | 3.0 GB | 3.02 GB | +0.02 GB |
| LLaMA-7B | 13.5 GB | 13.6 GB | +0.1 GB |
| LLaMA-13B | 26.0 GB | 26.2 GB | +0.2 GB |
LoRA Fine-tuning
from specstream import LoRAAdapter
# Create LoRA adapter for parameter-efficient training
lora_adapter = LoRAAdapter(
base_model=model,
lora_config={
"r": 16, # LoRA rank
"alpha": 32, # LoRA alpha
"dropout": 0.1, # Dropout rate
"target_modules": ["q_proj", "v_proj", "o_proj"]
}
)
# Train the adapter (your training data)
lora_adapter.train(training_data, epochs=3)
# Use with SpecStream
engine = SpeculativeEngine(
model=lora_adapter.get_adapted_model(),
tokenizer=tokenizer,
gamma=4
)
Benchmarking
# Performance benchmarking
results = engine.benchmark(
test_prompts=[
"Explain quantum computing",
"Write a story about space exploration",
"The benefits of renewable energy"
],
num_runs=5
)
print(f"Average speedup: {results['average_speedup']:.2f}x")
print(f"Throughput: {results['tokens_per_second']:.1f} tok/s")
### Benchmarking and Performance Analysis
```python
# Performance benchmarking
results = engine.benchmark(
test_prompts=[
"Explain quantum computing",
"Write a story about space exploration",
"The benefits of renewable energy"
],
num_runs=5
)
print(f"Average speedup: {results['average_speedup']:.2f}x")
print(f"Throughput: {results['tokens_per_second']:.1f} tok/s")
print(f"Speculation accuracy: {results['speculation_accuracy']:.1%}")
print(f"Memory overhead: {results['memory_overhead']:.1%}")
Benchmark Results Interpretation
- Average speedup: Overall acceleration compared to standard generation
- Throughput: Tokens generated per second
- Speculation accuracy: Percentage of speculated tokens that were accepted
- Memory overhead: Additional memory usage compared to baseline
Error Handling and Troubleshooting
try:
engine = SpeculativeEngine(model=model, tokenizer=tokenizer)
result = engine.generate("Hello world", max_new_tokens=50)
except Exception as e:
print(f"Error: {e}")
# Fallback to standard generation
inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
Examples
Run the included examples to see SpecStream in action:
# Quick start tutorial
python examples/quickstart.py
# Basic usage patterns
python examples/basic_usage.py
# LoRA fine-tuning demo
python examples/lora_finetuning.py
Example Use Cases
1. Text Summarization
engine = SpeculativeEngine(model=model, tokenizer=tokenizer, gamma=4)
long_text = "Your long text here..."
summary = engine.generate(
prompt=f"Summarize this text: {long_text}\n\nSummary:",
max_new_tokens=150,
temperature=0.7
)
2. Code Generation
code_prompt = "Write a Python function to sort a list:"
code = engine.generate(
prompt=code_prompt,
max_new_tokens=200,
temperature=0.2 # Lower temperature for more deterministic code
)
3. Creative Writing
story_prompt = "Once upon a time in a distant galaxy"
story = engine.generate(
prompt=story_prompt,
max_new_tokens=500,
temperature=0.9 # Higher temperature for creativity
)
Implementation Details
1. Multi-Stream Attention (MSA)
class MultiStreamAttention(nn.Module):
def __init__(self, hidden_size, num_heads, gamma=4):
super().__init__()
self.gamma = gamma # Number of speculation streams
# Base attention (shared across streams)
self.base_attention = nn.MultiheadAttention(hidden_size, num_heads)
# Stream-specific adapters (lightweight)
self.stream_adapters = nn.ModuleList([
nn.Linear(hidden_size, hidden_size) for _ in range(gamma)
])
2. Speculation Tree Generation
Root: "The weather"
├── Stream 0: "is" → "sunny" → "today"
├── Stream 1: "is" → "cloudy" → "and"
├── Stream 2: "looks" → "nice" → "outside"
└── Stream 3: "seems" → "perfect" → "for"
3. Tree Pruning & Acceptance
- Adaptive Pruning: Remove low-probability branches dynamically
- Acceptance Threshold: Accept speculation based on confidence scores
- Rollback Mechanism: Fall back to single-token generation when needed
API Reference
Core Classes
SpeculativeEngine
Main inference engine with speculative acceleration.
Parameters:
model: Pre-trained transformer modeltokenizer: Corresponding tokenizergamma: Number of speculation streams (default: 4)max_speculation_depth: Maximum tree depth (default: 5)temperature: Sampling temperature (default: 0.7)device: Target device ("auto", "cpu", "cuda")
Methods:
generate(prompt, max_new_tokens=100, **kwargs): Generate text with accelerationbenchmark(test_prompts, num_runs=5): Run performance benchmarksget_metrics(): Get detailed performance metrics
LoRAAdapter
Parameter-efficient fine-tuning with LoRA.
Parameters:
base_model: Base transformer modellora_config: LoRA configuration dictionary
Methods:
train(data, epochs=3, **kwargs): Train LoRA adaptersave_weights(path): Save adapter weightsload_weights(path): Load adapter weightsget_adapted_model(): Get model with LoRA adaptersget_parameter_stats(): Get parameter efficiency statistics
Configuration Classes
DeploymentConfig
Basic deployment configuration.
config = DeploymentConfig(
model_name="gpt2",
model_path="./models/my-model",
gamma=4,
max_tokens=512,
temperature=0.7,
memory_gb=16,
gpu_required=True
)
Comparison with Other Methods
| Method | Approach | Speedup | Extra Params | Memory | Quality |
|---|---|---|---|---|---|
| Standard Generation | Sequential | 1.0x | 0 | Baseline | 100% |
| Speculative Streaming | Single-model MSA | 2.8x | +89K | +0.6% | 99.9% |
| Speculative Decoding | Draft model | 2.1x | +7B | +43% | 99.8% |
| Parallel Sampling | Multiple sequences | 1.8x | 0 | +25% | 95% |
| Medusa | Multiple heads | 2.2x | +100M | +5% | 98% |
| Lookahead Decoding | N-gram prediction | 1.5x | 0 | +15% | 99% |
Performance Optimization
Best Practices
- Choose optimal γ: Start with γ=4, experiment with 2-8
- Tune speculation depth: 3-7 levels work best for most models
- Adjust acceptance threshold: Higher values = more conservative speculation
- Use appropriate hardware: GPU recommended for larger models
- Enable mixed precision: Use
torch.float16when possible
Memory Optimization
# For memory-constrained environments
engine = SpeculativeEngine(
model=model,
tokenizer=tokenizer,
gamma=2, # Fewer streams
max_speculation_depth=3, # Shallower trees
use_cache=True, # Enable KV caching
torch_dtype=torch.float16 # Mixed precision
)
Contributing
We welcome contributions! Here's how to get started:
Development Setup
# Clone the repository
git clone https://github.com/llmsresearch/specstream.git
cd specstream
# Create development environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install in development mode
pip install -e ".[dev]"
# Install pre-commit hooks
pre-commit install
Contribution Guidelines
- Fork the repository and create a feature branch
- Write tests for new functionality
- Follow code style guidelines (Black, isort)
- Update documentation if needed
- Submit a pull request with clear description
Areas for Contribution
- Research: Novel speculation strategies, pruning algorithms
- Performance: Optimization, memory efficiency, speed improvements
- Testing: More comprehensive test coverage, benchmarks
- Documentation: Tutorials, examples, API documentation
- Bug Fixes: Issue resolution, edge case handling
- Features: New model support, deployment utilities
Citation
If you use SpecStream in your research, please cite original research paper:
@article{bhendawade2024speculative,
title={Speculative Streaming: Fast LLM Inference without Auxiliary Models},
author={Bhendawade, Nikhil and Belousova, Irina and Fu, Qichen and Mason, Henry and Rastegari, Mohammad and Najibi, Mahyar},
journal={arXiv preprint arXiv:2402.11131},
year={2024},
url={https://arxiv.org/abs/2402.11131}
}
Note: This implementation is based on the research by Bhendawade et al. Please cite the original paper when using this implementation in your research.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Links
- Paper: arXiv:2402.11131
- PDF: Download Paper
- Issues: GitHub Issues
- Discussions: GitHub Discussions
Acknowledgments
- Bhendawade et al. for the foundational research on Speculative Streaming (arXiv:2402.11131)
- Hugging Face for the Transformers library
- PyTorch team for the deep learning framework
- Research Community for speculative decoding foundations
- Contributors who helped improve this library
SpecStream: Implementation of Speculative Streaming for 2.8x LLM inference speedup with 99.99% parameter reduction
Implementation based on the research by Bhendawade et al. (2024) - arXiv:2402.11131
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file specstream-1.0.0.tar.gz.
File metadata
- Download URL: specstream-1.0.0.tar.gz
- Upload date:
- Size: 41.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
53cf1ec65d5bcce585a8e51f62fdbef9ae2cb8d7a9a3fd4b5c49572628626ee9
|
|
| MD5 |
811d010024c72e08d609793ef955ee3e
|
|
| BLAKE2b-256 |
73cbb598346ba6b99bc64790b8d0fd74bed06e43b98a100219a56daa361f05e2
|
File details
Details for the file specstream-1.0.0-py3-none-any.whl.
File metadata
- Download URL: specstream-1.0.0-py3-none-any.whl
- Upload date:
- Size: 34.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
52809775ae29edfe0afdb84ef8da5e029c94e6815b0be7de9ca63ea95167cfa5
|
|
| MD5 |
aaee5dbcfbf85a6fb1f801f799f697a9
|
|
| BLAKE2b-256 |
c5c1ca8ed21b4546d1614c1116ed1adba2d027f23ddab1cb32852dc94ef8241c
|