
Ultra-fast parallel training and inference for language models


Parallel-LLM: Ultra-Fast Parallel Training & Inference


Parallel-LLM is a library for training and running language models with parallel token generation. Instead of producing tokens one at a time, its hybrid diffusion-energy architecture predicts all output positions at once and refines them over a small number of forward passes.

🚀 Key Features

Training

  • Full Parallelism: Data + Tensor + Pipeline + Expert parallelism
  • FSDP2: PyTorch's DTensor-based fully sharded data parallelism
  • DeepSpeed ZeRO: Stages 1, 2, 3 with CPU offloading
  • Flash Attention 3: Up to 75% GPU utilization on H100
  • torch.compile: Automatic kernel fusion and optimization
  • Mixed Precision: FP16, BF16, FP8 support
  • Gradient Checkpointing: Selective activation checkpointing
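To see why ZeRO stages (and FSDP full sharding, which corresponds to stage 3) reduce memory, the accounting from the ZeRO paper is a useful rule of thumb: mixed-precision Adam keeps about 16 bytes of model state per parameter (2 for half-precision weights, 2 for gradients, 12 for FP32 master weights, momentum, and variance), and each stage shards more of those terms across GPUs. The function below is an illustrative sketch, not part of Parallel-LLM; it ignores activations, communication buffers, and CPU offloading.

```python
def zero_bytes_per_gpu(num_params: int, num_gpus: int, stage: int) -> float:
    """Approximate model-state memory per GPU (bytes) under ZeRO.

    Mixed-precision Adam accounting: 2 bytes (half-precision weights)
    + 2 bytes (gradients) + 12 bytes (FP32 master weights, momentum,
    variance) per parameter. Activations and buffers are ignored.
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:  # plain data parallelism: everything replicated
        return weights + grads + optim
    if stage == 1:  # shard optimizer states only
        return weights + grads + optim / num_gpus
    if stage == 2:  # shard optimizer states and gradients
        return weights + (grads + optim) / num_gpus
    if stage == 3:  # shard weights, gradients, and optimizer states
        return (weights + grads + optim) / num_gpus
    raise ValueError("stage must be 0, 1, 2, or 3")


if __name__ == "__main__":
    params = 7_000_000_000  # a 7B-parameter model
    for s in range(4):
        gb = zero_bytes_per_gpu(params, num_gpus=8, stage=s) / 1e9
        print(f"ZeRO stage {s}: {gb:.2f} GB of model state per GPU")
```

For a 7B-parameter model on 8 GPUs this estimate gives 112 GB per GPU with no sharding and 14 GB per GPU at stage 3.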

Inference

  • Parallel Generation: Generate 64+ tokens simultaneously
  • 1.5-3× Speedup: Over autoregressive decoding
  • Paged KV Cache: vLLM-style memory-efficient KV cache management
  • CUDA Graphs: Minimal CPU launch overhead
  • Continuous Batching: Dynamic request handling
  • Speculative Decoding: Draft model verification
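The paged KV cache idea (introduced by vLLM's PagedAttention) stores each sequence's keys and values in fixed-size blocks drawn from a shared pool, with a per-sequence block table translating token positions into physical slots, so wasted memory is limited to the last partly filled block. The toy allocator below models only that bookkeeping, with no tensors; `PagedKVAllocator` and its methods are hypothetical names, not Parallel-LLM's API.

```python
class PagedKVAllocator:
    """Toy block allocator for a paged KV cache (bookkeeping only)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_tokens(self, seq_id: int, num_tokens: int) -> None:
        """Reserve cache slots for num_tokens new tokens of seq_id."""
        table = self.block_tables.setdefault(seq_id, [])
        new_length = self.lengths.get(seq_id, 0) + num_tokens
        needed = -(-new_length // self.block_size)  # ceiling division
        if needed - len(table) > len(self.free_blocks):
            raise MemoryError("KV cache pool exhausted")
        while len(table) < needed:
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = new_length

    def slot(self, seq_id: int, position: int) -> tuple:
        """Map a logical token position to (physical_block, offset)."""
        block = self.block_tables[seq_id][position // self.block_size]
        return block, position % self.block_size

    def free(self, seq_id: int) -> None:
        """Return all blocks of a finished sequence to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


if __name__ == "__main__":
    alloc = PagedKVAllocator(num_blocks=4, block_size=16)
    alloc.append_tokens(seq_id=0, num_tokens=20)  # occupies 2 blocks
    print(alloc.block_tables[0], alloc.slot(0, 17))
```

Because blocks are allocated on demand and returned when a sequence finishes, many requests of different lengths can share one pool, which is what makes continuous batching practical.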

Multimodal

  • Vision-Language Models: CLIP-style contrastive learning
  • Cross-Modal Fusion: Attention-based alignment
  • Unified Architecture: Single model for text + vision
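CLIP-style contrastive learning uses a symmetric cross-entropy (InfoNCE) loss over a batch of paired image and text embeddings: the i-th image should score highest against the i-th caption, and vice versa. The dependency-free sketch below assumes the embeddings have already been computed and uses a fixed temperature of 0.07 (CLIP's initial value; CLIP itself learns the temperature). It illustrates the general technique and is not Parallel-LLM's implementation.

```python
import math


def _normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _cross_entropy(logits_row, target):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits_row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return log_sum - logits_row[target]


def clip_contrastive_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric InfoNCE loss over a batch of (image, text) pairs."""
    imgs = [_normalize(v) for v in image_embs]
    txts = [_normalize(v) for v in text_embs]
    n = len(imgs)
    # logits[i][j] = cosine_similarity(image_i, text_j) / temperature
    logits = [[sum(a * b for a, b in zip(imgs[i], txts[j])) / temperature
               for j in range(n)] for i in range(n)]
    image_to_text = sum(_cross_entropy(logits[i], i) for i in range(n)) / n
    columns = [[logits[i][j] for i in range(n)] for j in range(n)]
    text_to_image = sum(_cross_entropy(columns[j], j) for j in range(n)) / n
    return (image_to_text + text_to_image) / 2
```

Averaging both directions penalizes an image that matches the wrong caption and a caption that matches the wrong image equally.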

📦 Installation

pip install parallel-llm

From Source

git clone https://github.com/furqan-y-khan/parallel-llm
cd parallel-llm
pip install -e .

Requirements

  • Python >= 3.9
  • PyTorch >= 2.2.0
  • CUDA >= 11.8 (for GPU)
  • 16GB+ GPU memory recommended
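A quick way to check a machine against these requirements is sketched below. It reads only existing PyTorch attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`, `torch.cuda.get_device_properties`), but the version parser is a simplification of PEP 440; `packaging.version.Version` is the robust choice when that package is available. The function names are illustrative.

```python
import re
import sys


def _numeric_parts(version: str) -> tuple:
    """'2.3.1+cu121' -> (2, 3, 1); stops at the first non-numeric piece."""
    parts = []
    for piece in version.split("+")[0].split("."):
        m = re.match(r"\d+", piece)
        if not m:
            break
        parts.append(int(m.group()))
    return tuple(parts)


def version_at_least(found: str, required: str) -> bool:
    a, b = _numeric_parts(found), _numeric_parts(required)
    width = max(len(a), len(b))
    # Pad with zeros so "2.2" compares equal to "2.2.0".
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment() -> dict:
    """Report whether this machine meets the listed requirements."""
    report = {"python_ok": sys.version_info >= (3, 9)}
    try:
        import torch
    except ImportError:
        report.update(torch_ok=False, cuda_available=False)
        return report
    report["torch_ok"] = version_at_least(torch.__version__, "2.2.0")
    report["cuda_available"] = torch.cuda.is_available()
    if report["cuda_available"]:
        report["cuda_ok"] = version_at_least(torch.version.cuda, "11.8")
        mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        report["gpu_memory_ok"] = mem_gb >= 16
    return report


if __name__ == "__main__":
    print(check_environment())
```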

🔥 Quick Start

Training a Unimodal LLM

import torch
from parallel_llm import DiffusionTransformer, ModelConfig, TrainingConfig, DistributedTrainer

# Configure model
model_config = ModelConfig(
    vocab_size=50257,
    hidden_size=2048,
    num_hidden_layers=24,
    num_attention_heads=16,
    use_flash_attention=True,
)

# Create model
model = DiffusionTransformer(model_config)

# Configure training
train_config = TrainingConfig(
    batch_size=8,
    learning_rate=3e-4,
    use_fsdp=True,
    fsdp_sharding_strategy="full",
    mixed_precision="bf16",
    use_torch_compile=True,
    torch_compile_mode="max-autotune",
)

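# Create trainer
# train_dataloader and eval_dataloader are your own data loaders
# (for example, torch.utils.data.DataLoader instances); they are not defined here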
trainer = DistributedTrainer(
    model=model,
    train_config=train_config,
    model_config=model_config,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
)

# Train!
trainer.train()
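Before launching a run, it helps to estimate how large this configuration is. Assuming a standard Transformer block (4h² for the attention projections, 8h² for an MLP with a 4× hidden expansion) and tied input/output embeddings, none of which the `ModelConfig` above states explicitly, the sketch below gives about 1.31B parameters for 24 layers, hidden size 2048, and a 50,257-token vocabulary. Biases, normalization layers, and positional parameters are ignored, and `estimate_params` is a hypothetical helper.

```python
def estimate_params(num_layers: int, hidden_size: int, vocab_size: int,
                    ffn_multiplier: int = 4, tied_embeddings: bool = True) -> int:
    """Rough parameter count for a standard Transformer stack.

    Per layer: 4*h^2 for the Q, K, V, and output projections plus
    2*ffn_multiplier*h^2 for the two MLP matrices.
    """
    h = hidden_size
    per_layer = 4 * h * h + 2 * ffn_multiplier * h * h
    embeddings = vocab_size * h * (1 if tied_embeddings else 2)
    return num_layers * per_layer + embeddings


if __name__ == "__main__":
    n = estimate_params(num_layers=24, hidden_size=2048, vocab_size=50257)
    # 16 bytes/parameter: half-precision weights + grads + FP32 Adam state
    print(f"~{n / 1e9:.2f}B parameters, ~{16 * n / 1e9:.0f} GB of model state")
```

At about 16 bytes per parameter for mixed-precision Adam, that is roughly 21 GB of weights, gradients, and optimizer state before FSDP sharding divides it across GPUs.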

Parallel Generation (Inference)

import torch
from parallel_llm import ParallelGenerator, GenerationConfig

# Reuses `model` from the training example above

# Configure generation
gen_config = GenerationConfig(
    max_new_tokens=512,
    temperature=1.0,
    num_refinement_steps=5,
    confidence_threshold=0.9,
)

# Create generator
generator = ParallelGenerator(model, gen_config, use_cuda_graphs=True)

# Generate (all 512 tokens in ~5 forward passes!)
prompt = torch.tensor([[1, 2, 3, 4, 5]])  # Your prompt tokens
generated = generator.generate(prompt)

print(f"Generated {generated.shape[1]} tokens")
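The "~5 forward passes" figure follows from the refinement schedule: each pass predicts every masked position at once, positions whose confidence reaches `confidence_threshold` are kept, and the final pass commits whatever is still masked, so decoding never takes more than `num_refinement_steps` passes. The toy loop below illustrates that control flow with a hypothetical `predict` function standing in for the model; it is a generic sketch of confidence-based masked decoding, not Parallel-LLM's internal implementation.

```python
import random

MASK = -1  # hypothetical id for the [MASK] token


def parallel_decode(predict, length, num_steps=5, threshold=0.9):
    """Confidence-thresholded iterative unmasking.

    predict(tokens) -> one (token_id, confidence) pair per position; each
    call stands for one forward pass. Returns (tokens, passes_used).
    """
    tokens = [MASK] * length
    for step in range(num_steps):
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        if not masked:
            return tokens, step
        preds = predict(tokens)  # all positions predicted in one pass
        if step == num_steps - 1:
            accept = masked  # final pass: commit everything still masked
        else:
            accept = [i for i in masked if preds[i][1] >= threshold]
            if not accept:
                # Guarantee progress: commit the most confident position.
                accept = [max(masked, key=lambda i: preds[i][1])]
        for i in accept:
            tokens[i] = preds[i][0]
    return tokens, num_steps


def toy_predict(tokens):
    """Stand-in model: fixed token per position, pseudo-random confidence."""
    rng = random.Random(tokens.count(MASK))
    return [(i % 100, rng.random()) for i in range(len(tokens))]


if __name__ == "__main__":
    out, passes = parallel_decode(toy_predict, length=512)
    print(f"decoded {len(out)} tokens in {passes} forward passes")
```

Forcing at least one commit per non-final pass is a design choice that guarantees progress when no prediction clears the threshold; without it, a low-confidence pass would spend a forward pass and change nothing.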

Multimodal Training

from parallel_llm import MultimodalModel, MultimodalConfig

# Configure multimodal model
config = MultimodalConfig(
    # Text config
    vocab_size=50257,
    hidden_size=2048,
    num_hidden_layers=24,

    # Vision config
    vision_encoder="clip",
    image_size=224,
    patch_size=16,
    vision_hidden_size=1024,

    # Fusion
    fusion_type="cross_attention",
    use_contrastive=True,
)

# Create model
model = MultimodalModel(config)

# Train with image-text pairs
# ... (similar to unimodal training)
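With `image_size=224` and `patch_size=16`, a ViT-style vision encoder splits each image into (224 / 16)² = 14 × 14 = 196 patches, so each image contributes 196 vision tokens to cross-attention fusion (197 if the encoder prepends a class token, as CLIP's ViT does). The helpers below sketch that arithmetic and the patchify step on a single-channel image stored as nested lists; they are illustrative and not part of Parallel-LLM.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


def patchify(image, patch_size):
    """Split an H x W image (nested lists, H and W divisible by patch_size)
    into flattened patch_size x patch_size patches in row-major order."""
    h, w = len(image), len(image[0])
    patches = []
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            patches.append([image[r][c]
                            for r in range(top, top + patch_size)
                            for c in range(left, left + patch_size)])
    return patches


if __name__ == "__main__":
    print(num_patches(224, 16), "vision tokens per image")
```

In a real encoder each flattened patch is then linearly projected to the vision hidden size (1024 in the config above).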

🏗️ Architecture

Hybrid Diffusion-Energy Framework

┌─────────────────────────────────────────┐
│  Input: [MASK] [MASK] [MASK] ... [MASK] │
└───────────────┬─────────────────────────┘
                ↓
    ┌────────────────────────────┐
    │  Diffusion Transformer     │
    │  (Bidirectional Attention) │
    └───────────┬────────────────┘
                ↓
    ┌────────────────────────────┐
    │  Multi-Token Predictions   │
    │  With Confidence Scores    │
    └───────────┬────────────────┘
                ↓
    ┌────────────────────────────┐
    │  Energy-Based Refinement   │
    │  (Sequence-Level Scoring)  │
    └───────────┬────────────────┘
                ↓
    ┌────────────────────────────┐
    │  Adaptive Masking          │
    │  (Keep high-confidence)    │
    └───────────┬────────────────┘
                ↓
    Output: All tokens generated
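The energy-based refinement stage scores whole sequences rather than individual tokens. One common way to use such a score is reranking: take several candidate completions and keep the one with the lowest energy. In the sketch below the energy is a hypothetical stand-in (negative log-likelihood plus a penalty on immediately repeated tokens); Parallel-LLM's actual energy model is not described here, and a learned energy network would replace `energy()`.

```python
import math


def energy(tokens, token_probs, repeat_penalty=1.0):
    """Hypothetical sequence energy: lower means more coherent.

    tokens: token ids; token_probs: the model's probability for each token.
    """
    nll = -sum(math.log(p) for p in token_probs)
    repeats = sum(1 for a, b in zip(tokens, tokens[1:]) if a == b)
    return nll + repeat_penalty * repeats


def rerank(candidates):
    """Return the (tokens, probs) candidate with the lowest energy."""
    return min(candidates, key=lambda c: energy(c[0], c[1]))
```

The point of a sequence-level score is that it can reject outputs whose individual tokens are each confident but which are incoherent as a whole, such as the repetitive first candidate in the test.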

Key Innovations

  1. Masked Diffusion: Start with all [MASK] tokens, iteratively refine
  2. Bidirectional Attention: Each token attends to the entire sequence
  3. Confidence-Based Masking: Adaptively accept high-confidence predictions
  4. Energy Model: Global sequence coherence checking
  5. Parallel Decoding: 64+ tokens per forward pass
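Training a masked diffusion model mirrors the decoding loop: sample a masking ratio t, replace each token with [MASK] independently with probability t, and train the model to recover only the masked tokens. The sketch below follows the objective used in published masked diffusion language models such as LLaDA, where the cross-entropy on masked positions is weighted by 1/t; whether Parallel-LLM uses exactly this weighting is an assumption, and `MASK_ID` is a placeholder.

```python
import random

MASK_ID = -1  # placeholder id for the [MASK] token


def corrupt(tokens, rng, eps=1e-3):
    """Forward (noising) process: mask each token with probability t."""
    t = rng.uniform(eps, 1.0)  # keep t away from 0 so 1/t stays bounded
    noisy = [MASK_ID if rng.random() < t else tok for tok in tokens]
    return noisy, t


def masked_diffusion_loss(noisy, t, log_probs):
    """1/t-weighted negative log-likelihood on masked positions only.

    log_probs[i] is the model's log-probability of the original token at
    position i given `noisy`. Dividing by sequence length is a choice.
    """
    total = sum(-lp for tok, lp in zip(noisy, log_probs) if tok == MASK_ID)
    return total / (t * len(noisy))


if __name__ == "__main__":
    rng = random.Random(0)
    noisy, t = corrupt(list(range(16)), rng)
    print(f"t={t:.2f}", noisy)
```

The 1/t weight compensates for the fact that lightly masked sequences contribute few loss terms, so every masking level contributes comparably in expectation.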

📊 Performance

Speed Comparison (Llama-7B equivalent)

Method                Tokens/sec   Speedup
Autoregressive (HF)   25           1.0×
vLLM                  45           1.8×
Parallel-LLM          75           3.0×
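The Speedup column is each method's throughput divided by the autoregressive Hugging Face baseline (for example, 75 / 25 = 3.0×). Using only the figures in the speed table, the sketch below converts throughput into wall-clock time for a fixed output length: 512 tokens take about 20.5 s at 25 tokens/sec and about 6.8 s at 75 tokens/sec.

```python
BASELINE_TPS = 25.0  # autoregressive (HF) tokens/sec from the speed table

THROUGHPUT = {"Autoregressive (HF)": 25.0, "vLLM": 45.0, "Parallel-LLM": 75.0}


def speedup(tps: float, baseline: float = BASELINE_TPS) -> float:
    """Throughput relative to the autoregressive baseline."""
    return tps / baseline


def seconds_for(num_tokens: int, tps: float) -> float:
    """Wall-clock time to emit num_tokens at a steady throughput."""
    return num_tokens / tps


if __name__ == "__main__":
    for name, tps in THROUGHPUT.items():
        print(f"{name:<20} {speedup(tps):.1f}x  "
              f"{seconds_for(512, tps):.1f} s for 512 tokens")
```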

Memory Efficiency

Batch Size   Standard   Parallel-LLM
1            16 GB      12 GB
8            128 GB     48 GB
32           OOM        96 GB
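Much of the batch-size dependence of inference memory comes from the KV cache, which grows linearly with both batch size and sequence length. For a Llama-7B-shaped model (32 layers, hidden size 4096, standard multi-head attention), 16-bit keys and values take 2 × 32 × 4096 × 2 bytes = 512 KiB per token, or 1 GiB per 2,048-token sequence. The helper below computes this; it is a general formula, not a reconstruction of how the memory table above was measured.

```python
def kv_cache_bytes(num_layers, hidden_size, seq_len, batch_size,
                   bytes_per_value=2, num_kv_heads=None, num_heads=None):
    """KV cache size in bytes for a decoder-only Transformer.

    The leading 2 counts one key and one value vector per layer per token.
    With grouped-query attention, pass num_kv_heads and num_heads so the
    KV width shrinks to hidden_size * num_kv_heads / num_heads.
    """
    kv_width = hidden_size
    if num_kv_heads is not None and num_heads is not None:
        kv_width = hidden_size * num_kv_heads // num_heads
    return 2 * num_layers * kv_width * seq_len * batch_size * bytes_per_value


if __name__ == "__main__":
    for batch in (1, 8, 32):
        gib = kv_cache_bytes(32, 4096, seq_len=2048, batch_size=batch) / 2**30
        print(f"batch {batch:>2}: {gib:.1f} GiB of KV cache")
```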

🛠️ Advanced Features

Distributed Training

# Launch with torchrun
torchrun --nproc_per_node=8 train.py \
    --use-fsdp \
    --fsdp-sharding-strategy full \
    --tensor-parallel-size 1 \
    --pipeline-parallel-size 1
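The command assumes a `train.py` that accepts these flags; the parser below is a hypothetical sketch of such a script, not Parallel-LLM's own CLI. `torchrun` exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to every worker, and the data-parallel degree is whatever remains after tensor and pipeline parallelism: world_size / (tp × pp), so the command above runs 8-way data parallelism.

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Hypothetical train.py flags")
    parser.add_argument("--use-fsdp", action="store_true")
    parser.add_argument("--fsdp-sharding-strategy", default="full")
    parser.add_argument("--tensor-parallel-size", type=int, default=1)
    parser.add_argument("--pipeline-parallel-size", type=int, default=1)
    return parser.parse_args(argv)


def data_parallel_size(world_size: int, tp: int, pp: int) -> int:
    if world_size % (tp * pp):
        raise ValueError("world_size must be divisible by tp * pp")
    return world_size // (tp * pp)


def main(argv=None):
    args = parse_args(argv)
    # torchrun sets these for each worker process; defaults allow a plain run.
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    dp = data_parallel_size(world_size, args.tensor_parallel_size,
                            args.pipeline_parallel_size)
    print(f"rank {rank} (local {local_rank}): dp={dp}, fsdp={args.use_fsdp}")


if __name__ == "__main__":
    main()
```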

Custom Kernels

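from parallel_llm.kernels import fused_attention, parallel_decode

# query, key, value: your attention input tensors (not defined here)
# Use optimized Triton kernels
output = fused_attention(query, key, value, use_flash=True)

# Parallel token decoding (logits: the model's per-position output scores)
tokens = parallel_decode(logits, num_parallel=64)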
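When validating an optimized kernel such as `fused_attention`, it helps to compare against a slow, obviously correct reference. The sketch below implements single-head scaled dot-product attention, softmax(QKᵀ / √d)·V, and greedy per-position decoding of logits in plain Python. The function names are illustrative; the real signatures and tensor layouts of `parallel_llm.kernels` are not specified here.

```python
import math


def reference_attention(q, k, v):
    """Single-head scaled dot-product attention on nested lists.

    q: [n_q][d], k: [n_k][d], v: [n_k][d_v] -> output [n_q][d_v].
    """
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def greedy_parallel_decode(logits):
    """Pick the argmax token independently at every position."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]
```

Comparing a fused kernel's output against such a reference with a tolerance (rather than exact equality) is the usual check, since fused kernels reorder floating-point operations.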

Quantization

from parallel_llm.quantization import quantize_model

# Quantize to INT8 or FP8
model = quantize_model(model, precision="fp8")
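The internals of `quantize_model` are not described here. As a reference for what INT8 quantization does to a tensor, the sketch below applies symmetric per-tensor quantization (scale = max|x| / 127, round to nearest, clamp to [-127, 127]). This is a common scheme, not necessarily the one Parallel-LLM uses, and FP8 formats such as E4M3 use a floating-point encoding instead.

```python
def quantize_int8(values):
    """Symmetric per-tensor INT8 quantization. Returns (q, scale)."""
    max_abs = max((abs(x) for x in values), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(x / scale))) for x in values]
    return q, scale


def dequantize_int8(q, scale):
    """Map INT8 codes back to approximate real values."""
    return [qi * scale for qi in q]
```

With round-to-nearest, each restored value is within half a quantization step (scale / 2) of the original, which is why a single outlier that inflates max|x| hurts accuracy for every other value in the tensor.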

📚 Documentation

🤝 Contributing

We welcome contributions! See CONTRIBUTING.md for guidelines.

📄 License

Apache 2.0 License. See LICENSE for details.

🙏 Acknowledgments

Built on research from:

  • FlashAttention (Dao et al.)
  • Diffusion Language Models (various)
  • DeepSpeed ZeRO (Microsoft)
  • vLLM (UC Berkeley)
  • PyTorch FSDP (Meta)

📞 Contact

🌟 Star History

If you find this project useful, please give it a star! ⭐

Citation

@software{parallel_llm,
  title = {Parallel-LLM: Ultra-Fast Parallel Training and Inference},
  author = {Last App Standing Team},
  year = {2025},
  url = {https://github.com/furqan-y-khan/parallel-llm}
}



Download files


Source Distribution

parallel_llm-0.1.0.tar.gz (30.9 kB)

Uploaded Source

Built Distribution


parallel_llm-0.1.0-py3-none-any.whl (21.5 kB)

Uploaded Python 3

File details

Details for the file parallel_llm-0.1.0.tar.gz.

File metadata

  • Download URL: parallel_llm-0.1.0.tar.gz
  • Upload date:
  • Size: 30.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.0b4

File hashes

Hashes for parallel_llm-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        11ed976ec83447d6728adfd90ea9af9b2632a7ca569c1ad8d8c10de413bcb21a
MD5           4fc097011474ad5e205df10764d71ec1
BLAKE2b-256   73555c70b42363a0551066cc7b4b8029bb0754704712010b09a658ffbc2c4339
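To check a downloaded archive against the published SHA256 digest, hash it locally and compare. The standard-library sketch below streams the file in chunks; the expected value is the source-distribution digest listed above, and the path is whatever you downloaded. Alternatively, pip's hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file) performs the same check at install time.

```python
import hashlib

EXPECTED_SDIST_SHA256 = (
    "11ed976ec83447d6728adfd90ea9af9b2632a7ca569c1ad8d8c10de413bcb21a"
)


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large downloads never sit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file's SHA256 matches the expected hex digest."""
    return sha256_of_file(path) == expected.lower()


if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1:
        print("OK" if verify(sys.argv[1]) else "HASH MISMATCH")
```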


File details

Details for the file parallel_llm-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: parallel_llm-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 21.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.0b4

File hashes

Hashes for parallel_llm-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        66363d35e90747a92d261073f50ffb5b1e914e6aa32e6bb175bcd96a764ab58e
MD5           5356f9be1205e1f29957bd03d884260e
BLAKE2b-256   396975e6feaf376fbe87b1fafee442990021a81399c4eef24ea323143bb85991

