Effortless PyTorch training - define your model, Kito handles the rest
Kito is a lightweight PyTorch training library that eliminates boilerplate code. Define your model architecture and loss function - Kito automatically handles training loops, optimization, callbacks, distributed training, and more.
✨ Key Features
- Zero Boilerplate - No training loops, no optimizer setup, no device management
- Auto-Everything - Automatic model building, optimizer binding, and device assignment
- Built-in DDP - Distributed training works out of the box
- Smart Callbacks - TensorBoard, checkpointing, logging, and custom callbacks
- Flexible - Simple for beginners, powerful for experts
- Lightweight - Minimal dependencies, pure PyTorch under the hood
Quick Start
Installation
```bash
pip install pytorch-kito
```
Your First Model in 3 Steps
```python
import torch
import torch.nn as nn

from kito import Engine, KitoModule

# 1. Define your model
class MyModel(KitoModule):
    def build_inner_model(self):
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        self.model_input_size = (784,)

    def bind_optimizer(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )

# 2. Initialize (device and config come from your own setup)
model = MyModel('MyModel', device, config)
engine = Engine(model, config)

# 3. Train! (That's it - everything else is automatic)
engine.fit(train_loader, val_loader, max_epochs=10)
```
Philosophy
Kito follows a "define once, train anywhere" philosophy:
- You focus on: Model architecture and research ideas
- Kito handles: Training loops, optimization, distributed training, callbacks
Perfect for researchers who want to iterate quickly without rewriting training code.
Core Concepts
KitoModule
Your model inherits from KitoModule and implements two methods:
```python
class MyModel(KitoModule):
    def build_inner_model(self):
        # Define your architecture
        self.model = nn.Sequential(...)
        self.model_input_size = (C, H, W)  # Input shape

    def bind_optimizer(self):
        # Choose your optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )
```
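The two-method contract above implies a lifecycle the engine drives for you: it calls `build_inner_model` and then `bind_optimizer` before training starts. As a toy illustration of that ordering (a hypothetical sketch, not Kito's actual internals):

```python
# Hypothetical sketch of the build/bind lifecycle implied by the
# two-method contract -- NOT Kito's real implementation.

class ToyModule:
    def __init__(self):
        self.built = []

    def build_inner_model(self):
        # In Kito, this would create self.model and set self.model_input_size.
        self.built.append("model")

    def bind_optimizer(self):
        # In Kito, this would attach an optimizer to self.model's parameters.
        self.built.append("optimizer")

class ToyEngine:
    def __init__(self, module):
        self.module = module
        # The engine, not the user, invokes the two hooks in order:
        module.build_inner_model()
        module.bind_optimizer()

m = ToyModule()
ToyEngine(m)
print(m.built)  # ['model', 'optimizer']
```

The point of the pattern: your subclass only declares *what* to build; *when* it gets built (and moved to the right device) is the engine's job.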
Engine
The Engine orchestrates everything:
```python
engine = Engine(module, config)

# Training
engine.fit(train_loader, val_loader, max_epochs=100)

# Inference
predictions = engine.predict(test_loader)
```
Data Pipeline
Kito provides a clean data pipeline with preprocessing:
```python
from kito.data import H5Dataset, GenericDataPipeline
from kito.data.preprocessing import Pipeline, Normalize, ToTensor

# Create dataset
dataset = H5Dataset('data.h5')

# Add preprocessing
preprocessing = Pipeline([
    Normalize(min_val=0.0, max_val=1.0),
    ToTensor()
])

# Set up the data pipeline
pipeline = GenericDataPipeline(
    config=config,
    dataset=dataset,
    preprocessing=preprocessing
)
pipeline.setup()

# Get dataloaders
train_loader = pipeline.train_dataloader()
val_loader = pipeline.val_dataloader()
```
Callbacks
Kito includes powerful callbacks for common tasks:
```python
from kito.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

callbacks = [
    ModelCheckpoint('best_model.pt', monitor='val_loss', mode='min'),
    EarlyStopping(patience=10, monitor='val_loss'),
    CSVLogger('training.csv')
]

engine.fit(train_loader, val_loader, callbacks=callbacks)
```
Or create custom callbacks:
```python
from kito.callbacks import Callback

class MyCallback(Callback):
    def on_epoch_end(self, epoch, logs, **kwargs):
        print(f"Epoch {epoch}: loss={logs['train_loss']:.4f}")
```
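To see when such a hook fires, here is a toy training loop that dispatches `on_epoch_end` after each epoch (a schematic sketch of the callback protocol; the real Engine's dispatch may differ):

```python
# Toy sketch of callback dispatch -- illustrates the hook protocol,
# not Kito's actual Engine internals.

class Callback:
    def on_epoch_end(self, epoch, logs, **kwargs):
        pass

class LossRecorder(Callback):
    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs, **kwargs):
        self.seen.append((epoch, logs["train_loss"]))

def toy_fit(callbacks, max_epochs=3):
    for epoch in range(max_epochs):
        # ... training work would happen here ...
        logs = {"train_loss": 1.0 / (epoch + 1)}  # fake, decreasing loss
        # After each epoch, every callback's hook receives the metrics:
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)

recorder = LossRecorder()
toy_fit([recorder])
print(recorder.seen[0])  # (0, 1.0)
```

Because the engine owns the loop, a callback never touches training logic; it only observes the `logs` dict it is handed.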
Advanced Features
Distributed Training (DDP)
Enable distributed training with one config change:
```python
config.training.distributed_training = True

# Everything else stays the same!
engine.fit(train_loader, val_loader, max_epochs=100)
```
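Since Kito's distributed mode is built on PyTorch DDP, the script would typically be launched with one process per GPU via `torchrun`. The exact launcher Kito expects is not documented here, so treat this as an assumption based on standard `torch.distributed` usage:

```shell
# Assumed standard torch.distributed launch: one process per GPU
# (train.py is your own script containing the engine.fit call above).
torchrun --nproc_per_node=4 train.py
```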
Custom Training Logic
Override training_step for custom behavior:
```python
class MyModel(KitoModule):
    def training_step(self, batch, pbar_handler=None):
        inputs, targets = batch

        # Custom forward pass
        outputs = self.model(inputs)
        loss = self.compute_loss((inputs, targets), outputs)

        # Custom backward (e.g., gradient clipping)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {'loss': loss}
```
Multiple Datasets
Kito supports HDF5 and in-memory datasets out of the box:
```python
from kito.data import H5Dataset, MemDataset

# HDF5 dataset (lazy loading)
dataset = H5Dataset('large_data.h5')

# In-memory dataset (fast)
dataset = MemDataset(x_train, y_train)
```
Register custom datasets easily:
```python
from kito.data.registry import DATASETS

@DATASETS.register('my_custom_dataset')
class MyDataset(KitoDataset):
    def _load_sample(self, index):
        return data, labels
```
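The `register` decorator follows the common registry pattern: the decorator stores the class under a string key and returns it unchanged, so it can later be looked up by name (e.g. from a config). A minimal standalone version of the idea (illustrative only, not the actual code of `kito.data.registry`):

```python
# Minimal registry pattern -- the same idea behind DATASETS.register,
# shown standalone (not Kito's actual implementation).

class Registry:
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            self._classes[name] = cls
            return cls  # return the class unchanged so it stays usable
        return decorator

    def get(self, name):
        return self._classes[name]

DATASETS = Registry()

@DATASETS.register("my_custom_dataset")
class MyDataset:
    pass

print(DATASETS.get("my_custom_dataset") is MyDataset)  # True
```

This is what lets a string in a config file select a dataset class without any import-time coupling.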
📦 Installation Options
```bash
# Basic installation
pip install pytorch-kito

# With TensorBoard support
pip install pytorch-kito[tensorboard]

# Development installation
pip install pytorch-kito[dev]

# Everything
pip install pytorch-kito[all]
```
🤝 Contributing
Contributions are very welcome! Please check out our Contributing Guide.
📄 License
MIT License - see LICENSE file for details.
🙏 Acknowledgments
Kito is inspired by PyTorch Lightning and Keras, aiming to bring similar ease-of-use to pure PyTorch workflows for researchers.
Contact
- GitHub Issues: Report bugs or request features
- GitHub Discussions: Ask questions
Made with ❤️ for the PyTorch community