Deterministic CUDA operations for reproducible deep learning
BitExact
Deterministic CUDA Kernels for Reproducible Deep Learning
BitExact is a research-driven CUDA library providing bit-exact deterministic GPU tensor operations. It ensures identical floating-point results across runs, batches, and devices, removing nondeterminism from key deep-learning computations.
The library is designed to be plug and play with PyTorch. This means it can serve as a drop-in replacement for selected PyTorch tensor operations while guaranteeing bit-level reproducibility.
BitExact is particularly suited for:
- Model reproducibility research - verifying training consistency across runs
- Numerical analysis and benchmarking - comparing model outputs with precision guarantees
- Deployment pipelines where deterministic inference is required for compliance or scientific validation
Quick Links
- Quick Start
- API Reference
- Design Reference
- Performance Reference
- Testing
- Project Structure
- Contributing
- Project Status
- Acknowledgements
Quick Start Example
```python
import torch, bitexact

x = torch.randn(4, 4, device="cuda")
w = torch.ones(4, device="cuda")
y = bitexact.rms_norm(x, w)
print(y)
```
Current Features
| Category | Kernel Operation | Reference |
|---|---|---|
| Linear Algebra | Matrix Multiplication | MatMul |
| Normalization | RMS Normalization | RmsNorm |
| Normalization | Layer Normalization | LayerNorm |
| Reductions | Sum | Sum |
| Reductions | Mean | Mean |
| Reductions | Max | Max |
| Reductions | Min | Min |
| Activations | Sigmoid | Sigmoid |
More deterministic kernels may be coming soon.
Installation
Prerequisites
- Python $\geq 3.9$
- CUDA $\geq 12.0$
- PyTorch $\geq 2.1$
- A C++ Compiler (MSVC 2022 / gcc $\geq 9$)
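A quick sanity check for the Python side of the prerequisites above (this is an illustrative snippet, not part of BitExact; the CUDA toolkit and C++ compiler still need to be verified separately, e.g. with `nvcc --version`):

```python
# Verify the interpreter meets the minimum Python version before building.
import sys

assert sys.version_info >= (3, 9), "BitExact requires Python >= 3.9"
print("Python OK:", sys.version.split()[0])
```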
From Source
```shell
git clone https://github.com/aaravkohli1/BitExact.git
cd BitExact
pip install . --no-build-isolation
```
PyPI
```shell
pip install bitexact
```
Performance at a Glance
| Operation | Throughput (vs PyTorch) | Notes |
|---|---|---|
| Matrix Multiplication | 0.47x | Slower than cuBLAS; PyTorch's highly tuned GEMM outperforms the deterministic reduction. |
| RMS Normalization | 5.09x | Fused mean, sqrt, and scaling operations reduce kernel launches and memory access. |
| Layer Normalization | 1.66x | Fused single-kernel variance reduces global memory passes and improves speed on small tensors. |
| Sum | 1.98x | Optimized shared-memory reduction with fixed traversal order for determinism. |
| Mean | 1.69x | Builds on the Sum kernel with deterministic normalization by element count. |
| Max | 1.75x | Deterministic warp-level reduction; avoids divergent branching used in PyTorch. |
| Min | 1.98x | Similar to Max; uses unified deterministic traversal for all elements. |
| Variance | 1.35x | Uses fused E[x²] - (E[x])² formulation with deterministic accumulation. |
| Sigmoid | 0.92x | Identical arithmetic to PyTorch; near-equal performance and perfect bit equivalence. |
| Average | 1.88x | Tests performed on small-scale tensors; PyTorch is optimized for large batch sizes. |
(Benchmarked on NVIDIA GeForce RTX 4060 Ti, PyTorch 2.6.0, CUDA 12.5)
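The fused variance formulation listed in the table, E[x²] - (E[x])², can be checked against the definitional two-pass variance in plain Python. This is a numerical sketch of the identity, not the CUDA kernel itself:

```python
# Population variance two ways: the definitional two-pass form and the
# fused one-pass form E[x^2] - (E[x])^2 used by the Variance kernel.
xs = [0.5, 1.5, 2.0, 4.0]
n = len(xs)

mean = sum(xs) / n
var_two_pass = sum((x - mean) ** 2 for x in xs) / n    # definitional form
var_fused = sum(x * x for x in xs) / n - mean ** 2     # E[x^2] - (E[x])^2

print(var_two_pass, var_fused)
assert abs(var_two_pass - var_fused) < 1e-12
```

The fused form needs only one pass over the data, which is why it maps well onto a single kernel; the trade-off is potential cancellation error when the mean is large relative to the spread.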
Interpretation of Results
BitExact's performance advantage comes primarily from kernel fusion and deterministic reduction order, which minimize synchronization and memory traffic. However, PyTorch's fused kernels outperform in large-batch GEMM and high-throughput workloads. These results emphasize that BitExact prioritizes determinism and reproducibility over raw FLOPS.
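Reduction order matters at all because floating-point addition is not associative: summing the same values in a different grouping can change the last bits of the result. A minimal pure-Python illustration (not BitExact code):

```python
# Floating-point addition is not associative: grouping changes the result.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c     # 0.6000000000000001
right = a + (b + c)    # 0.6
print(left == right)   # False: the two groupings differ in the last bit
```

A GPU reduction whose accumulation order depends on thread scheduling can therefore return different bits on different runs even though every run is "correct" to within rounding; fixing the reduction order removes that variability.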
Local Benchmarks
To see how BitExact benchmarks on your machine, run:
```shell
python benchmarks/benchmark.py
```
Example output
```
BitExact vs PyTorch - Benchmark Suite
Operation     Torch (ms)   BitExact (ms)   Speed   Max Diff   Match
-------------------------------------------------------------------------
MatMul        0.0336       0.0692          0.48x   1.07e-04   True
Sum           0.0086       0.0117          0.73x   1.14e-05   True
Mean          0.0083       0.0079          1.05x   1.12e-08   True
Max           0.0087       0.0117          0.74x   0.00e+00   True
Min           0.0097       0.0080          1.21x   0.00e+00   True
Sigmoid       0.0074       0.0073          1.01x   0.00e+00   True
RMSNorm       0.0430       0.0084          5.12x   1.91e-06   True
LayerNorm     0.0881       0.0547          1.61x   1.91e-06   True
Variance      0.0311       0.0266          1.17x   2.38e-07   True
Note: Matches use atol=1e-4, rtol=1e-6 tolerance (within FP32 rounding).
-------------------------------------------------------------------------
Summary
-------------------------------------------------------------------------
Operations faster than PyTorch: 6/9
All operations deterministic: True
Average speedup: 1.46x
=========================================================================
```
All measurements use CUDA events for precise GPU timing with 10 warmup and 100 timed iterations. Run-to-run variance of 5-15% is typical due to GPU boost clocks, thermal state, and driver scheduling. Focus on relative speedup trends rather than absolute millisecond values.
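The Match column's atol/rtol criterion follows the usual absolute-plus-relative tolerance test (the same shape as `torch.allclose`). A pure-Python sketch of that comparison, using the benchmark's `atol=1e-4`, `rtol=1e-6`:

```python
def close(actual: float, expected: float,
          atol: float = 1e-4, rtol: float = 1e-6) -> bool:
    """Elementwise tolerance test: |a - b| <= atol + rtol * |b|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)

print(close(1.0, 1.00009))   # 9e-5 difference: within tolerance
print(close(1.0, 1.001))     # 1e-3 difference: fails
```

Note that the relative term scales with the magnitude of the expected value, so larger matmul outputs tolerate proportionally larger absolute differences.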
Testing
BitExact includes deterministic equality tests for all kernels.
To run the test suite, first ensure you have pytest installed:

```shell
pip install -U pytest
```
Then you can run the test suite with:
```shell
pytest tests/
```
Recommended Flags
- `-v`: verbose; shows the result of each individual test
- `-s`: don't capture output; allows setup logs from conftest.py
Example:
```shell
pytest tests/ -v -s
```
Because many tests use randomized tensors, running the suite multiple times helps verify reproducibility and numerical stability. You can run the tests any number of times; the examples below simply use 3 as a placeholder.
Linux:

```shell
for i in {1..3}; do pytest -v; done
```

Windows (PowerShell):

```powershell
for ($i = 1; $i -le 3; $i++) { pytest -v }
```
Troubleshooting
- CUDA OOM: close other GPU workloads, then re-run. Cache is auto-cleared; if needed, re-run with `-s` to confirm setup logs.
- No GPU: tests require a CUDA-capable device; CPU fallbacks are not provided.
All tests verify bit-exact equivalence to PyTorchโs reference implementations and ensure reproducibility across multiple runs and devices.
Deterministic Inference
The examples/deterministic_inference.py script demonstrates a small neural network using BitExact kernels (matmul, rms_norm, and sigmoid). Running the example verifies that the network's outputs are bit-for-bit identical across multiple runs, confirming complete GPU determinism.
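The kind of bit-for-bit check this example performs can be sketched in plain Python: run the computation twice and compare the exact bytes of the results. Here a seeded stdlib PRNG stands in for the GPU model; the `hashlib`/`struct` harness is illustrative, not BitExact's actual test code:

```python
import hashlib
import random
import struct

def run_model(seed: int) -> list:
    """Stand-in 'inference': a seeded computation producing floats."""
    rng = random.Random(seed)
    return [rng.random() * 2.0 - 1.0 for _ in range(16)]

def digest(values) -> str:
    """Hash the exact IEEE-754 bytes, so even 1-ulp drift is detected."""
    packed = struct.pack("<%dd" % len(values), *values)
    return hashlib.sha256(packed).hexdigest()

out1 = digest(run_model(seed=42))
out2 = digest(run_model(seed=42))
print(out1 == out2)   # identical bytes across runs
assert out1 == out2
```

Hashing raw float bytes, rather than comparing with a tolerance, is what distinguishes a bit-exactness check from an ordinary numerical-accuracy check.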
Run the file with:
```shell
python examples/deterministic_inference.py
```
Project Structure
```
bitexact/
├── bitexact/                 # Python bindings and high-level API
│   └── __init__.py
│
├── benchmarks/               # Benchmarking suite for performance comparison
│   ├── benchmark.py
│   └── utils.py
│
├── docs/                     # Technical documentation
│   ├── api.md
│   └── design.md
│
├── examples/                 # Minimal runnable examples
│   ├── basic_usage.py        # Simple demonstration of deterministic ops
│   └── deterministic_inference.py  # Reproducible model inference pipeline
│
├── src/                      # Core CUDA/C++ source
│   ├── bindings.cpp          # PyTorch extension bindings (exposes kernels to Python)
│   │
│   └── ops/                  # Kernel implementations
│       ├── matmul/           # Matrix multiplication kernels
│       │   ├── matmul.cu
│       │   └── matmul.cuh
│       │
│       ├── reductions/       # Deterministic reduction kernels
│       │   ├── sum.cu
│       │   ├── sum.cuh
│       │   ├── mean.cu
│       │   ├── mean.cuh
│       │   ├── max.cu
│       │   ├── max.cuh
│       │   ├── min.cu
│       │   ├── min.cuh
│       │   ├── var.cu
│       │   └── var.cuh
│       │
│       ├── normalization/    # Normalization kernels
│       │   ├── rms_norm.cu
│       │   ├── rms_norm.cuh
│       │   ├── layer_norm.cu
│       │   └── layer_norm.cuh
│       │
│       ├── activations/      # Activation kernels
│       │   ├── sigmoid.cu
│       │   └── sigmoid.cuh
│       │
│       └── utils/            # Shared CUDA utilities
│           ├── cuda_utils.cuh   # Common device helpers (grid-stride loops, etc.)
│           ├── dtype_utils.cuh  # Type casting and precision utilities
│           └── reduction.cuh    # Shared reduction patterns for deterministic ops
│
├── tests/                    # Pytest suite
│   ├── conftest.py
│   └── test_determinism.py
│
├── LICENSE                   # License file
├── README.md                 # Project overview and documentation
└── setup.py                  # Build and installation script
```
Contributions
Contributions are welcome! If you have an idea for a kernel, feel free to implement it (the largest missing one is attention).
Please ensure new kernels:
- Pass the deterministic equality tests (see the testing suite).
- Use warp-synchronous, non-atomic reduction patterns.
- Include both `.cu` and `.cuh` files and a corresponding test.
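To see why the "non-atomic, fixed traversal order" requirement matters, here is a CPU sketch in plain Python (not kernel code) contrasting a scheduling-dependent accumulation with a fixed pairwise-tree reduction of the kind a deterministic kernel would use:

```python
# Values chosen so that accumulation order visibly changes the float result.
vals = [1e16, 1.0, -1e16]

# "Atomic-style" accumulation: the order in which threads commit their adds
# is arbitrary, so different runs can effectively sum in different orders.
order_a = (vals[0] + vals[1]) + vals[2]   # (1e16 + 1) - 1e16 -> 0.0
order_b = (vals[0] + vals[2]) + vals[1]   # (1e16 - 1e16) + 1 -> 1.0
print(order_a, order_b)

def tree_sum(xs):
    """Fixed pairwise-tree reduction: the combination order depends only on
    element indices, never on scheduling, so the result bits are stable."""
    xs = list(xs)
    while len(xs) > 1:
        xs = [xs[i] + xs[i + 1] if i + 1 < len(xs) else xs[i]
              for i in range(0, len(xs), 2)]
    return xs[0]

# Same input, same combination tree, same bits on every run.
assert tree_sum(vals) == tree_sum(vals)
```

The tree order still rounds (here it also yields 0.0), but it rounds the same way every time, which is the property the equality tests check.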
Project Status
This project began as an experiment following a research article. The problem of determinism proved interesting enough that I spent part of my reading week building this library, and I plan to keep developing it, though on no fixed schedule.
There are many ways the library could be expanded, outlined in the design document. If you are interested, feel free to make a contribution.
Acknowledgements
This project draws inspiration from research by Thinking Machines Lab on deterministic GPU computation and reproducible deep learning. Their exploration of bit-exact kernels and floating-point determinism informed the design philosophy of BitExact.