
Build PyTorch-based NN projects faster


LayerZero

A modular PyTorch training framework with automatic performance optimizations.

Features

Trainer

  • Model compilation via torch.compile() (PyTorch 2.0+)
  • Mixed precision training (AMP)
  • Automatic GPU augmentation integration
  • Asynchronous CUDA data transfers
  • Metric tracking and logging
  • Model checkpointing
  • Custom callbacks

ImageDataLoader

  • GPU-accelerated augmentation using Kornia
  • Configurable augmentation modes
  • Automatic worker detection
  • Torchvision dataset support
  • Pinned memory for GPU training
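
"Automatic worker detection" usually means deriving the DataLoader worker count from the host's CPU count. A minimal sketch of such a heuristic (`detect_num_workers` is a hypothetical helper, not LayerZero's actual logic):

```python
import os

def detect_num_workers(reserve: int = 1) -> int:
    """Pick a DataLoader worker count from available CPU cores,
    keeping `reserve` cores free for the main process."""
    cores = os.cpu_count() or 1
    return max(1, cores - reserve)
```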

Helper

  • Training/validation metric tracking
  • Loss curve visualization
  • Experiment logging
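
Conceptually, the tracking role described above is an epoch-indexed history of named metrics. The class below is an illustrative stand-in (the name and API are assumptions, not Helper's real interface):

```python
from collections import defaultdict

class MetricHistory:
    """Hypothetical metric tracker mirroring the role described above."""
    def __init__(self):
        self.history = defaultdict(list)

    def log(self, **metrics):
        # Record one epoch's worth of named metrics
        for name, value in metrics.items():
            self.history[name].append(value)

    def best(self, name, mode='min'):
        # Best value seen so far, e.g. lowest validation loss
        values = self.history[name]
        return min(values) if mode == 'min' else max(values)
```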

Performance Optimizations

Applied automatically:

  • torch.compile() for model compilation
  • Mixed precision (FP16) training
  • Non-blocking CUDA transfers
  • GPU-based augmentation (Kornia)
  • Optimized DataLoader configuration
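
In plain PyTorch, the mixed-precision and non-blocking-transfer items above amount to roughly the following training step. This is a sketch of the general technique, not Trainer's actual source:

```python
import torch
from torch import nn

def amp_train_step(model, loss_fn, optimizer, scaler, X, y, device):
    # Non-blocking transfers overlap host-to-device copies with compute
    # (only effective when the source tensors live in pinned memory)
    X = X.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

    # Autocast runs the forward pass in reduced precision on CUDA
    with torch.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
        loss = loss_fn(model(X), y)

    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()   # scale loss to avoid FP16 gradient underflow
    scaler.step(optimizer)          # unscales gradients, then optimizer.step()
    scaler.update()
    return loss.item()
```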

Installation

pip install torch torchvision matplotlib

# Optional: GPU augmentation
pip install kornia kornia-rs

Usage

Basic Example

import torch
from torch import nn
from LayerZero import ImageDataLoader, Trainer, TrainerConfig
from torchvision.datasets import CIFAR10

# Model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3*32*32, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Data
loader = ImageDataLoader(
    CIFAR10,
    root='./data',
    image_size=32,
    batch_size=128,
    download=True,
    use_gpu_augmentation='auto'  # Automatic GPU acceleration
)

train_loader, test_loader = loader.get_loaders()

# Training configuration
config = TrainerConfig(
    epochs=10,
    amp=True,
    compile_model='auto'
)

# Train
trainer = Trainer(
    model=model,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    config=config
)

results = trainer.fit(
    train_loader, 
    test_loader,
    data_loader=loader  # Auto-detects GPU augmentation!
)

Configuration

Augmentation Modes

from LayerZero import ImageDataLoader, AugmentationMode

loader = ImageDataLoader(
    CIFAR10,
    augmentation_mode=AugmentationMode.MINIMAL,  # Flip + Crop
    # AugmentationMode.BASIC,   # + ColorJitter (default)
    # AugmentationMode.STRONG,  # + Rotation + Blur + Erasing
    # AugmentationMode.OFF,     # No augmentation
)

GPU Augmentation

# Automatic integration with Trainer (Recommended)
loader = ImageDataLoader(
    CIFAR10,
    use_gpu_augmentation='auto',  # Auto-detect GPU and Kornia
    auto_install_kornia=True       # Install if missing
)

train_loader, test_loader = loader.get_loaders()

# GPU augmentation auto-detected when fit() is called!
trainer = Trainer(
    model=model,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    config=config
)

trainer.fit(
    train_loader,
    test_loader,
    data_loader=loader  # ← Pass loader here, Trainer auto-detects GPU aug!
)

# Manual usage in custom training loops
device = torch.device('cuda')
gpu_aug = loader.get_gpu_augmentation(device=device)

for X, y in train_loader:
    X = X.to(device, non_blocking=True)
    X = gpu_aug(X)
    # ... training code ...

Mixed Precision

config = TrainerConfig(
    amp=True,   # Enable (default)
    # amp=False  # Disable for debugging
)

Model Compilation

config = TrainerConfig(
    compile_model='auto',        # Auto-detect PyTorch 2.0+
    compile_mode='default',      # Compilation mode
    # compile_mode='reduce-overhead'
    # compile_mode='max-autotune'
)
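
With `compile_model='auto'`, the Trainer presumably guards compilation on PyTorch version, roughly equivalent to the sketch below (illustrative, not Trainer's actual code):

```python
import torch

def maybe_compile(model, mode='default'):
    """Compile the model when torch.compile exists (PyTorch 2.0+);
    fall back to the eager model otherwise."""
    if hasattr(torch, 'compile'):
        return torch.compile(model, mode=mode)
    return model
```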

Custom Metrics

def accuracy_fn(y_pred, y_true):
    return (y_pred.argmax(1) == y_true).float().mean().item() * 100

config = TrainerConfig(
    metrics={'accuracy': accuracy_fn}
)
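
As a quick sanity check, this metric returns a percentage in [0, 100]; on a half-correct batch it yields 50.0 (the function is repeated here so the snippet is self-contained):

```python
import torch

def accuracy_fn(y_pred, y_true):
    return (y_pred.argmax(1) == y_true).float().mean().item() * 100

logits = torch.tensor([[2.0, 1.0],   # predicts class 0 (correct)
                       [2.0, 1.0]])  # predicts class 0 (wrong)
labels = torch.tensor([0, 1])
print(accuracy_fn(logits, labels))  # → 50.0
```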

Callbacks

def save_checkpoint(model, epoch, metrics):
    torch.save(model.state_dict(), f'model_epoch_{epoch}.pt')

config = TrainerConfig(
    callbacks={'on_epoch_end': save_checkpoint}
)
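
Callbacks can also carry state across epochs. A hedged example with the same `(model, epoch, metrics)` signature, tracking the best value of a chosen metric (the class name and behavior are illustrative, not part of LayerZero):

```python
class BestMetricTracker:
    """Hypothetical on_epoch_end callback remembering the best metric value."""
    def __init__(self, key='accuracy', mode='max'):
        self.key, self.mode = key, mode
        self.best, self.best_epoch = None, None

    def __call__(self, model, epoch, metrics):
        value = metrics.get(self.key)
        if value is None:
            return
        improved = (self.best is None
                    or (value > self.best if self.mode == 'max' else value < self.best))
        if improved:
            self.best, self.best_epoch = value, epoch
```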

API Reference

ImageDataLoader

ImageDataLoader(
    dataset_cls,                          # Torchvision dataset class
    root='./data',                        # Data directory
    image_size=224,                       # Image size
    batch_size=64,                        # Batch size
    augmentation_mode=AugmentationMode.BASIC,
    use_gpu_augmentation='auto',
    auto_install_kornia=True,
    num_workers=None,                     # Auto-detect
    download=False,
)

TrainerConfig

TrainerConfig(
    epochs=10,
    amp=True,                     # Mixed precision
    compile_model='auto',         # torch.compile()
    compile_mode='default',
    metrics={},
    callbacks={},
    device='auto',
    log_interval=100,
    save_dir='./checkpoints',
)

Trainer

Trainer(
    model,
    loss_fn,
    optimizer,
    config,
    metrics=None,
    callbacks=None,
)

# Run training with optional GPU augmentation auto-detection
trainer.fit(
    train_loader, 
    val_loader,
    epochs=None,        # Optional: Override config.epochs
    data_loader=None    # Optional: ImageDataLoader for GPU aug auto-detection
)

trainer.evaluate(dataloader)  # Evaluate on data
trainer.predict(dataloader)   # Get predictions

KorniaHelper

from LayerZero import (
    is_kornia_available,
    install_kornia,
    ensure_kornia,
    get_kornia_version,
)

if ensure_kornia(auto_install=True):
    # Kornia available
    pass
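
Availability checks like `is_kornia_available` are typically a `find_spec` probe, which detects an importable package without importing it. An illustrative equivalent (not LayerZero's actual implementation):

```python
import importlib.util

def is_package_available(name: str) -> bool:
    # True when the package can be imported, without importing it
    return importlib.util.find_spec(name) is not None
```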

Architecture

LayerZero/
├── Trainer.py              # Training loop
├── ImageDataLoader.py      # Data loading
├── GPUAugmentation.py      # Kornia augmentation
├── AugmentationMode.py     # Augmentation enums
├── KorniaHelper.py         # Kornia management
└── Helper.py               # Metrics tracking

Troubleshooting

Kornia installation fails

Install it manually:

pip install kornia kornia-rs

torch.compile not available

Requires PyTorch 2.0+:

pip install --upgrade torch torchvision

Out of memory

Reduce batch size or enable mixed precision:

config = TrainerConfig(amp=True)

Slow on CPU

Use minimal augmentation:

loader = ImageDataLoader(
    ...,
    augmentation_mode=AugmentationMode.MINIMAL
)

Releasing New Versions

# Bump version (bug fixes: 0.1.3 → 0.1.4)
make bump-patch

# Push to trigger PyPI release
make release

See RELEASE_WORKFLOW.md for the complete guide.


License

MIT
