
Monkey-patch PyTorch with optimized CuTile kernels for training (forward + backward)

Project description

Bastile

Drop-in monkey-patch that replaces HuggingFace Qwen3 ops with optimized CuTile and cuteDSL kernels for training on NVIDIA Blackwell GPUs.

Requires: NVIDIA Blackwell (B200 / B100) + CUDA Toolkit 13.1+

Benchmarks

Qwen3-8B (36 layers, 4096 hidden, 32 heads), single B200, batch_size=1, bf16, AdamW:

[Figures: Throughput (tokens/sec), Peak GPU Memory (GB), Latency (ms/iter)]

Bastile's fused linear cross-entropy avoids materializing the full [batch * seq_len, vocab_size] logits tensor, which is the dominant memory cost at longer sequences. Because that tensor grows linearly with sequence length, long sequences are where Bastile's memory savings and throughput gains are largest.
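The size of that logits tensor is simple arithmetic. The sketch below assumes bf16 logits (2 bytes per element) and Qwen3's 151,936-entry vocabulary. It counts only the logits themselves, not the same-shaped gradient produced in the backward pass or any fp32 upcast that a standard cross-entropy may perform.

```python
# Back-of-the-envelope size of the full [batch * seq_len, vocab_size] logits tensor
# that a standard lm_head -> cross_entropy pipeline materializes.
QWEN3_VOCAB = 151_936  # vocab size quoted for Qwen3 in this README
BYTES_PER_ELEM = {"bf16": 2, "fp32": 4}


def logits_bytes(batch: int, seq_len: int, vocab: int = QWEN3_VOCAB, dtype: str = "bf16") -> int:
    """Bytes needed to hold the logits of one forward pass."""
    return batch * seq_len * vocab * BYTES_PER_ELEM[dtype]


if __name__ == "__main__":
    for seq_len in (2048, 4096, 8192, 16384):
        gib = logits_bytes(1, seq_len) / 2**30
        print(f"seq_len={seq_len:>6}: {gib:5.2f} GiB of bf16 logits")
```

With batch_size=1 and seq_len=4096 this comes to 1,244,659,712 bytes (about 1.16 GiB), and the figure doubles every time the sequence length doubles.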

Installation

pip install bastile

Prerequisites:

  • NVIDIA Blackwell GPU (B200, B100, GB200)
  • CUDA Toolkit 13.1+
  • PyTorch 2.4+ with CUDA support

# Inside a CUDA 13.1 container (e.g. baseten/gpu-dev:v8-cu13_1):
pip install bastile

Quick Start

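import torch
import bastile

# Apply all patches BEFORE loading / creating the model
bastile.apply()

from transformers import Qwen3ForCausalLM

# Load in bf16 and move to the GPU so the patched CUDA kernels are actually used
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B", torch_dtype=torch.bfloat16).to("cuda")
model.train()

# Train as usual: Bastile automatically uses the optimized kernels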

Selective patching:

bastile.apply(
    rms_norm=True,                   # cuteDSL RMSNorm (via quack)
    swiglu=True,                     # CuTile SwiGLU MLP
    rope=True,                       # CuTile RoPE with autotuning
    fused_linear_cross_entropy=True,  # Fused linear + CE (via quack)
)

Reset to original HuggingFace implementations:

bastile.reset()

Kernel Implementations

Bastile patches 4 operations in transformers.models.qwen3.modeling_qwen3. Each has a full forward and backward pass:

| Operation | Backend | Source | What it replaces |
|---|---|---|---|
| RMSNorm | cuteDSL (via quack) | ops/rms_norm.py | Qwen3RMSNorm |
| SwiGLU MLP | CuTile (cuda.tile) | ops/swiglu.py | Qwen3MLP |
| RoPE | CuTile (cuda.tile) | ops/rope.py | apply_rotary_pos_emb |
| Fused Linear Cross-Entropy | cuteDSL (via quack) | ops/fused_linear_cross_entropy.py | Qwen3ForCausalLM.forward |

RMSNorm — cuteDSL

Wraps quack's compiled cuteDSL kernels with reduced CPU dispatch overhead. Bypasses torch.library.custom_op dispatch by directly invoking the compiled kernel from a lookup cache, and caches SM counts to avoid repeated queries.

src/bastile/ops/rms_norm.py → patches Qwen3RMSNorm

SwiGLU MLP — CuTile

Native CuTile kernels written with cuda.tile, using gather/scatter memory access. The sigmoid uses flush_to_zero and an approximate reciprocal for speed on Blackwell. The backward pass recomputes the activation instead of saving it, so it needs no extra activation memory.

src/bastile/ops/swiglu.py → patches Qwen3MLP
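The recomputation trick is easiest to see in the elementwise math. The following scalar Python reference is not the CuTile kernel. It assumes the SiLU gate used by Qwen3MLP and covers only the elementwise step h = silu(gate) * up, not the gate/up/down projections. The backward needs only the saved inputs `gate` and `up`, because sigmoid(gate) is recomputed.

```python
import math


def sigmoid(x: float) -> float:
    # On the GPU this division is where an approximate reciprocal can be used.
    return 1.0 / (1.0 + math.exp(-x))


def swiglu_forward(gate: float, up: float) -> float:
    """h = silu(gate) * up, the elementwise core between the MLP projections."""
    return gate * sigmoid(gate) * up


def swiglu_backward(gate: float, up: float, grad_h: float) -> tuple[float, float]:
    """Return (d_gate, d_up), recomputing sigmoid(gate) instead of loading a saved activation."""
    s = sigmoid(gate)                        # recomputed, not stored
    silu = gate * s
    d_silu = s * (1.0 + gate * (1.0 - s))    # d/dx [x * sigmoid(x)]
    return grad_h * up * d_silu, grad_h * silu
```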

RoPE — CuTile

CuTile rotary position embedding with occupancy-based autotuning: it tries occupancy 1, 2, 4, and 8 and caches the fastest setting. The rotation is applied in place on reshaped tensors to minimize memory traffic.

src/bastile/ops/rope.py → patches apply_rotary_pos_emb
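For checking a kernel against, the following is a scalar reference of what apply_rotary_pos_emb computes for one head vector: x * cos + rotate_half(x) * sin. Each pair (x[i], x[i + d/2]) is rotated by a position-dependent angle. This is a sketch, not the CuTile kernel. The default base of 1,000,000 is an assumption taken to match Qwen3's `rope_theta`, and the helper names are hypothetical.

```python
import math


def rope_cos_sin(position: int, head_dim: int, base: float = 1_000_000.0):
    """cos/sin vectors for one position, laid out as cat(freqs, freqs)."""
    half = head_dim // 2
    inv_freq = [base ** (-2 * i / head_dim) for i in range(half)]
    angles = [position * f for f in inv_freq] * 2
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rotate_half(x: list[float]) -> list[float]:
    """cat(-x[d/2:], x[:d/2]), the convention used by HuggingFace's RoPE."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: list[float], position: int, base: float = 1_000_000.0) -> list[float]:
    """x * cos + rotate_half(x) * sin for a single head vector."""
    cos, sin = rope_cos_sin(position, len(x), base)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rotate_half(x), sin)]
```

Because each pair is rotated rather than scaled, vector norms are preserved. The dot product of a rotated query and key depends only on the distance between their positions, which is the property RoPE exists to provide.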

Fused Linear Cross-Entropy — cuteDSL

Replaces the standard lm_head(hidden_states) → logits → cross_entropy(logits, labels) pipeline with quack's chunked_linear_cross_entropy. This never materializes the full logits tensor ([batch * seq, 151936] for Qwen3), instead computing cross-entropy in chunks of 4096. This is the single biggest memory saver at long sequence lengths.

src/bastile/ops/fused_linear_cross_entropy.py → patches Qwen3ForCausalLM.forward
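The chunking idea can be illustrated without a GPU. The following pure-Python reference computes mean cross-entropy of softmax(hidden @ weight.T) while holding logits for at most `chunk_size` tokens at a time. It is not quack's fused kernel. Chunking along the token dimension is an assumption, `chunk_ce_reference` is a hypothetical name, and `ignore_index=-100` follows the usual HuggingFace/PyTorch convention for masked labels.

```python
import math


def _logsumexp(row: list[float]) -> float:
    m = max(row)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(z - m) for z in row))


def chunk_ce_reference(hidden, weight, labels, chunk_size=4096, ignore_index=-100):
    """Mean cross-entropy over non-ignored tokens.

    hidden: [num_tokens][hidden_dim], weight: [vocab][hidden_dim] (lm_head layout),
    labels: [num_tokens]. Only a [chunk_size, vocab] block of logits exists at once.
    """
    total, count = 0.0, 0
    for start in range(0, len(hidden), chunk_size):
        h_chunk = hidden[start:start + chunk_size]
        y_chunk = labels[start:start + chunk_size]
        # The only logits ever allocated: one block for this chunk of tokens.
        logits = [[sum(a * b for a, b in zip(h, w)) for w in weight] for h in h_chunk]
        for row, y in zip(logits, y_chunk):
            if y == ignore_index:  # e.g. padding tokens excluded from the loss
                continue
            total += _logsumexp(row) - row[y]
            count += 1
    return total / max(count, 1)
```

Fused implementations typically also produce gradients chunk by chunk (softmax minus one-hot, projected back through the weight), so the backward pass never needs the full logits either. That part is omitted here.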

API Reference

import bastile

bastile.apply()                  # Patch all ops
bastile.apply(rope=False)        # Patch everything except RoPE
bastile.reset()                  # Restore original implementations
bastile.get_patched_ops()        # List currently active patches
bastile.warmup_all_kernels()     # Pre-compile kernels (avoids JIT lag)
bastile.clear_autotune_cache()   # Re-run autotuning on next call
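Occupancy autotuning can be modeled as a memo table keyed by problem shape, similar to the cache that clear_autotune_cache() resets. The sketch below is a guess at that general shape. `autotune`, its key format, and the timing method are hypothetical, since Bastile's actual cache key is not documented here. When `run` launches a GPU kernel, it must synchronize (for example with torch.cuda.synchronize()), or the timer measures only launch overhead.

```python
import time
from typing import Callable, Hashable

OCCUPANCY_CANDIDATES = (1, 2, 4, 8)  # the settings the RoPE autotuner is described as trying
_autotune_cache: dict[Hashable, int] = {}


def autotune(key: Hashable, run: Callable[[int], None],
             candidates=OCCUPANCY_CANDIDATES, reps: int = 3) -> int:
    """Return the fastest candidate for `key`, benchmarking only on a cache miss."""
    if key in _autotune_cache:
        return _autotune_cache[key]
    timings = {}
    for cand in candidates:
        run(cand)  # warm-up call (e.g. JIT compilation) kept outside the timed region
        start = time.perf_counter()
        for _ in range(reps):
            run(cand)
        timings[cand] = (time.perf_counter() - start) / reps
    best = min(timings, key=timings.get)
    _autotune_cache[key] = best
    return best


def clear_autotune_cache() -> None:
    """Forget all tuned choices so the next call re-benchmarks."""
    _autotune_cache.clear()
```

Keying on shape means each new sequence length pays the benchmarking cost once. Calling a warmup step before training, as warmup_all_kernels() is described as doing for compilation, moves that cost out of the first timed iterations.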

Running Benchmarks

# Small model comparison (HuggingFace vs Liger vs Bastile)
make bench-small

# Qwen3-8B sequence length sweep (parallel on 3 GPUs)
make bench-8b

# Qwen3-8B sweep (sequential, single GPU)
make bench-8b-seq

# Kernel profiling with torch.profiler
make bench-profile

Why CuTile instead of Triton?

Bastile uses NVIDIA's CuTile (cuda.tile) and cuteDSL instead of Triton. On Blackwell (sm_100), CuTile generates native PTX through NVIDIA's own compiler toolchain, while Triton's code generation for sm_100 is still maturing. In our benchmarks, Triton-based kernels (Liger) often underperform raw PyTorch on B200, whereas CuTile kernels consistently match or beat it.

License

MIT

Download files

Source Distribution

bastile-0.1.0.tar.gz (16.3 kB)

Uploaded Source

Built Distribution

bastile-0.1.0-py3-none-any.whl (20.5 kB)

Uploaded Python 3

File details

Details for the file bastile-0.1.0.tar.gz.

File metadata

  • Download URL: bastile-0.1.0.tar.gz
  • Size: 16.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for bastile-0.1.0.tar.gz
Algorithm Hash digest
SHA256 b9c5b4e28dc43616c0aa6c2533c28431aca5b7d3bf042898a2ca4230edd8685c
MD5 43f00d94d5da154a3d585567a9e8e357
BLAKE2b-256 e1f9fc2f32221142d2f50461ccdee701d44329eede53161474883831525f4960


Provenance

The following attestation bundles were made for bastile-0.1.0.tar.gz:

Publisher: publish.yml on aghilann/bastile

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file bastile-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: bastile-0.1.0-py3-none-any.whl
  • Size: 20.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for bastile-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 5603a26cd20c0df78b0d20017a3182b73a15253ff22338f45afa62a3d8a5aec4
MD5 e666aa89a9a91145dc5c42f7e473d738
BLAKE2b-256 4c80ab943f50f3c3b54b51cf5141d918fab5355a172abd414107adc922e33ee5


Provenance

The following attestation bundles were made for bastile-0.1.0-py3-none-any.whl:

Publisher: publish.yml on aghilann/bastile

