Build PyTorch-Based NN Projects Faster

LayerZero

A modular PyTorch training framework with automatic performance optimizations.

Features

Trainer

  • Model compilation via torch.compile() (PyTorch 2.0+)
  • Mixed precision training (AMP)
  • Asynchronous CUDA data transfers
  • Metric tracking and logging
  • Model checkpointing
  • Custom callbacks

ImageDataLoader

  • GPU-accelerated augmentation using Kornia
  • Configurable augmentation modes
  • Automatic worker detection
  • Torchvision dataset support
  • Pinned memory for GPU training
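Automatic worker detection can be approximated with the standard library alone. The helper below is an illustrative sketch, not LayerZero's actual logic; the name `detect_num_workers` and the cap of 8 are assumptions:

```python
import os

def detect_num_workers(cap: int = 8) -> int:
    """Pick a DataLoader worker count: one per CPU core, capped.

    The cap avoids oversubscription on many-core machines; 0 is
    returned when the core count cannot be determined (workers off).
    """
    cores = os.cpu_count() or 0
    return min(cores, cap)
```

Passing `num_workers=None` to ImageDataLoader presumably triggers logic along these lines.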

Helper

  • Training/validation metric tracking
  • Loss curve visualization
  • Experiment logging
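A minimal sketch of what the Helper's metric tracking might look like; the class name and API here are illustrative assumptions, not the library's actual interface:

```python
from collections import defaultdict

class MetricTracker:
    """Accumulate per-step metric values and expose their history."""

    def __init__(self):
        self.history = defaultdict(list)  # metric name -> list of values

    def log(self, name: str, value: float) -> None:
        self.history[name].append(value)

    def latest(self, name: str) -> float:
        return self.history[name][-1]

tracker = MetricTracker()
tracker.log("train_loss", 1.9)
tracker.log("train_loss", 1.2)
```

A history dict like this is exactly what loss-curve plotting (e.g. with matplotlib) consumes.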

Performance Optimizations

Applied automatically:

  • torch.compile() for model compilation
  • Mixed precision (FP16) training
  • Non-blocking CUDA transfers
  • GPU-based augmentation (Kornia)
  • Optimized DataLoader configuration
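The pinned-memory plus non-blocking-transfer combination is standard PyTorch; a sketch of the pattern being applied (the toy dataset is illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

ds = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
# pin_memory=True keeps host batches in page-locked memory so that
# non_blocking=True copies can overlap with compute on CUDA.
loader = DataLoader(ds, batch_size=16, pin_memory=(device == "cuda"))

for X, y in loader:
    X = X.to(device, non_blocking=True)  # harmless no-op hint on CPU
    y = y.to(device, non_blocking=True)
```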

Installation

pip install torch torchvision matplotlib

# Optional: GPU augmentation
pip install kornia kornia-rs

Usage

Basic Example

import torch
from torch import nn
from LayerZero import ImageDataLoader, Trainer, TrainerConfig
from torchvision.datasets import CIFAR10

# Model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3*32*32, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Data
loader = ImageDataLoader(
    CIFAR10,
    root='./data',
    image_size=32,
    batch_size=128,
    download=True
)

train_loader, test_loader = loader.get_dataloaders()

# Training configuration
config = TrainerConfig(
    epochs=10,
    amp=True,
    compile_model='auto'
)

# Train
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    config=config,
)

results = trainer.train()

Configuration

Augmentation Modes

from LayerZero import ImageDataLoader, AugmentationMode

loader = ImageDataLoader(
    CIFAR10,
    augmentation_mode=AugmentationMode.MINIMAL,  # Flip + Crop
    # AugmentationMode.BASIC,   # + ColorJitter (default)
    # AugmentationMode.STRONG,  # + Rotation + Blur + Erasing
    # AugmentationMode.OFF,     # No augmentation
)

GPU Augmentation

loader = ImageDataLoader(
    CIFAR10,
    use_gpu_augmentation='auto',  # Auto-detect
    auto_install_kornia=True       # Install if missing
)

# Manual usage
gpu_aug = loader.get_gpu_augmentation(device='cuda')

device = 'cuda'
for X, y in train_loader:
    X = X.to(device, non_blocking=True)
    X = gpu_aug(X)
    # ... training code ...

Mixed Precision

config = TrainerConfig(
    amp=True,   # Enable (default)
    # amp=False  # Disable for debugging
)
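Under the hood, `amp=True` corresponds to PyTorch's autocast machinery. A minimal forward-pass sketch (CPU autocast with bfloat16 is shown so it runs anywhere; on CUDA you would use float16 and pair autocast with `torch.amp.GradScaler` for the backward pass):

```python
import torch
from torch import nn

model = nn.Linear(8, 4)
x = torch.randn(2, 8)

# Autocast runs eligible ops (matmul, linear, conv) in a lower-precision
# dtype automatically while keeping precision-sensitive ops in float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)
```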

Model Compilation

config = TrainerConfig(
    compile_model='auto',        # Auto-detect PyTorch 2.0+
    compile_mode='default',      # Compilation mode
    # compile_mode='reduce-overhead'
    # compile_mode='max-autotune'
)
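The 'auto' setting can be approximated with a simple feature check; the guard below is an assumption about LayerZero's behavior, but the check itself is the standard way to detect `torch.compile` support:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)

# torch.compile() exists from PyTorch 2.0; fall back to eager otherwise.
# Wrapping is cheap: actual compilation happens lazily on the first call.
if hasattr(torch, "compile"):
    model = torch.compile(model, mode="default")
```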

Custom Metrics

def accuracy_fn(y_pred, y_true):
    return (y_pred.argmax(1) == y_true).float().mean().item() * 100

config = TrainerConfig(
    metrics={'accuracy': accuracy_fn}
)
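The metric function receives raw logits and integer targets; a quick check of `accuracy_fn` on a toy batch:

```python
import torch

def accuracy_fn(y_pred, y_true):
    return (y_pred.argmax(1) == y_true).float().mean().item() * 100

# Two samples: logits favor classes 1 and 0; targets are 1 and 1.
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
targets = torch.tensor([1, 1])
acc = accuracy_fn(logits, targets)  # one of two correct -> 50.0
```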

Callbacks

def save_checkpoint(model, epoch, metrics):
    torch.save(model.state_dict(), f'model_epoch_{epoch}.pt')

config = TrainerConfig(
    callbacks={'on_epoch_end': save_checkpoint}
)
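How the Trainer might dispatch such callbacks: the hook name `on_epoch_end` comes from the example above, but the dispatch helper itself is an illustrative assumption about the internals:

```python
def fire(callbacks: dict, event: str, *args, **kwargs):
    """Invoke the callback registered for `event`, if any."""
    fn = callbacks.get(event)
    if fn is not None:
        fn(*args, **kwargs)

seen = []
callbacks = {"on_epoch_end": lambda model, epoch, metrics: seen.append(epoch)}

fire(callbacks, "on_epoch_end", None, 3, {"loss": 0.5})  # runs the hook
fire(callbacks, "on_batch_end")                          # unregistered: no-op
```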

API Reference

ImageDataLoader

ImageDataLoader(
    dataset_cls,                          # Torchvision dataset class
    root='./data',                        # Data directory
    image_size=224,                       # Image size
    batch_size=64,                        # Batch size
    augmentation_mode=AugmentationMode.BASIC,
    use_gpu_augmentation='auto',
    auto_install_kornia=True,
    num_workers=None,                     # Auto-detect
    download=False,
)

TrainerConfig

TrainerConfig(
    epochs=10,
    amp=True,                     # Mixed precision
    compile_model='auto',         # torch.compile()
    compile_mode='default',
    metrics={},
    callbacks={},
    device='auto',
    log_interval=100,
    save_dir='./checkpoints',
)

Trainer

Trainer(
    model,
    train_loader,
    val_loader,
    loss_fn,
    optimizer,
    config,
)

trainer.train()                    # Run training
trainer.predict(dataloader)        # Get predictions
trainer.save_checkpoint(path)      # Save model

KorniaHelper

from LayerZero import (
    is_kornia_available,
    install_kornia,
    ensure_kornia,
    get_kornia_version,
)

if ensure_kornia(auto_install=True):
    # Kornia available
    pass
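A check like `is_kornia_available` can be implemented with `importlib` alone; this is a generic sketch, not LayerZero's actual code:

```python
import importlib.util

def is_package_available(name: str) -> bool:
    """True if `name` can be imported, without actually importing it."""
    return importlib.util.find_spec(name) is not None

kornia_ok = is_package_available("kornia")
```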

Architecture

LayerZero/
├── Trainer.py              # Training loop
├── ImageDataLoader.py      # Data loading
├── GPUAugmentation.py      # Kornia augmentation
├── AugmentationMode.py     # Augmentation enums
├── KorniaHelper.py         # Kornia management
└── Helper.py               # Metrics tracking

Troubleshooting

Kornia installation fails

pip install kornia kornia-rs

torch.compile not available

Requires PyTorch 2.0+:

pip install --upgrade torch torchvision

Out of memory

Reduce batch size or enable mixed precision:

config = TrainerConfig(amp=True)

Slow on CPU

Use minimal augmentation:

loader = ImageDataLoader(
    ...,
    augmentation_mode=AugmentationMode.MINIMAL
)

Releasing New Versions

# Bump version (bug fixes: 0.1.3 → 0.1.4)
make bump-patch

# Push to trigger PyPI release
make release

See RELEASE_WORKFLOW.md for the complete guide.


License

MIT
