
🧠 ConvNet


A high-performance, educational CNN framework: SciPy optimization for CPU, JAX for GPU/TPU

This project was created as a school assignment with the goal of understanding deep learning from the ground up. It is designed to be easy to read and learn from: a complete CNN framework with SciPy-optimized CPU training and optional JAX acceleration for GPU/TPU, written in simple, readable code that still trains quickly.


🌟 Features

Core Functionality

  • ✅ Pure Python Core - Clean, educational code
  • 🔥 Complete CNN Support - Conv2D, MaxPool2D, Flatten, Dense layers
  • 📊 Modern Training - Batch normalization, dropout, early stopping
  • 🎯 Smart Optimizers - SGD with momentum and Adam optimizer
  • 📈 Learning Rate Scheduling - Plateau-based LR reduction
  • 💾 Model Persistence - Save/load models in HDF5 or NPZ format
  • 🔄 Data Augmentation Ready - Thread-pooled data loading

Performance Options

  • 🚀 SciPy CPU Optimization - BLAS-based linear algebra and optimized operations
  • ⚡ JAX GPU/TPU Support - XLA compilation for maximum throughput
  • 📦 NumPy Fallback - Pure NumPy when SciPy unavailable
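
The SciPy-or-NumPy fallback can be implemented as a guarded import; a minimal sketch (the names here are illustrative, not ConvNet's actual internals):

```python
import numpy as np

try:
    # Prefer SciPy's BLAS bindings when they are installed.
    from scipy.linalg import blas as _blas
    HAVE_SCIPY = True
except ImportError:
    _blas = None
    HAVE_SCIPY = False

def matmul(a, b):
    """Matrix multiply (float arrays) via SciPy's dgemm if available, else NumPy."""
    if HAVE_SCIPY:
        return _blas.dgemm(1.0, a, b)  # computes alpha * a @ b with alpha = 1.0
    return a @ b
```

Either path returns the same result; the SciPy branch just routes through the optimized BLAS library.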

Developer Experience

  • 📚 Clean Code - Well-documented and easy to follow
  • 🎓 Educational - Built for learning deep learning fundamentals
  • 🔧 Modular Design - Easy to extend and customize
  • 💻 Examples Included - MNIST training example and GUI demo

🚀 Quick Start

Installation

Install from PyPI (Recommended):

# Install base package (includes SciPy for CPU optimization)
pip install convnet

# For GPU/TPU support, add JAX:
pip install "convnet[jax]"    # CPU JAX
pip install "convnet[gpu]"    # NVIDIA GPU (CUDA 12)
pip install "convnet[tpu]"    # Google TPU

Install from Source:

# Clone the repository
git clone https://github.com/codinggamer-dev/ConvNet.git
cd ConvNet

# Install in development mode
pip install -e .

# Add JAX for GPU acceleration
pip install jax jaxlib      # For CPU JAX
pip install "jax[cuda12]"   # For NVIDIA GPU

Performance Expectations

| Backend           | MNIST Training Speed | Best For                  |
|-------------------|----------------------|---------------------------|
| SciPy+NumPy (CPU) | ~12-15 it/s          | CPU training, educational |
| Pure NumPy        | ~8-10 it/s           | Fallback, compatibility   |
| JAX (CPU)         | ~10-12 it/s          | Development               |
| JAX (GPU)         | ~50+ it/s            | Production training       |

Note: SciPy uses optimized BLAS routines for linear algebra. Actual training speed depends on CPU and BLAS library. For maximum performance, ensure you have a modern CPU with AVX2 support.

Your First Neural Network in 10 Lines

from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense

# Build a simple CNN
model = Model([
    Conv2D(8, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(10), Activation('softmax')
])

# Configure training
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)

# Train on your data
history = model.fit(train_dataset, epochs=10, batch_size=32, num_classes=10)

📖 Complete MNIST Example

import numpy as np
from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense, Dropout, Dataset
from convnet.data import load_mnist_gz  # or load_dataset_gz for other datasets

# Load MNIST data
train_data, test_data = load_mnist_gz('mnist_dataset')

# For EMNIST or other IDX format datasets, use load_dataset_gz:
# from convnet.data import load_dataset_gz
# train_data, test_data = load_dataset_gz(
#     'emnist_dataset',
#     train_images_file='emnist-letters-train-images-idx3-ubyte.gz',
#     train_labels_file='emnist-letters-train-labels-idx1-ubyte.gz',
#     test_images_file='emnist-letters-test-images-idx3-ubyte.gz',
#     test_labels_file='emnist-letters-test-labels-idx1-ubyte.gz'
# )

# Build the model
model = Model([
    Conv2D(8, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Conv2D(16, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(64), Activation('relu'), Dropout(0.2),
    Dense(10)  # or 26 for EMNIST Letters
])

# Compile and train
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)
history = model.fit(train_data, epochs=50, batch_size=128, num_classes=10)

# Save the model
model.save('my_model.hdf5')

🧩 Architecture Components

Available Layers

| Layer                        | Description             | Parameters                            |
|------------------------------|-------------------------|---------------------------------------|
| Conv2D(filters, kernel_size) | 2D convolutional layer  | filters, kernel_size, stride, padding |
| Dense(units)                 | Fully connected layer   | units, use_bias                       |
| MaxPool2D(pool_size)         | Max pooling layer       | pool_size, stride                     |
| Activation(type)             | Activation function     | 'relu', 'tanh', 'sigmoid', 'softmax'  |
| Flatten()                    | Reshape to 1D           | None                                  |
| Dropout(rate)                | Dropout regularization  | rate (0.0 to 1.0)                     |
| BatchNorm2D()                | Batch normalization     | momentum, epsilon                     |

Optimizers

  • SGD - Stochastic Gradient Descent with momentum

    model.compile(optimizer='sgd', lr=0.01, momentum=0.9)
    
  • Adam - Adaptive Moment Estimation (recommended)

    model.compile(optimizer='adam', lr=0.001, beta1=0.9, beta2=0.999)
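
    For reference, the Adam rule this configures can be written out in a few lines of NumPy (a textbook sketch of the update, not ConvNet's internal code):

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for parameters w at step t (t starts at 1)."""
    m = beta1 * m + (1 - beta1) * grad       # first moment (mean of gradients)
    v = beta2 * v + (1 - beta2) * grad ** 2  # second moment (uncentered variance)
    m_hat = m / (1 - beta1 ** t)             # bias correction for zero init
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v
```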
    

Loss Functions

  • 'categorical_crossentropy' - For multi-class classification
  • 'mse' - Mean Squared Error for regression
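
As a reference point, 'categorical_crossentropy' over softmax outputs is typically computed with the max-subtraction trick for numerical stability; a sketch of that standard technique (not necessarily the framework's exact code):

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy for a batch of logits against one-hot labels."""
    z = logits - logits.max(axis=1, keepdims=True)  # shift so exp() cannot overflow
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1).mean()
```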

🎮 Examples & Demos

The examples/ directory contains several demonstrations:

1. MNIST Training (mnist_train-example.py)

Complete training pipeline with early stopping, LR scheduling, and model persistence.

python examples/mnist_train-example.py

2. Interactive GUI Demo (mnist_gui.py)

Draw digits and see real-time predictions! Requires tkinter.

python examples/mnist_gui.py

3. GPU Training Test (test_gpu_training.py)

Benchmark GPU vs CPU performance.

python examples/test_gpu_training.py

4. JAX Benchmark (benchmark_jax.py)

Compare JAX JIT performance vs NumPy baseline.

python examples/benchmark_jax.py

⚙️ Advanced Features

GPU/TPU Acceleration

ConvNet automatically detects and uses available hardware accelerators via JAX:

# Install with GPU support
pip install "convnet[gpu]"

# Or for TPU
pip install "convnet[tpu]"

# Or install JAX with CUDA manually
pip install "jax[cuda12]"  # For CUDA 12.x

The framework will automatically:

  • Detect available GPUs/TPUs
  • JIT-compile operations with XLA
  • Move tensors to accelerator devices
  • Handle device transfers transparently
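
You can check what JAX detected before training (standard JAX API; the output depends on your hardware):

```python
import jax

# Devices XLA discovered, best backend first (GPU/TPU before CPU).
print(jax.devices())
# The platform JAX will run computations on: 'cpu', 'gpu', or 'tpu'.
print(jax.default_backend())
```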

Regularization

model.compile(
    optimizer='adam',
    lr=0.001,
    weight_decay=1e-4,  # L2 regularization
    clip_norm=5.0        # Gradient clipping
)

Learning Rate Scheduling

history = model.fit(
    dataset,
    lr_schedule='plateau',  # Reduce LR when validation plateaus
    lr_factor=0.5,         # Multiply LR by 0.5
    lr_patience=5,         # Wait 5 epochs before reducing
    lr_min=1e-6           # Minimum learning rate
)
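
The plateau schedule above can be sketched in a few lines of plain Python (illustrative logic under the same factor/patience/min-lr semantics, not ConvNet's internal implementation):

```python
def reduce_on_plateau(lr, val_losses, factor=0.5, patience=5, min_lr=1e-6):
    """Reduce lr when the last `patience` epochs fail to beat the earlier best."""
    if len(val_losses) <= patience:
        return lr  # not enough history yet
    best_before = min(val_losses[:-patience])
    if min(val_losses[-patience:]) >= best_before:  # no improvement in the window
        lr = max(lr * factor, min_lr)
    return lr
```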

Early Stopping

history = model.fit(
    dataset,
    val_data=(X_val, y_val),
    early_stopping=True,
    patience=10,      # Stop after 10 epochs without improvement
    min_delta=0.001   # Minimum change to qualify as improvement
)
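
Early stopping follows the same pattern; a minimal sketch of the patience/min_delta logic (not the framework's exact code):

```python
def should_stop(val_losses, patience=10, min_delta=0.001):
    """True once `patience` epochs pass without beating the best loss by min_delta."""
    best = float("inf")
    epochs_since_best = 0
    for loss in val_losses:
        if loss < best - min_delta:  # meaningful improvement
            best = loss
            epochs_since_best = 0
        else:
            epochs_since_best += 1
    return epochs_since_best >= patience
```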

📊 Model Inspection

# Print model architecture and parameter counts
model.summary()

# Output:
# Model summary:
# Conv2D: params=80
# Activation: params=0
# MaxPool2D: params=0
# Conv2D: params=1168
# Activation: params=0
# MaxPool2D: params=0
# Flatten: params=0
# Dense: params=40064
# Activation: params=0
# Dropout: params=0
# Dense: params=650
# Total params: 41962
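
These counts follow the standard formulas: a Conv2D layer has filters × (kernel_h × kernel_w × in_channels + 1) parameters (the +1 is the per-filter bias), and a Dense layer has inputs × units + units. Checking a few lines of the summary above:

```python
# First Conv2D: 8 filters, 3x3 kernel, 1 input channel (grayscale MNIST)
assert 8 * (3 * 3 * 1 + 1) == 80
# Second Conv2D: 16 filters, 3x3 kernel, 8 input channels
assert 16 * (3 * 3 * 8 + 1) == 1168
# Final Dense: 64 inputs, 10 outputs
assert 64 * 10 + 10 == 650
```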

🔧 Configuration

Thread Configuration

The framework automatically configures BLAS threads for optimal CPU performance:

import os
os.environ['NN_DISABLE_AUTO_THREADS'] = '1'  # Disable auto-configuration
import convnet
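
If you disable the auto-configuration, you can pin BLAS thread counts yourself with the standard environment variables; the BLAS library reads them at load time, so they must be set before NumPy/SciPy are first imported in the process:

```python
import os

# Read by the BLAS library when it loads, so set these before importing NumPy.
os.environ["OMP_NUM_THREADS"] = "4"       # OpenMP-backed BLAS builds
os.environ["OPENBLAS_NUM_THREADS"] = "4"  # OpenBLAS
os.environ["MKL_NUM_THREADS"] = "4"       # Intel MKL

import numpy as np  # the linked BLAS now uses 4 threads
```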

Custom RNG Seeds

For reproducibility:

import numpy as np
rng = np.random.default_rng(seed=42)

model = Model([
    Conv2D(8, (3, 3), rng=rng),
    Dense(10, rng=rng)
])

📚 Understanding the Code

This project is designed for learning. Here's how to explore:

Start Here

  1. convnet/layers.py - See how Conv2D, Dense, and other layers work
  2. convnet/model.py - Understand forward/backward propagation
  3. convnet/optim.py - Learn how optimizers update weights
  4. examples/mnist_train-example.py - Complete training example

Key Concepts Implemented

  • 🔄 Backpropagation - Full gradient computation chain
  • 📉 Gradient Descent - SGD and Adam optimization
  • 🎲 Weight Initialization - Glorot/Xavier uniform
  • 🧮 Convolution Math - JIT-compiled implementation
  • 📊 Batch Normalization - Running mean/variance tracking
  • 🎯 Softmax & Cross-Entropy - Numerically stable implementation
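
Of these, Glorot/Xavier uniform initialization is compact enough to show inline: weights are drawn from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). A textbook sketch in NumPy (not the framework's exact code):

```python
import numpy as np

def glorot_uniform(fan_in, fan_out, rng=None):
    """Sample a (fan_in, fan_out) weight matrix from U(-limit, limit)."""
    rng = rng or np.random.default_rng()
    limit = np.sqrt(6.0 / (fan_in + fan_out))  # Glorot/Xavier uniform bound
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))
```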

🎯 Project Goals

This framework was built to:

  1. Understand deep learning by implementing it from scratch
  2. Learn how CNNs actually work under the hood
  3. Teach others the fundamentals of neural networks
  4. Provide a clean, readable codebase for education

Not for production use - Use PyTorch, TensorFlow, or JAX for real applications!


📦 Project Structure

ConvNet/
├── convnet/              # Core framework
│   ├── __init__.py       # Package initialization & auto-config
│   ├── layers.py         # Layer implementations
│   ├── model.py          # Model class with training loop
│   ├── optim.py          # Optimizers (SGD, Adam)
│   ├── losses.py         # Loss functions
│   ├── data.py           # Data loading utilities
│   ├── utils.py          # Helper functions
│   ├── jax_backend.py    # JAX acceleration backend
│   └── io.py             # Model save/load
├── examples/             # Example scripts
│   ├── mnist_train-example.py
│   ├── mnist_gui.py
│   ├── test_gpu_training.py
│   └── benchmark_jax.py
├── requirements.txt      # Dependencies
├── setup.py              # Package setup
├── LICENSE.md            # MIT License
└── README.md             # This file

๐Ÿค Contributing

This is an educational project, but contributions are welcome! Feel free to:

  • ๐Ÿ› Report bugs
  • ๐Ÿ’ก Suggest improvements
  • ๐Ÿ“– Improve documentation
  • โœจ Add new features

๐Ÿ“ Requirements

Core Dependencies

  • Python 3.8 or higher
  • JAX โ‰ฅ 0.4.0 (high-performance computing! ๐Ÿš€)
  • jaxlib โ‰ฅ 0.4.0 (JAX runtime)
  • NumPy โ‰ฅ 1.20.0 (compatibility layer)
  • tqdm โ‰ฅ 4.60.0 (progress bars)
  • h5py โ‰ฅ 3.0.0 (model serialization)

Optional Dependencies

  • jax[cuda12] (GPU acceleration)
  • jax[tpu] (TPU acceleration)
  • tkinter (for GUI demo, usually included with Python)

📄 License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Copyright (c) 2025 Tim Bauer


๐Ÿ™ Acknowledgments

  • Built as a school project to learn deep learning fundamentals
  • Inspired by PyTorch and TensorFlow's clean APIs
  • Powered by JAX and XLA for high-performance computation
  • MNIST dataset by Yann LeCun and Corinna Cortes - the perfect dataset for learning CNNs

💬 Questions?

Feel free to open an issue on GitHub if you have questions or run into problems!


Made with ❤️ for learning and education

⭐ If this helped you understand CNNs better, consider giving it a star! ⭐
