ConvNet
A high-performance, educational CNN framework: SciPy optimization for CPU, JAX for GPU/TPU
This project was created as a school assignment with the goal of understanding deep learning from the ground up. The code is kept simple and readable so it is easy to learn from, while still delivering strong performance: a complete CNN framework with SciPy optimization for efficient CPU training and optional JAX acceleration for GPU/TPU.
Features
Core Functionality
- Pure Python Core - Clean, educational code
- Complete CNN Support - Conv2D, MaxPool2D, Flatten, Dense layers
- Modern Training - Batch normalization, dropout, early stopping
- Smart Optimizers - SGD with momentum and Adam optimizer
- Learning Rate Scheduling - Plateau-based LR reduction
- Model Persistence - Save/load models in HDF5 or NPZ format
- Data Augmentation Ready - Thread-pooled data loading
Performance Options
- SciPy CPU Optimization - BLAS-based linear algebra and optimized operations
- JAX GPU/TPU Support - XLA compilation for maximum throughput
- NumPy Fallback - Pure NumPy when SciPy is unavailable
Developer Experience
- Clean Code - Well-documented and easy to follow
- Educational - Built for learning deep learning fundamentals
- Modular Design - Easy to extend and customize
- Examples Included - MNIST training example and GUI demo
Quick Start
Installation
Install from PyPI (Recommended):
# Install base package (includes SciPy for CPU optimization)
pip install convnet
# For GPU/TPU support, add JAX:
pip install convnet[jax] # CPU JAX
pip install convnet[gpu] # NVIDIA GPU (CUDA 12)
pip install convnet[tpu] # Google TPU
Install from Source:
# Clone the repository
git clone https://github.com/codinggamer-dev/ConvNet.git
cd ConvNet
# Install in development mode
pip install -e .
# Add JAX for GPU acceleration
pip install jax jaxlib # For CPU JAX
pip install jax[cuda12] # For NVIDIA GPU
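After installing, you can check which backend JAX will actually use. These are standard JAX calls, independent of ConvNet:
import jax
print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'
print(jax.devices())          # devices visible to XLA, e.g. [CudaDevice(id=0)]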
Performance Expectations
| Backend | MNIST Training Speed | Best For |
|---|---|---|
| SciPy+NumPy (CPU) | ~12-15 it/s | CPU training, educational |
| Pure NumPy | ~8-10 it/s | Fallback, compatibility |
| JAX (CPU) | ~10-12 it/s | Development |
| JAX (GPU) | ~50+ it/s | Production training |
Note: SciPy uses optimized BLAS routines for linear algebra. Actual training speed depends on CPU and BLAS library. For maximum performance, ensure you have a modern CPU with AVX2 support.
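To see which BLAS library your NumPy build actually links against, use the standard NumPy introspection call:
import numpy as np
np.show_config()  # prints BLAS/LAPACK build info, e.g. OpenBLAS or MKL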
Your First Neural Network in 10 Lines
from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense
# Build a simple CNN
model = Model([
Conv2D(8, (3, 3)), Activation('relu'),
MaxPool2D((2, 2)),
Flatten(),
Dense(10), Activation('softmax')
])
# Configure training
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)
# Train on your data
history = model.fit(train_dataset, epochs=10, batch_size=32, num_classes=10)
Complete MNIST Example
import numpy as np
from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense, Dropout, Dataset
from convnet.data import load_mnist_gz # or load_dataset_gz for other datasets
# Load MNIST data
train_data, test_data = load_mnist_gz('mnist_dataset')
# For EMNIST or other IDX format datasets, use load_dataset_gz:
# from convnet.data import load_dataset_gz
# train_data, test_data = load_dataset_gz(
# 'emnist_dataset',
# train_images_file='emnist-letters-train-images-idx3-ubyte.gz',
# train_labels_file='emnist-letters-train-labels-idx1-ubyte.gz',
# test_images_file='emnist-letters-test-images-idx3-ubyte.gz',
# test_labels_file='emnist-letters-test-labels-idx1-ubyte.gz'
# )
# Build the model
model = Model([
Conv2D(8, (3, 3)), Activation('relu'),
MaxPool2D((2, 2)),
Conv2D(16, (3, 3)), Activation('relu'),
MaxPool2D((2, 2)),
Flatten(),
Dense(64), Activation('relu'), Dropout(0.2),
Dense(10) # or 26 for EMNIST Letters
])
# Compile and train
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)
history = model.fit(train_data, epochs=50, batch_size=128, num_classes=10)
# Save the model
model.save('my_model.hdf5')
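Loading the model back is the mirror of save. The exact loader entry point isn't shown in this README, so treat the name below as an assumption and check convnet/io.py for the real API:
# Hypothetical loader name -- see convnet/io.py for the actual entry point
model = Model.load('my_model.hdf5')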
Architecture Components
Available Layers
| Layer | Description | Parameters |
|---|---|---|
| Conv2D(filters, kernel_size) | 2D Convolutional layer | filters, kernel_size, stride, padding |
| Dense(units) | Fully connected layer | units, use_bias |
| MaxPool2D(pool_size) | Max pooling layer | pool_size, stride |
| Activation(type) | Activation function | 'relu', 'tanh', 'sigmoid', 'softmax' |
| Flatten() | Reshape to 1D | None |
| Dropout(rate) | Dropout regularization | rate (0.0 to 1.0) |
| BatchNorm2D() | Batch normalization | momentum, epsilon |
Optimizers
- SGD - Stochastic Gradient Descent with momentum
  model.compile(optimizer='sgd', lr=0.01, momentum=0.9)
- Adam - Adaptive Moment Estimation (recommended)
  model.compile(optimizer='adam', lr=0.001, beta1=0.9, beta2=0.999)
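For reference, these are the standard update rules behind the two optimizers, as a generic NumPy sketch (not ConvNet's internal code):
import numpy as np

def sgd_momentum_step(w, grad, v, lr=0.01, momentum=0.9):
    # Velocity accumulates an exponentially decaying sum of past gradients
    v = momentum * v - lr * grad
    return w + v, v

def adam_step(w, grad, m, s, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    # Bias-corrected first (m) and second (s) moment estimates; t is 1-based
    m = beta1 * m + (1 - beta1) * grad
    s = beta2 * s + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    s_hat = s / (1 - beta2 ** t)
    return w - lr * m_hat / (np.sqrt(s_hat) + eps), m, s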
Loss Functions
- 'categorical_crossentropy' - For multi-class classification
- 'mse' - Mean Squared Error, for regression
Examples & Demos
The examples/ directory contains several demonstrations:
1. MNIST Training (mnist_train-example.py)
Complete training pipeline with early stopping, LR scheduling, and model persistence.
python examples/mnist_train-example.py
2. Interactive GUI Demo (mnist_gui.py)
Draw digits and see real-time predictions! Requires tkinter.
python examples/mnist_gui.py
3. GPU Training Test (test_gpu_training.py)
Benchmark GPU vs CPU performance.
python examples/test_gpu_training.py
4. JAX Benchmark (benchmark_jax.py)
Compare JAX JIT performance vs NumPy baseline.
python examples/benchmark_jax.py
Advanced Features
GPU/TPU Acceleration
ConvNet automatically detects and uses available hardware accelerators via JAX:
# Install with GPU support
pip install convnet[gpu]
# Or for TPU
pip install convnet[tpu]
# Or install JAX with CUDA manually
pip install jax[cuda12] # For CUDA 12.x
The framework will automatically:
- Detect available GPUs/TPUs
- JIT-compile operations with XLA (see the sketch below)
- Move tensors to accelerator devices
- Handle device transfers transparently
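"JIT-compile with XLA" in practice looks like the generic JAX pattern below. This is a sketch of the mechanism, not ConvNet's backend code:
import jax
import jax.numpy as jnp

@jax.jit  # traced once, then compiled by XLA for the active device
def dense_forward(x, w, b):
    return jax.nn.relu(x @ w + b)

x = jnp.ones((32, 784))
w = jnp.ones((784, 10)) * 0.01
b = jnp.zeros(10)
y = dense_forward(x, w, b)  # runs on GPU/TPU automatically when one is visible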
Regularization
model.compile(
optimizer='adam',
lr=0.001,
weight_decay=1e-4, # L2 regularization
clip_norm=5.0 # Gradient clipping
)
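What these two options typically do per update step, as a generic sketch (assumed semantics; ConvNet may, for example, clip globally across all parameters rather than per tensor):
import numpy as np

def regularized_step(w, grad, lr=0.001, weight_decay=1e-4, clip_norm=5.0):
    grad = grad + weight_decay * w        # L2 penalty folded into the gradient
    norm = np.linalg.norm(grad)
    if norm > clip_norm:
        grad = grad * (clip_norm / norm)  # rescale so ||grad|| <= clip_norm
    return w - lr * grad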
Learning Rate Scheduling
history = model.fit(
dataset,
lr_schedule='plateau', # Reduce LR when validation plateaus
lr_factor=0.5, # Multiply LR by 0.5
lr_patience=5, # Wait 5 epochs before reducing
lr_min=1e-6 # Minimum learning rate
)
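The plateau schedule boils down to a counter of stale epochs, roughly like this sketch (assumed logic, not ConvNet's scheduler code):
def plateau_step(lr, stale_epochs, lr_factor=0.5, lr_patience=5, lr_min=1e-6):
    # Cut the learning rate once validation loss has not improved
    # for lr_patience consecutive epochs, never going below lr_min
    if stale_epochs >= lr_patience:
        return max(lr * lr_factor, lr_min), 0  # reset the stale counter
    return lr, stale_epochs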
Early Stopping
history = model.fit(
dataset,
val_data=(X_val, y_val),
early_stopping=True,
patience=10, # Stop after 10 epochs without improvement
min_delta=0.001 # Minimum change to qualify as improvement
)
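Early stopping applies the same stale-epoch idea, ending training instead of lowering the rate; a minimal sketch under those assumptions:
def should_stop(val_losses, patience=10, min_delta=0.001):
    # True once none of the last `patience` epochs improved on the
    # earlier best validation loss by at least min_delta
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return min(val_losses[-patience:]) > best_before - min_delta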
Model Inspection
# Print model architecture and parameter counts
model.summary()
# Output:
# Model summary:
# Conv2D: params=80
# Activation: params=0
# MaxPool2D: params=0
# Conv2D: params=1168
# Activation: params=0
# MaxPool2D: params=0
# Flatten: params=0
# Dense: params=40064
# Activation: params=0
# Dropout: params=0
# Dense: params=650
# Total params: 41962
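The counts are weights plus biases: the first Conv2D has 8 × (3 × 3 × 1) + 8 = 80 parameters, the second 16 × (3 × 3 × 8) + 16 = 1168, and the final Dense 64 × 10 + 10 = 650.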
Configuration
Thread Configuration
The framework automatically configures BLAS threads for optimal CPU performance:
import os
os.environ['NN_DISABLE_AUTO_THREADS'] = '1' # Disable auto-configuration
import convnet
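If you disable auto-configuration, you can pin the thread count yourself via the standard BLAS environment variables before any imports (exactly which variables ConvNet's auto-config sets internally is an assumption here):
import os
os.environ['OMP_NUM_THREADS'] = '8'       # OpenMP-based BLAS builds
os.environ['OPENBLAS_NUM_THREADS'] = '8'  # OpenBLAS specifically
os.environ['MKL_NUM_THREADS'] = '8'       # Intel MKL
import convnet                            # BLAS reads these at import time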
Custom RNG Seeds
For reproducibility:
import numpy as np
rng = np.random.default_rng(seed=42)
model = Model([
Conv2D(8, (3, 3), rng=rng),
Dense(10, rng=rng)
])
Understanding the Code
This project is designed for learning. Here's how to explore:
Start Here
- convnet/layers.py - See how Conv2D, Dense, and other layers work
- convnet/model.py - Understand forward/backward propagation
- convnet/optim.py - Learn how optimizers update weights
- examples/mnist_train-example.py - Complete training example
Key Concepts Implemented
- Backpropagation - Full gradient computation chain
- Gradient Descent - SGD and Adam optimization
- Weight Initialization - Glorot/Xavier uniform
- Convolution Math - JIT-compiled implementation
- Batch Normalization - Running mean/variance tracking
- Softmax & Cross-Entropy - Numerically stable implementation (see the sketch below)
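As an example of the last point, the standard numerically stable approach subtracts the row maximum before exponentiating and clips probabilities before the log; a generic NumPy sketch, not necessarily ConvNet's exact code:
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # shift so exp() cannot overflow
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def cross_entropy(p, y_onehot, eps=1e-12):
    # Clip to avoid log(0) on confident predictions
    return -np.mean(np.sum(y_onehot * np.log(np.clip(p, eps, 1.0)), axis=-1))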
Project Goals
This framework was built to:
- Understand deep learning by implementing it from scratch
- Learn how CNNs actually work under the hood
- Teach others the fundamentals of neural networks
- Provide a clean, readable codebase for education
Not for production use - Use PyTorch, TensorFlow, or JAX for real applications!
Project Structure
ConvNet/
├── convnet/                 # Core framework
│   ├── __init__.py          # Package initialization & auto-config
│   ├── layers.py            # Layer implementations
│   ├── model.py             # Model class with training loop
│   ├── optim.py             # Optimizers (SGD, Adam)
│   ├── losses.py            # Loss functions
│   ├── data.py              # Data loading utilities
│   ├── utils.py             # Helper functions
│   ├── jax_backend.py       # JAX acceleration backend
│   └── io.py                # Model save/load
├── examples/                # Example scripts
│   ├── mnist_train-example.py
│   ├── mnist_gui.py
│   ├── test_gpu_training.py
│   └── benchmark_jax.py
├── requirements.txt         # Dependencies
├── setup.py                 # Package setup
├── LICENSE.md               # MIT License
└── README.md                # This file
Contributing
This is an educational project, but contributions are welcome! Feel free to:
- Report bugs
- Suggest improvements
- Improve documentation
- Add new features
Requirements
Core Dependencies
- Python 3.8 or higher
- SciPy (BLAS-backed CPU optimization; included with the base install)
- JAX ≥ 0.4.0 (high-performance computing)
- jaxlib ≥ 0.4.0 (JAX runtime)
- NumPy ≥ 1.20.0 (compatibility layer)
- tqdm ≥ 4.60.0 (progress bars)
- h5py ≥ 3.0.0 (model serialization)
Optional Dependencies
- jax[cuda12] (GPU acceleration)
- jax[tpu] (TPU acceleration)
- tkinter (for GUI demo, usually included with Python)
License
This project is licensed under the MIT License - see the LICENSE.md file for details.
Copyright (c) 2025 Tim Bauer
Acknowledgments
- Built as a school project to learn deep learning fundamentals
- Inspired by PyTorch and TensorFlow's clean APIs
- Powered by JAX and XLA for high-performance computation
- MNIST dataset by Yann LeCun and Corinna Cortes - the perfect dataset for learning CNNs
Questions?
Feel free to open an issue on GitHub if you have questions or run into problems!
Made with ❤️ for learning and education
⭐ If this helped you understand CNNs better, consider giving it a star! ⭐
File details
Details for the file convnet-2.5.3.tar.gz.
File metadata
- Download URL: convnet-2.5.3.tar.gz
- Upload date:
- Size: 38.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7de25999160f5eff5f9397a80c2426f59a99a84179988c6ef5ac172dafe5b8c7 |
| MD5 | 40793448df3dcb4bdb874102457db317 |
| BLAKE2b-256 | 68ddefca7d103d7dba6da7d85b559c3c78ffcb304d7cb1bf236c91630f789ff7 |
File details
Details for the file convnet-2.5.3-py3-none-any.whl.
File metadata
- Download URL: convnet-2.5.3-py3-none-any.whl
- Upload date:
- Size: 30.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d436d401e6c00c32c51fdcaed7e5bcfa9d8c680359ea7e2fc91cf80899a18b5c |
| MD5 | 523f5533170f89042a48cd59c9633bcd |
| BLAKE2b-256 | 11ef7db6a3abad6a8253e6770b5d4a297b454687c210bdf661821b5798af301a |