Monkey-patch PyTorch with optimized CuTile kernels for training (forward + backward)

Project description

Bastile

Drop-in monkey-patch that replaces HuggingFace Qwen3 ops with optimized CuTile kernels for training on NVIDIA Blackwell GPUs.

Requires: an NVIDIA Blackwell GPU (B200, B100, GB200, or RTX 50 series) + CUDA Toolkit 13.1+

Benchmarks

Qwen3-8B (36 layers, 4096 hidden, 32 heads) — single B200, batch_size=1, bf16, AdamW:

[Figures: Throughput (tokens/sec), Peak GPU Memory (GB), and Latency (ms/iter) vs. sequence length]

Bastile's fused linear cross-entropy avoids materializing the full [batch * seq_len, vocab_size] logits tensor, which is the dominant memory cost at longer sequences. This is where the memory savings and throughput gains compound.
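
For scale: with Qwen3's 151,936-entry vocabulary, at batch_size=1 and, say, seq_len=8192 in bf16, the logits tensor alone is 8192 × 151,936 × 2 bytes ≈ 2.5 GB, and its gradient adds the same again during backward.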

Installation

pip install bastile

Prerequisites:

  • NVIDIA Blackwell GPU (B200, B100, GB200, or an RTX 50-series card)
  • CUDA Toolkit 13.1+
  • PyTorch 2.9+ with CUDA support

# Inside a CUDA 13.1 container:
pip install bastile

Quick Start

import bastile

# Apply all patches BEFORE loading / creating the model
bastile.apply()

from transformers import Qwen3ForCausalLM

model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
model.train()

# Train as usual — Bastile automatically uses optimized kernels
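
Training then looks like stock PyTorch. A minimal sketch of one step (AdamW matches the benchmark setup; the learning rate and dummy batch are placeholders, not part of Bastile's API):

import torch

model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Dummy batch; in practice this comes from your dataloader
input_ids = torch.randint(0, model.config.vocab_size, (1, 4096), device="cuda")

loss = model(input_ids=input_ids, labels=input_ids).loss  # fused linear + CE path
loss.backward()                                           # CuTile backward kernels run here
optimizer.step()
optimizer.zero_grad()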

Selective patching:

bastile.apply(
    rms_norm=True,                   # CuTile RMSNorm
    swiglu=True,                     # CuTile SwiGLU MLP
    rope=True,                       # CuTile RoPE
    fused_linear_cross_entropy=True,  # CuTile fused linear + CE
)

Reset to original HuggingFace implementations:

bastile.reset()

Kernel Implementations

Bastile patches 4 operations in transformers.models.qwen3.modeling_qwen3. Each has a full forward and backward pass:

Operation                    Backend   Source                              What it replaces
RMSNorm                      CuTile    ops/rms_norm.py                     Qwen3RMSNorm
SwiGLU MLP                   CuTile    ops/swiglu.py                       Qwen3MLP
RoPE                         CuTile    ops/rope.py                         apply_rotary_pos_emb
Fused Linear Cross-Entropy   CuTile    ops/fused_linear_cross_entropy.py   Qwen3ForCausalLM.forward

RMSNorm — CuTile

Native CuTile RMSNorm with persistent forward and backward kernels. Uses gather/scatter memory access with SM-aware tile sizing.

src/bastile/ops/rms_norm.py → patches Qwen3RMSNorm
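
For reference, a plain-PyTorch sketch of the computation Qwen3RMSNorm performs (the fp32 accumulation and the eps default follow the stock HuggingFace module):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute in fp32, scale by 1/RMS over the hidden dim,
    # then apply the learned weight in the input dtype.
    input_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)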

SwiGLU MLP — CuTile

Native CuTile kernels using cuda.tile with gather/scatter memory access. Uses flush_to_zero and approximate reciprocal for fast sigmoid on Blackwell. Full backward with recomputation (no extra activation memory).

src/bastile/ops/swiglu.py → patches Qwen3MLP
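
For reference, the math Qwen3MLP computes, as a plain-PyTorch sketch (the projection modules are passed in for brevity):

import torch.nn.functional as F

def swiglu_mlp_reference(x, gate_proj, up_proj, down_proj):
    # SwiGLU: silu(gate(x)) * up(x), projected back down.
    # silu(z) = z * sigmoid(z); the sigmoid is what the CuTile kernel
    # computes with flush-to-zero and an approximate reciprocal.
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))

Because silu(gate_proj(x)) * up_proj(x) is cheap to rebuild from the saved inputs, the backward pass can recompute it on the fly instead of caching the intermediate, which is where the activation-memory saving noted above comes from.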

RoPE — CuTile

CuTile rotary position embedding with compiler-optimized occupancy. In-place rotation on reshaped tensors to minimize memory traffic. Backward reuses the forward kernel by negating sin (inverse rotation identity).

src/bastile/ops/rope.py → patches apply_rotary_pos_emb
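
A sketch of the rotate_half formulation that apply_rotary_pos_emb uses, with the negated-sin backward identity spelled out:

import torch

def rotate_half(x):
    # Split the head dim and rotate the halves: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

# Rotation is orthogonal, so the backward pass is the inverse rotation:
# cos(-t) = cos(t) and sin(-t) = -sin(t), hence
#   grad_x = apply_rope_reference(grad_out, cos, -sin)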

Fused Linear Cross-Entropy — CuTile

Replaces the standard lm_head(hidden_states) → logits → cross_entropy(logits, labels) pipeline with a chunked computation: cross-entropy is evaluated chunk by chunk, so the full logits tensor ([batch * seq, 151936] for Qwen3) is never materialized. This is the single biggest memory saver at long sequence lengths.

src/bastile/ops/fused_linear_cross_entropy.py → patches Qwen3ForCausalLM.forward
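
A plain-PyTorch sketch of the chunking idea (the chunk size, function name, and ignore_index convention are illustrative, not Bastile's internals):

import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, lm_head_weight, labels, chunk_size=4096):
    # hidden: [N, hidden_dim] (flattened batch*seq)
    # lm_head_weight: [vocab_size, hidden_dim], labels: [N]
    # Only a [chunk_size, vocab_size] slice of logits exists at a time.
    total_loss = hidden.new_zeros(())
    total_tokens = 0
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start : start + chunk_size]
        y = labels[start : start + chunk_size]
        logits = h @ lm_head_weight.T  # [chunk, vocab], freed next iteration
        total_loss = total_loss + F.cross_entropy(
            logits, y, reduction="sum", ignore_index=-100
        )
        total_tokens += int((y != -100).sum())
    return total_loss / max(total_tokens, 1)

Note that under plain autograd each chunk's logits would still be retained for backward, so this sketch only bounds forward memory; the fused CuTile version streams the backward pass as well, which is what makes the saving hold during training.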

API Reference

import bastile

bastile.apply()                  # Patch all ops
bastile.apply(rope=False)        # Patch everything except RoPE
bastile.reset()                  # Restore original implementations
bastile.get_patched_ops()        # List currently active patches
bastile.warmup_all_kernels()     # Pre-compile kernels (avoids JIT lag)
bastile.clear_autotune_cache()   # Clear kernel caches
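
For illustration, a typical session (the printed value of get_patched_ops() is assumed here to be a readable list of op names):

import bastile

bastile.apply()                   # must happen before the model is built
bastile.warmup_all_kernels()      # optional: pay the JIT cost up front
print(bastile.get_patched_ops())  # inspect which patches are active

# ... load model, train ...

bastile.reset()                   # back to stock HuggingFace ops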

Running Benchmarks

# Qwen3-8B sequence length sweep (parallel on 3 GPUs)
make bench-8b

# Qwen3-8B sweep (sequential, single GPU)
make bench-8b-seq

# Qwen3-8B FSDP multi-GPU benchmark
make bench-fsdp

# Kernel micro-benchmarks
make bench-rmsnorm
make bench-lce

Why CuTile instead of Triton?

Bastile uses NVIDIA's CuTile (cuda.tile) instead of Triton. On Blackwell (sm_100), CuTile generates native PTX through NVIDIA's own compiler toolchain, while Triton's code generation for sm_100 is still maturing. In our benchmarks, Triton-based kernels (Liger) often underperform raw PyTorch on B200, whereas CuTile kernels consistently match or beat it.

License

MIT
