Kito

Effortless PyTorch training - define your model, Kito handles the rest.


Kito is a lightweight PyTorch training library that eliminates boilerplate code. Define your model architecture and loss function - Kito automatically handles training loops, optimization, callbacks, distributed training, and more.

✨ Key Features

  • Zero Boilerplate - No training loops, no optimizer setup, no device management
  • Auto-Everything - Automatic model building, optimizer binding, and device assignment
  • Built-in DDP - Distributed training works out of the box
  • Smart Callbacks - TensorBoard, checkpointing, logging, and custom callbacks
  • Flexible - Simple for beginners, powerful for experts
  • Lightweight - Minimal dependencies, pure PyTorch under the hood

Quick Start

Installation

pip install pytorch-kito

Your First Model in 3 Steps

import torch
import torch.nn as nn
from kito import Engine, KitoModule

# 1. Define your model
class MyModel(KitoModule):
    def build_inner_model(self):
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        self.model_input_size = (784,)

    def bind_optimizer(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )

# 2. Initialize (device and config come from your own setup)
model = MyModel('MyModel', device, config)
engine = Engine(model, config)

# 3. Train! (That's it - everything else is automatic)
engine.fit(train_loader, val_loader, max_epochs=10)

Philosophy

Kito follows a "define once, train anywhere" philosophy:

  1. You focus on: Model architecture and research ideas
  2. Kito handles: Training loops, optimization, distributed training, callbacks

Perfect for researchers who want to iterate quickly without rewriting training code.
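To make that trade-off concrete, here is the shape of the hand-written loop that Kito's engine owns for you, reduced to a toy one-parameter model. This is pure Python and purely illustrative; it is not Kito code.

```python
# Toy example: fit w in y = w * x by stochastic gradient descent,
# written out the way a manual training loop is usually structured.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # (x, y) pairs, true w = 2

w = 0.0
lr = 0.05
for epoch in range(200):            # the loop engine.fit() would own
    for x, y in data:               # batch iteration
        pred = w * x                # forward pass
        grad = 2 * (pred - y) * x   # gradient of squared error w.r.t. w
        w -= lr * grad              # optimizer step
print(round(w, 3))  # → 2.0
```

Every line of that loop is boilerplate that stays the same across projects, which is exactly the part Kito factors out.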

Core Concepts

KitoModule

Your model inherits from KitoModule and implements two methods:

class MyModel(KitoModule):
    def build_inner_model(self):
        # Define your architecture
        self.model = nn.Sequential(...)
        self.model_input_size = (C, H, W)  # Input shape, excluding the batch dimension

    def bind_optimizer(self):
        # Choose your optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate
        )

Engine

The Engine orchestrates everything:

engine = Engine(module, config)

# Training
engine.fit(train_loader, val_loader, max_epochs=100)

# Inference
predictions = engine.predict(test_loader)

Data Pipeline

Kito provides a clean data pipeline with preprocessing:

from kito.data import H5Dataset, GenericDataPipeline
from kito.data.preprocessing import Pipeline, Normalize, ToTensor

# Create dataset
dataset = H5Dataset('data.h5')

# Add preprocessing
preprocessing = Pipeline([
    Normalize(min_val=0.0, max_val=1.0),
    ToTensor()
])

# Setup data pipeline
pipeline = GenericDataPipeline(
    config=config,
    dataset=dataset,
    preprocessing=preprocessing
)
pipeline.setup()

# Get dataloaders
train_loader = pipeline.train_dataloader()
val_loader = pipeline.val_dataloader()
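Conceptually, a preprocessing `Pipeline` just applies its transforms to each sample in order. A stand-alone sketch of that composition (the names here are illustrative, not Kito's implementation):

```python
class Pipeline:
    """Apply a list of callable transforms to a sample, in order."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        for transform in self.transforms:
            sample = transform(sample)
        return sample

def normalize(xs, min_val=0.0, max_val=1.0):
    """Rescale a list of numbers into [min_val, max_val]."""
    lo, hi = min(xs), max(xs)
    scale = (max_val - min_val) / (hi - lo)
    return [min_val + (x - lo) * scale for x in xs]

pipe = Pipeline([normalize, lambda s: [round(x, 2) for x in s]])
print(pipe([0.0, 5.0, 10.0]))  # → [0.0, 0.5, 1.0]
```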

Callbacks

Kito includes powerful callbacks for common tasks:

from kito.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

callbacks = [
    ModelCheckpoint('best_model.pt', monitor='val_loss', mode='min'),
    EarlyStopping(patience=10, monitor='val_loss'),
    CSVLogger('training.csv')
]

engine.fit(train_loader, val_loader, callbacks=callbacks)
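The `EarlyStopping(patience=10, monitor='val_loss')` behavior above boils down to a patience counter over the monitored metric. A stand-alone sketch of that logic (illustrative, not Kito's internals):

```python
class EarlyStopper:
    """Stop when the monitored metric hasn't improved for `patience` epochs."""
    def __init__(self, patience, mode="min"):
        self.patience = patience
        self.mode = mode
        self.best = None
        self.bad_epochs = 0

    def step(self, value):
        """Record one epoch's metric; return True when training should stop."""
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best)
            or (self.mode == "max" and value > self.best)
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopper(patience=2)
losses = [1.0, 0.8, 0.9, 0.85, 0.95]  # best at epoch 1, no improvement after
stops = [stopper.step(v) for v in losses]
print(stops)  # → [False, False, False, True, True]
```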

Or create custom callbacks:

from kito.callbacks import Callback

class MyCallback(Callback):
    def on_epoch_end(self, epoch, logs, **kwargs):
        print(f"Epoch {epoch}: loss={logs['train_loss']:.4f}")
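Under the hood, an engine typically fires these hooks by iterating its callback list at each lifecycle point. A minimal sketch of that dispatch pattern (illustrative, not Kito's internals):

```python
class Callback:
    """Base class: hooks default to no-ops, subclasses override what they need."""
    def on_epoch_end(self, epoch, logs, **kwargs):
        pass

class RecordLoss(Callback):
    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs, **kwargs):
        self.seen.append((epoch, logs["train_loss"]))

def run_epochs(callbacks, losses):
    """Mimic an engine firing on_epoch_end after every epoch."""
    for epoch, loss in enumerate(losses):
        for cb in callbacks:
            cb.on_epoch_end(epoch, {"train_loss": loss})

cb = RecordLoss()
run_epochs([cb], [0.9, 0.5])
print(cb.seen)  # → [(0, 0.9), (1, 0.5)]
```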

Advanced Features

Distributed Training (DDP)

Enable distributed training with one config change:

config.training.distributed_training = True

# Everything else stays the same!
engine.fit(train_loader, val_loader, max_epochs=100)
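The description doesn't show a launch command, but multi-process DDP jobs in the PyTorch ecosystem are conventionally started with `torchrun`. A typical single-node launch might look like this (the script name and GPU count are assumptions about your setup, not a Kito-specific CLI):

```shell
# One worker process per GPU on a single node (4 GPUs assumed)
torchrun --nproc_per_node=4 train.py
```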

Custom Training Logic

Override training_step for custom behavior:

class MyModel(KitoModule):
    def training_step(self, batch, pbar_handler=None):
        inputs, targets = batch

        # Custom forward pass
        outputs = self.model(inputs)
        loss = self.compute_loss((inputs, targets), outputs)

        # Custom backward (e.g., gradient clipping)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {'loss': loss}

Multiple Datasets

Kito supports HDF5 and in-memory datasets out of the box:

from kito.data import H5Dataset, MemDataset

# HDF5 dataset (lazy loading)
dataset = H5Dataset('large_data.h5')

# In-memory dataset (fast)
dataset = MemDataset(x_train, y_train)
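The practical difference is when samples hit memory: an HDF5-backed dataset reads each sample on access, while an in-memory dataset holds everything up front. The lazy pattern, stripped to its core (illustrative, not Kito's code):

```python
class LazyDataset:
    """Load each sample only when it is indexed, like an HDF5-backed dataset."""
    def __init__(self, n):
        self.n = n
        self.loads = 0  # track how many samples were actually read

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        self.loads += 1   # in a real dataset: read from disk here
        return i * i      # stand-in for the loaded sample

ds = LazyDataset(1000)
print(len(ds), ds.loads)  # → 1000 0  (nothing loaded yet)
print(ds[3], ds.loads)    # → 9 1    (one sample read on demand)
```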

Register custom datasets easily:

from kito.data import KitoDataset  # base class, assumed importable like H5Dataset
from kito.data.registry import DATASETS

@DATASETS.register('my_custom_dataset')
class MyDataset(KitoDataset):
    def _load_sample(self, index):
        # read and return the (data, labels) pair for this index
        return data, labels
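A registry like `DATASETS` is typically just a name-to-class mapping populated by a decorator. A minimal stand-alone sketch of the pattern (not Kito's actual implementation):

```python
class Registry:
    """Map string names to classes via a decorator, for config-driven lookup."""
    def __init__(self):
        self._entries = {}

    def register(self, name):
        def decorator(cls):
            self._entries[name] = cls
            return cls  # class remains usable directly as well
        return decorator

    def get(self, name):
        return self._entries[name]

DATASETS = Registry()

@DATASETS.register("my_custom_dataset")
class MyDataset:
    pass

print(DATASETS.get("my_custom_dataset").__name__)  # → MyDataset
```

Registering by string name is what lets a config file pick the dataset class without importing it explicitly.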

📦 Installation Options

# Basic installation
pip install pytorch-kito

# With TensorBoard support
pip install pytorch-kito[tensorboard]

# Development installation
pip install pytorch-kito[dev]

# Everything
pip install pytorch-kito[all]

🤝 Contributing

Contributions are very welcome! Please check out our Contributing Guide.

📄 License

MIT License - see LICENSE file for details.

🙏 Acknowledgments

Kito is inspired by PyTorch Lightning and Keras, aiming to bring similar ease-of-use to pure PyTorch workflows for researchers.

Made with ❤️ for the PyTorch community

Project details

Latest release: pytorch_kito 0.2.15, published as a source distribution (pytorch_kito-0.2.15.tar.gz, 53.5 kB) and a py3-none-any wheel (pytorch_kito-0.2.15-py3-none-any.whl, 61.0 kB), uploaded via twine/6.2.0 on CPython/3.12.8.
