Hiraishin
A thin PyTorch-Lightning wrapper for building configuration-based DL pipelines with Hydra.
Dependencies
- PyTorch Lightning
- Hydra
- Pydantic
- etc.
Installation
$ pip install -U hiraishin
Basic workflow
1. Model initialization with type annotations
Define a model class that declares its training components with type annotations.
import torch.nn as nn
import torch.optim as optim
from omegaconf import DictConfig

from hiraishin.models import BaseModel


class ToyModel(BaseModel):

    net: nn.Linear
    criterion: nn.CrossEntropyLoss
    optimizer: optim.Adam
    scheduler: optim.lr_scheduler.ExponentialLR

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)
Modules with the following prefixes are instantiated by their own role-specific logic:
- net
- criterion
- optimizer
- scheduler
The same notation can be used to define components other than the learning components listed above (e.g., tokenizers). Constants of built-in types can also be defined, as long as they are YAML-serializable.
class ToyModel(BaseModel):

    net: nn.Linear
    criterion: nn.CrossEntropyLoss
    optimizer: optim.Adam
    scheduler: optim.lr_scheduler.ExponentialLR

    # additional components and constants
    tokenizer: MyTokenizer
    n_classes: int

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)
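MyTokenizer above stands in for any user-defined component; it is not part of Hiraishin. A minimal sketch of such a class, assuming simple whitespace tokenization, might look like this:

from typing import List


class MyTokenizer:
    """Hypothetical user-defined component; any class that
    hydra.utils.instantiate can construct works the same way."""

    def __init__(self, lowercase: bool = True) -> None:
        self.lowercase = lowercase

    def __call__(self, text: str) -> List[str]:
        # Naive whitespace tokenization, purely for illustration.
        if self.lowercase:
            text = text.lower()
        return text.split()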
2. Configuration file generation
Hiraishin provides a CLI command that automatically generates a configuration file based on type annotations.
For example, if ToyModel is defined in models.py (i.e., from models import ToyModel can be executed in the code), the following command generates the configuration file automatically.
$ hiraishin generate model.ToyModel --output_dir config/model
The config has been generated! --> config/model/ToyModel.yaml
Let's take a look at the generated file.
_target_: models.ToyModel
_recursive_: false
config:
  networks:
    net:
      args:
        _target_: torch.nn.Linear
        out_features: ???
        in_features: ???
      weights:
        initializer: null
        path: null
  losses:
    criterion:
      args:
        _target_: torch.nn.CrossEntropyLoss
      weight: 1.0
  optimizers:
    optimizer:
      args:
        _target_: torch.optim.Adam
        params:
          - ???
      scheduler:
        args:
          _target_: torch.optim.lr_scheduler.ExponentialLR
          gamma: ???
        interval: epoch
        frequency: 1
        strict: true
        monitor: null
  tokenizer:
    _target_: MyTokenizer
  n_classes: ???
First of all, it is compliant with instantiation by hydra.utils.instantiate. Positional arguments are filled with ???, which indicates mandatory parameters; override them with the values you want to set.
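For example, once the mandatory fields of the network are filled in, that part of the file might look like this (the concrete sizes are placeholders chosen for illustration):

networks:
  net:
    args:
      _target_: torch.nn.Linear
      out_features: 10    # placeholder value
      in_features: 128    # placeholder value
    weights:
      initializer: null
      path: null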
3. Training routine definition
The rest of the model definition is just writing your training routine in the usual PyTorch Lightning style.
class ToyModel(BaseModel):

    ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def training_step(self, batch, *args, **kwargs) -> torch.Tensor:
        x, target = batch
        pred = self.forward(x)
        loss = self.criterion(pred, target)
        self.log('loss/train', loss)
        return loss
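If you also need a validation loop, it follows the same PyTorch Lightning conventions. The sketch below is an assumption about your routine, not something Hiraishin prescribes:

class ToyModel(BaseModel):

    ...

    def validation_step(self, batch, *args, **kwargs) -> torch.Tensor:
        x, target = batch
        pred = self.forward(x)
        loss = self.criterion(pred, target)
        # This metric could also serve as the scheduler's `monitor` key.
        self.log('loss/val', loss)
        return loss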
4. Model instantiation
The defined model can be instantiated from the configuration file. Let's train your models!
from hydra.utils import instantiate
from omegaconf import OmegaConf


def app():
    ...

    config = OmegaConf.load('config/model/ToyModel.yaml')
    model = instantiate(config)
    print(model)
    # ToyModel(
    #   (net): Linear(in_features=1, out_features=1, bias=True)
    #   (criterion): CrossEntropyLoss()
    # )

    trainer.fit(model, ...)
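Putting the pieces together, a minimal end-to-end training script might look like the sketch below. The random dataset, tensor sizes, and Trainer arguments are assumptions for illustration, not part of Hiraishin's API.

import pytorch_lightning as pl
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, TensorDataset


def app():
    config = OmegaConf.load('config/model/ToyModel.yaml')
    model = instantiate(config)

    # Dummy classification data; sizes must match the in_features /
    # out_features you set in the config.
    x = torch.randn(64, 128)
    y = torch.randint(0, 10, (64,))
    loader = DataLoader(TensorDataset(x, y), batch_size=8)

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, loader)


if __name__ == '__main__':
    app()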
5. Model loading
You can easily load trained models from the checkpoints that PyTorch Lightning saves during training. Let's test your models!
from hiraishin.utils import load_from_checkpoint
model = load_from_checkpoint('path/to/model.ckpt')
print(model)
# ToyModel(
# (net): Linear(in_features=1, out_features=1, bias=True)
# (criterion): CrossEntropyLoss()
# )
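After loading, the model behaves like any other nn.Module, so inference is straightforward; the input size below is an assumption for illustration.

import torch

model.eval()
with torch.no_grad():
    # The input width must match the in_features configured for the network.
    x = torch.randn(1, 128)
    logits = model(x)
    pred = logits.argmax(dim=-1)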
License
Hiraishin is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.