Bastile
Monkey-patch PyTorch with optimized CuTile kernels for training (forward + backward).
Drop-in monkey-patch that replaces HuggingFace Qwen3 ops with optimized CuTile and cuteDSL kernels for training on NVIDIA Blackwell GPUs.
Requires: NVIDIA Blackwell (B200 / B100) + CUDA Toolkit 13.1+
Benchmarks
Qwen3-8B (36 layers, 4096 hidden, 32 heads) — single B200, batch_size=1, bf16, AdamW:
(Benchmark charts, not reproduced here: throughput in tokens/sec, peak GPU memory in GB, and latency in ms/iter.)
Bastile's fused linear cross-entropy avoids materializing the full [batch * seq_len, vocab_size] logits tensor, which is the dominant memory cost at longer sequences. This is where the memory savings and throughput gains compound.
Installation

Prerequisites:
- NVIDIA Blackwell GPU (B200, B100, GB200)
- CUDA Toolkit 13.1+
- PyTorch 2.4+ with CUDA support

```bash
# Inside a CUDA 13.1 container (e.g. baseten/gpu-dev:v8-cu13_1):
pip install bastile
```
Quick Start

```python
import bastile

# Apply all patches BEFORE loading / creating the model
bastile.apply()

from transformers import Qwen3ForCausalLM

model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
model.train()

# Train as usual; Bastile automatically uses the optimized kernels
```
Selective patching:

```python
bastile.apply(
    rms_norm=True,                    # cuteDSL RMSNorm (via quack)
    swiglu=True,                      # CuTile SwiGLU MLP
    rope=True,                        # CuTile RoPE with autotuning
    fused_linear_cross_entropy=True,  # Fused linear + CE (via quack)
)
```
Reset to original HuggingFace implementations:

```python
bastile.reset()
```
Kernel Implementations
Bastile patches 4 operations in transformers.models.qwen3.modeling_qwen3. Each has a full forward and backward pass:

| Operation | Backend | Source | What it replaces |
|---|---|---|---|
| RMSNorm | cuteDSL (via quack) | ops/rms_norm.py | Qwen3RMSNorm |
| SwiGLU MLP | CuTile (cuda.tile) | ops/swiglu.py | Qwen3MLP |
| RoPE | CuTile (cuda.tile) | ops/rope.py | apply_rotary_pos_emb |
| Fused Linear Cross-Entropy | cuteDSL (via quack) | ops/fused_linear_cross_entropy.py | Qwen3ForCausalLM.forward |
RMSNorm — cuteDSL
Wraps quack's compiled cuteDSL kernels with reduced CPU dispatch overhead. Bypasses torch.library.custom_op dispatch by directly invoking the compiled kernel from a lookup cache, and caches SM counts to avoid repeated queries.
src/bastile/ops/rms_norm.py → patches Qwen3RMSNorm
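A rough illustration of the caching pattern described above. This is a sketch only: the cache keys and the `_compile_rmsnorm` stand-in are hypothetical, not Bastile's actual internals.

```python
import torch

_KERNELS = {}     # (dtype, hidden_size) -> compiled kernel, looked up per call
_SM_COUNT = None  # device SM count, queried once and then reused

def _sm_count():
    global _SM_COUNT
    if _SM_COUNT is None:
        _SM_COUNT = torch.cuda.get_device_properties(0).multi_processor_count
    return _SM_COUNT

def _compile_rmsnorm(dtype, hidden_size, sm_count):
    # Hypothetical stand-in for quack's compile step; eager math for illustration.
    def kernel(x, weight, eps):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return weight * (x.float() * torch.rsqrt(var + eps)).to(dtype)
    return kernel

def rms_norm(x, weight, eps=1e-6):
    key = (x.dtype, x.shape[-1])
    kernel = _KERNELS.get(key)
    if kernel is None:
        _KERNELS[key] = kernel = _compile_rmsnorm(x.dtype, x.shape[-1], _sm_count())
    # Direct call: no torch.library.custom_op dispatch on the hot path.
    return kernel(x, weight, eps)
```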
SwiGLU MLP — CuTile
Native CuTile kernels using cuda.tile with gather/scatter memory access. Uses flush_to_zero and approximate reciprocal for fast sigmoid on Blackwell. Full backward with recomputation (no extra activation memory).
src/bastile/ops/swiglu.py → patches Qwen3MLP
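For reference, the recompute-in-backward idea expressed in plain PyTorch. This sketches the semantics of the silu(gate) * up core only; the actual kernel does this fused in CuTile.

```python
import torch

class SwiGLU(torch.autograd.Function):
    """silu(gate) * up, with the sigmoid recomputed in backward (nothing extra saved)."""

    @staticmethod
    def forward(ctx, gate, up):
        ctx.save_for_backward(gate, up)           # inputs only, no activations
        return torch.nn.functional.silu(gate) * up

    @staticmethod
    def backward(ctx, grad_out):
        gate, up = ctx.saved_tensors
        sig = torch.sigmoid(gate)                 # recomputed here, not stored
        d_gate = grad_out * up * sig * (1 + gate * (1 - sig))
        d_up = grad_out * gate * sig              # grad_out * silu(gate)
        return d_gate, d_up
```

In Qwen3MLP this is the middle of down_proj(silu(gate_proj(x)) * up_proj(x)); saving only the projection outputs and recomputing the sigmoid is what keeps activation memory flat.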
RoPE — CuTile
CuTile rotary position embedding with occupancy-based autotuning (tests occupancy 1, 2, 4, 8 and caches the best). In-place rotation on reshaped tensors to minimize memory traffic.
src/bastile/ops/rope.py → patches apply_rotary_pos_emb
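The autotuning amounts to timing each occupancy candidate once and caching the winner per input shape. A minimal sketch, where the `launch` callable stands in for the actual CuTile kernel:

```python
import torch

_BEST = {}  # (shape, dtype) -> winning occupancy, cached across calls

def _time_ms(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)

def rope_autotuned(launch, q, k, cos, sin):
    # launch(q, k, cos, sin, occupancy=...) is a stand-in for the CuTile launch.
    key = (tuple(q.shape), q.dtype)
    if key not in _BEST:
        _BEST[key] = min(
            (1, 2, 4, 8),
            key=lambda occ: _time_ms(lambda: launch(q, k, cos, sin, occupancy=occ)),
        )
    return launch(q, k, cos, sin, occupancy=_BEST[key])
```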
Fused Linear Cross-Entropy — cuteDSL
Replaces the standard lm_head(hidden_states) → logits → cross_entropy(logits, labels) pipeline with quack's chunked_linear_cross_entropy. This never materializes the full logits tensor ([batch * seq, 151936] for Qwen3), instead computing cross-entropy in chunks of 4096. This is the single biggest memory saver at long sequence lengths.
src/bastile/ops/fused_linear_cross_entropy.py → patches Qwen3ForCausalLM.forward
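A plain-PyTorch illustration of the chunking idea (semantics only, not quack's implementation; the function name here is ours):

```python
import torch
import torch.nn.functional as F

def chunked_linear_ce_reference(hidden, lm_head_weight, labels, chunk_size=4096):
    """hidden: [N, d], lm_head_weight: [vocab, d], labels: [N] (-100 = ignore)."""
    total = hidden.new_zeros(())
    n_valid = labels.ne(-100).sum().clamp(min=1)
    for i in range(0, hidden.shape[0], chunk_size):
        # Only a [chunk_size, vocab] slice of the logits ever exists at once.
        logits = hidden[i:i + chunk_size] @ lm_head_weight.T
        total = total + F.cross_entropy(
            logits.float(), labels[i:i + chunk_size],
            ignore_index=-100, reduction="sum",
        )
    return total / n_valid
```

Note that under plain autograd each chunk's logits would still be saved for the backward pass; quack's kernel implements the backward without materializing the logits as well, so the savings hold during training.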
API Reference

```python
import bastile

bastile.apply()                  # Patch all ops
bastile.apply(rope=False)        # Patch everything except RoPE
bastile.reset()                  # Restore original implementations
bastile.get_patched_ops()        # List currently active patches
bastile.warmup_all_kernels()     # Pre-compile kernels (avoids JIT lag)
bastile.clear_autotune_cache()   # Re-run autotuning on next call
```
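Putting it together, one plausible training-step flow (a sketch; the hyperparameters and warmup placement are illustrative, not prescribed by Bastile):

```python
import torch
import bastile
from transformers import Qwen3ForCausalLM

bastile.apply()               # patch before the model is created
bastile.warmup_all_kernels()  # compile kernels now instead of on step one

model = Qwen3ForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B", torch_dtype=torch.bfloat16
).cuda()
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

input_ids = torch.randint(0, model.config.vocab_size, (1, 4096), device="cuda")
out = model(input_ids=input_ids, labels=input_ids)  # fused linear CE path
out.loss.backward()
optimizer.step()
optimizer.zero_grad()
```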
Running Benchmarks

```bash
# Small model comparison (HuggingFace vs Liger vs Bastile)
make bench-small

# Qwen3-8B sequence length sweep (parallel on 3 GPUs)
make bench-8b

# Qwen3-8B sweep (sequential, single GPU)
make bench-8b-seq

# Kernel profiling with torch.profiler
make bench-profile
```
Why CuTile instead of Triton?
Bastile uses NVIDIA's CuTile (cuda.tile) and cuteDSL instead of Triton. On Blackwell (sm_100), CuTile generates native PTX through NVIDIA's own compiler toolchain, while Triton's code generation for sm_100 is still maturing. In our benchmarks, Triton-based kernels (Liger) often underperform raw PyTorch on B200, whereas CuTile kernels consistently match or beat it.
License
MIT
File details
Details for the file bastile-0.1.0.tar.gz.
File metadata
- Download URL: bastile-0.1.0.tar.gz
- Upload date:
- Size: 16.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | b9c5b4e28dc43616c0aa6c2533c28431aca5b7d3bf042898a2ca4230edd8685c |
| MD5 | 43f00d94d5da154a3d585567a9e8e357 |
| BLAKE2b-256 | e1f9fc2f32221142d2f50461ccdee701d44329eede53161474883831525f4960 |
Provenance

The following attestation bundles were made for bastile-0.1.0.tar.gz:

Publisher: publish.yml on aghilann/bastile

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: bastile-0.1.0.tar.gz
- Subject digest: b9c5b4e28dc43616c0aa6c2533c28431aca5b7d3bf042898a2ca4230edd8685c
- Sigstore transparency entry: 942266781
- Sigstore integration time:
- Permalink: aghilann/bastile@1dc5ca7856b7bbbe2d1fb79aab9a57a07e641265
- Branch / Tag: refs/tags/v0.0.1
- Owner: https://github.com/aghilann
- Access: private
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@1dc5ca7856b7bbbe2d1fb79aab9a57a07e641265
- Trigger Event: release
File details
Details for the file bastile-0.1.0-py3-none-any.whl.
File metadata
- Download URL: bastile-0.1.0-py3-none-any.whl
- Upload date:
- Size: 20.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 5603a26cd20c0df78b0d20017a3182b73a15253ff22338f45afa62a3d8a5aec4 |
| MD5 | e666aa89a9a91145dc5c42f7e473d738 |
| BLAKE2b-256 | 4c80ab943f50f3c3b54b51cf5141d918fab5355a172abd414107adc922e33ee5 |
Provenance

The following attestation bundles were made for bastile-0.1.0-py3-none-any.whl:

Publisher: publish.yml on aghilann/bastile

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: bastile-0.1.0-py3-none-any.whl
- Subject digest: 5603a26cd20c0df78b0d20017a3182b73a15253ff22338f45afa62a3d8a5aec4
- Sigstore transparency entry: 942266791
- Sigstore integration time:
- Permalink: aghilann/bastile@1dc5ca7856b7bbbe2d1fb79aab9a57a07e641265
- Branch / Tag: refs/tags/v0.0.1
- Owner: https://github.com/aghilann
- Access: private
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@1dc5ca7856b7bbbe2d1fb79aab9a57a07e641265
- Trigger Event: release