Build PyTorch-Based NN Projects Faster
LayerZero
A modular PyTorch training framework with automatic performance optimizations.
Features
Trainer
- Model compilation via torch.compile() (PyTorch 2.0+)
- Mixed precision training (AMP)
- Asynchronous CUDA data transfers
- Metric tracking and logging
- Model checkpointing
- Custom callbacks
ImageDataLoader
- GPU-accelerated augmentation using Kornia
- Configurable augmentation modes
- Automatic worker detection
- Torchvision dataset support
- Pinned memory for GPU training
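How "Automatic worker detection" picks a value is not documented here. A minimal sketch of one common heuristic, assuming the loader derives `num_workers` from the CPU count when `num_workers=None`; the helper name `default_num_workers` and the cap of 8 are hypothetical, not LayerZero's actual logic:

```python
import os
from typing import Optional


def default_num_workers(num_workers: Optional[int] = None, cap: int = 8) -> int:
    """Hypothetical helper: choose a DataLoader worker count.

    An explicit value always wins; otherwise use the CPU count minus one
    core (left for the main training process), clamped to [0, cap].
    """
    if num_workers is not None:
        return num_workers
    cpus = os.cpu_count() or 1  # os.cpu_count() can return None
    return max(0, min(cap, cpus - 1))
```

Leaving one core free keeps the training process responsive; on a single-core machine the sketch falls back to 0 workers, meaning data is loaded in the main process.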
Helper
- Training/validation metric tracking
- Loss curve visualization
- Experiment logging
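The Helper module's internals are not shown. As an illustration of per-epoch metric tracking, here is a hypothetical `MetricHistory` class (not LayerZero's API) that stores one value per metric per epoch, which is the shape a loss-curve plot needs:

```python
from collections import defaultdict
from typing import Dict, List


class MetricHistory:
    """Hypothetical per-epoch store for training/validation metrics."""

    def __init__(self) -> None:
        # Maps names such as "train_loss" to one value per epoch
        self.history: Dict[str, List[float]] = defaultdict(list)

    def update(self, **metrics: float) -> None:
        """Append one value per metric for the current epoch."""
        for name, value in metrics.items():
            self.history[name].append(float(value))

    def best(self, name: str, mode: str = "min") -> float:
        """Return the best recorded value (lowest by default, as for losses)."""
        values = self.history.get(name, [])
        if not values:
            raise KeyError(f"no values recorded for {name!r}")
        return min(values) if mode == "min" else max(values)
```

Each list can be passed directly to `matplotlib.pyplot.plot` to draw a loss curve.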
Performance Optimizations
Applied automatically:
- torch.compile() for model compilation
- Mixed precision (FP16) training
- Non-blocking CUDA transfers
- GPU-based augmentation (Kornia)
- Optimized DataLoader configuration
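The exact DataLoader settings LayerZero uses are not listed. A sketch of a typical configuration, assuming it builds keyword arguments for `torch.utils.data.DataLoader`; `pin_memory`, `persistent_workers` and `prefetch_factor` are real DataLoader parameters, while the helper `dataloader_kwargs` and the chosen values are hypothetical:

```python
from typing import Any, Dict


def dataloader_kwargs(batch_size: int, num_workers: int, cuda_available: bool) -> Dict[str, Any]:
    """Hypothetical builder for torch.utils.data.DataLoader keyword arguments."""
    kwargs: Dict[str, Any] = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Page-locked host memory lets .to(device, non_blocking=True)
        # overlap host-to-GPU copies with computation
        "pin_memory": cuda_available,
    }
    if num_workers > 0:
        # Both options are only valid with worker processes; PyTorch raises
        # an error if persistent_workers=True is combined with num_workers=0
        kwargs["persistent_workers"] = True  # reuse workers across epochs
        kwargs["prefetch_factor"] = 2        # batches prepared ahead per worker
    return kwargs
```

Usage would look like `DataLoader(dataset, shuffle=True, **dataloader_kwargs(64, 4, torch.cuda.is_available()))`.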
Installation
pip install torch torchvision matplotlib
# Optional: GPU augmentation
pip install kornia kornia-rs
Usage
Basic Example
import torch
from torch import nn
from LayerZero import ImageDataLoader, Trainer, TrainerConfig
from torchvision.datasets import CIFAR10
# Model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3*32*32, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# Data
loader = ImageDataLoader(
    CIFAR10,
    root='./data',
    image_size=32,
    batch_size=128,
    download=True
)
train_loader, test_loader = loader.get_dataloaders()
# Training configuration
config = TrainerConfig(
    epochs=10,
    amp=True,
    compile_model='auto'
)
# Train
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    config=config,
)
results = trainer.train()
Configuration
Augmentation Modes
from torchvision.datasets import CIFAR10
from LayerZero import ImageDataLoader, AugmentationMode
loader = ImageDataLoader(
    CIFAR10,
    augmentation_mode=AugmentationMode.MINIMAL,  # Flip + Crop
    # AugmentationMode.BASIC,   # + ColorJitter (default)
    # AugmentationMode.STRONG,  # + Rotation + Blur + Erasing
    # AugmentationMode.OFF,     # No augmentation
)
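The "+" in the comments suggests each mode adds transforms on top of the previous one. A minimal sketch of that cumulative relationship, using a hypothetical `AugMode` enum and transform names as plain strings; the real `AugmentationMode` values and Kornia transform objects may differ:

```python
from enum import Enum
from typing import Dict, List


class AugMode(Enum):
    """Hypothetical stand-in for AugmentationMode."""
    OFF = "off"
    MINIMAL = "minimal"
    BASIC = "basic"
    STRONG = "strong"


# Transforms each mode adds beyond the previous one, per the comments above
_EXTRA: Dict[AugMode, List[str]] = {
    AugMode.OFF: [],
    AugMode.MINIMAL: ["flip", "crop"],
    AugMode.BASIC: ["color_jitter"],
    AugMode.STRONG: ["rotation", "blur", "erasing"],
}
_ORDER = [AugMode.OFF, AugMode.MINIMAL, AugMode.BASIC, AugMode.STRONG]


def transforms_for(mode: AugMode) -> List[str]:
    """Return the cumulative list of transform names enabled by `mode`."""
    names: List[str] = []
    for m in _ORDER[: _ORDER.index(mode) + 1]:
        names.extend(_EXTRA[m])
    return names
```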
GPU Augmentation
loader = ImageDataLoader(
    CIFAR10,
    use_gpu_augmentation='auto',  # Auto-detect
    auto_install_kornia=True      # Install if missing
)
# Manual usage
device = 'cuda'
train_loader, test_loader = loader.get_dataloaders()
gpu_aug = loader.get_gpu_augmentation(device=device)
for X, y in train_loader:
    X = X.to(device, non_blocking=True)  # asynchronous copy from pinned memory
    X = gpu_aug(X)
    # ... training code ...
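What `'auto'` resolves to is not spelled out. A plausible rule, sketched under the assumption that GPU augmentation is enabled only when a CUDA device exists and Kornia is importable; `resolve_gpu_augmentation` is hypothetical, and in real code `cuda_available` would come from `torch.cuda.is_available()`:

```python
import importlib.util
from typing import Union


def resolve_gpu_augmentation(setting: Union[bool, str], cuda_available: bool) -> bool:
    """Hypothetical resolution of use_gpu_augmentation.

    True/False are honoured as given; 'auto' enables GPU augmentation only
    when a CUDA device is present and the kornia package can be found.
    """
    if setting == "auto":
        kornia_installed = importlib.util.find_spec("kornia") is not None
        return cuda_available and kornia_installed
    return bool(setting)
```

`importlib.util.find_spec` checks for the package without importing it, so the check stays cheap even when Kornia is installed.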
Mixed Precision
config = TrainerConfig(
    amp=True,     # Enable (default)
    # amp=False   # Disable for debugging
)
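FP16 mixed precision is usually paired with dynamic loss scaling so that small gradients do not underflow. PyTorch's `GradScaler` implements this on tensors; the pure-Python sketch below only simulates its scale-update rule, with defaults mirroring `GradScaler` (initial scale 2**16, growth factor 2.0, backoff factor 0.5, growth interval 2000). Whether LayerZero configures its scaler this way is an assumption, and `LossScaleSimulator` is a hypothetical name:

```python
class LossScaleSimulator:
    """Simulates the scale-update rule of dynamic loss scaling (no tensors)."""

    def __init__(self, init_scale: float = 2.0**16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000) -> None:
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, found_inf: bool) -> bool:
        """Record one step; return True if the optimizer step should run."""
        if found_inf:
            # Overflow: skip this optimizer step and shrink the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of stable steps: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True
```

The loss is multiplied by `scale` before backpropagation and gradients are divided by it before the optimizer step, so the update itself is unchanged; only its numerical range in FP16 improves.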
Model Compilation
config = TrainerConfig(
    compile_model='auto',     # Auto-detect PyTorch 2.0+
    compile_mode='default',   # Compilation mode
    # compile_mode='reduce-overhead'
    # compile_mode='max-autotune'
)
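The comment says `'auto'` means "PyTorch 2.0+", the first release with `torch.compile()`. A sketch of that check, assuming the version string comes from `torch.__version__`; `should_compile` and `check_compile_mode` are hypothetical helpers, and the accepted modes are just the three listed above (PyTorch itself accepts a few more):

```python
from typing import Union

VALID_COMPILE_MODES = {"default", "reduce-overhead", "max-autotune"}


def should_compile(compile_model: Union[bool, str], torch_version: str) -> bool:
    """Hypothetical resolution of compile_model: 'auto' means PyTorch >= 2.0."""
    if compile_model != "auto":
        return bool(compile_model)
    # Versions look like "2.3.1" or "2.3.1+cu121"; only the major number matters
    major = int(torch_version.split(".")[0])
    return major >= 2


def check_compile_mode(mode: str) -> str:
    """Reject modes outside the documented set before calling torch.compile()."""
    if mode not in VALID_COMPILE_MODES:
        raise ValueError(f"unsupported compile_mode: {mode!r}")
    return mode
```

Roughly, `'reduce-overhead'` uses CUDA graphs to cut per-step launch overhead for small batches, while `'max-autotune'` spends longer compiling to search for faster kernels.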
Custom Metrics
def accuracy_fn(y_pred, y_true):
    # Percentage of samples whose highest-scoring class matches the label
    return (y_pred.argmax(1) == y_true).float().mean().item() * 100
config = TrainerConfig(
    metrics={'accuracy': accuracy_fn}
)
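`accuracy_fn` expects scores of shape `(batch, num_classes)` and integer labels of shape `(batch,)`, and `.item()` turns the result into a plain Python float for logging. The same arithmetic in pure Python, shown only to make the computation concrete (`accuracy_from_scores` is a hypothetical name):

```python
from typing import Sequence


def accuracy_from_scores(y_pred: Sequence[Sequence[float]], y_true: Sequence[int]) -> float:
    """Percent of rows whose argmax equals the label (pure-Python accuracy_fn)."""
    if len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must have the same length")
    correct = sum(
        # Index of the highest score in the row (first one on ties, like argmax)
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(y_pred, y_true)
    )
    return correct / len(y_true) * 100
```

For example, scores `[[0.1, 0.9], [0.8, 0.2]]` with labels `[1, 1]` give one correct prediction out of two, i.e. 50.0.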
Callbacks
import torch
def save_checkpoint(model, epoch, metrics):
    # Save the model weights at the end of every epoch
    torch.save(model.state_dict(), f'model_epoch_{epoch}.pt')
config = TrainerConfig(
    callbacks={'on_epoch_end': save_checkpoint}
)
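How the Trainer invokes callbacks is not documented beyond the example signature. A sketch of event-name dispatch, assuming the trainer passes `model`, `epoch` and `metrics` as keyword arguments matching `save_checkpoint` above; `fire` and `run_epochs` are hypothetical, and whether epochs are counted from 0 or 1 is unknown:

```python
from typing import Any, Callable, Dict

Callbacks = Dict[str, Callable[..., Any]]


def fire(callbacks: Callbacks, event: str, **kwargs: Any) -> None:
    """Call the callback registered for `event`, if there is one."""
    callback = callbacks.get(event)
    if callback is not None:
        callback(**kwargs)


def run_epochs(num_epochs: int, callbacks: Callbacks, model: Any = None) -> None:
    """Skeleton loop showing where an 'on_epoch_end' hook would fire."""
    for epoch in range(num_epochs):
        metrics = {"train_loss": 1.0 / (epoch + 1)}  # placeholder values
        fire(callbacks, "on_epoch_end", model=model, epoch=epoch, metrics=metrics)
```

Keying callbacks by event name keeps the config a plain dict, and unregistered events cost nothing.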
API Reference
ImageDataLoader
ImageDataLoader(
    dataset_cls,                # Torchvision dataset class
    root='./data',              # Data directory
    image_size=224,             # Image size
    batch_size=64,              # Batch size
    augmentation_mode=AugmentationMode.BASIC,
    use_gpu_augmentation='auto',
    auto_install_kornia=True,
    num_workers=None,           # Auto-detect
    download=False,
)
TrainerConfig
TrainerConfig(
    epochs=10,
    amp=True,                   # Mixed precision
    compile_model='auto',       # torch.compile()
    compile_mode='default',
    metrics={},
    callbacks={},
    device='auto',
    log_interval=100,
    save_dir='./checkpoints',
)
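The `metrics={}` and `callbacks={}` defaults above read as "an empty dict per config". If you model such a config as a Python dataclass, a literal `{}` default is rejected because it would be shared between instances; `field(default_factory=dict)` is the idiomatic fix. A hypothetical mirror of the documented defaults (not LayerZero's class):

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Union


@dataclass
class TrainerConfigSketch:
    """Hypothetical mirror of the documented TrainerConfig defaults."""
    epochs: int = 10
    amp: bool = True
    compile_model: Union[bool, str] = "auto"
    compile_mode: str = "default"
    # default_factory gives every instance its own dict; a literal {} default
    # raises ValueError in dataclasses because it would be shared
    metrics: Dict[str, Callable[..., float]] = field(default_factory=dict)
    callbacks: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    device: str = "auto"
    log_interval: int = 100
    save_dir: str = "./checkpoints"
```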
Trainer
Trainer(
    model,
    train_loader,
    val_loader,
    loss_fn,
    optimizer,
    config,
)
trainer.train() # Run training
trainer.predict(dataloader) # Get predictions
trainer.save_checkpoint(path) # Save model
KorniaHelper
from LayerZero import (
    is_kornia_available,
    install_kornia,
    ensure_kornia,
    get_kornia_version,
)
if ensure_kornia(auto_install=True):
    # Kornia available
    pass
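What `get_kornia_version()` returns when Kornia is missing is not documented. A standard-library way to query an installed package's version without importing it, assuming "not installed" maps to `None`; `installed_version` and `kornia_version` are hypothetical helpers:

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def kornia_version() -> Optional[str]:
    # Hypothetical analogue of get_kornia_version()
    return installed_version("kornia")
```

Reading package metadata avoids paying Kornia's import cost just to log its version.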
Architecture
LayerZero/
├── Trainer.py # Training loop
├── ImageDataLoader.py # Data loading
├── GPUAugmentation.py # Kornia augmentation
├── AugmentationMode.py # Augmentation enums
├── KorniaHelper.py # Kornia management
└── Helper.py # Metrics tracking
Troubleshooting
Kornia installation fails
Install it manually:
pip install kornia kornia-rs
torch.compile not available
Requires PyTorch 2.0+:
pip install --upgrade torch torchvision
Out of memory
Reduce batch size or enable mixed precision:
config = TrainerConfig(amp=True)
Slow on CPU
Use minimal augmentation:
loader = ImageDataLoader(
    ...,
    augmentation_mode=AugmentationMode.MINIMAL
)
Releasing New Versions
# Bump version (bug fixes: 0.1.3 → 0.1.4)
make bump-patch
# Push to trigger PyPI release
make release
See RELEASE_WORKFLOW.md for the complete guide.
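The Makefile is not shown, so what `make bump-patch` runs is an assumption. Going by the comment (0.1.3 → 0.1.4), it increments the PATCH part of a MAJOR.MINOR.PATCH version. A sketch of that rule for plain numeric versions (no pre-release suffixes):

```python
def bump_patch(version: str) -> str:
    """Increment the PATCH part of a MAJOR.MINOR.PATCH version string."""
    major, minor, patch = (int(part) for part in version.split("."))
    return f"{major}.{minor}.{patch + 1}"
```

Parsing the parts as integers matters: `0.1.9` must become `0.1.10`, not be compared or incremented as text.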
License
MIT
Download files
Source Distribution
layerzero-0.1.11.tar.gz (19.0 kB)
Built Distribution
File details
Details for the file layerzero-0.1.11.tar.gz.
File metadata
- Download URL: layerzero-0.1.11.tar.gz
- Upload date:
- Size: 19.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b855998af4e7d62830ff5a3dc95ffa49e3f457e78070411c15a53c1a996454b1 |
| MD5 | c81167d6dd7379d3c59e288a2ef524db |
| BLAKE2b-256 | b7aae6f207e5211d5ac513f8c7cacf80911114b307168fbe0a322ae4144336c6 |
File details
Details for the file layerzero-0.1.11-py3-none-any.whl.
File metadata
- Download URL: layerzero-0.1.11-py3-none-any.whl
- Upload date:
- Size: 21.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 89d1b8fda4ac8691c2bbecf4da276e502f06d0ee60560018819f5ad6892df7c0 |
| MD5 | 6afd1519791f43f327d2baa690a9de27 |
| BLAKE2b-256 | 7d0f9b10c65608b13ff77129cf595f5ba0e5fe8f2f5925a364757e3ee24331e5 |
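The SHA256 digests above let you check a downloaded file before installing it. A minimal standard-library sketch; `sha256_of` and `matches` are illustrative helper names:

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read in 1 MiB chunks so large files never load fully into memory
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: str, expected_hex: str) -> bool:
    """Compare against a published digest, ignoring case and stray whitespace."""
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, `matches("layerzero-0.1.11.tar.gz", "b855998af4e7d62830ff5a3dc95ffa49e3f457e78070411c15a53c1a996454b1")` should be `True` for an intact source archive. pip can also enforce digests itself through its hash-checking mode (`--require-hashes`).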