Windows-only GPU neural network training via DirectCompute (D3D11 Compute Shaders)
DirectCompute Neural Network Engine
A high-performance, from-scratch neural network training framework that runs entirely on the GPU using DirectCompute (D3D11 Compute Shaders). No CUDA, no cuDNN — just raw HLSL shaders dispatched through a thin C++ runtime, driven from Python.
Quick Install (Windows Only)
pip install directcompute-nn
The Windows wheel includes everything: pre-compiled engine.dll, bundled HLSL shaders, and the Python API. No C++ compiler required.
Key Features
- Pure DirectCompute: Uses standard D3D11 compute shaders (HLSL) — runs on any DirectX 11 GPU (Intel, AMD, NVIDIA) on Windows.
- Full Autograd: Python-based automatic differentiation with GPU-accelerated gradient accumulation.
- Complete Layer Library: Linear, Conv2D, DepthwiseConv2D, BatchNorm2d, MaxPool, GlobalAvgPool2D, and more.
- Differentiable Skip Connections: add() is fully differentiable, so you can build ResNet and MobileNet-style residual blocks.
- Modern Optimizers: SGD, Adam, AdamW, and the state-of-the-art Muon optimizer.
- Transfer Learning: Load pretrained weights, freeze backbone layers, fine-tune on custom data.
- ONNX Inference: Load and run ONNX models on the DirectCompute engine.
Usage Examples
1. Minimal Forward Pass
import numpy as np
from nn_engine import Tensor, Linear, relu
x = Tensor(np.random.randn(32, 128).astype(np.float32))
layer = Linear(128, 64)
y = relu(layer(x))
print(y.shape) # (32, 64)
2. Standard Training Loop (LeNet)
import numpy as np
from nn_engine import (Tensor, ConvLayer, Linear, Model, BatchNorm2d,
                       maxpool2d, flatten, relu, softmax_ce,
                       AdamW, Metrics, end_batch)

class LeNet(Model):
    def __init__(self):
        super().__init__()
        self.c1 = ConvLayer(1, 6, 5)
        self.c2 = ConvLayer(6, 16, 5)
        self.l1 = Linear(16*4*4, 120)
        self.l2 = Linear(120, 10)

    def forward(self, x):
        x = maxpool2d(relu(self.c1(x)))
        x = maxpool2d(relu(self.c2(x)))
        x = flatten(x)
        x = relu(self.l1(x))
        return self.l2(x)

model = LeNet()
optimizer = AdamW(model.parameters(), lr=0.001)
metrics = Metrics()
for epoch in range(10):
    for xb_np, yb_np in your_dataloader():  # yield (np.float32, np.int32)
        optimizer.zero_grad()
        xb, yb = Tensor(xb_np), Tensor(yb_np)
        logits = model(xb)  # run the forward pass once and reuse the result
        loss = softmax_ce(logits, yb)
        metrics.update(loss, logits, yb)
        loss.backward()
        optimizer.step(clip=1.0)
        end_batch()  # flushes GPU, frees intermediates
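The 16*4*4 input size of l1 follows from the layer shapes. Here is a minimal sketch of that arithmetic. It assumes 28x28 single-channel inputs (MNIST-sized), unpadded stride-1 convolutions, and 2x2 max pooling with stride 2. None of these defaults are stated above, so check them against the engine. The helper names conv_out and pool_out are illustrative only.

```python
def conv_out(size, ks, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - ks) // stride + 1

def pool_out(size, pool=2, stride=2):
    """Spatial output size of a pooling window along one dimension."""
    return (size - pool) // stride + 1

size = 28                              # assumed input height/width
size = pool_out(conv_out(size, ks=5))  # c1 + pool: 28 -> 24 -> 12
size = pool_out(conv_out(size, ks=5))  # c2 + pool: 12 -> 8 -> 4
flat_features = 16 * size * size       # 16 channels after c2
print(flat_features)                   # 256, i.e. 16*4*4
```

With a different input resolution, the first argument of Linear in l1 has to change accordingly.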
3. ResNet-style Residual Block
The add() function is fully differentiable and supports skip connections:
from nn_engine import (Tensor, ConvLayer, BatchNorm2d, Linear,
                       relu, add, flatten, softmax_ce, AdamW, end_batch)
import numpy as np

class ResBlock:
    def __init__(self, channels):
        self.conv1 = ConvLayer(channels, channels, ks=3, padding=1)
        self.bn1 = BatchNorm2d(channels)
        self.conv2 = ConvLayer(channels, channels, ks=3, padding=1)
        self.bn2 = BatchNorm2d(channels)

    def __call__(self, x):
        identity = x
        out = relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return relu(add(out, identity))  # differentiable skip connection

    def parameters(self):
        return [self.conv1.filters, self.conv1.bias,
                self.bn1.gamma, self.bn1.beta,
                self.conv2.filters, self.conv2.bias,
                self.bn2.gamma, self.bn2.beta]
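To see why a differentiable add is enough for skip connections, it helps to look at what its backward pass does: the incoming gradient is passed unchanged to both inputs, and a tensor used on two paths accumulates both contributions. The sketch below is a tiny scalar reverse-mode autograd written for illustration. The Value class and its helpers are hypothetical and are not the engine's implementation.

```python
class Value:
    """Scalar node in a toy autograd graph (illustration only)."""
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None

def add(a, b):
    out = Value(a.data + b.data, (a, b))
    def _backward():
        a.grad += out.grad  # gradient flows unchanged to both inputs
        b.grad += out.grad
    out._backward = _backward
    return out

def mul(a, b):
    out = Value(a.data * b.data, (a, b))
    def _backward():
        a.grad += b.data * out.grad
        b.grad += a.data * out.grad
    out._backward = _backward
    return out

def backward(root):
    """Visit nodes in reverse topological order and apply each local rule."""
    order, seen = [], set()
    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for p in v._parents:
                visit(p)
            order.append(v)
    visit(root)
    root.grad = 1.0
    for v in reversed(order):
        v._backward()

x, w = Value(2.0), Value(3.0)
y = add(mul(x, w), x)  # residual form: f(x) + x with f(x) = w * x
backward(y)
print(x.grad)          # 4.0 = w (through f) + 1 (through the skip path)
```

The extra +1 on the skip path is what keeps gradients from vanishing through deep stacks of residual blocks.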
4. MobileNet-style Inverted Residual Block
DepthwiseConvLayer and relu6 enable full MobileNetV2-style blocks:
from nn_engine import (Tensor, ConvLayer, DepthwiseConvLayer, BatchNorm2d,
                       relu6, add, global_avg_pool2d, flatten,
                       Linear, softmax_ce, AdamW, end_batch)
import numpy as np

class InvertedResidual:
    """MobileNetV2 inverted residual block."""
    def __init__(self, inp, oup, stride, expand_ratio):
        hidden = int(round(inp * expand_ratio))
        self.use_res = (stride == 1 and inp == oup)
        self.expand_ratio = expand_ratio
        if expand_ratio != 1:
            self.expand_conv = ConvLayer(inp, hidden, ks=1)
            self.expand_bn = BatchNorm2d(hidden)
        self.dw = DepthwiseConvLayer(hidden, ks=3, stride=stride, padding=1)
        self.dw_bn = BatchNorm2d(hidden)
        self.project = ConvLayer(hidden, oup, ks=1)
        self.project_bn = BatchNorm2d(oup)

    def __call__(self, x):
        identity = x
        if self.expand_ratio != 1:
            x = relu6(self.expand_bn(self.expand_conv(x)))
        x = relu6(self.dw_bn(self.dw(x)))
        x = self.project_bn(self.project(x))
        if self.use_res:
            x = add(x, identity)
        return x
Transfer Learning with Pretrained MobileNetV2
The engine supports loading pretrained weights and performing feature extraction. Below is a complete working example — training a Cat/Dog classifier on PetImages in ~42 seconds using just a 128MB Intel iGPU, matching PyTorch CPU performance.
Step 1 — Download pretrained weights (one-time, requires torch)
pip install torch torchvision
# download_weights.py
import numpy as np
try:
    from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
    model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
except ImportError:
    # Older torchvision releases without the weights enum
    from torchvision.models import mobilenet_v2
    model = mobilenet_v2(pretrained=True)
model.eval()
arrays = {k: v.cpu().numpy() for k, v in model.state_dict().items()
          if 'num_batches_tracked' not in k}
np.savez("mobilenet_v2_weights.npz", **arrays)
print(f"Saved {len(arrays)} tensors")
Step 2 — Build MobileNetV2 and load ImageNet weights
import ctypes, numpy as np
from nn_engine import (Tensor, ConvLayer, DepthwiseConvLayer, BatchNorm2d,
                       relu6, add, global_avg_pool2d, flatten, Linear,
                       softmax_ce, AdamW, Metrics, end_batch, lib)

# (expand ratio t, output channels c, repeats n, first stride s) per stage
MOBILENET_SETTINGS = [
    (1, 16, 1, 1), (6, 24, 2, 2), (6, 32, 3, 2),
    (6, 64, 4, 2), (6, 96, 3, 1), (6, 160, 3, 2), (6, 320, 1, 1),
]

class InvertedResidual:
    def __init__(self, inp, oup, stride, t):
        hidden = int(round(inp * t))
        self.use_res = (stride == 1 and inp == oup)
        self.t = t
        if t != 1:
            self.expand_conv = ConvLayer(inp, hidden, ks=1)
            self.expand_bn = BatchNorm2d(hidden)
        self.dw = DepthwiseConvLayer(hidden, ks=3, stride=stride, padding=1)
        self.dw_bn = BatchNorm2d(hidden)
        self.project = ConvLayer(hidden, oup, ks=1)
        self.project_bn = BatchNorm2d(oup)

    def __call__(self, x):
        identity = x
        if self.t != 1:
            x = relu6(self.expand_bn(self.expand_conv(x)))
        x = relu6(self.dw_bn(self.dw(x)))
        x = self.project_bn(self.project(x))
        return add(x, identity) if self.use_res else x

    def parameters(self):
        p = []
        if self.t != 1:
            p += [self.expand_conv.filters, self.expand_conv.bias,
                  self.expand_bn.gamma, self.expand_bn.beta]
        return p + [self.dw.filters, self.dw.bias,
                    self.dw_bn.gamma, self.dw_bn.beta,
                    self.project.filters, self.project.bias,
                    self.project_bn.gamma, self.project_bn.beta]

    def set_training(self, mode):
        if self.t != 1: self.expand_bn.training = mode
        self.dw_bn.training = mode
        self.project_bn.training = mode
class MobileNetV2:
    def __init__(self, num_classes=2):
        self.conv0 = ConvLayer(3, 32, ks=3, stride=2, padding=1)
        self.bn0 = BatchNorm2d(32)
        inp = 32
        idx = 0
        for t, c, n, s in MOBILENET_SETTINGS:
            for i in range(n):
                setattr(self, f'b{idx}', InvertedResidual(inp, c, s if i == 0 else 1, t))
                idx += 1
                inp = c
        self.num_blocks = idx  # 17
        self.conv_last = ConvLayer(320, 1280, ks=1)
        self.bn_last = BatchNorm2d(1280)
        self.classifier = Linear(1280, num_classes)

    def forward(self, x):
        x = relu6(self.bn0(self.conv0(x)))
        for i in range(self.num_blocks):
            x = getattr(self, f'b{i}')(x)
        x = relu6(self.bn_last(self.conv_last(x)))
        x = global_avg_pool2d(x)
        return self.classifier(flatten(x))

    def extract_features(self, x):
        """Backbone only: returns 1280-dim features without classifier."""
        x = relu6(self.bn0(self.conv0(x)))
        for i in range(self.num_blocks):
            x = getattr(self, f'b{i}')(x)
        x = relu6(self.bn_last(self.conv_last(x)))
        x = global_avg_pool2d(x)
        return flatten(x)

    def parameters(self):
        p = [self.conv0.filters, self.conv0.bias, self.bn0.gamma, self.bn0.beta]
        for i in range(self.num_blocks): p.extend(getattr(self, f'b{i}').parameters())
        return p + [self.conv_last.filters, self.conv_last.bias,
                    self.bn_last.gamma, self.bn_last.beta,
                    self.classifier.w, self.classifier.b]

    def eval(self):
        self.bn0.training = False
        for i in range(self.num_blocks): getattr(self, f'b{i}').set_training(False)
        self.bn_last.training = False

    def load_pretrained(self, npz_path):
        weights = np.load(npz_path)

        def upload(tensor, key):
            data = weights[key].astype(np.float32)
            assert data.shape == tensor.shape, f"Shape mismatch {key}: {data.shape} vs {tensor.shape}"
            tensor.data = data.copy()
            lib.UpdateBuffer(tensor.gpu_buf, data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))

        def load_conv_bn(ck, bk, cl, bl):
            upload(cl.filters, f"{ck}.weight")
            upload(bl.gamma, f"{bk}.weight"); upload(bl.beta, f"{bk}.bias")
            upload(bl.running_mean, f"{bk}.running_mean")
            upload(bl.running_var, f"{bk}.running_var")

        load_conv_bn("features.0.0", "features.0.1", self.conv0, self.bn0)
        bidx, fidx = 0, 1
        for t, c, n, s in MOBILENET_SETTINGS:
            for _ in range(n):
                blk, p = getattr(self, f'b{bidx}'), f"features.{fidx}"
                if t != 1:
                    # expand (conv.0) -> depthwise (conv.1) -> project (conv.2) + BN (conv.3)
                    load_conv_bn(f"{p}.conv.0.0", f"{p}.conv.0.1", blk.expand_conv, blk.expand_bn)
                    upload(blk.dw.filters, f"{p}.conv.1.0.weight")
                    for attr, key in [("gamma", "weight"), ("beta", "bias"),
                                      ("running_mean", "running_mean"), ("running_var", "running_var")]:
                        upload(getattr(blk.dw_bn, attr), f"{p}.conv.1.1.{key}")
                    upload(blk.project.filters, f"{p}.conv.2.weight")
                    for attr, key in [("gamma", "weight"), ("beta", "bias"),
                                      ("running_mean", "running_mean"), ("running_var", "running_var")]:
                        upload(getattr(blk.project_bn, attr), f"{p}.conv.3.{key}")
                else:
                    # no expansion: depthwise (conv.0) -> project (conv.1) + BN (conv.2)
                    upload(blk.dw.filters, f"{p}.conv.0.0.weight")
                    for attr, key in [("gamma", "weight"), ("beta", "bias"),
                                      ("running_mean", "running_mean"), ("running_var", "running_var")]:
                        upload(getattr(blk.dw_bn, attr), f"{p}.conv.0.1.{key}")
                    upload(blk.project.filters, f"{p}.conv.1.weight")
                    for attr, key in [("gamma", "weight"), ("beta", "bias"),
                                      ("running_mean", "running_mean"), ("running_var", "running_var")]:
                        upload(getattr(blk.project_bn, attr), f"{p}.conv.2.{key}")
                bidx += 1; fidx += 1
        load_conv_bn("features.18.0", "features.18.1", self.conv_last, self.bn_last)
        print(f"Loaded pretrained weights from {npz_path}")
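The constructor and load_pretrained both walk MOBILENET_SETTINGS in the same order, which is why block b{k} lines up with the torchvision key prefix features.{k+1}. The sketch below expands the settings table into per-block configurations without touching the GPU. It is a convenient way to sanity-check the block count and see which blocks get a residual connection. The expand_blocks helper and its dictionary keys are illustrative names, not part of nn_engine.

```python
MOBILENET_SETTINGS = [
    (1, 16, 1, 1), (6, 24, 2, 2), (6, 32, 3, 2),
    (6, 64, 4, 2), (6, 96, 3, 1), (6, 160, 3, 2), (6, 320, 1, 1),
]

def expand_blocks(settings, first_inp=32):
    """Expand (t, c, n, s) stage rows into one config dict per block."""
    blocks, inp = [], first_inp
    for t, c, n, s in settings:
        for i in range(n):
            stride = s if i == 0 else 1  # only the first block of a stage downsamples
            blocks.append({
                "features_idx": len(blocks) + 1,  # features.0 is the stem conv
                "inp": inp, "oup": c, "stride": stride, "t": t,
                "hidden": int(round(inp * t)),
                "use_res": stride == 1 and inp == c,
            })
            inp = c
    return blocks

blocks = expand_blocks(MOBILENET_SETTINGS)
print(len(blocks))                        # 17
print(sum(b["use_res"] for b in blocks))  # 10 blocks carry a skip connection
```

The first block (t = 1) is the only one without an expansion conv, which is why load_pretrained needs a separate key layout for it.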
Step 3 — Feature extraction + classifier training
import os, time
import numpy as np
from PIL import Image

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

def load_images(folder, label, limit=500):
    images, labels = [], []
    for f in list(os.listdir(folder))[:limit + 200]:  # buffer for corrupt files
        if len(images) >= limit: break
        try:
            img = Image.open(os.path.join(folder, f)).convert('RGB').resize((224, 224))
            arr = np.array(img).transpose(2, 0, 1).astype(np.float32) / 255.0
            images.append((arr - IMAGENET_MEAN) / IMAGENET_STD)
            labels.append(label)
        except Exception:
            continue
    return images, labels

# Load data
cats, cat_labels = load_images("PetImages/Cat", 0, limit=500)
dogs, dog_labels = load_images("PetImages/Dog", 1, limit=500)
X = np.array(cats + dogs, dtype=np.float32)
Y = np.array(cat_labels + dog_labels, dtype=np.int32)
idx = np.random.permutation(len(X)); X, Y = X[idx], Y[idx]
split = int(0.9 * len(X))
X_train, X_val = X[:split], X[split:]
Y_train, Y_val = Y[:split], Y[split:]

# Build model + load weights
model = MobileNetV2(num_classes=2)
model.load_pretrained("mobilenet_v2_weights.npz")
for p in model.parameters(): p.requires_grad = False
model.eval()

# Phase 1: Extract features once (runs frozen backbone on GPU)
def extract(images, label=""):
    feats = []
    for i in range(0, len(images), 2):
        xb = Tensor(images[i:i+2])
        feats.append(model.extract_features(xb).sync().copy())
        end_batch()
    return np.concatenate(feats)

print("Extracting features...")
t0 = time.time()
F_train = extract(X_train, "train")
F_val = extract(X_val, "val")
print(f"Done in {time.time()-t0:.1f}s ({F_train.shape[1]}-dim features)")

# Phase 2: Train linear head on cached features
clf = Linear(1280, 2)
opt = AdamW([clf.w, clf.b], lr=0.001, weight_decay=0.01)
for epoch in range(30):
    perm = np.random.permutation(len(F_train))
    opt.zero_grad()
    for i in range(0, len(F_train), 32):
        xb = Tensor(F_train[perm[i:i+32]])
        yb = Tensor(Y_train[perm[i:i+32]])
        loss = softmax_ce(clf(xb), yb)
        loss.backward(); opt.step(); opt.zero_grad()
        end_batch()
print("Training complete!")
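The script above stops once training finishes. One way to score the head on the held-out features is to read the logits back to NumPy and compare the arg-max class with the labels. The helper below is a sketch that works on plain arrays. It assumes the logits have already been copied to the host as an (N, num_classes) array, for instance via sync(), whose exact return shape is an assumption here. The accuracy function is a hypothetical helper, not part of nn_engine.

```python
import numpy as np

def accuracy(logits, labels):
    """Fraction of rows whose arg-max class matches the integer label."""
    preds = np.argmax(logits, axis=1)
    return float(np.mean(preds == labels))

# Stand-in data with the same layout as 2-class logits and int32 labels
demo_logits = np.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]], dtype=np.float32)
demo_labels = np.array([0, 1, 1, 1], dtype=np.int32)
print(accuracy(demo_logits, demo_labels))  # 0.75
```

For the example above, the inputs would be the host-side logits for F_val and the Y_val array.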
Benchmark Results (PetImages Cat vs Dog, Intel Xe iGPU)
Changelog
v0.2.1 — The "Heavy Matmul" Update & Transformers
- Adaptive Matrix Multiplication: Added nn_matmul_heavy.hlsl (128x128 tile size, 8x8 thread workload), designed specifically for very large matrices like those found in AlexNet and Transformers. The C++ engine now switches dynamically between Universal (16x16), Coarsened/dGPU (64x64), and Heavy (128x128) kernels based on matrix dimensions, nearly eliminating the performance gap with PyTorch CUDA on large workloads.
- Custom Shaders: Added CustomShader and custom_unary APIs. Users can now write pure HLSL inline in Python, dynamically compile it via D3DCompileFromFile, and plug it directly into the autograd graph without touching C++.
- Transformer Primitives: Added support for Multi-Head Attention (attention_forward, attention_causal), LayerNorm, and GELU activations.
v0.2.0 — GPU-Side Performance Optimizations
Major GPU performance improvements reducing per-batch time on segmentation workloads:
Dice Loss: 2-pass parallel backward (was O(N²), now O(1) per thread)
- nn_dice_loss.hlsl rewritten with 256-thread shared-memory reduction; dispatch is (batch, 1, 1) instead of a single serial thread
- nn_dice_loss_grad.hlsl rewritten to a two-pass approach: pre-compute per-batch {intersection, sum_pred, sum_target} sums in a new Pass 1 shader (nn_dice_loss_sums.hlsl), then each backward thread reads directly from that buffer, eliminating the ~19ms O(N²) loop entirely
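The two-pass idea can be illustrated on the CPU. Once the three per-sample sums are cached, the gradient for every element is a closed-form expression of those sums, so no thread has to re-scan the sample. The NumPy sketch below assumes the common soft Dice form L = 1 - (2I + eps) / (P + T + eps), averaged over the batch. The exact formula and smoothing term used by the shaders are not stated in this changelog, so treat both as assumptions.

```python
import numpy as np

EPS = 1e-6  # smoothing term (assumed, not taken from the shaders)

def dice_sums(pred, target):
    """Pass 1: per-sample reductions (intersection, sum_pred, sum_target)."""
    return (pred * target).sum(axis=1), pred.sum(axis=1), target.sum(axis=1)

def dice_loss(pred, target):
    inter, sp, st = dice_sums(pred, target)
    return float(np.mean(1.0 - (2.0 * inter + EPS) / (sp + st + EPS)))

def dice_grad(pred, target):
    """Pass 2: each element only reads the cached sums, O(1) work per element."""
    inter, sp, st = dice_sums(pred, target)
    denom = (sp + st + EPS)[:, None]
    numer = (2.0 * inter + EPS)[:, None]
    # d/dp_i of 1 - numer/denom, divided by batch size for the mean
    return -(2.0 * target * denom - numer) / denom**2 / pred.shape[0]

rng = np.random.default_rng(0)
pred = rng.random((2, 6))                              # (batch, flattened pixels)
target = (rng.random((2, 6)) > 0.5).astype(np.float64)
grad = dice_grad(pred, target)
print(dice_loss(pred, target), grad.shape)
```

A finite-difference check against dice_loss is a cheap way to validate a gradient like this before porting it to a shader.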
GPU-side gradient clipping — eliminates 30+ CopyResource stalls
- New sgd_step_clipped(params, lr, max_norm) function: gradient norm computation and SGD weight update are fully GPU-resident
- Three new shaders: nn_grad_sq_reduce.hlsl (per-parameter squared-gradient partial sums), nn_grad_norm_final.hlsl (total norm + clip scale), nn_sgd_clipped.hlsl (scaled weight update)
- New C++ function SGDBatchClipped in engine.cpp: only 1 GPU→CPU readback (the norm scalar) vs. 1 readback per parameter previously
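The three shader stages map onto a standard global-norm clipping recipe. The NumPy sketch below mirrors that structure on the CPU so the data flow is easy to follow. The clip-scale formula min(1, max_norm / (norm + 1e-6)) is the convention used by common frameworks and is an assumption here, as is the function name sgd_step_clipped_reference.

```python
import numpy as np

def sgd_step_clipped_reference(params, grads, lr, max_norm):
    """CPU reference for clipped SGD; updates params in place, returns the norm."""
    # Stage 1: one squared-sum per parameter tensor
    partial = [float(np.sum(g * g)) for g in grads]
    # Stage 2: total norm and a single clip scale shared by all parameters
    total_norm = float(np.sqrt(sum(partial)))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    # Stage 3: scaled weight update
    for p, g in zip(params, grads):
        p -= lr * scale * g
    return total_norm

params = [np.array([1.0, 1.0]), np.array([0.0])]
grads = [np.array([3.0, 0.0]), np.array([4.0])]
norm = sgd_step_clipped_reference(params, grads, lr=0.5, max_norm=1.0)
print(norm)    # 5.0
print(params)  # roughly [0.7, 1.0] and [-0.4]
```

Only the scalar norm needs to leave the device in a scheme like this, which matches the single readback described above.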
New and updated shaders in this release:
nn_dice_loss_sums.hlsl, nn_grad_sq_reduce.hlsl, nn_grad_norm_final.hlsl, nn_sgd_clipped.hlsl, plus updated nn_dice_loss.hlsl and nn_dice_loss_grad.hlsl
v0.1.9 and earlier
See GitHub releases for full history.
Benchmark results (PetImages Cat vs Dog, Intel Xe iGPU):
| | DirectCompute (iGPU) | PyTorch (CPU) |
|---|---|---|
| Feature extraction (1600 imgs) | 41.7s | 43.1s |
| Classifier training (30 epochs) | 0.9s | 1.4s |
| Total | 42.7s | 44.5s |
| Test accuracy (400 unseen) | 98.2% | 98.0% |
The DirectCompute engine runs feature extraction faster than PyTorch CPU, on a 128MB integrated GPU with no dedicated VRAM.
Full API Reference
Layers
| Class / function | Description |
|---|---|
| Linear(in, out) | Fully-connected layer. Weights: (in, out), bias: (out,) |
| ConvLayer(inC, outC, ks, stride, padding) | 2D convolution via im2col+matmul. Skips im2col when requires_grad=False (frozen layers save GPU memory) |
| DepthwiseConvLayer(channels, ks, stride, padding) | Depthwise convolution: one filter per input channel. Used in MobileNet-style blocks |
| BatchNorm2d(num_features) | Batch normalization with running stats. Set .training=False for eval/frozen mode |
| maxpool2d(x, pool_size, stride) | Max pooling with saved indices for backward |
| global_avg_pool2d(x) | Global average pool: (N, C, H, W) → (N, C, 1, 1) |
| flatten(x) | Flatten spatial dims: (N, C, H, W) → (N, C*H*W) |
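For readers unfamiliar with the im2col+matmul formulation, the NumPy sketch below shows the idea on the CPU: every receptive field is unfolded into one row of a patch matrix, and the convolution becomes a single matrix multiply against the flattened filters. It assumes the (outC, inC, ks, ks) filter layout used by torchvision weights. It is a reference for the technique, not the engine's HLSL implementation.

```python
import numpy as np

def im2col(x, ks, stride=1, padding=0):
    """Unfold (N, C, H, W) into a (N*OH*OW, C*ks*ks) patch matrix."""
    n, c, h, w = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    oh = (h + 2 * padding - ks) // stride + 1
    ow = (w + 2 * padding - ks) // stride + 1
    cols = np.empty((n, oh, ow, c, ks, ks), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            cols[:, i, j] = xp[:, :, i*stride:i*stride+ks, j*stride:j*stride+ks]
    return cols.reshape(n * oh * ow, c * ks * ks), (n, oh, ow)

def conv2d_im2col(x, filters, bias, stride=1, padding=0):
    """Convolution as one matmul: patches @ flattened filters + bias."""
    out_c, _, ks, _ = filters.shape
    cols, (n, oh, ow) = im2col(x, ks, stride, padding)
    out = cols @ filters.reshape(out_c, -1).T + bias  # (N*OH*OW, outC)
    return out.reshape(n, oh, ow, out_c).transpose(0, 3, 1, 2)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 5))
filters = rng.standard_normal((4, 3, 3, 3))
bias = rng.standard_normal(4)
y = conv2d_im2col(x, filters, bias, stride=2, padding=1)
print(y.shape)  # (2, 4, 3, 3)
```

The patch matrix is ks*ks times larger than the input, which is the memory cost that frozen layers avoid when im2col is skipped.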
Differentiable Operations
| Function | Description |
|---|---|
| relu(x) | ReLU activation |
| relu6(x) | Clamped ReLU: min(max(x, 0), 6), used in MobileNetV2 |
| add(a, b) | Element-wise add with full gradient support; enables residual/skip connections |
| softmax_ce(logits, labels) | Fused softmax + cross-entropy loss |
| matmul(A, B, transA, transB) | Matrix multiply with optional transpose flags |
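As a reference for what two of these operations compute, here are NumPy equivalents of relu6 and a fused softmax plus cross-entropy. The _ref suffix marks them as illustrative helpers. The mean reduction over the batch is an assumption, since the table does not say how softmax_ce reduces the per-sample losses.

```python
import numpy as np

def relu6_ref(x):
    """min(max(x, 0), 6), element-wise."""
    return np.minimum(np.maximum(x, 0.0), 6.0)

def softmax_ce_ref(logits, labels):
    """Numerically stable log-softmax followed by negative log-likelihood."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    picked = log_probs[np.arange(len(labels)), labels]
    return float(-picked.mean())  # mean over the batch (assumed reduction)

print(relu6_ref(np.array([-1.0, 3.0, 8.0])))                 # [0. 3. 6.]
print(softmax_ce_ref(np.zeros((2, 4)), np.array([0, 3])))    # log(4), about 1.386
```

Fusing the two steps avoids materialising the softmax probabilities and keeps the loss stable for large logits.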
Optimizers
| Class | Description |
|---|---|
| SGD(params, lr) | Stochastic Gradient Descent with gradient clipping |
| Adam(params, lr, weight_decay) | Adaptive Moment Estimation |
| AdamW(params, lr, weight_decay) | Adam with decoupled weight decay (recommended for most tasks) |
| Muon(params, lr) | Orthogonal gradient optimizer via Newton-Schulz iteration |
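"Decoupled" weight decay means the decay is applied directly to the weight rather than being folded into the gradient before the moment estimates. The scalar sketch below shows one AdamW step following the published algorithm. The default betas and eps are the usual textbook values and are assumptions here, since the engine's defaults are not listed in this table.

```python
import math

def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar weight; returns (w, m, v)."""
    w = w - lr * weight_decay * w        # decoupled decay: does not touch m or v
    m = beta1 * m + (1 - beta1) * g      # first moment
    v = beta2 * v + (1 - beta2) * g * g  # second moment
    m_hat = m / (1 - beta1 ** t)         # bias correction for step t (1-based)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

w, m, v = adamw_step(w=1.0, g=0.5, m=0.0, v=0.0, t=1, lr=0.1)
print(round(w, 6))  # 0.899: 0.001 from decay plus 0.1 from the Adam step
```

With Adam plus an L2 term, the decay would instead pass through the adaptive scaling, so weights with large gradients would be regularised less.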
Tensor
t = Tensor(np_array, requires_grad=True)
t.sync() # read data back from GPU → numpy array
t.upload(data) # push new data into existing GPU buffer (no realloc)
t.backward() # run autograd backward from this tensor
t.shape # tuple
t.size # total element count
t.grad # gradient Tensor (set after backward)
Model Base Class
class MyModel(Model):
    def forward(self, x): ...   # implement forward pass
    def parameters(self): ...   # return list of Tensors

model.parameters()              # auto-discovered from Linear/ConvLayer/BatchNorm2d
model.train(); model.eval()     # toggle BN training mode
model.export("model.onnx", [1,3,224,224])  # ONNX export
Training Utilities
end_batch() # flush GPU pipeline + bulk-free intermediate tensors
Metrics() # tracks loss and accuracy across batches
get_pool_stats() # (hits, misses) for GPU buffer pool
get_pool_memory() # bytes currently in pool
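To make the pool statistics concrete, here is a toy size-keyed buffer pool in plain Python. It is a guess at the general technique, not a description of the engine's allocator: the BufferPool class, its bucketing by exact byte size, and the use of bytearray in place of GPU buffers are all assumptions made for illustration.

```python
class BufferPool:
    """Toy pool that reuses released buffers of the same size."""
    def __init__(self):
        self._free = {}  # size in bytes -> list of idle buffers
        self.hits = 0
        self.misses = 0

    def acquire(self, nbytes):
        bucket = self._free.get(nbytes)
        if bucket:
            self.hits += 1           # reuse: no new allocation
            return bucket.pop()
        self.misses += 1             # nothing idle of this size: allocate
        return bytearray(nbytes)

    def release(self, buf):
        self._free.setdefault(len(buf), []).append(buf)

    def stats(self):
        return self.hits, self.misses

    def memory(self):
        """Bytes held by idle buffers currently sitting in the pool."""
        return sum(size * len(bufs) for size, bufs in self._free.items())

pool = BufferPool()
a = pool.acquire(1024)  # miss: pool is empty
pool.release(a)
b = pool.acquire(1024)  # hit: same-size buffer is reused
pool.release(b)
print(pool.stats(), pool.memory())  # (1, 1) 1024
```

In a training loop with fixed batch shapes, a pool like this should report almost all hits after the first iteration, so a rising miss count hints at shape changes between batches.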
Contributing and Source Code
Full source code, C++ engine, and HLSL shader implementation:
https://github.com/raviadi12/directcompute_torch
Future Roadmap
- Vulkan & DX12 Backends: Cross-platform GPU support.
- UMA Optimizations: Better memory paths for integrated GPUs.
- Transposed Convolutions: For upsampling / generative models.
- Enhanced ONNX: Full graph import for pretrained model deployment.