A high-performance tensor library with CUDA acceleration

Project description


⚡ Tensorax

A from-scratch tensor library with hand-written CUDA kernels.

No PyTorch. No NumPy. Pure C++/CUDA + Python.



Usage Guide · Architecture · Contributing · Examples




🔩   Zero heavy dependencies

Only pybind11 — no PyTorch, NumPy, or cuBLAS at runtime.

⚡   Hand-written CUDA kernels

6 matmul variants, 5 attention kernels, 14 element-wise ops — all from scratch.

🧠   Full autograd engine

Reverse-mode autodiff with gradient tracking through 18+ operations.
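To show the idea behind reverse-mode autodiff from first principles, here is a minimal scalar sketch in plain Python. This is illustrative only, not Tensorax's actual `Tensor` class: the real engine tracks gradients through tensors and 18+ ops, but the chain-rule mechanics are the same.

```python
# Minimal scalar reverse-mode autodiff (illustrative sketch, not Tensorax code).
class Value:
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward():
            self.grad += other.data * out.grad   # d(xy)/dx = y
            other.grad += self.data * out.grad   # d(xy)/dy = x
        out._backward = backward
        return out

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward():
            self.grad += out.grad                # d(x+y)/dx = 1
            other.grad += out.grad               # d(x+y)/dy = 1
        out._backward = backward
        return out

    def backward(self):
        # Topologically order the graph, then apply the chain rule in reverse.
        order, seen = [], set()
        def visit(v):
            if v not in seen:
                seen.add(v)
                for p in v._parents:
                    visit(p)
                order.append(v)
        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()

x, y = Value(3.0), Value(4.0)
z = x * y + x          # z = xy + x, so dz/dx = y + 1 = 5 and dz/dy = x = 3
z.backward()
```

Each op records a closure that knows how to push the output gradient back to its inputs; `backward()` just replays those closures in reverse topological order.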

🎯   PyTorch-like API

Familiar Tensor, nn.Module, optim.Adam, lr_scheduler interface — minimal learning curve.

🧱   Batteries included

Linear, Embedding, GELU, SiLU, LayerNorm, BatchNorm, Dropout, MultiHeadAttention, GQA, Flash Attention, LR schedulers — ready to train.

📚   Built to learn from

Clean, readable implementation of a DL framework from first principles.




Get Started

pip install tensorax
from tensorax import Tensor, nn, optim, lr_scheduler, functional as F

# Build
model = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.LayerNorm(8), nn.Linear(8, 3))
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Train (x_train, y_train: your training data as Tensors)
for epoch in range(100):
    loss = F.mse_loss(model(x_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
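The `CosineAnnealingLR` schedule used above follows the standard cosine curve. As a sketch, assuming the usual formula lr_t = eta_min + (lr_0 − eta_min)·(1 + cos(πt/T_max))/2 with eta_min defaulting to 0 (the helper below is hypothetical, not a Tensorax API):

```python
import math

def cosine_annealing_lr(base_lr, t, t_max, eta_min=0.0):
    """Standard cosine annealing: decays base_lr toward eta_min over t_max steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

# lr starts at base_lr, reaches base_lr/2 halfway through, and eta_min at t = t_max
lrs = [cosine_annealing_lr(0.001, t, t_max=100) for t in range(101)]
```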

Full usage guide with all APIs, code examples, and details: docs/USAGE.md




What's Inside


Core

  • Tensor with CPU ↔ CUDA
  • Broadcasting arithmetic
  • sum, mean with keepdim
  • reshape, transpose
  • exp, log, sqrt, pow
  • 13 dtype constants
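Broadcasting follows the familiar NumPy-style rule: shapes are aligned from the trailing dimension, and each pair of dims must either match or one of them must be 1. A shape-only sketch of that rule (illustrative, not Tensorax's internal code):

```python
import itertools

def broadcast_shape(a, b):
    """NumPy-style broadcast rule: align shapes from the right; a size-1
    dim stretches to match the other shape's dim."""
    out = []
    for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible shapes {a} and {b}")
        out.append(max(x, y))
    return tuple(reversed(out))
```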

Neural Networks

  • Linear, Embedding, Sequential
  • ReLU, Sigmoid, Tanh, Softmax, GELU, SiLU
  • LayerNorm, RMSNorm, BatchNorm
  • Dropout
  • Module base class

Training

  • SGD with momentum
  • Adam with bias correction
  • MSE, CrossEntropy, CE from logits
  • 5 LR schedulers (Step, Cosine, Exponential, Linear, MultiStep)
  • Autograd through 18+ ops
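The "bias correction" in Adam divides the zero-initialized moment estimates by (1 − β^t), so early steps are not biased toward zero. A single scalar update step, sketched in plain Python (illustrative; the function name and signature are hypothetical, not Tensorax's `optim.Adam` API):

```python
import math

def adam_step(param, grad, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter.

    m, v are the running first/second moment estimates; t is the 1-based
    step count used for bias correction."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - b1 ** t)   # bias-corrected first moment
    v_hat = v / (1 - b2 ** t)   # bias-corrected second moment
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v
```

On the very first step the corrections make m_hat equal the raw gradient and v_hat its square, so the update is roughly lr in the gradient's direction regardless of magnitude.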

Attention

  • Scaled dot-product attention
  • Multi-Head Attention with projections
  • 5 CUDA kernels (naive → MMA)
  • Grouped Query Attention
  • Causal & padding masks
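The math all five attention kernels implement is the same scaled dot-product: softmax(QKᵀ/√d_k)V, with masked positions set to −∞ before the softmax. A dependency-free sketch on nested lists (illustrative only; the CUDA kernels differ enormously in how, not what, they compute):

```python
import math

def sdpa(Q, K, V, causal=False):
    """Scaled dot-product attention on plain nested lists.
    Q, K: [S, d_k], V: [S, d_v] -> output [S, d_v]."""
    S, d_k = len(Q), len(Q[0])
    out = []
    for i in range(S):
        # scores_ij = (q_i . k_j) / sqrt(d_k), masked to -inf if causal and j > i
        scores = []
        for j in range(S):
            if causal and j > i:
                scores.append(float("-inf"))
            else:
                scores.append(sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k))
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]   # numerically stable softmax
        z = sum(exps)
        weights = [e / z for e in exps]
        out.append([sum(w * V[j][d] for j, w in enumerate(weights))
                    for d in range(len(V[0]))])
    return out
```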

CUDA Kernels

  • 6 matmul implementations
  • 14 element-wise ops
  • Parallel reductions
  • Tiled + coalesced access
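Tiling is the same cache-blocking idea in both worlds: process TILE×TILE sub-blocks so each loaded value is reused TILE times (in the CUDA kernels, via shared memory). A plain-Python sketch of the loop structure (illustrative; the actual kernels live in csrc/cuda/kernels/):

```python
def matmul_tiled(A, B, tile=2):
    """Blocked matrix multiply on nested lists: C = A @ B.
    Each (i0, j0, p0) iteration multiplies one tile of A by one tile of B
    and accumulates into the corresponding tile of C."""
    n, k, m = len(A), len(B), len(B[0])
    C = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for p0 in range(0, k, tile):
                for i in range(i0, min(i0 + tile, n)):
                    for p in range(p0, min(p0 + tile, k)):
                        a = A[i][p]              # reused across the j-loop
                        for j in range(j0, min(j0 + tile, m)):
                            C[i][j] += a * B[p][j]
    return C
```

On a GPU the two inner loops become threads of a block, and the tiles of A and B are staged into shared memory with coalesced loads before the accumulation.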

Infra

  • 433 tests, 95% coverage
  • CI/CD with GitHub Actions
  • pybind11 bindings
  • Automatic CUDA fallback



Performance

Matrix Multiplication — fp32, 3×1024×1024, 100 runs:

PyTorch CUDA (ref)         ████████████████████████████████████████████  0.41s  (4.51×)
Tensorax 1D Block Tiling   ██████████████████████████████████████████    0.95s  (2.31×)  ← best
Tensorax Tiled             ████████████████████████████████              1.22s  (1.80×)
NumPy CPU (baseline)       █████████████████████████                    1.85s  (1.00×)

2.31× faster than NumPy · 43% of PyTorch's cuBLAS throughput · all hand-written, zero library calls

Attention Kernels — fp32/fp16, B=4, H=8, S=256, Dk=512, Dv=512, 30 runs:

PyTorch SDPA (ref)         ████████████████████████████████████████████  0.04s  (2340×)
Tensorax MMA Tensor Core   ██████████████████████████████████████        0.33s   (297×)  ← best
Tensorax Optim. Flash      ██████████████████████████████████            0.52s   (187×)
Tensorax Flash SDPA        ██████████████████████████                    3.10s    (31×)
NumPy CPU (baseline)       ████████████████████                          7.06s    (14×)
Tensorax Tiled SDPA        ██████                                       32.91s     (3×)
Tensorax Naive SDPA        █                                            98.26s     (1×)

9.3× faster than the Flash SDPA kernel via raw inline PTX assembly, using Ampere mma.sync Tensor Core instructions and SFU intrinsics.




Project Structure

csrc/                           C++ / CUDA backend
  cuda/kernels/                   elementwise · matmul (×6) · reduction · attention (×5)
  cpu/                            CPU fallback for all ops
  tensor_ops.{cpp,h}             pybind11 bindings

tensorax/                       Python package
  tensor.py                       Tensor class + autograd
  functional.py                   F.relu, F.gelu, F.silu, F.softmax, F.sdpa, ...
  nn/                             Linear, Embedding, norms, dropout, attention (SDPA, MHA, GQA)
  optim.py                        SGD, Adam
  lr_scheduler.py                 StepLR, CosineAnnealingLR, ExponentialLR, LinearLR, MultiStepLR



Roadmap

Status
✅ Core ops · autograd · NN layers · norms · optimizers · losses · attention (5 CUDA kernels) · GQA · MHA · matmul (6 variants) · GELU/SiLU · Embedding · LR schedulers
🚧 Expanded benchmarking · higher test coverage
🔮 Conv2D · MaxPool2D · AdamW · indexing/slicing · serialization · DataLoader · multi-GPU · mixed precision · DDP · ONNX export



Documentation

Usage Guide: API reference, code examples, training patterns
Architecture: System design, kernel strategy, autograd internals
Development: Build, test, contribute
Examples: Runnable scripts for common tasks



Contributing

Fork → Branch → Commit → PR

See DEVELOPMENT.md for build instructions and guidelines.




Citation

@software{tensorax2025,
  title  = {Tensorax: Pure C++/CUDA Tensor Library},
  author = {Shrirang Mahajan},
  year   = {2025},
  url    = {https://github.com/NotShrirang/tensorax}
}



GitHub  ·  Issues  ·  Discussions

Built with ❤️ by @NotShrirang

⭐ Star if you find this useful


Download files

Download the file for your platform.

Source Distribution

tensorax-0.2.1.tar.gz (49.4 kB)

Uploaded Source

File details

Details for the file tensorax-0.2.1.tar.gz.

File metadata

  • Download URL: tensorax-0.2.1.tar.gz
  • Size: 49.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.3

File hashes

Hashes for tensorax-0.2.1.tar.gz
Algorithm Hash digest
SHA256 94dd1111c58071194b7a62eb86bd8af6628ce3a8a0d504ba14e8c28732b16a2c
MD5 9e33903607ff9e07d9664d6cc2d0a8ef
BLAKE2b-256 f1fe01bff5697e0485568c729b9671da82581fe35c56f60487461f6d3158ba91
