
A high-performance CNN framework: Numba for CPU, JAX for GPU/TPU


🧠 ConvNet


A high-performance, educational CNN framework: Numba for CPU, JAX for GPU/TPU

This project was created as a school assignment with the goal of understanding deep learning from the ground up. It is designed to be easy to read and learn from: a complete CNN framework with Numba acceleration for fast CPU training and optional JAX acceleration for GPU/TPU. The code stays simple and readable while still training MNIST at roughly 10-16 it/s on CPU.


🌟 Features

Core Functionality

  • ✅ Pure Python Core - Clean, educational code
  • 🔥 Complete CNN Support - Conv2D, MaxPool2D, Flatten, Dense layers
  • 📊 Modern Training - Batch normalization, dropout, early stopping
  • 🎯 Smart Optimizers - SGD with momentum and Adam optimizer
  • 📈 Learning Rate Scheduling - Plateau-based LR reduction
  • 💾 Model Persistence - Save/load models in HDF5 or NPZ format
  • 🔄 Data Augmentation Ready - Thread-pooled data loading

Performance Options

  • 🚀 Numba CPU Acceleration - Parallel JIT-compiled operations (15+ it/s)
  • ⚡ JAX GPU/TPU Support - XLA compilation for maximum throughput
  • 📦 OpenBLAS Integration - SIMD-optimized matrix operations

Developer Experience

  • 📚 Clean Code - Well-documented and easy to follow
  • 🎓 Educational - Built for learning deep learning fundamentals
  • 🔧 Modular Design - Easy to extend and customize
  • 💻 Examples Included - MNIST training example and GUI demo

🚀 Quick Start

Installation

Install from PyPI (Recommended):

# Install base package (includes Numba for fast CPU)
pip install convnet

# For GPU/TPU support, add JAX (quotes keep shells such as zsh from expanding the brackets):
pip install "convnet[jax]"    # CPU JAX
pip install "convnet[gpu]"    # NVIDIA GPU (CUDA 12)
pip install "convnet[tpu]"    # Google TPU

Install from Source:

# Clone the repository
git clone https://github.com/codinggamer-dev/ConvNet.git
cd ConvNet

# Install in development mode
pip install -e .

# Add JAX for GPU acceleration
pip install jax jaxlib     # For CPU JAX
pip install "jax[cuda12]"  # For NVIDIA GPU

Performance Expectations

Backend        MNIST Training Speed     Best For
Numba (CPU)    ~16 it/s (model only)    Fast CPU training
Pure NumPy     ~6-8 it/s                Fallback, compatibility
JAX (CPU)      ~10-12 it/s              Development
JAX (GPU)      ~50+ it/s                Production training

Note: Numba runs its JIT-compiled kernels in parallel across all available CPU cores. The ~16 it/s figure covers the model alone; end-to-end training runs at ~10-12 it/s once loss computation, optimizer updates, and data loading are included. For maximum performance, use a modern CPU with AVX2 support.
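
The it/s figures above count iterations (batches) per second. As a minimal sketch of how such a number can be measured, the snippet below times a placeholder step function. The names measure_its and dummy_step are hypothetical and not part of ConvNet; the warm-up calls are there because JIT-compiled code pays a one-time compilation cost.

```python
import time

def measure_its(step_fn, n_iters=50, warmup=5):
    """Return iterations per second for step_fn, excluding warm-up calls."""
    # Warm-up calls absorb one-time costs such as JIT compilation.
    for _ in range(warmup):
        step_fn()
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    elapsed = time.perf_counter() - start
    return n_iters / elapsed

def dummy_step():
    # Stand-in for one training iteration (hypothetical workload).
    return sum(i * i for i in range(10_000))

its = measure_its(dummy_step)
print(f"{its:.1f} it/s")
```

To benchmark a real model, replace dummy_step with a function that runs one training batch.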

Your First Neural Network in 10 Lines

from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense

# Build a simple CNN
model = Model([
    Conv2D(8, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(10), Activation('softmax')
])

# Configure training
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)

# Train on your data
history = model.fit(train_dataset, epochs=10, batch_size=32, num_classes=10)

📖 Complete MNIST Example

import numpy as np
from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense, Dropout, Dataset
from convnet.data import load_mnist_gz

# Load MNIST data
train_data, test_data = load_mnist_gz('mnist_dataset')

# Build the model
model = Model([
    Conv2D(8, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Conv2D(16, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(64), Activation('relu'), Dropout(0.2),
    Dense(10)
])

# Compile and train
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)
history = model.fit(train_data, epochs=50, batch_size=128, num_classes=10)

# Save the model
model.save('my_model.hdf5')

🧩 Architecture Components

Available Layers

Layer                           Description               Parameters
Conv2D(filters, kernel_size)    2D convolutional layer    filters, kernel_size, stride, padding
Dense(units)                    Fully connected layer     units, use_bias
MaxPool2D(pool_size)            Max pooling layer         pool_size, stride
Activation(type)                Activation function       'relu', 'tanh', 'sigmoid', 'softmax'
Flatten()                       Reshape to 1D             None
Dropout(rate)                   Dropout regularization    rate (0.0 to 1.0)
BatchNorm2D()                   Batch normalization       momentum, epsilon
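
When stacking Conv2D and MaxPool2D layers it helps to know how kernel size, stride, and padding change the spatial size of the feature map. The helper below applies the standard output-size formula. It is a sketch with a hypothetical name (conv_output_size), and it assumes zero padding and stride 1 for the convolution and a stride equal to the pool size for pooling; ConvNet's actual defaults may differ.

```python
def conv_output_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window along one axis."""
    return (n + 2 * padding - kernel) // stride + 1

# 28-pixel input, 3x3 kernel, no padding, stride 1 (assumed defaults)
after_conv = conv_output_size(28, kernel=3)
# 2x2 max pooling with stride 2 (assumed default)
after_pool = conv_output_size(after_conv, kernel=2, stride=2)
print(after_conv, after_pool)  # 26 13
```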

Optimizers

  • SGD - Stochastic Gradient Descent with momentum

    model.compile(optimizer='sgd', lr=0.01, momentum=0.9)
    
  • Adam - Adaptive Moment Estimation (recommended)

    model.compile(optimizer='adam', lr=0.001, beta1=0.9, beta2=0.999)
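
For readers who want to see what these hyperparameters control, here is a minimal NumPy sketch of the textbook update rules for SGD with momentum and for Adam. The function names are hypothetical and this is not ConvNet's own code in convnet/optim.py; the eps value is an assumed default.

```python
import numpy as np

def sgd_momentum_step(w, grad, velocity, lr=0.01, momentum=0.9):
    """One SGD-with-momentum update; returns new weights and velocity."""
    velocity = momentum * velocity - lr * grad
    return w + velocity, velocity

def adam_step(w, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update at step t (t starts at 1); returns new w, m, v."""
    m = beta1 * m + (1 - beta1) * grad        # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)              # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

# Minimize f(w) = w^2 (gradient 2w), starting from w = 1.0.
w_sgd, vel = np.array(1.0), np.array(0.0)
w_adam, m, v = np.array(1.0), np.array(0.0), np.array(0.0)
for t in range(1, 201):
    w_sgd, vel = sgd_momentum_step(w_sgd, 2 * w_sgd, vel)
    w_adam, m, v = adam_step(w_adam, 2 * w_adam, m, v, t, lr=0.05)
print(float(w_sgd), float(w_adam))
```

Both optimizers move w toward the minimum at 0; Adam rescales each step by the running gradient magnitude, which is why it is often less sensitive to the choice of lr.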
    

Loss Functions

  • 'categorical_crossentropy' - For multi-class classification
  • 'mse' - Mean Squared Error for regression
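
The two losses can be sketched in a few lines of NumPy. The cross-entropy version below works on raw logits and subtracts the row maximum before exponentiating, a common way to keep softmax numerically stable. The function names are hypothetical and the code is an illustration, not ConvNet's implementation in convnet/losses.py.

```python
import numpy as np

def softmax_cross_entropy(logits, one_hot):
    """Mean cross-entropy from raw logits using the log-sum-exp trick."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoids overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(one_hot * log_probs).sum(axis=1).mean()

def mse(pred, target):
    """Mean squared error over all elements."""
    return np.mean((pred - target) ** 2)

# A logit of 1000 would overflow a naive exp(); the shifted version stays finite.
logits = np.array([[1000.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
ce = softmax_cross_entropy(logits, labels)
err = mse(np.array([1.0, 2.0]), np.array([1.0, 4.0]))
print(ce, err)
```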

🎮 Examples & Demos

The examples/ directory contains several demonstrations:

1. MNIST Training (mnist_train-example.py)

Complete training pipeline with early stopping, LR scheduling, and model persistence.

python examples/mnist_train-example.py

2. Interactive GUI Demo (mnist_gui.py)

Draw digits and see real-time predictions! Requires tkinter.

python examples/mnist_gui.py

3. GPU Training Test (test_gpu_training.py)

Benchmark GPU vs CPU performance.

python examples/test_gpu_training.py

4. JAX Benchmark (benchmark_jax.py)

Compare JAX JIT performance vs NumPy baseline.

python examples/benchmark_jax.py

⚙️ Advanced Features

GPU/TPU Acceleration

ConvNet automatically detects and uses available hardware accelerators via JAX:

# Install with GPU support
pip install "convnet[gpu]"

# Or for TPU
pip install "convnet[tpu]"

# Or install JAX with CUDA manually
pip install "jax[cuda12]"  # For CUDA 12.x

The framework will automatically:

  • Detect available GPUs/TPUs
  • JIT-compile operations with XLA
  • Move tensors to accelerator devices
  • Handle device transfers transparently
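
To check which accelerator JAX itself sees on a machine, the sketch below queries JAX directly. It uses only jax.default_backend(), a public JAX function, and falls back gracefully when JAX is not installed. The helper name describe_backend is hypothetical, and this does not show how ConvNet performs its own detection.

```python
def describe_backend():
    """Return the JAX default backend name, or a note if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "jax not installed"
    # jax.default_backend() returns a platform name such as "cpu", "gpu", or "tpu".
    return jax.default_backend()

backend = describe_backend()
print("backend:", backend)
```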

Regularization

model.compile(
    optimizer='adam',
    lr=0.001,
    weight_decay=1e-4,  # L2 regularization
    clip_norm=5.0        # Gradient clipping
)
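
As a rough illustration of what weight_decay and clip_norm do to the gradients, the sketch below adds an L2 penalty term and then rescales all gradients so that their combined norm does not exceed the limit. The function names are hypothetical, and whether ConvNet clips per tensor or by global norm, and in which order it applies the two steps, is an assumption here.

```python
import numpy as np

def add_weight_decay(grads, weights, weight_decay):
    """L2 regularization: add weight_decay * w to each gradient."""
    return [g + weight_decay * w for g, w in zip(grads, weights)]

def clip_by_global_norm(grads, clip_norm):
    """Scale all gradients so their combined L2 norm is at most clip_norm."""
    total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, clip_norm / (total + 1e-12))
    return [g * scale for g in grads]

weights = [np.ones((2, 2)), np.ones(3)]
grads = [np.full((2, 2), 3.0), np.full(3, 4.0)]
grads = add_weight_decay(grads, weights, weight_decay=1e-4)
clipped = clip_by_global_norm(grads, clip_norm=5.0)
norm = float(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))
print(round(norm, 6))
```

Clipping leaves the gradient direction unchanged and only shrinks its length, which limits the size of any single update.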

Learning Rate Scheduling

history = model.fit(
    dataset,
    lr_schedule='plateau',  # Reduce LR when validation plateaus
    lr_factor=0.5,         # Multiply LR by 0.5
    lr_patience=5,         # Wait 5 epochs before reducing
    lr_min=1e-6           # Minimum learning rate
)
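
The plateau schedule can be pictured as a small state machine: track the best validation loss, count epochs without improvement, and multiply the learning rate by lr_factor once the count reaches lr_patience. The class below is a sketch with a hypothetical name (PlateauScheduler); ConvNet's internal logic may differ in details such as how improvement is measured.

```python
class PlateauScheduler:
    """Reduce the learning rate when the monitored loss stops improving."""

    def __init__(self, lr, factor=0.5, patience=5, lr_min=1e-6):
        self.lr, self.factor = lr, factor
        self.patience, self.lr_min = patience, lr_min
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Call once per epoch; returns the learning rate to use next."""
        if val_loss < self.best:
            self.best, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                # Never go below the configured minimum learning rate.
                self.lr = max(self.lr * self.factor, self.lr_min)
                self.bad_epochs = 0
        return self.lr

sched = PlateauScheduler(lr=0.001, factor=0.5, patience=2)
lrs = [sched.step(loss) for loss in [1.0, 0.8, 0.8, 0.8, 0.7]]
print(lrs)  # [0.001, 0.001, 0.001, 0.0005, 0.0005]
```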

Early Stopping

history = model.fit(
    dataset,
    val_data=(X_val, y_val),
    early_stopping=True,
    patience=10,      # Stop after 10 epochs without improvement
    min_delta=0.001   # Minimum change to qualify as improvement
)
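
The interaction between patience and min_delta is easiest to see in code. In the sketch below, an epoch only counts as an improvement if the validation loss drops by more than min_delta; otherwise a counter increases, and training stops when it reaches patience. The class name EarlyStopping is hypothetical and the code is not taken from ConvNet.

```python
class EarlyStopping:
    """Signal a stop after `patience` epochs without sufficient improvement."""

    def __init__(self, patience=10, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best, self.wait = val_loss, 0   # real improvement: reset counter
        else:
            self.wait += 1                       # change too small to count
        return self.wait >= self.patience

stopper = EarlyStopping(patience=3, min_delta=0.001)
losses = [0.50, 0.40, 0.3995, 0.3999, 0.3998, 0.30]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(stopped_at)  # 4
```

The drops after epoch 1 are all smaller than min_delta, so three such epochs in a row trigger the stop before the final 0.30 is ever seen.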

📊 Model Inspection

# Print model architecture and parameter counts
model.summary()

# Output:
# Model summary:
# Conv2D: params=80
# Activation: params=0
# MaxPool2D: params=0
# Conv2D: params=1168
# Activation: params=0
# MaxPool2D: params=0
# Flatten: params=0
# Dense: params=40064
# Activation: params=0
# Dropout: params=0
# Dense: params=650
# Total params: 41962
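
Several of the counts in the summary can be reproduced by hand: a convolution has one weight per kernel element, input channel, and filter, plus one bias per filter, and a dense layer has one weight per input-output pair plus one bias per unit. The helper names below are hypothetical, and the single input channel is an assumption (grayscale MNIST).

```python
def conv2d_params(in_channels, filters, kernel_size, use_bias=True):
    """Weights (kh * kw * in_channels * filters) plus one bias per filter."""
    kh, kw = kernel_size
    return kh * kw * in_channels * filters + (filters if use_bias else 0)

def dense_params(in_features, units, use_bias=True):
    """Weights (in_features * units) plus one bias per unit."""
    return in_features * units + (units if use_bias else 0)

print(conv2d_params(1, 8, (3, 3)))    # 80: first Conv2D, 1 input channel assumed
print(conv2d_params(8, 16, (3, 3)))   # 1168: second Conv2D, 8 input channels
print(dense_params(64, 10))           # 650: final Dense, 64 inputs to 10 units
```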

🔧 Configuration

Thread Configuration

The framework automatically configures BLAS threads for optimal CPU performance. To opt out, set NN_DISABLE_AUTO_THREADS before importing convnet:

import os
os.environ['NN_DISABLE_AUTO_THREADS'] = '1'  # Disable auto-configuration
import convnet

Custom RNG Seeds

For reproducibility:

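import numpy as np
from convnet import Model, Conv2D, Flatten, Dense

# One shared generator makes weight initialization repeatable
rng = np.random.default_rng(seed=42)

model = Model([
    Conv2D(8, (3, 3), rng=rng),
    Flatten(),
    Dense(10, rng=rng)
])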
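
To see why passing a seeded generator gives reproducible weights, here is a sketch of Glorot/Xavier uniform initialization driven by a NumPy Generator. The function glorot_uniform is a hypothetical stand-in; how ConvNet computes fan-in and fan-out for its layers is not shown here.

```python
import numpy as np

def glorot_uniform(rng, fan_in, fan_out, shape):
    """Sample weights uniformly from [-limit, limit], limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape)

# Two generators created with the same seed produce identical weights.
w1 = glorot_uniform(np.random.default_rng(42), 64, 10, (64, 10))
w2 = glorot_uniform(np.random.default_rng(42), 64, 10, (64, 10))
print(np.array_equal(w1, w2))  # True
```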

📚 Understanding the Code

This project is designed for learning. Here's how to explore:

Start Here

  1. convnet/layers.py - See how Conv2D, Dense, and other layers work
  2. convnet/model.py - Understand forward/backward propagation
  3. convnet/optim.py - Learn how optimizers update weights
  4. examples/mnist_train-example.py - Complete training example

Key Concepts Implemented

  • 🔄 Backpropagation - Full gradient computation chain
  • 📉 Gradient Descent - SGD and Adam optimization
  • 🎲 Weight Initialization - Glorot/Xavier uniform
  • 🧮 Convolution Math - JIT-compiled implementation
  • 📊 Batch Normalization - Running mean/variance tracking
  • 🎯 Softmax & Cross-Entropy - Numerically stable implementation
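
As a starting point for the convolution math, the sketch below is a deliberately naive forward pass in plain NumPy: no padding, stride 1, and explicit loops over output positions. It assumes a channels-last layout and computes cross-correlation, as most deep learning frameworks do. The function name and array layout are hypothetical; ConvNet's JIT-compiled version in convnet/layers.py may be organized differently.

```python
import numpy as np

def conv2d_forward(x, w, b):
    """Naive valid convolution with stride 1.

    x: (H, W, C_in), w: (kh, kw, C_in, C_out), b: (C_out,)
    returns: (H - kh + 1, W - kw + 1, C_out)
    """
    kh, kw, _, c_out = w.shape
    out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((out_h, out_w, c_out))
    for i in range(out_h):
        for j in range(out_w):
            patch = x[i:i + kh, j:j + kw, :]   # (kh, kw, C_in) window
            # Contract the window against every filter, then add the bias.
            out[i, j] = np.tensordot(patch, w, axes=3) + b
    return out

x = np.ones((5, 5, 1))
w = np.ones((3, 3, 1, 2))
b = np.array([0.0, 1.0])
y = conv2d_forward(x, w, b)
print(y.shape, y[0, 0])
```

Each output value is the sum of a 3x3 window of ones (9) plus the filter's bias, which makes the result easy to verify by hand.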

🎯 Project Goals

This framework was built to:

  1. Understand deep learning by implementing it from scratch
  2. Learn how CNNs actually work under the hood
  3. Teach others the fundamentals of neural networks
  4. Provide a clean, readable codebase for education

Not for production use - Use PyTorch, TensorFlow, or JAX for real applications!


📦 Project Structure

ConvNet/
├── convnet/              # Core framework
│   ├── __init__.py       # Package initialization & auto-config
│   ├── layers.py         # Layer implementations
│   ├── model.py          # Model class with training loop
│   ├── optim.py          # Optimizers (SGD, Adam)
│   ├── losses.py         # Loss functions
│   ├── data.py           # Data loading utilities
│   ├── utils.py          # Helper functions
│   ├── jax_backend.py    # JAX acceleration backend
│   └── io.py             # Model save/load
├── examples/             # Example scripts
│   ├── mnist_train-example.py
│   └── mnist_gui.py
├── requirements.txt      # Dependencies
├── setup.py              # Package setup
├── LICENSE.md            # MIT License
└── README.md             # This file

🤝 Contributing

This is an educational project, but contributions are welcome! Feel free to:

  • 🐛 Report bugs
  • 💡 Suggest improvements
  • 📖 Improve documentation
  • ✨ Add new features

📝 Requirements

Core Dependencies

  • Python 3.8 or higher
  • JAX ≥ 0.4.0 (high-performance computing! 🚀)
  • jaxlib ≥ 0.4.0 (JAX runtime)
  • NumPy ≥ 1.20.0 (compatibility layer)
  • tqdm ≥ 4.60.0 (progress bars)
  • h5py ≥ 3.0.0 (model serialization)

Optional Dependencies

  • jax[cuda12] (GPU acceleration)
  • jax[tpu] (TPU acceleration)
  • tkinter (for GUI demo, usually included with Python)

📄 License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Copyright (c) 2025 Tim Bauer


🙏 Acknowledgments

  • Built as a school project to learn deep learning fundamentals
  • Inspired by PyTorch and TensorFlow's clean APIs
  • Powered by JAX and XLA for high-performance computation
  • MNIST dataset by Yann LeCun and Corinna Cortes - the perfect dataset for learning CNNs

💬 Questions?

Feel free to open an issue on GitHub if you have questions or run into problems!


Made with ❤️ for learning and education

⭐ If this helped you understand CNNs better, consider giving it a star! ⭐
