Kito

Effortless PyTorch training - define your model, Kito handles the rest.

Kito is a lightweight PyTorch training library that eliminates boilerplate code. Define your model architecture and loss function - Kito automatically handles training loops, optimization, callbacks, distributed training, and more.

✨ Key Features

  • Zero Boilerplate - No training loops, no optimizer setup, no device management
  • Auto-Everything - Automatic model building, optimizer binding, and device assignment
  • Built-in DDP - Distributed training works out of the box
  • Smart Callbacks - TensorBoard, checkpointing, logging, and custom callbacks
  • Flexible - Simple for beginners, powerful for experts
  • Lightweight - Minimal dependencies, pure PyTorch under the hood

Quick Start

Installation

pip install pytorch-kito

Your First Model in 3 Steps

import torch
import torch.nn as nn
from kito import Engine, KitoModule

# 1. Define your model
class MyModel(KitoModule):
    def build_inner_model(self):
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        self.model_input_size = (784,)

    def bind_optimizer(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )

# 2. Initialize (device is a torch.device; config is your Kito configuration)
model = MyModel('MyModel', device, config)
engine = Engine(model, config)

# 3. Train! (That's it - everything else is automatic)
engine.fit(train_loader, val_loader, max_epochs=10)

Philosophy

Kito follows a "define once, train anywhere" philosophy:

  1. You focus on: Model architecture and research ideas
  2. Kito handles: Training loops, optimization, distributed training, callbacks

Perfect for researchers who want to iterate quickly without rewriting training code.

Core Concepts

KitoModule

Your model inherits from KitoModule and implements two methods:

class MyModel(KitoModule):
    def build_inner_model(self):
        # Define your architecture
        self.model = nn.Sequential(...)
        self.model_input_size = (C, H, W)  # Input shape (channels, height, width)

    def bind_optimizer(self):
        # Choose your optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )
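
As a concrete sketch (the layer sizes here are illustrative, not part of Kito's API), an image model might fill in those two methods like this:

import torch
import torch.nn as nn
from kito import KitoModule

class MyCNN(KitoModule):
    def build_inner_model(self):
        # Small convolutional classifier for 1x28x28 inputs
        self.model = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * 28 * 28, 10)
        )
        self.model_input_size = (1, 28, 28)  # (C, H, W)

    def bind_optimizer(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )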

Engine

The Engine orchestrates everything:

engine = Engine(module, config)

# Training
engine.fit(train_loader, val_loader, max_epochs=100)

# Inference
predictions = engine.predict(test_loader)

Data Pipeline

Kito provides a clean data pipeline with preprocessing:

from kito.data import H5Dataset, GenericDataPipeline
from kito.data.preprocessing import Pipeline, Normalize, ToTensor

# Create dataset
dataset = H5Dataset('data.h5')

# Add preprocessing
preprocessing = Pipeline([
    Normalize(min_val=0.0, max_val=1.0),
    ToTensor()
])

# Setup data pipeline
pipeline = GenericDataPipeline(
    config=config,
    dataset=dataset,
    preprocessing=preprocessing
)
pipeline.setup()

# Get dataloaders
train_loader = pipeline.train_dataloader()
val_loader = pipeline.val_dataloader()
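
These loaders plug straight into the Engine from the quick start:

engine.fit(train_loader, val_loader, max_epochs=10)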

Callbacks

Kito includes powerful callbacks for common tasks:

from kito.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

callbacks = [
    ModelCheckpoint('best_model.pt', monitor='val_loss', mode='min'),
    EarlyStopping(patience=10, monitor='val_loss'),
    CSVLogger('training.csv')
]

engine.fit(train_loader, val_loader, callbacks=callbacks)

Or create custom callbacks:

from kito.callbacks import Callback

class MyCallback(Callback):
    def on_epoch_end(self, epoch, logs, **kwargs):
        print(f"Epoch {epoch}: loss={logs['train_loss']:.4f}")

Advanced Features

Distributed Training (DDP)

Enable distributed training with one config change:

config.training.distributed_training = True

# Everything else stays the same!
engine.fit(train_loader, val_loader, max_epochs=100)
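
Kito is pure PyTorch under the hood, so a multi-GPU run would typically be started with PyTorch's standard launcher (a sketch; train.py is a hypothetical entry-point script containing the training code above):

# Run the training script on 4 GPUs on one machine
torchrun --nproc_per_node=4 train.py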

Custom Training Logic

Override training_step for custom behavior:

class MyModel(KitoModule):
    def training_step(self, batch, pbar_handler=None):
        inputs, targets = batch

        # Custom forward pass
        outputs = self.model(inputs)
        loss = self.compute_loss((inputs, targets), outputs)

        # Custom backward (e.g., gradient clipping)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {'loss': loss}

Multiple Datasets

Kito supports HDF5 and in-memory datasets out of the box:

from kito.data import H5Dataset, MemDataset

# HDF5 dataset (lazy loading)
dataset = H5Dataset('large_data.h5')

# In-memory dataset (fast)
dataset = MemDataset(x_train, y_train)
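
For example, small NumPy arrays can be wrapped directly (a sketch assuming MemDataset accepts array-likes, matching the signature above):

import numpy as np

# Synthetic data: 1000 samples with 784 features, 10 classes
x_train = np.random.rand(1000, 784).astype('float32')
y_train = np.random.randint(0, 10, size=1000)

dataset = MemDataset(x_train, y_train)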

Register custom datasets easily:

from kito.data.registry import DATASETS

@DATASETS.register('my_custom_dataset')
class MyDataset(KitoDataset):  # subclass Kito's base dataset class
    def _load_sample(self, index):
        # Return one (data, labels) pair for the given index
        return data, labels

📦 Installation Options

# Basic installation
pip install pytorch-kito

# With TensorBoard support
pip install pytorch-kito[tensorboard]

# Development installation
pip install pytorch-kito[dev]

# Everything
pip install pytorch-kito[all]

🤝 Contributing

Contributions are very welcome! Please check out our Contributing Guide.

📄 License

MIT License - see LICENSE file for details.

🙏 Acknowledgments

Kito is inspired by PyTorch Lightning and Keras, aiming to bring similar ease-of-use to pure PyTorch workflows for researchers.

Contact


Made with ❤️ for the PyTorch community
