Kito

Effortless PyTorch training - define your model, Kito handles the rest.

Kito is a lightweight PyTorch training library that eliminates boilerplate code. Define your model architecture and loss function - Kito automatically handles training loops, optimization, callbacks, distributed training, and more.

✨ Key Features

  • Zero Boilerplate - No training loops, no optimizer setup, no device management
  • Auto-Everything - Automatic model building, optimizer binding, and device assignment
  • Built-in DDP - Distributed training works out of the box
  • Smart Callbacks - TensorBoard, checkpointing, logging, and custom callbacks
  • Flexible - Simple for beginners, powerful for experts
  • Lightweight - Minimal dependencies, pure PyTorch under the hood

Quick Start

Installation

pip install pytorch-kito

Your First Model in 3 Steps

import torch
import torch.nn as nn
from kito import Engine, KitoModule

# 1. Define your model
class MyModel(KitoModule):
    def build_inner_model(self):
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        self.model_input_size = (784,)

    def bind_optimizer(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )

# 2. Initialize (assumes `device` and `config` come from your own setup)
model = MyModel('MyModel', device, config)
engine = Engine(model, config)

# 3. Train! (That's it - everything else is automatic)
engine.fit(train_loader, val_loader, max_epochs=10)
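For a sense of what `engine.fit()` absorbs, here is a minimal sketch of the equivalent manual PyTorch loop, written in plain PyTorch with a tiny synthetic dataset standing in for a real `train_loader` (no Kito APIs involved):

```python
import torch
import torch.nn as nn

# The boilerplate a manual training loop requires: model, optimizer, loss,
# data loading, zero_grad/backward/step bookkeeping.
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Tiny synthetic dataset standing in for a real train_loader
inputs = torch.randn(64, 784)
targets = torch.randint(0, 10, (64,))
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(inputs, targets), batch_size=16
)

for epoch in range(2):
    for x, y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
```

Device placement, validation, checkpointing, and logging would all add further code on top of this; that is the surface area Kito's engine manages for you.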

Philosophy

Kito follows a "define once, train anywhere" philosophy:

  1. You focus on: Model architecture and research ideas
  2. Kito handles: Training loops, optimization, distributed training, callbacks

Perfect for researchers who want to iterate quickly without rewriting training code.

Core Concepts

KitoModule

Your model inherits from KitoModule and implements two methods:

class MyModel(KitoModule):
    def build_inner_model(self):
        # Define your architecture
        self.model = nn.Sequential(...)
        self.model_input_size = (C, H, W)  # Input shape

    def bind_optimizer(self):
        # Choose your optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )

Engine

The Engine orchestrates everything:

engine = Engine(module, config)

# Training
engine.fit(train_loader, val_loader, max_epochs=100)

# Inference
predictions = engine.predict(test_loader)

Data Pipeline

Kito provides a clean data pipeline with preprocessing:

from kito.data import H5Dataset, GenericDataPipeline
from kito.data.preprocessing import Pipeline, Normalize, ToTensor

# Create dataset
dataset = H5Dataset('data.h5')

# Add preprocessing
preprocessing = Pipeline([
    Normalize(min_val=0.0, max_val=1.0),
    ToTensor()
])

# Setup data pipeline
pipeline = GenericDataPipeline(
    config=config,
    dataset=dataset,
    preprocessing=preprocessing
)
pipeline.setup()

# Get dataloaders
train_loader = pipeline.train_dataloader()
val_loader = pipeline.val_dataloader()
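`Normalize(min_val=0.0, max_val=1.0)` presumably performs min-max scaling into the given range; a standalone sketch of that transform in plain PyTorch (not Kito's actual implementation) looks like this:

```python
import torch

def min_max_normalize(x, min_val=0.0, max_val=1.0):
    """Min-max scale a tensor into [min_val, max_val] - a sketch of what a
    Normalize(min_val, max_val) preprocessing step typically does."""
    x_min, x_max = x.min(), x.max()
    scaled = (x - x_min) / (x_max - x_min + 1e-12)  # into [0, 1]
    return scaled * (max_val - min_val) + min_val

x = torch.tensor([2.0, 4.0, 6.0])
y = min_max_normalize(x)  # ≈ tensor([0.0, 0.5, 1.0])
```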

Callbacks

Kito includes powerful callbacks for common tasks:

from kito.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

callbacks = [
    ModelCheckpoint('best_model.pt', monitor='val_loss', mode='min'),
    EarlyStopping(patience=10, monitor='val_loss'),
    CSVLogger('training.csv')
]

engine.fit(train_loader, val_loader, callbacks=callbacks)

Or create custom callbacks:

from kito.callbacks import Callback

class MyCallback(Callback):
    def on_epoch_end(self, epoch, logs, **kwargs):
        print(f"Epoch {epoch}: loss={logs['train_loss']:.4f}")
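Using only the `on_epoch_end(self, epoch, logs, **kwargs)` hook shown above, a callback that records per-epoch wall-clock time might look like the following. The `Callback` base class here is a minimal stand-in so the sketch runs standalone, not Kito's actual class:

```python
import time

class Callback:
    """Minimal stand-in for kito.callbacks.Callback (standalone sketch)."""
    def on_epoch_end(self, epoch, logs, **kwargs):
        pass

class EpochTimer(Callback):
    """Records wall-clock time between successive on_epoch_end calls."""
    def __init__(self):
        self._last = time.monotonic()
        self.durations = []

    def on_epoch_end(self, epoch, logs, **kwargs):
        now = time.monotonic()
        self.durations.append(now - self._last)
        self._last = now

timer = EpochTimer()
timer.on_epoch_end(0, {'train_loss': 0.5})
```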

Advanced Features

Distributed Training (DDP)

Enable distributed training with one config change:

config.training.distributed_training = True

# Everything else stays the same!
engine.fit(train_loader, val_loader, max_epochs=100)
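The docs don't show the launch command here; assuming Kito's DDP support follows standard PyTorch conventions, a multi-GPU run would typically be launched with `torchrun` (where `train.py` is a placeholder for your own training script):

```shell
# Launch 4 processes, one per GPU, via PyTorch's standard launcher.
torchrun --nproc_per_node=4 train.py
```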

Custom Training Logic

Override training_step for custom behavior:

class MyModel(KitoModule):
    def training_step(self, batch, pbar_handler=None):
        inputs, targets = batch

        # Custom forward pass
        outputs = self.model(inputs)
        loss = self.compute_loss((inputs, targets), outputs)

        # Custom backward (e.g., gradient clipping)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {'loss': loss}
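The `clip_grad_norm_` call above rescales all gradients so their global L2 norm is at most `max_norm`; its effect can be checked in isolation with plain PyTorch:

```python
import torch
import torch.nn as nn

# Large inputs produce large gradients, so clipping will definitely engage.
model = nn.Linear(10, 1)
model(torch.randn(32, 10) * 100).sum().backward()

# Returns the total norm *before* clipping; gradients are modified in place.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the global norm after clipping - it is now capped at max_norm.
clipped_norm = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))
```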

Multiple Datasets

Kito supports HDF5 and in-memory datasets out of the box:

from kito.data import H5Dataset, MemDataset

# HDF5 dataset (lazy loading)
dataset = H5Dataset('large_data.h5')

# In-memory dataset (fast)
dataset = MemDataset(x_train, y_train)

Register custom datasets easily:

from kito.data.registry import DATASETS

@DATASETS.register('my_custom_dataset')
class MyDataset(KitoDataset):
    def _load_sample(self, index):
        # Return one (data, labels) pair for the given index
        return data, labels
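The `@DATASETS.register` decorator follows a standard name-to-class registry pattern; a minimal standalone version of that pattern (not Kito's actual implementation) is:

```python
class Registry:
    """Minimal name -> class registry, the pattern behind @DATASETS.register."""
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            self._classes[name] = cls
            return cls
        return decorator

    def build(self, name, *args, **kwargs):
        return self._classes[name](*args, **kwargs)

DATASETS = Registry()

@DATASETS.register('my_custom_dataset')
class MyDataset:
    def __init__(self, path):
        self.path = path

# Instantiate by registered name, e.g. from a config string
ds = DATASETS.build('my_custom_dataset', 'data.h5')
```

This lets a config file refer to datasets by string name while the registry resolves the actual class at runtime.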

📦 Installation Options

# Basic installation
pip install pytorch-kito

# With TensorBoard support
pip install pytorch-kito[tensorboard]

# Development installation
pip install pytorch-kito[dev]

# Everything
pip install pytorch-kito[all]

🤝 Contributing

Contributions are very welcome! Please check out our Contributing Guide.

📄 License

MIT License - see LICENSE file for details.

🙏 Acknowledgments

Kito is inspired by PyTorch Lightning and Keras, aiming to bring similar ease-of-use to pure PyTorch workflows for researchers.

Contact


Made with ❤️ for the PyTorch community
