A minimal, educational convolutional neural network framework with JAX acceleration

Project description

🧠 ConvNet

A clean, educational Convolutional Neural Network framework built from scratch with JAX acceleration

This project was created as a school assignment with the goal of understanding deep learning from the ground up. It's designed to be easy to understand and learn from, implementing a complete CNN framework with JAX for high-performance JIT compilation and automatic GPU/TPU acceleration. The framework uses simple, readable code while leveraging JAX's XLA compiler for production-grade performance.


🌟 Features

Core Functionality

  • ✅ JAX-Powered Core - All neural network operations with JIT compilation
  • 🔥 Complete CNN Support - Conv2D, MaxPool2D, Flatten, Dense layers
  • 📊 Modern Training - Batch normalization, dropout, early stopping
  • 🎯 Smart Optimizers - SGD with momentum and Adam optimizer
  • 📈 Learning Rate Scheduling - Plateau-based LR reduction
  • 💾 Model Persistence - Save/load models in HDF5 or NPZ format
  • 🔄 Data Augmentation Ready - Thread-pooled data loading

Performance Enhancements

  • ⚡ JAX JIT Compilation - XLA-compiled operations for maximum speed
  • 🚀 GPU/TPU Support - Automatic hardware acceleration via JAX
  • 🧵 XLA Optimization - Automatic kernel fusion and optimization
  • 📦 Batch Processing - Efficient mini-batch training

Developer Experience

  • 📚 Clean Code - Well-documented and easy to follow
  • 🎓 Educational - Built for learning deep learning fundamentals
  • 🔧 Modular Design - Easy to extend and customize
  • 💻 Examples Included - MNIST training example and GUI demo

🚀 Quick Start

Installation

Install from PyPI (Recommended):

# Install the latest version from PyPI
pip install convnet

# Or install with GPU support
pip install convnet[gpu]   # For NVIDIA GPU (CUDA)
pip install convnet[tpu]   # For Google TPU

Install from Source:

# Clone the repository
git clone https://github.com/codinggamer-dev/ConvNet.git
cd ConvNet

# Install in development mode
pip install -e .

Your First Neural Network in 10 Lines

from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense

# Build a simple CNN
model = Model([
    Conv2D(8, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(10), Activation('softmax')
])

# Configure training
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)

# Train on your data
history = model.fit(train_dataset, epochs=10, batch_size=32, num_classes=10)

📖 Complete MNIST Example

import numpy as np
from convnet import Model, Conv2D, Activation, MaxPool2D, Flatten, Dense, Dropout, Dataset
from convnet.data import load_mnist_gz

# Load MNIST data
train_data, test_data = load_mnist_gz('mnist_dataset')

# Build the model
model = Model([
    Conv2D(8, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Conv2D(16, (3, 3)), Activation('relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(64), Activation('relu'), Dropout(0.2),
    Dense(10)
])

# Compile and train
model.compile(loss='categorical_crossentropy', optimizer='adam', lr=0.001)
history = model.fit(train_data, epochs=50, batch_size=128, num_classes=10)

# Save the model
model.save('my_model.hdf5')

🧩 Architecture Components

Available Layers

Layer                         | Description             | Parameters
------------------------------|-------------------------|--------------------------------------
Conv2D(filters, kernel_size)  | 2D convolutional layer  | filters, kernel_size, stride, padding
Dense(units)                  | Fully connected layer   | units, use_bias
MaxPool2D(pool_size)          | Max pooling layer       | pool_size, stride
Activation(type)              | Activation function     | 'relu', 'tanh', 'sigmoid', 'softmax'
Flatten()                     | Reshape to 1D           | None
Dropout(rate)                 | Dropout regularization  | rate (0.0 to 1.0)
BatchNorm2D()                 | Batch normalization     | momentum, epsilon

Optimizers

  • SGD - Stochastic Gradient Descent with momentum

    model.compile(optimizer='sgd', lr=0.01, momentum=0.9)
    
  • Adam - Adaptive Moment Estimation (recommended)

    model.compile(optimizer='adam', lr=0.001, beta1=0.9, beta2=0.999)
    

Loss Functions

  • 'categorical_crossentropy' - For multi-class classification
  • 'mse' - Mean Squared Error for regression
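As a rough illustration of what these two losses compute, here is a hedged NumPy sketch (the formulas are standard; the function names are hypothetical and not the framework's internals):

```python
import numpy as np

def categorical_crossentropy(probs, targets, eps=1e-12):
    # probs: (batch, classes) softmax outputs; targets: one-hot (batch, classes).
    # Clipping avoids log(0) - the "numerically stable" trick mentioned below.
    probs = np.clip(probs, eps, 1.0)
    return -np.mean(np.sum(targets * np.log(probs), axis=1))

def mse(preds, targets):
    # Mean squared error over all elements
    return np.mean((preds - targets) ** 2)

# Tiny example: one sample, three classes, correct class has probability 0.7
probs = np.array([[0.7, 0.2, 0.1]])
targets = np.array([[1.0, 0.0, 0.0]])
ce = categorical_crossentropy(probs, targets)  # equals -log(0.7), about 0.357
```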

🎮 Examples & Demos

The examples/ directory contains several demonstrations:

1. MNIST Training (mnist_train-example.py)

Complete training pipeline with early stopping, LR scheduling, and model persistence.

python examples/mnist_train-example.py

2. Interactive GUI Demo (mnist_gui.py)

Draw digits and see real-time predictions! Requires tkinter.

python examples/mnist_gui.py

3. GPU Training Test (test_gpu_training.py)

Benchmark GPU vs CPU performance.

python examples/test_gpu_training.py

4. JAX Benchmark (benchmark_jax.py)

Compare JAX JIT performance vs NumPy baseline.

python examples/benchmark_jax.py

โš™๏ธ Advanced Features

GPU/TPU Acceleration

ConvNet automatically detects and uses available hardware accelerators via JAX:

# Install with GPU support
pip install convnet[gpu]

# Or for TPU
pip install convnet[tpu]

# Or install JAX with CUDA manually
pip install jax[cuda12]  # For CUDA 12.x

The framework will automatically:

  • Detect available GPUs/TPUs
  • JIT-compile operations with XLA
  • Move tensors to accelerator devices
  • Handle device transfers transparently
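You can check what hardware JAX sees on your machine with its standard device API (this is plain JAX, not a ConvNet-specific call):

```python
import jax

# All devices visible to JAX; on a CPU-only machine this is a single CpuDevice
print(jax.devices())

# Platform of the default backend: 'cpu', 'gpu', or 'tpu'
print(jax.default_backend())
```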

Regularization

model.compile(
    optimizer='adam',
    lr=0.001,
    weight_decay=1e-4,  # L2 regularization
    clip_norm=5.0        # Gradient clipping
)
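Conceptually, weight_decay adds an L2 penalty term to each gradient and clip_norm rescales the global gradient norm. A minimal NumPy sketch of that idea (the helper name and exact behavior are assumptions for illustration, not the framework's code):

```python
import numpy as np

def apply_regularization(grads, params, weight_decay=1e-4, clip_norm=5.0):
    # L2 weight decay: gradient of (wd/2)*||param||^2 is wd * param
    grads = [g + weight_decay * p for g, p in zip(grads, params)]
    # Global-norm clipping: rescale all gradients together if the norm is too big
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if global_norm > clip_norm:
        grads = [g * (clip_norm / global_norm) for g in grads]
    return grads

# Tiny example: one parameter tensor with large gradients gets clipped to norm 5
params = [np.ones(4)]
grads = [np.full(4, 10.0)]
clipped = apply_regularization(grads, params)
```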

Learning Rate Scheduling

history = model.fit(
    dataset,
    lr_schedule='plateau',  # Reduce LR when validation plateaus
    lr_factor=0.5,         # Multiply LR by 0.5
    lr_patience=5,         # Wait 5 epochs before reducing
    lr_min=1e-6           # Minimum learning rate
)
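The plateau schedule can be pictured as a small stateful rule: track the best validation loss, and multiply the LR by lr_factor once it fails to improve for lr_patience epochs. A hedged sketch (class name and details are illustrative, not the framework's implementation):

```python
class ReduceLROnPlateau:
    # Illustrative plateau scheduler, mirroring the parameters shown above
    def __init__(self, lr=0.001, factor=0.5, patience=5, lr_min=1e-6):
        self.lr, self.factor, self.patience, self.lr_min = lr, factor, patience, lr_min
        self.best = float('inf')
        self.wait = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best, self.wait = val_loss, 0  # improvement: reset the counter
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.lr = max(self.lr * self.factor, self.lr_min)  # reduce LR
                self.wait = 0
        return self.lr

# Two non-improving epochs with patience=2 trigger one halving: 0.1 -> 0.05
sched = ReduceLROnPlateau(lr=0.1, patience=2)
for loss in [1.0, 0.9, 0.95, 0.95]:
    lr = sched.step(loss)
```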

Early Stopping

history = model.fit(
    dataset,
    val_data=(X_val, y_val),
    early_stopping=True,
    patience=10,      # Stop after 10 epochs without improvement
    min_delta=0.001   # Minimum change to qualify as improvement
)
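The stopping rule described by patience and min_delta can be sketched as follows (an illustrative helper, not the framework's actual code):

```python
def should_stop(history, patience=10, min_delta=0.001):
    # Stop once validation loss has failed to improve by at least
    # min_delta for `patience` consecutive epochs.
    best = float('inf')
    wait = 0
    for loss in history:
        if loss < best - min_delta:
            best, wait = loss, 0  # meaningful improvement: reset
        else:
            wait += 1
            if wait >= patience:
                return True
    return False

# The last two losses improve by less than min_delta, so wait reaches 2
losses = [1.0, 0.8, 0.79, 0.795, 0.794]
stop = should_stop(losses, patience=3)  # False: only 2 stale epochs so far
```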

📊 Model Inspection

# Print model architecture and parameter counts
model.summary()

# Output:
# Model summary:
# Conv2D: params=80
# Activation: params=0
# MaxPool2D: params=0
# Conv2D: params=1168
# Activation: params=0
# MaxPool2D: params=0
# Flatten: params=0
# Dense: params=40064
# Activation: params=0
# Dropout: params=0
# Dense: params=650
# Total params: 41962

🔧 Configuration

Thread Configuration

The framework automatically configures BLAS threads for optimal CPU performance:

import os
os.environ['NN_DISABLE_AUTO_THREADS'] = '1'  # Disable auto-configuration
import convnet

Custom RNG Seeds

For reproducibility:

import numpy as np
rng = np.random.default_rng(seed=42)

model = Model([
    Conv2D(8, (3, 3), rng=rng),
    Dense(10, rng=rng)
])

📚 Understanding the Code

This project is designed for learning. Here's how to explore:

Start Here

  1. convnet/layers.py - See how Conv2D, Dense, and other layers work
  2. convnet/model.py - Understand forward/backward propagation
  3. convnet/optim.py - Learn how optimizers update weights
  4. examples/mnist_train-example.py - Complete training example

Key Concepts Implemented

  • 🔄 Backpropagation - Full gradient computation chain
  • 📉 Gradient Descent - SGD and Adam optimization
  • 🎲 Weight Initialization - Glorot/Xavier uniform
  • 🧮 Convolution Math - JIT-compiled implementation
  • 📊 Batch Normalization - Running mean/variance tracking
  • 🎯 Softmax & Cross-Entropy - Numerically stable implementation
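For example, Glorot/Xavier uniform initialization (the scheme named above) draws weights from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). A standalone NumPy sketch of the formula, not the framework's own initializer:

```python
import numpy as np

def glorot_uniform(rng, fan_in, fan_out):
    # Glorot/Xavier uniform: variance scaled by the fan-in and fan-out
    # so activations neither explode nor vanish at initialization
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

rng = np.random.default_rng(42)
W = glorot_uniform(rng, 64, 10)  # e.g. weights for a Dense(10) layer
```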

🎯 Project Goals

This framework was built to:

  1. Understand deep learning by implementing it from scratch
  2. Learn how CNNs actually work under the hood
  3. Teach others the fundamentals of neural networks
  4. Provide a clean, readable codebase for education

Not for production use - Use PyTorch, TensorFlow, or JAX for real applications!


📦 Project Structure

ConvNet/
├── convnet/              # Core framework
│   ├── __init__.py       # Package initialization & auto-config
│   ├── layers.py         # Layer implementations
│   ├── model.py          # Model class with training loop
│   ├── optim.py          # Optimizers (SGD, Adam)
│   ├── losses.py         # Loss functions
│   ├── data.py           # Data loading utilities
│   ├── utils.py          # Helper functions
│   ├── jax_backend.py    # JAX acceleration backend
│   └── io.py             # Model save/load
├── examples/             # Example scripts
│   ├── mnist_train-example.py
│   ├── mnist_gui.py
│   ├── test_gpu_training.py
│   └── benchmark_jax.py
├── requirements.txt      # Dependencies
├── setup.py              # Package setup
├── LICENSE.md            # MIT License
└── README.md             # This file

๐Ÿค Contributing

This is an educational project, but contributions are welcome! Feel free to:

  • ๐Ÿ› Report bugs
  • ๐Ÿ’ก Suggest improvements
  • ๐Ÿ“– Improve documentation
  • โœจ Add new features

๐Ÿ“ Requirements

Core Dependencies

  • Python 3.8 or higher
  • JAX ≥ 0.4.0 (high-performance computing! 🚀)
  • jaxlib ≥ 0.4.0 (JAX runtime)
  • NumPy ≥ 1.20.0 (compatibility layer)
  • tqdm ≥ 4.60.0 (progress bars)
  • h5py ≥ 3.0.0 (model serialization)

Optional Dependencies

  • jax[cuda12] (GPU acceleration)
  • jax[tpu] (TPU acceleration)
  • tkinter (for GUI demo, usually included with Python)

📄 License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Copyright (c) 2025 Tim Bauer


๐Ÿ™ Acknowledgments

  • Built as a school project to learn deep learning fundamentals
  • Inspired by PyTorch and TensorFlow's clean APIs
  • Powered by JAX and XLA for high-performance computation
  • MNIST dataset by Yann LeCun and Corinna Cortes - the perfect dataset for learning CNNs

💬 Questions?

Feel free to open an issue on GitHub if you have questions or run into problems!


Made with ❤️ for learning and education

โญ If this helped you understand CNNs better, consider giving it a star! โญ
