
Windows-only GPU neural network training via DirectCompute (D3D11 Compute Shaders)

Project description

DirectCompute Neural Network Engine

A high-performance, from-scratch neural network training framework that runs entirely on the GPU using DirectCompute (D3D11 Compute Shaders). No CUDA and no cuDNN: just raw HLSL shaders dispatched through a thin C++ runtime and driven from Python.
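DirectCompute launches work in thread groups. An HLSL kernel declares its group size with the [numthreads(X, Y, Z)] attribute, and the host calls ID3D11DeviceContext::Dispatch with the number of groups per dimension. The sketch below shows the usual ceiling-division arithmetic for a 1-D element-wise kernel. The group size of 256 and the helper name dispatch_groups are illustrative assumptions, not taken from this engine's source.

```python
# D3D11 caps each Dispatch dimension at 65535 thread groups
# (D3D11_CS_DISPATCH_MAX_THREAD_GROUPS_PER_DIMENSION).
MAX_GROUPS_PER_DIM = 65535


def dispatch_groups(num_elements: int, threads_per_group: int = 256) -> int:
    """Thread groups needed so each element gets one thread (hypothetical helper).

    Assumes a 1-D kernel declared with [numthreads(256, 1, 1)]. The last group
    may be only partly used, so the shader must bounds-check its thread index.
    """
    if num_elements <= 0:
        raise ValueError("num_elements must be positive")
    groups = (num_elements + threads_per_group - 1) // threads_per_group  # ceiling division
    if groups > MAX_GROUPS_PER_DIM:
        raise ValueError("too many groups for one dimension; spread work over Y/Z")
    return groups


print(dispatch_groups(32 * 64))  # 8 groups for a (32, 64) activation
```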

Quick Install (Windows Only)

pip install directcompute-nn

The Windows wheel (currently built for CPython 3.13 on x86-64) includes everything: the pre-compiled engine.dll, bundled HLSL shaders, and the Python API, so no C++ compiler is required to start using it.
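The only prebuilt wheel listed for version 0.1.6 is tagged cp313-cp313-win_amd64. A quick way to check whether your interpreter matches that tag before running pip is sketched below; the helper name prebuilt_wheel_matches is hypothetical. On a non-matching interpreter, pip falls back to the source distribution, which presumably needs a C++ toolchain to build engine.dll.

```python
import platform
import sys


def prebuilt_wheel_matches(plat: str, py_version: tuple, machine: str) -> bool:
    """True if the cp313-cp313-win_amd64 wheel fits this interpreter.

    Hypothetical helper based only on the wheel tags listed on this page.
    """
    return plat == "win32" and tuple(py_version[:2]) == (3, 13) and machine.upper() == "AMD64"


if __name__ == "__main__":
    ok = prebuilt_wheel_matches(sys.platform, tuple(sys.version_info), platform.machine())
    print("prebuilt wheel available" if ok else "pip would fall back to the source distribution")
```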

Key Features

  • Pure DirectCompute: Uses standard D3D11 compute shaders (HLSL).
  • Fast Autograd: Python-based automatic differentiation with GPU-accelerated gradient accumulation.
  • Modern Optimizers: Support for SGD, Adam, AdamW, and the recently proposed Muon optimizer.
  • Advanced Layers: Convolutional layers, Linear/Dense, BatchNorm2d, MaxPool, and more.
  • ONNX Inference: Load and run ONNX models directly on the DirectCompute engine.
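Muon is the least familiar of the optimizers listed. Its core idea is to orthogonalize each 2-D momentum update before applying it. The NumPy sketch below reproduces the quintic Newton-Schulz iteration from the publicly described Muon reference implementation (coefficients 3.4445, -4.7750, 2.0315; five steps). It runs on the CPU and is not this engine's HLSL kernel; the helper names and the lr/momentum defaults are assumptions, and the momentum is simplified to plain heavy-ball.

```python
import numpy as np


def newton_schulz_orthogonalize(g: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Push the singular values of g toward 1 while keeping its singular vectors.

    Quintic Newton-Schulz iteration with the coefficients published for Muon.
    It does not converge to an exactly orthogonal matrix; the singular values
    end up close to 1 (roughly within 0.5 to 1.5), which suffices for an update.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + eps)  # Frobenius norm => every singular value <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # iterate on the wide orientation so the Gram matrix is smaller
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x  # X <- aX + b(XX^T)X + c(XX^T)^2 X
    return x.T if transposed else x


def muon_step(w: np.ndarray, grad: np.ndarray, buf: np.ndarray, lr: float = 0.02, momentum: float = 0.95):
    """Simplified Muon update for one 2-D weight (hypothetical helper).

    Uses plain momentum; the reference optimizer adds Nesterov momentum and
    shape-dependent scaling, which are omitted here.
    """
    buf = momentum * buf + grad
    w = w - lr * newton_schulz_orthogonalize(buf)
    return w, buf
```

In published usage, Muon is typically applied only to 2-D hidden-layer weights, with AdamW handling biases, normalization parameters, and embeddings.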

Usage Examples

1. Minimal Forward Pass

import numpy as np
from nn_engine import Tensor, Linear, relu

# Data automatically moves to GPU
x = Tensor(np.random.randn(32, 128).astype(np.float32))
layer = Linear(128, 64)
y = relu(layer(x))

print(y.shape)  # (32, 64)
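The minimal forward pass corresponds to an affine map followed by an element-wise max with zero. A NumPy reference like the one below can be handy for sanity-checking shapes and values on the CPU. The weight layout (in_features, out_features), the zero bias, and the He-style initialization are assumptions; this page does not document how the engine stores or initializes Linear parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 128)).astype(np.float32)

# Assumed layout: W is (in_features, out_features), b is (out_features,).
W = (rng.standard_normal((128, 64)) * np.sqrt(2.0 / 128)).astype(np.float32)  # He-style init (assumption)
b = np.zeros(64, dtype=np.float32)

y_ref = np.maximum(x @ W + b, 0.0)  # Linear followed by ReLU
print(y_ref.shape)  # (32, 64)
```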

2. Standard Training (LeNet)

from nn_engine import ConvLayer, Linear, Model, maxpool2d, flatten, relu, Adam

class LeNet(Model):
    def __init__(self):
        super().__init__()
        self.c1 = ConvLayer(1, 6, 5)
        self.c2 = ConvLayer(6, 16, 5)
        self.l1 = Linear(16*4*4, 120)
        self.l2 = Linear(120, 84)
        self.l3 = Linear(84, 10)

    def forward(self, x):
        x = maxpool2d(relu(self.c1(x)))
        x = maxpool2d(relu(self.c2(x)))
        x = flatten(x)
        x = relu(self.l1(x))
        x = relu(self.l2(x))
        return self.l3(x)

model = LeNet()
optimizer = Adam(model.parameters(), lr=0.001)

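# Standard training loop (sketch). xb is a batch of 1-channel 28x28 images,
# yb the matching labels; "batches" is a placeholder for your data iterator.
# for xb, yb in batches:
#     loss = softmax_ce(model(xb), yb)  # softmax cross-entropy loss
#     loss.backward()
#     optimizer.step()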
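The 16*4*4 input size of self.l1 follows from the spatial arithmetic of the layers, assuming 28x28 single-channel inputs (MNIST-sized), no convolution padding, and a maxpool2d default of a 2x2 window with stride 2. Those defaults are assumptions about the engine; the small helper below just traces the shapes.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of one spatial dimension for a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_flat_features(image_size: int = 28) -> int:
    s = conv_out(image_size, 5)   # c1: 5x5 conv, no padding (28 -> 24)
    s = conv_out(s, 2, stride=2)  # 2x2 max pool (24 -> 12)
    s = conv_out(s, 5)            # c2: 5x5 conv (12 -> 8)
    s = conv_out(s, 2, stride=2)  # 2x2 max pool (8 -> 4)
    return 16 * s * s             # c2 produces 16 channels


print(lenet_flat_features())  # 256, i.e. 16*4*4
```

The same helper shows why a different input size would need a different l1: 32x32 inputs, for example, would give 16*5*5 features.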

3. Large Model (AlexNet)

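The engine can also run larger architectures such as AlexNet. The version below is simplified: it keeps AlexNet's five convolutional layers but replaces the three fully connected layers with a single classifier. Its 256 * 6 * 6 classifier input corresponds to 224x224 RGB images; to handle other image sizes, add global pooling before the classifier (and size the Linear layer to match).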

from nn_engine import ConvLayer, Linear, Model, maxpool2d, flatten, relu

class AlexNet(Model):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.c1 = ConvLayer(3, 64, 11, stride=4, padding=2)
        self.c2 = ConvLayer(64, 192, 5, padding=2)
        self.c3 = ConvLayer(192, 384, 3, padding=1)
        self.c4 = ConvLayer(384, 256, 3, padding=1)
        self.c5 = ConvLayer(256, 256, 3, padding=1)
        self.fc = Linear(256 * 6 * 6, num_classes)

    def forward(self, x):
        x = maxpool2d(relu(self.c1(x)), pool_size=3, stride=2)
        x = maxpool2d(relu(self.c2(x)), pool_size=3, stride=2)
        x = relu(self.c3(x))
        x = relu(self.c4(x))
        x = maxpool2d(relu(self.c5(x)), pool_size=3, stride=2)
        return self.fc(flatten(x))
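To estimate how much GPU memory the AlexNet weights need, one can count parameters layer by layer. The sketch below assumes every ConvLayer and Linear has a bias term and that parameters are stored as float32 (4 bytes); both are assumptions about the engine, and the helper names are hypothetical. Gradients and Adam's two moment buffers would each add another copy of the same size.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    return c_out * c_in * k * k + c_out  # weights + bias (bias assumed)


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weights + bias (bias assumed)


def alexnet_params(num_classes: int = 1000) -> int:
    return (
        conv_params(3, 64, 11)
        + conv_params(64, 192, 5)
        + conv_params(192, 384, 3)
        + conv_params(384, 256, 3)
        + conv_params(256, 256, 3)
        + linear_params(256 * 6 * 6, num_classes)
    )


n = alexnet_params()
print(n, f"{n * 4 / 2**20:.1f} MiB as float32")  # 11686696 44.6 MiB as float32
```

Most of the parameters sit in the single classifier layer (9,217,000 of about 11.7 million), which is why the input size feeding it matters.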

Contributing and Source Code

For the full source code, C++ engine implementation, and detailed architecture documentation, please visit our GitHub repository:

https://github.com/raviadi12/directcompute_torch

Future Roadmap

  • Vulkan & DX12 Backends: Vulkan for cross-platform GPU support and DX12 for lower-level control.
  • Enhanced Memory Management: Improved DMA copies and UMA (Unified Memory Architecture) optimizations for integrated GPUs.
  • Expanded Op Library: More activation functions, Dropout, and Transposed Convolutions.
  • Deployment: Enhanced ONNX support for both export and import.


Download files

Download the file for your platform.

Source Distribution

directcompute_nn-0.1.6.tar.gz (192.9 kB)

Built Distribution

directcompute_nn-0.1.6-cp313-cp313-win_amd64.whl (123.6 kB)

CPython 3.13, Windows x86-64

File details

Details for the file directcompute_nn-0.1.6.tar.gz.

File metadata

  • Download URL: directcompute_nn-0.1.6.tar.gz
  • Upload date:
  • Size: 192.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for directcompute_nn-0.1.6.tar.gz
Algorithm    Hash digest
SHA256       27616174c091d2ff50a60a093ee2e2ec8c463bba00908c2305238c2adb8b6910
MD5          502bc04f697b4c5a9522972cca3b64fc
BLAKE2b-256  45e5ce6fd63da1086bfc4ce4ba9fba074a8a7bfc6a456ad77b566fe26af9fa69
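A downloaded file can be checked against the published SHA256 digest with Python's standard hashlib module. The sketch below assumes the source distribution sits in the current directory; the helper name sha256_of is illustrative. Alternatively, pip can enforce digests itself in hash-checking mode (--require-hashes) with a requirements file that pins each file's --hash=sha256:... value.

```python
import hashlib
from pathlib import Path

EXPECTED_SDIST_SHA256 = "27616174c091d2ff50a60a093ee2e2ec8c463bba00908c2305238c2adb8b6910"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 of a file, read in chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    sdist = Path("directcompute_nn-0.1.6.tar.gz")  # assumed download location
    if sdist.exists():
        ok = sha256_of(sdist) == EXPECTED_SDIST_SHA256
        print("hash OK" if ok else "hash MISMATCH: do not install")
```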

File details

Details for the file directcompute_nn-0.1.6-cp313-cp313-win_amd64.whl.

File hashes

Hashes for directcompute_nn-0.1.6-cp313-cp313-win_amd64.whl
Algorithm    Hash digest
SHA256       aece1898e7a1b2de222f1c88cd8ebac75091aa5289dc5939cbbbea4dcc7dfc4a
MD5          ce0eea5c50d5c776317cacbf9ed04dea
BLAKE2b-256  090dd2ada2bd6a9e7befe7d6387c52cc64179c78c4daba7e28d00534032b3841
