Ultra-fast parallel training and inference for language models
Project description
Parallel-LLM: Ultra-Fast Parallel Training & Inference
Parallel-LLM is a production-ready library for training and inference of language models with revolutionary parallel token generation. Generate all tokens at once instead of one-by-one using our hybrid diffusion-energy architecture.
๐ Key Features
Training
- Full Parallelism: Data + Tensor + Pipeline + Expert parallelism
- FSDP2: PyTorch's latest fully sharded data parallel with DTensor
- DeepSpeed ZeRO: Stages 1, 2, 3 with CPU offloading
- Flash Attention 3: Up to 75% GPU utilization on H100
- torch.compile: Automatic kernel fusion and optimization
- Mixed Precision: FP16, BF16, FP8 support
- Gradient Checkpointing: Selective activation checkpointing
Inference
- Parallel Generation: Generate 64+ tokens simultaneously
- 1.5-3ร Faster: Compared to autoregressive decoding
- Paged KV Cache: Memory-efficient attention like vLLM
- CUDA Graphs: Zero CPU overhead
- Continuous Batching: Dynamic request handling
- Speculative Decoding: Draft model verification
Multimodal
- Vision-Language Models: CLIP-style contrastive learning
- Cross-Modal Fusion: Attention-based alignment
- Unified Architecture: Single model for text + vision
๐ฆ Installation
pip install parallel-llm
From Source
git clone https://github.com/furqan-y-khan/parallel-llm
cd parallel-llm
pip install -e .
Requirements
- Python >= 3.9
- PyTorch >= 2.2.0
- CUDA >= 11.8 (for GPU)
- 16GB+ GPU memory recommended
๐ฅ Quick Start
Training a Unimodal LLM
import torch
from parallel_llm import DiffusionTransformer, ModelConfig, TrainingConfig, DistributedTrainer
# Configure model
model_config = ModelConfig(
vocab_size=50257,
hidden_size=2048,
num_hidden_layers=24,
num_attention_heads=16,
use_flash_attention=True,
)
# Create model
model = DiffusionTransformer(model_config)
# Configure training
train_config = TrainingConfig(
batch_size=8,
learning_rate=3e-4,
use_fsdp=True,
fsdp_sharding_strategy="full",
mixed_precision="bf16",
use_torch_compile=True,
torch_compile_mode="max-autotune",
)
# Create trainer
trainer = DistributedTrainer(
model=model,
train_config=train_config,
model_config=model_config,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
)
# Train!
trainer.train()
Parallel Generation (Inference)
from parallel_llm import ParallelGenerator, GenerationConfig
# Configure generation
gen_config = GenerationConfig(
max_new_tokens=512,
temperature=1.0,
num_refinement_steps=5,
confidence_threshold=0.9,
)
# Create generator
generator = ParallelGenerator(model, gen_config, use_cuda_graphs=True)
# Generate (all 512 tokens in ~5 forward passes!)
prompt = torch.tensor([[1, 2, 3, 4, 5]]) # Your prompt tokens
generated = generator.generate(prompt)
print(f"Generated {generated.shape[1]} tokens")
Multimodal Training
from parallel_llm import MultimodalModel, MultimodalConfig
# Configure multimodal model
config = MultimodalConfig(
# Text config
vocab_size=50257,
hidden_size=2048,
num_hidden_layers=24,
# Vision config
vision_encoder="clip",
image_size=224,
patch_size=16,
vision_hidden_size=1024,
# Fusion
fusion_type="cross_attention",
use_contrastive=True,
)
# Create model
model = MultimodalModel(config)
# Train with image-text pairs
# ... (similar to unimodal training)
๐๏ธ Architecture
Hybrid Diffusion-Energy Framework
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Input: [MASK] [MASK] [MASK] ... [MASK] โ
โโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Diffusion Transformer โ
โ (Bidirectional Attention) โ
โโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโ
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Multi-Token Predictions โ
โ With Confidence Scores โ
โโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโ
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Energy-Based Refinement โ
โ (Sequence-Level Scoring) โ
โโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโ
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Adaptive Masking โ
โ (Keep high-confidence) โ
โโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโ
โ
Output: All tokens generated
Key Innovations
- Masked Diffusion: Start with all [MASK] tokens, iteratively refine
- Bidirectional Attention: Each token sees entire context
- Confidence-Based Masking: Adaptively accept high-confidence predictions
- Energy Model: Global sequence coherence checking
- Parallel Decoding: 64+ tokens per forward pass
๐ Performance
Speed Comparison (Llama-7B equivalent)
| Method | Tokens/sec | Speedup |
|---|---|---|
| Autoregressive (HF) | 25 | 1.0ร |
| vLLM | 45 | 1.8ร |
| Parallel-LLM | 75 | 3.0ร |
Memory Efficiency
| Batch Size | Standard | Parallel-LLM |
|---|---|---|
| 1 | 16 GB | 12 GB |
| 8 | 128 GB | 48 GB |
| 32 | OOM | 96 GB |
๐ ๏ธ Advanced Features
Distributed Training
# Launch with torchrun
torchrun --nproc_per_node=8 train.py \
--use-fsdp \
--fsdp-sharding-strategy full \
--tensor-parallel-size 1 \
--pipeline-parallel-size 1
Custom Kernels
from parallel_llm.kernels import fused_attention, parallel_decode
# Use optimized Triton kernels
output = fused_attention(query, key, value, use_flash=True)
# Parallel token decoding
tokens = parallel_decode(logits, num_parallel=64)
Quantization
from parallel_llm.quantization import quantize_model
# Quantize to INT8 or FP8
model = quantize_model(model, precision="fp8")
๐ Documentation
๐ค Contributing
We welcome contributions! See CONTRIBUTING.md for guidelines.
๐ License
Apache 2.0 License. See LICENSE for details.
๐ Acknowledgments
Built on research from:
- FlashAttention (Dao et al.)
- Diffusion Language Models (various)
- DeepSpeed ZeRO (Microsoft)
- vLLM (UC Berkeley)
- PyTorch FSDP (Meta)
๐ Contact
- Email: furqan@lastappstanding.com
๐ Star History
If you find this project useful, please give it a star! โญ
Citation
@software{parallel_llm,
title = {Parallel-LLM: Ultra-Fast Parallel Training and Inference},
author = {Last App Standing Team},
year = {2025},
url = {https://github.com/furqan-y-khan/parallel-llm}
}
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file parallel_llm-0.1.0.tar.gz.
File metadata
- Download URL: parallel_llm-0.1.0.tar.gz
- Upload date:
- Size: 30.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.0b4
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
11ed976ec83447d6728adfd90ea9af9b2632a7ca569c1ad8d8c10de413bcb21a
|
|
| MD5 |
4fc097011474ad5e205df10764d71ec1
|
|
| BLAKE2b-256 |
73555c70b42363a0551066cc7b4b8029bb0754704712010b09a658ffbc2c4339
|
File details
Details for the file parallel_llm-0.1.0-py3-none-any.whl.
File metadata
- Download URL: parallel_llm-0.1.0-py3-none-any.whl
- Upload date:
- Size: 21.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.0b4
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
66363d35e90747a92d261073f50ffb5b1e914e6aa32e6bb175bcd96a764ab58e
|
|
| MD5 |
5356f9be1205e1f29957bd03d884260e
|
|
| BLAKE2b-256 |
396975e6feaf376fbe87b1fafee442990021a81399c4eef24ea323143bb85991
|