
GPU neural network training via DirectCompute (D3D11 Compute Shaders) — no CUDA required


DirectCompute Neural Network Engine

A from-scratch neural network training framework that runs entirely on the GPU using DirectCompute (D3D11 Compute Shaders). No CUDA, no cuDNN — just raw HLSL shaders dispatched through a thin C++ runtime, driven from Python via ctypes.

Architecture Overview

Python (nn_engine.py)          ←  Autograd, layers, optimizer
    │
    ▼  ctypes FFI
C++ DLL (engine.dll)           ←  D3D11 device, shader dispatch, buffer management
    │
    ▼  ID3D11DeviceContext::Dispatch()
HLSL Compute Shaders (nn_*.hlsl)  ←  GPU kernels for every operation

Layer Stack

| Component | File | Role |
|---|---|---|
| Runtime | engine.cpp → engine.dll | D3D11 device init, buffer create/read/release, shader compile & dispatch |
| Framework | nn_engine.py | Tensor class, autograd (topological backward), layers (Linear, ConvLayer, MaxPool2D, Flatten), SGD optimizer |
| Training scripts | train_lenet.py, train_alexnet.py | End-to-end training loops with validation |
| Shaders | nn_*.hlsl | One HLSL file per GPU kernel (see full list below) |

Prerequisites

  • Windows 10/11 with a DirectX 11-capable GPU
  • Visual Studio 2022+ (cl.exe and vcvarsall.bat are needed to compile the DLL)
  • Python 3.10+ with numpy and Pillow
pip install numpy pillow

Building

Compile the engine DLL

compile_engine.bat

This runs:

call "C:\Program Files\Microsoft Visual Studio\18\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
cl.exe /EHsc /O2 /LD engine.cpp /link /OUT:engine.dll

Note: Adjust the Visual Studio path in compile_engine.bat to match your installation. The key requirement is vcvarsall.bat to set up the MSVC x64 toolchain.

Output: engine.dll in the project root (loaded by nn_engine.py at import time).

HLSL shaders

Shaders are compiled at runtime by D3DCompileFromFile() when nn_engine.py is imported — no separate shader compilation step needed. Each shader must have entry point CSMain and target cs_5_0.

Running

LeNet on MNIST

python train_lenet.py

Expects mnist/0/, mnist/1/, ..., mnist/9/ folders containing grayscale digit images.

AlexNet on PetImages

python train_alexnet.py

Expects PetImages/Cat/ and PetImages/Dog/ folders containing JPEG images (resized to 224×224 at load time).

How the Engine Works

engine.cpp — The D3D11 Runtime

The C++ DLL exposes 8 functions via extern "C" __declspec(dllexport):

| Function | Purpose |
|---|---|
| InitEngine() | Creates D3D11 device + immediate context |
| CompileShader(name, path) | Compiles HLSL from file, stores ID3D11ComputeShader* by name |
| CreateBuffer(data, count) | Allocates a StructuredBuffer<float> on GPU with SRV + UAV views |
| AddRefBuffer(handle) | Increments ref count (for shared buffers like Flatten) |
| ReleaseBuffer(handle) | Decrements ref count; frees GPU resources at zero |
| ReadBuffer(handle, dst) | Copies GPU → CPU via staging buffer (triggers Flush()) |
| ClearBuffer(handle) | Zeros a UAV buffer (used before scatter ops like MaxPool backward) |
| RunShader(name, srvs, srvCount, uavs, uavCount, threads, cb, cbSize) | The main dispatch call: binds resources, sets constant buffer, calls Dispatch() |

Buffer Layout

Every GPU buffer is a StructuredBuffer<float> wrapped in a GPUBuffer struct:

struct GPUBuffer {
    ID3D11Buffer* buffer;              // The raw D3D11 buffer
    ID3D11ShaderResourceView* srv;     // For reading in shaders (register t0, t1, ...)
    ID3D11UnorderedAccessView* uav;    // For writing in shaders (register u0, u1, ...)
    uint32_t size;                     // Element count (floats)
    int refCount;                      // Manual ref counting for shared ownership
};

Constant Buffer Convention

Shader parameters are passed through a constant buffer at register(b0). Two layouts are used:

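ParamsCB: up to 9 × uint32 (36 bytes of payload). D3D11 requires a constant buffer's ByteWidth to be a multiple of 16, so the buffer backing it must be at least 48 bytes: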

cbuffer Params : register(b0) {
    uint u1, u2, u3, u4, u5, u6, u7, u8, u9;
};

Each shader interprets these fields differently (e.g., M/K/N/flags for matmul, batch/inC/inH/inW/kH/stride/padding/outH/outW for im2col).

SGDParamsCB — 16 bytes:

cbuffer Params : register(b0) {
    uint count; float lr; float clip; uint pad;
};
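
To make the two layouts concrete, here is a minimal sketch of how they could be packed on the Python side with the standard struct module. The helper names pack_params and pack_sgd_params are hypothetical, and zero-filling unused fields is an assumption; the actual packing code in nn_engine.py may differ.

```python
import struct

def pack_params(*values):
    """Pack up to nine uint32 fields in ParamsCB order (u1..u9), zero-filling the rest."""
    if len(values) > 9:
        raise ValueError("ParamsCB holds at most 9 uint32 fields")
    fields = list(values) + [0] * (9 - len(values))
    return struct.pack("<9I", *fields)               # 9 * 4 = 36 bytes

def pack_sgd_params(count, lr, clip):
    """Pack SGDParamsCB: uint count, float lr, float clip, uint pad."""
    return struct.pack("<IffI", count, lr, clip, 0)  # 4 * 4 = 16 bytes

matmul_cb = pack_params(64, 128, 32, 1)   # M, K, N, flags (bit 0 set: transpose A)
sgd_cb = pack_sgd_params(1000, 0.01, 1.0)
print(len(matmul_cb), len(sgd_cb))        # 36 16
```

The "<" prefix selects little-endian with no implicit padding, so the byte counts match the sizes quoted above.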

nn_engine.py — The Python Autograd Framework

Tensor

  • Wraps a NumPy array + a GPU buffer handle
  • requires_grad=True enables gradient tracking
  • track=True registers for bulk cleanup via release_all_buffers()
  • .sync() reads GPU data back to CPU (triggers D3D11 Flush())
  • .backward() builds topological order and runs backward pass

Autograd

Each differentiable operation is a Function subclass with forward() and backward() methods. The forward pass stores self.inputs and attaches res._ctx = self to the output tensor. The backward pass receives grad_output and calls input._accumulate_grad(grad).

Gradient accumulation uses the grad_accum shader to add in-place on GPU when a tensor receives gradients from multiple paths.
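
The pattern described above can be sketched on the CPU with NumPy. This is an illustrative reduction, not the code in nn_engine.py: the Mul and Add classes are hypothetical examples, requires_grad handling is omitted, and plain array addition stands in for the grad_accum shader.

```python
import numpy as np

class Tensor:
    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)
        self.grad = None
        self._ctx = None                      # Function that produced this tensor

    def _accumulate_grad(self, grad):
        # Stand-in for the grad_accum shader: add when a gradient already exists
        self.grad = grad if self.grad is None else self.grad + grad

    def backward(self):
        order, seen = [], set()
        def visit(t):                         # depth-first post-order = topological order
            if id(t) in seen:
                return
            seen.add(id(t))
            if t._ctx is not None:
                for inp in t._ctx.inputs:
                    visit(inp)
            order.append(t)
        visit(self)
        self.grad = np.ones_like(self.data)
        for t in reversed(order):             # every consumer runs before its producers
            if t._ctx is not None:
                t._ctx.backward(t.grad)

class Function:
    def __call__(self, *inputs):
        self.inputs = inputs
        res = Tensor(self.forward(*[t.data for t in inputs]))
        res._ctx = self
        return res

class Mul(Function):
    def forward(self, a, b):
        return a * b
    def backward(self, grad_output):
        a, b = self.inputs
        a._accumulate_grad(grad_output * b.data)
        b._accumulate_grad(grad_output * a.data)

class Add(Function):
    def forward(self, a, b):
        return a + b
    def backward(self, grad_output):
        for t in self.inputs:
            t._accumulate_grad(grad_output)

x = Tensor(3.0)
z = Add()(Mul()(x, x), x)     # z = x*x + x, so x receives gradients from three paths
z.backward()
print(float(x.grad))          # 7.0, i.e. 2*x + 1 at x = 3
```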

Convolution: The im2col Approach

Convolutions are implemented via the im2col transformation, which converts convolution into matrix multiplication:

Forward:

input → [im2col] → col_matrix → [matmul] filters × col_matrix → [conv_reshape] → output
  1. im2col shader extracts patches into a (inC*kH*kW) × (batch*outH*outW) matrix
  2. matmul computes filters (outC × inC*kH*kW) × col_matrix (inC*kH*kW × batch*outH*outW), giving an outC × (batch*outH*outW) matrix
  3. conv_reshape adds bias and reshapes to (batch, outC, outH, outW)
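
The three forward steps can be checked against a NumPy reference. The sketch below assumes square kernels and a column ordering of (batch, outH, outW); the shaders may order columns differently, and the function names are illustrative only.

```python
import numpy as np

def im2col(x, k, stride=1, padding=0):
    """Return a (inC*k*k) x (batch*outH*outW) patch matrix for a square k x k kernel."""
    batch, in_c, in_h, in_w = x.shape
    out_h = (in_h + 2 * padding - k) // stride + 1
    out_w = (in_w + 2 * padding - k) // stride + 1
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    col = np.empty((in_c * k * k, batch * out_h * out_w), dtype=x.dtype)
    for c in range(in_c):
        for i in range(k):
            for j in range(k):
                # All output positions that see kernel offset (i, j) of channel c
                patch = xp[:, c, i:i + stride * out_h:stride, j:j + stride * out_w:stride]
                col[(c * k + i) * k + j] = patch.reshape(-1)
    return col, out_h, out_w

def conv_forward(x, filters, bias, stride=1, padding=0):
    out_c, _, k, _ = filters.shape
    col, out_h, out_w = im2col(x, k, stride, padding)        # step 1
    out = filters.reshape(out_c, -1) @ col                   # step 2: outC x (batch*outH*outW)
    out = out + bias[:, None]                                # step 3: add bias ...
    return out.reshape(out_c, x.shape[0], out_h, out_w).transpose(1, 0, 2, 3)  # ... and reshape

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8)).astype(np.float32)
w = rng.standard_normal((4, 3, 3, 3)).astype(np.float32)
b = rng.standard_normal(4).astype(np.float32)
y = conv_forward(x, w, b, stride=1, padding=1)
print(y.shape)   # (2, 4, 8, 8)
```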

Backward:

grad_output → [conv_grad_reshape] → grad_reshaped (outC × batch*outH*outW)

Filter gradients:   grad_reshaped × im2col_matrix^T  (reuses saved im2col from forward)
Input gradients:    filters^T × grad_reshaped → [col2im] → grad_input
Bias gradients:     sum over each row of grad_reshaped (i.e., over batch and spatial dims)

The transpose operations use the matmul flags parameter instead of explicit transpose shaders.
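
The backward formulas above can be written out in NumPy as a reference. This is a sketch under the same assumptions as before (square kernels, columns ordered by batch, outH, outW); col2im here is a plain scatter-add and the function names are illustrative, not the engine's API.

```python
import numpy as np

def im2col(x, k, stride, padding):
    batch, in_c, in_h, in_w = x.shape
    out_h = (in_h + 2 * padding - k) // stride + 1
    out_w = (in_w + 2 * padding - k) // stride + 1
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    col = np.empty((in_c * k * k, batch * out_h * out_w), dtype=x.dtype)
    for c in range(in_c):
        for i in range(k):
            for j in range(k):
                patch = xp[:, c, i:i + stride * out_h:stride, j:j + stride * out_w:stride]
                col[(c * k + i) * k + j] = patch.reshape(-1)
    return col, out_h, out_w

def col2im(col, x_shape, k, stride, padding):
    """Scatter-add columns back to image layout (same access pattern as im2col)."""
    batch, in_c, in_h, in_w = x_shape
    out_h = (in_h + 2 * padding - k) // stride + 1
    out_w = (in_w + 2 * padding - k) // stride + 1
    xp = np.zeros((batch, in_c, in_h + 2 * padding, in_w + 2 * padding), dtype=col.dtype)
    for c in range(in_c):
        for i in range(k):
            for j in range(k):
                patch = col[(c * k + i) * k + j].reshape(batch, out_h, out_w)
                xp[:, c, i:i + stride * out_h:stride, j:j + stride * out_w:stride] += patch
    return xp[:, :, padding:padding + in_h, padding:padding + in_w]

def conv_backward(grad_output, x, filters, stride=1, padding=0):
    out_c, _, k, _ = filters.shape
    col, _, _ = im2col(x, k, stride, padding)            # the engine reuses the saved forward matrix
    g = grad_output.transpose(1, 0, 2, 3).reshape(out_c, -1)   # conv_grad_reshape
    grad_filters = (g @ col.T).reshape(filters.shape)    # grad_reshaped x im2col_matrix^T
    grad_col = filters.reshape(out_c, -1).T @ g          # filters^T x grad_reshaped
    grad_input = col2im(grad_col, x.shape, k, stride, padding)
    grad_bias = g.sum(axis=1)                            # row sums: batch and spatial positions
    return grad_input, grad_filters, grad_bias

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 7, 7))
w = rng.standard_normal((4, 3, 3, 3))
g_out = rng.standard_normal((2, 4, 4, 4))
gi, gw, gb = conv_backward(g_out, x, w, stride=2, padding=1)
print(gi.shape, gw.shape, gb.shape)   # (2, 3, 7, 7) (4, 3, 3, 3) (4,)
```

The two transposes appear here as .T on NumPy arrays; on the GPU the same products are obtained by setting the matmul flags.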

HLSL Shader Reference

Matmul Shaders

| Shader | Tile Size | Description |
|---|---|---|
| nn_matmul_universal.hlsl | 16×16 | Tiled matmul with transpose flags. Used for small matrices (M or N < 128) |
| nn_matmul_coarsened.hlsl | 64×64, 4×4 WPT | Coarsened tiled matmul. 16×16 thread groups, each thread computes a 4×4 output block. Used for large matrices |

Both shaders share the same constant buffer: {M, K, N, flags} where:

  • flags & 1 → transpose A (read A as column-major)
  • flags & 2 → transpose B (read B as column-major)

The Python helper _run_mm() automatically selects the shader based on matrix dimensions (threshold: 128).

Convolution Shaders

| Shader | Purpose |
|---|---|
| nn_im2col.hlsl | Extracts image patches into column matrix for convolution |
| nn_col2im.hlsl | Scatters column matrix back to image format (backward pass) |
| nn_conv_reshape.hlsl | Reshapes matmul output to (batch, C, H, W) + adds bias |
| nn_conv_grad_reshape.hlsl | Reshapes (batch, C, H, W) gradient to matrix form for backward matmul |

Activation / Loss Shaders

| Shader | Purpose |
|---|---|
| nn_relu.hlsl | ReLU forward: max(0, x) |
| nn_relu_grad.hlsl | ReLU backward: grad * (x > 0) |
| nn_softmax.hlsl | Numerically stable softmax (per-row max subtraction) |
| nn_softmax_ce_grad.hlsl | Combined softmax + cross-entropy gradient: softmax - one_hot |
| nn_loss.hlsl | Cross-entropy loss computation |
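
As a CPU reference for the softmax and loss entries, the following NumPy sketch shows the max-subtraction trick and the combined gradient. Whether the shaders average over the batch is not stated here, so this version keeps per-sample losses and leaves any 1/batch scaling out.

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)   # per-row max subtraction
    e = np.exp(shifted)                                     # largest exponent is exp(0) = 1
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(probs, labels):
    """Per-sample loss: -log of the probability assigned to the true class."""
    return -np.log(probs[np.arange(len(labels)), labels])

def softmax_ce_grad(probs, labels):
    """Gradient of the summed loss with respect to the logits: softmax - one_hot."""
    grad = probs.copy()
    grad[np.arange(len(labels)), labels] -= 1.0
    return grad

logits = np.array([[1000.0, 1001.0, 1002.0],   # would overflow exp() without the shift
                   [0.0, 0.0, 0.0]])
labels = np.array([2, 0])
probs = softmax(logits)
grad = softmax_ce_grad(probs, labels)
print(np.isfinite(probs).all(), probs.sum(axis=1))
```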

Pooling Shaders

| Shader | Purpose |
|---|---|
| nn_maxpool_forward.hlsl | Max pooling with index tracking (stores argmax for backward) |
| nn_maxpool_backward.hlsl | Scatter gradients to max positions using saved indices |
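
The index-tracking idea can be illustrated with a NumPy reference. This sketch stores each maximum's flat index into the input, which is an assumption; the shaders may encode the argmax differently.

```python
import numpy as np

def maxpool_forward(x, k=2, stride=2):
    batch, c, h, w = x.shape
    out_h, out_w = (h - k) // stride + 1, (w - k) // stride + 1
    out = np.empty((batch, c, out_h, out_w), dtype=x.dtype)
    idx = np.empty(out.shape, dtype=np.int64)       # flat index into x of each maximum
    flat = np.arange(x.size).reshape(x.shape)
    for i in range(out_h):
        for j in range(out_w):
            rows = slice(i * stride, i * stride + k)
            cols = slice(j * stride, j * stride + k)
            win = x[:, :, rows, cols].reshape(batch, c, -1)
            pos = flat[:, :, rows, cols].reshape(batch, c, -1)
            arg = win.argmax(axis=2)[..., None]
            out[:, :, i, j] = np.take_along_axis(win, arg, axis=2)[..., 0]
            idx[:, :, i, j] = np.take_along_axis(pos, arg, axis=2)[..., 0]
    return out, idx

def maxpool_backward(grad_output, idx, x_shape):
    # Start from zeros (the role ClearBuffer plays on the GPU), then scatter-add
    grad_input = np.zeros(int(np.prod(x_shape)), dtype=grad_output.dtype)
    np.add.at(grad_input, idx.ravel(), grad_output.ravel())
    return grad_input.reshape(x_shape)

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
out, idx = maxpool_forward(x)
grad_in = maxpool_backward(np.ones_like(out), idx, x.shape)
print(out[0, 0])   # maxima of the four 2x2 windows: 5, 7, 13, 15
```

np.add.at is used instead of plain indexing so that overlapping windows (stride smaller than the kernel) accumulate correctly.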

Utility Shaders

| Shader | Purpose |
|---|---|
| nn_add_bias.hlsl | Adds bias vector to each row: out[i] = A[i] + B[i % cols] |
| nn_bias_grad.hlsl | Sum-reduces rows to compute bias gradient |
| nn_sgd.hlsl | SGD update: param -= lr * clamp(grad, -clip, clip) |
| nn_grad_accum.hlsl | In-place gradient accumulation: accum += grad |
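
The bias and SGD formulas above translate directly to NumPy, which is handy for checking shader output on small inputs. The function names below are illustrative, not part of the engine.

```python
import numpy as np

def add_bias(a, bias):
    """Flat-index form of the shader formula: out[i] = A[i] + B[i % cols]."""
    cols = bias.shape[0]
    flat = a.reshape(-1)
    i = np.arange(flat.size)
    return (flat + bias[i % cols]).reshape(a.shape)

def sgd_step(param, grad, lr, clip):
    """In-place update: param -= lr * clamp(grad, -clip, clip)."""
    param -= lr * np.clip(grad, -clip, clip)
    return param

a = np.arange(6, dtype=np.float32).reshape(2, 3)
bias = np.array([10.0, 20.0, 30.0], dtype=np.float32)
biased = add_bias(a, bias)

w = np.array([1.0, 1.0, 1.0], dtype=np.float32)
g = np.array([0.5, 10.0, -10.0], dtype=np.float32)
sgd_step(w, g, lr=0.1, clip=1.0)   # the two large gradients are clipped to +/-1
print(biased, w)
```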

Legacy / Unused Shaders

These exist in the repo but are not used by the current engine; the im2col path and the flag-based matmul shaders replaced them:

| Shader | Note |
|---|---|
| nn_conv_forward.hlsl | Direct convolution (replaced by im2col + matmul) |
| nn_conv_forward_tiled.hlsl | Tiled direct convolution (replaced) |
| nn_conv_backprop_filters.hlsl | Direct filter gradient (replaced by matmul transpose) |
| nn_conv_backprop_filters_tiled.hlsl | Tiled version (still compiled but unused by im2col path) |
| nn_conv_backprop_input.hlsl | Direct input gradient (replaced by matmul transpose + col2im) |
| nn_conv_backprop_input_fused.hlsl | Fused version (still compiled but unused) |
| nn_conv_reduce_filters.hlsl | Filter gradient reduction (unused) |
| nn_matmul.hlsl | Original matmul without transpose (superseded by universal) |
| nn_matmul_transpose_a.hlsl | Dedicated transpose-A matmul (superseded by flags) |
| nn_matmul_transpose_b.hlsl | Dedicated transpose-B matmul (superseded by flags) |

Optimizations

Constant Buffer Caching

The engine reuses a single D3D11_USAGE_DYNAMIC constant buffer across all dispatches, updating it via Map() with D3D11_MAP_WRITE_DISCARD. This avoids creating and destroying a CB per shader call. The cache grows if a larger CB is needed but never shrinks.

Adaptive Matmul Selection

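_COARSEN_THRESHOLD = 128
# Shader name strings are illustrative; use the names passed to CompileShader()
if M >= _COARSEN_THRESHOLD and N >= _COARSEN_THRESHOLD:
    shader = "nn_matmul_coarsened"   # 64×64 tiles, 4×4 work-per-thread
else:
    shader = "nn_matmul_universal"   # 16×16 tiles, 1×1 work-per-thread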

The coarsened shader is ~2× faster for large matrices but has tile overhead that hurts small ones (like LeNet's 84×10 FC layer).
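
One way to see the tile overhead is to count how much of the dispatched tile area actually lands inside the output matrix. The sketch below assumes one thread group per output tile and a hypothetical batch size of 64 for the 84×10 layer (so the output is 64×10); both helper names are illustrative.

```python
import math

def tile_utilization(M, N, tile):
    """Return (fraction of tile area inside the M x N output, number of thread groups)."""
    groups_m = math.ceil(M / tile)
    groups_n = math.ceil(N / tile)
    covered = (groups_m * tile) * (groups_n * tile)
    return (M * N) / covered, groups_m * groups_n

def pick_shader(M, N, threshold=128):
    """Same rule as the selection logic above."""
    return "coarsened" if M >= threshold and N >= threshold else "universal"

M, N = 64, 10   # assumed: batch of 64 through the 84 -> 10 FC layer
for name, tile in (("universal", 16), ("coarsened", 64)):
    util, groups = tile_utilization(M, N, tile)
    print(f"{name}: {groups} thread group(s), {util:.1%} of tile area used")
print(pick_shader(M, N))
```

With these numbers the 16×16 tiles use 62.5% of their area while a single 64×64 tile uses under 16%, which is consistent with routing small outputs to the universal shader.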

Transpose via Flags (No Extra Buffers)

Instead of explicit transpose operations or separate transpose shaders, both matmul shaders accept a flags field in the constant buffer. The shader adjusts its indexing at load time:

uint idxA = (flags & 1) ? (col * M + row) : (row * K + col);  // transpose A
uint idxB = (flags & 2) ? (col * K + row) : (row * N + col);  // transpose B

This saves GPU memory and dispatch overhead for backward-pass transposes.
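
The index arithmetic can be checked on the CPU with a slow reference loop over flat buffers. In this sketch the inner index k plays the role of col for A and row for B in the snippet above; the function name is illustrative.

```python
import numpy as np

def matmul_flags(a, b, M, K, N, flags):
    """Compute an M x N product from flat buffers, honoring the two transpose bits."""
    a, b = np.ravel(a), np.ravel(b)
    out = np.zeros((M, N))
    for row in range(M):
        for col in range(N):
            acc = 0.0
            for k in range(K):
                idx_a = (k * M + row) if flags & 1 else (row * K + k)   # A stored as K x M when bit 0 is set
                idx_b = (col * K + k) if flags & 2 else (k * N + col)   # B stored as N x K when bit 1 is set
                acc += a[idx_a] * b[idx_b]
            out[row, col] = acc
    return out

rng = np.random.default_rng(0)
M, K, N = 3, 4, 5
a = rng.standard_normal((M, K))
b = rng.standard_normal((K, N))
a_t = np.ascontiguousarray(a.T)   # K x M buffer, as a backward pass might hold it
b_t = np.ascontiguousarray(b.T)   # N x K buffer

print(np.allclose(matmul_flags(a_t, b, M, K, N, flags=1), a @ b))    # True
print(np.allclose(matmul_flags(a, b_t, M, K, N, flags=2), a @ b))    # True
```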

SRV/UAV Unbinding

After every Dispatch(), the engine unbinds all SRVs and UAVs:

g_context->CSSetShaderResources(0, srvCount, nullSRVs);
g_context->CSSetUnorderedAccessViews(0, uavCount, nullUAVs, nullptr);

This is critical in D3D11. Without it, a buffer written through a UAV in one dispatch and read through an SRV in the next silently returns stale or zero data: while the UAV is still bound, the runtime detects the read/write hazard and forces the conflicting SRV slot to NULL.

Per-Dispatch Flush

Each RunShader() call ends with g_context->Flush(). This looks counter-intuitive, since batching dispatches seems better, but removing the Flush() causes a 2× slowdown in this engine because:

  • Without Flush, the CPU queues many dispatches but the GPU sits idle until ReadBuffer() triggers execution
  • With Flush, the GPU starts executing immediately while the CPU prepares the next dispatch
  • This enables CPU-GPU pipelining, which is essential for the many small dispatches in a neural network training step

Memory Management

  • Tracked tensors: Intermediate tensors created during forward/backward are registered via track=True and bulk-freed by release_all_buffers() at the end of each batch
  • Persistent tensors: Weights and biases use track=False and survive across batches
  • Ref counting: AddRefBuffer/ReleaseBuffer handle shared ownership (e.g., Flatten shares its input's GPU buffer)
  • im2col buffer: Saved during forward for reuse in backward, then explicitly released
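
The ownership rules in the list above can be modeled in a few lines of plain Python. This is a toy model only: BufferPool and its methods are hypothetical stand-ins for the handles and the AddRefBuffer/ReleaseBuffer calls in engine.dll, and no GPU memory is involved.

```python
class BufferPool:
    """Toy model of ref-counted buffer handles with bulk release of tracked ones."""

    def __init__(self):
        self._refs = {}        # handle -> reference count
        self._next = 1
        self._tracked = []     # one entry per tracked reference

    def create(self, track=True):
        handle = self._next
        self._next += 1
        self._refs[handle] = 1
        if track:
            self._tracked.append(handle)
        return handle

    def add_ref(self, handle, track=True):
        self._refs[handle] += 1
        if track:
            self._tracked.append(handle)
        return handle

    def release(self, handle):
        self._refs[handle] -= 1
        if self._refs[handle] == 0:
            del self._refs[handle]     # GPU resources would be freed at this point

    def release_all(self):
        for handle in self._tracked:
            self.release(handle)
        self._tracked.clear()

    def live(self):
        return sorted(self._refs)

pool = BufferPool()
weight = pool.create(track=False)    # persistent: survives across batches
act = pool.create()                  # intermediate activation
flat = pool.add_ref(act)             # a Flatten-style view sharing the same buffer
pool.release_all()                   # end of batch
print(pool.live())                   # only the weight handle remains
```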

File Structure

engine.cpp              # D3D11 compute shader runtime (→ engine.dll)
compile_engine.bat      # Build script for engine.dll
nn_engine.py            # Python autograd framework + ctypes bindings
train_lenet.py          # LeNet-5 on MNIST (1→6→16 conv, 256→120→84→10 FC)
train_alexnet.py        # AlexNet on PetImages (5 conv layers, 3 FC layers)
nn_*.hlsl               # HLSL compute shaders (one per operation)
mnist/                  # MNIST digit images (0-9 subfolders)
PetImages/              # Cat/Dog image dataset
