Windows-only GPU neural network training via DirectCompute (D3D11 Compute Shaders)

Project description

DirectCompute Neural Network Engine

A from-scratch neural network training framework that runs entirely on the GPU using DirectCompute (D3D11 Compute Shaders). It needs no CUDA and no cuDNN: raw HLSL shaders are dispatched through a thin C++ runtime and driven from Python.

Quick Install (Windows Only)

pip install directcompute-nn

The Windows wheel includes everything: pre-compiled engine.dll, bundled HLSL shaders, and the Python API. No C++ compiler is required to start using it.

Key Features

  • Pure DirectCompute: Uses standard D3D11 compute shaders (HLSL).
  • Fast Autograd: Python-based automatic differentiation with GPU-accelerated gradient accumulation.
  • Modern Optimizers: Support for SGD, Adam, AdamW, and the state-of-the-art Muon optimizer.
  • Advanced Layers: Convolutional layers, Linear/Dense, BatchNorm2d, MaxPool, and more.
  • ONNX Inference: Load and run ONNX models directly on the DirectCompute engine.

Usage Examples

1. Minimal Forward Pass

import numpy as np
from nn_engine import Tensor, Linear, relu

# Data automatically moves to GPU
x = Tensor(np.random.randn(32, 128).astype(np.float32))
layer = Linear(128, 64)
y = relu(layer(x))

print(y.shape)  # (32, 64)
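
For readers who want a CPU baseline to compare against, the forward pass above can be reproduced in plain NumPy. This is only a reference sketch: the `(in_features, out_features)` weight layout, the zero bias, and the helper name `linear_relu_reference` are assumptions made for illustration, not details of how `nn_engine` stores or initializes `Linear` parameters.

```python
import numpy as np

def linear_relu_reference(x, weight, bias):
    """CPU reference for a dense layer followed by ReLU: max(x @ W + b, 0)."""
    return np.maximum(x @ weight + bias, 0.0)

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 128)).astype(np.float32)
# Assumed layout: (in_features, out_features). The engine's layout may differ.
weight = (rng.standard_normal((128, 64)) * 0.1).astype(np.float32)
bias = np.zeros(64, dtype=np.float32)

y_ref = linear_relu_reference(x, weight, bias)
print(y_ref.shape)  # (32, 64)
```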

2. Standard Training (LeNet)

from nn_engine import ConvLayer, Linear, Model, maxpool2d, flatten, relu, Adam

class LeNet(Model):
    def __init__(self):
        super().__init__()
        self.c1 = ConvLayer(1, 6, 5)
        self.c2 = ConvLayer(6, 16, 5)
        self.l1 = Linear(16*4*4, 120)
        self.l2 = Linear(120, 84)
        self.l3 = Linear(84, 10)

    def forward(self, x):
        x = maxpool2d(relu(self.c1(x)))
        x = maxpool2d(relu(self.c2(x)))
        x = flatten(x)
        x = relu(self.l1(x))
        x = relu(self.l2(x))
        return self.l3(x)

model = LeNet()
optimizer = Adam(model.parameters(), lr=0.001)

# Standard training loop...
# loss = softmax_ce(model(xb), yb)
# loss.backward()
# optimizer.step()
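
The `16*4*4` input size of `l1` follows from the usual output-size arithmetic for convolution and pooling. The sketch below works it out, assuming 28x28 single-channel inputs, unpadded stride-1 convolutions, and a default `maxpool2d` window of 2 with stride 2. None of these defaults are stated above, so treat them as assumptions; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution (floor division, as is conventional)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, pool_size=2, stride=2):
    """Output side length of an unpadded max pool (assumed 2x2, stride 2 default)."""
    return (size - pool_size) // stride + 1

def lenet_flat_features(side=28):
    """Number of features reaching the first Linear layer of the LeNet above."""
    s = pool_out(conv_out(side, 5))   # c1 then pool: 28 -> 24 -> 12
    s = pool_out(conv_out(s, 5))      # c2 then pool: 12 -> 8 -> 4
    return 16 * s * s                 # 16 output channels from c2

print(lenet_flat_features(28))  # 256, i.e. 16*4*4
```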

3. Large Model (AlexNet)

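The engine can also run larger architectures such as AlexNet. The version below feeds a fixed 256 * 6 * 6 feature map into the classifier, which corresponds to 224x224 inputs under standard convolution and pooling arithmetic; to handle multiple image sizes, use global pooling before the final layer instead.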

from nn_engine import ConvLayer, Linear, Model, maxpool2d, flatten, relu

class AlexNet(Model):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.c1 = ConvLayer(3, 64, 11, stride=4, padding=2)
        self.c2 = ConvLayer(64, 192, 5, padding=2)
        self.c3 = ConvLayer(192, 384, 3, padding=1)
        self.c4 = ConvLayer(384, 256, 3, padding=1)
        self.c5 = ConvLayer(256, 256, 3, padding=1)
        self.fc = Linear(256 * 6 * 6, num_classes)

    def forward(self, x):
        x = maxpool2d(relu(self.c1(x)), pool_size=3, stride=2)
        x = maxpool2d(relu(self.c2(x)), pool_size=3, stride=2)
        x = relu(self.c3(x))
        x = relu(self.c4(x))
        x = maxpool2d(relu(self.c5(x)), pool_size=3, stride=2)
        return self.fc(flatten(x))

See how to train AlexNet models with this
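
To illustrate why global pooling makes the classifier independent of image size, here is a NumPy sketch that traces the spatial size through the AlexNet layers above and then averages over the spatial axes. It assumes floor-based output-size arithmetic; `global_avg_pool` is a hypothetical helper written for this example, not an `nn_engine` function, and with it the final layer would become something like `Linear(256, num_classes)`.

```python
import numpy as np

def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling window (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

def alexnet_feature_side(side):
    """Spatial side length after c5 and the last max pool of the AlexNet above."""
    s = conv_out(side, 11, stride=4, padding=2)  # c1
    s = conv_out(s, 3, stride=2)                 # maxpool 3x3, stride 2
    s = conv_out(s, 5, padding=2)                # c2
    s = conv_out(s, 3, stride=2)                 # maxpool
    for _ in range(3):                           # c3, c4, c5 keep the size
        s = conv_out(s, 3, padding=1)
    return conv_out(s, 3, stride=2)              # final maxpool

def global_avg_pool(features):
    """Average (N, C, H, W) over H and W, giving (N, C) for any H and W."""
    return features.mean(axis=(2, 3))

for side in (224, 256):
    s = alexnet_feature_side(side)
    feats = np.ones((2, 256, s, s), dtype=np.float32)
    print(side, s, global_avg_pool(feats).shape)
# 224 6 (2, 256)
# 256 7 (2, 256)
```

With a fixed `Linear(256 * 6 * 6, ...)` the 256x256 input would produce a 7x7 map and a shape mismatch, whereas the pooled `(N, 256)` output has the same shape in both cases.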

API Support

Layers

  • Linear: Fully-connected dense layer.
  • ConvLayer: 2D Convolutional layer with im2col/matmul optimization.
  • BatchNorm2d: Standard batch normalization for training and inference.
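  • MaxPool2d: Spatial max pooling (called as maxpool2d in the examples above).
  • Flatten: Reshapes a multidimensional input into a flat vector per sample (called as flatten in the examples above).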
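
The im2col/matmul approach mentioned for ConvLayer can be sketched on the CPU with NumPy: every receptive field is unfolded into a row, so the whole convolution becomes a single matrix multiplication. This is a generic illustration of the technique, not the engine's HLSL implementation; the function names and the `(out_channels, in_channels, k, k)` weight layout are assumptions.

```python
import numpy as np

def im2col(x, k, stride=1, padding=0):
    """Unfold (N, C, H, W) into rows of shape (N * out_h * out_w, C * k * k)."""
    n, c, h, w = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    out_h = (h + 2 * padding - k) // stride + 1
    out_w = (w + 2 * padding - k) // stride + 1
    cols = np.empty((n, out_h, out_w, c, k, k), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            hs, ws = i * stride, j * stride
            cols[:, i, j] = xp[:, :, hs:hs + k, ws:ws + k]  # one patch per output pixel
    return cols.reshape(n * out_h * out_w, c * k * k), out_h, out_w

def conv2d_im2col(x, weight, stride=1, padding=0):
    """Convolve x with weight of assumed shape (out_channels, in_channels, k, k)."""
    n = x.shape[0]
    out_c, _, k, _ = weight.shape
    cols, out_h, out_w = im2col(x, k, stride, padding)
    out = cols @ weight.reshape(out_c, -1).T  # the single large matmul
    return out.reshape(n, out_h, out_w, out_c).transpose(0, 3, 1, 2)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8)).astype(np.float32)
w = rng.standard_normal((4, 3, 3, 3)).astype(np.float32)
print(conv2d_im2col(x, w, stride=2, padding=1).shape)  # (2, 4, 4, 4)
```

The trade-off is memory: the unfolded matrix repeats overlapping pixels, in exchange for reusing one well-optimized matmul kernel.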

Optimizers

  • SGD: Classic Stochastic Gradient Descent with gradient clipping.
  • Adam: Adaptive Moment Estimation with GPU-optimized momentum.
  • AdamW: Adam with decoupled weight decay (SOTA for many vision/LLM tasks).
  • Muon: Momentum-based optimizer that orthogonalizes each weight-matrix update with a Newton-Schulz iteration.
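
As a rough illustration of the Newton-Schulz idea behind Muon, the NumPy sketch below uses the classical cubic iteration X ← 1.5·X − 0.5·X·Xᵀ·X, which drives all singular values of a normalized matrix toward 1. Muon implementations commonly use a tuned higher-order polynomial with fewer steps; which variant this engine uses is not stated here, so the function name, step count, and coefficients below are assumptions for illustration only.

```python
import numpy as np

def newton_schulz_orthogonalize(g, steps=20, eps=1e-7):
    """Approximate the nearest (semi-)orthogonal matrix to g.

    Uses the classical cubic Newton-Schulz iteration, not a tuned Muon polynomial.
    """
    x = g / (np.linalg.norm(g) + eps)  # Frobenius norm, so singular values are <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T                        # keep x @ x.T as the smaller Gram matrix
    for _ in range(steps):
        x = 1.5 * x - 0.5 * (x @ x.T) @ x
    return x.T if transposed else x

g = np.diag([3.0, 1.0, 0.5])           # stand-in for a gradient or momentum matrix
q = newton_schulz_orthogonalize(g)
print(np.round(q, 4))                  # close to the 3x3 identity
```

In an optimizer step, a matrix like `q` would replace the raw momentum as the update direction, so every direction in the weight matrix moves at a similar scale.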

Functions & Activations

  • relu: Rectified Linear Unit.
  • softmax_ce: Softmax and cross-entropy loss fused into one operation for numerical stability.
  • matmul: Matrix multiplication with optional transA and transB flags.
  • scale_add: Scaled element-wise addition.
  • rms: Root Mean Square calculation on GPU.
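
The numerical-stability point for softmax_ce can be shown with a NumPy reference that works from raw logits using the log-sum-exp trick. This is a sketch under assumptions: integer class labels and a mean reduction over the batch are choices made for the example, and the engine's softmax_ce may use different conventions.

```python
import numpy as np

def softmax_ce_reference(logits, labels):
    """Mean cross-entropy computed directly from logits (log-sum-exp trick)."""
    # Subtracting the row maximum keeps exp() from overflowing.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each row.
    return -log_probs[np.arange(len(labels)), labels].mean()

# A naive softmax would overflow on exp(1000); the fused form stays finite.
logits = np.array([[1000.0, 0.0, -1000.0], [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
print(softmax_ce_reference(logits, labels))  # ≈ 0.5493 (ln(3) / 2)
```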

Contributing and Source Code

For the full source code, C++ engine implementation, and detailed architecture documentation, please visit our GitHub repository:

https://github.com/raviadi12/directcompute_torch

Future Roadmap

  • Vulkan & DX12 Backends: Cross-platform GPU support and lower-level control.
  • Enhanced Memory Management: Improved DMA copies and UMA (Unified Memory Architecture) optimizations for integrated GPUs.
  • Expanded Op Library: More activation functions, Dropout, and Transposed Convolutions.
  • Deployment: Enhanced ONNX support for both export and import.

Download files

Source Distribution

directcompute_nn-0.1.7.tar.gz (193.9 kB)

Uploaded Source

Built Distribution

directcompute_nn-0.1.7-cp313-cp313-win_amd64.whl (124.1 kB)

Uploaded CPython 3.13 Windows x86-64

File details

Details for the file directcompute_nn-0.1.7.tar.gz.

File metadata

  • Download URL: directcompute_nn-0.1.7.tar.gz
  • Upload date:
  • Size: 193.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for directcompute_nn-0.1.7.tar.gz
Algorithm Hash digest
SHA256 a463e4c1791266c16ae903a07634bcf48acd8d511e05b0f2820a3288917dece6
MD5 d0fa54d2e4adb545625acb42fd25c29f
BLAKE2b-256 00282b5eaa1e6c69d50f2c7c82e36edc3d9e217cda5d5875f85d9fd163bcbe95

File details

Details for the file directcompute_nn-0.1.7-cp313-cp313-win_amd64.whl.

File hashes

Hashes for directcompute_nn-0.1.7-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 5ff2be0b0850ecde621576175c5e4628b9cf1c5af5ce8490e3088c910097095d
MD5 c2153b5cb38ed6d8e5e3938432296c97
BLAKE2b-256 0ef646d123ed7b2391c8bcfe88722e8560981fd09127065d76db6d6aaa0cddde
