xt-training
Description
This repository contains utilities, developed by Xtract AI, for training deep learning models in PyTorch.
Installation
From PyPI:
pip install xt-training
From source:
git clone https://github.com/XtractTech/xt-training.git
pip install ./xt-training
Usage
See specific help on a class or function using help(), e.g., help(Runner).
Training a model
from xt_training import Runner, metrics
from torch.utils.tensorboard import SummaryWriter
# Here, define class instances for the required objects
# model =
# optimizer =
# scheduler =
# loss_fn =
# Define metrics - each of these will be printed for each iteration
# Either per-batch or running-average values can be printed
batch_metrics = {
    'eps': metrics.BatchTimer(),
    'acc': metrics.accuracy,
    'kappa': metrics.kappa
}
# Define tensorboard writer
writer = SummaryWriter()
# Create runner
runner = Runner(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    batch_metrics=batch_metrics,
    device='cuda:0',
    writer=writer
)
# Define dataset and loaders
# dataset =
# train_loader =
# val_loader =
# Train
model.train()
runner(train_loader)
# Evaluate
model.eval()
runner(val_loader)
# Print training and evaluation history
print(runner)
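The comment in the example above distinguishes per-batch metric values from running averages. The sketch below illustrates that distinction in plain Python, assuming a running average weighted by batch size; the class name RunningAverage is hypothetical and is not part of xt_training, whose exact averaging may differ.

```python
class RunningAverage:
    """Hypothetical helper (not xt_training's API): running mean weighted by batch size."""

    def __init__(self):
        self.total = 0.0  # sum of (batch value * batch size)
        self.count = 0    # number of examples seen so far

    def update(self, batch_value, batch_size):
        # Add this batch's contribution and return the updated running average
        self.total += batch_value * batch_size
        self.count += batch_size
        return self.value

    @property
    def value(self):
        # Average over all examples seen so far; 0.0 before the first update
        return self.total / self.count if self.count else 0.0


# Illustrative per-batch accuracies and batch sizes (made-up numbers)
batches = [(0.50, 4), (1.00, 4), (0.75, 8)]
tracker = RunningAverage()
for batch_acc, size in batches:
    running = tracker.update(batch_acc, size)
    print(f"batch acc: {batch_acc:.3f}  running acc: {running:.3f}")
```

Weighting by batch size keeps a smaller final batch from counting as much as a full one.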
Scoring a model
import torch
from xt_training import Runner
# Here, define the model
# model =
# model.load_state_dict(torch.load(<checkpoint file>))
# Create runner
# (alternatively, a fully specified training runner, as in the example above, can be used)
runner = Runner(model=model, device='cuda:0')
# Define dataset and loaders
# dataset =
# test_loader =
# Score
model.eval()
y_pred, y = runner(test_loader, return_preds=True)
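Once y_pred and y have been returned, they can be summarised with the same kinds of metrics used during training. The sketch below computes accuracy and Cohen's kappa in plain Python on lists of integer class labels. It assumes that, if y_pred holds per-class scores, they are first reduced to labels (for a PyTorch tensor, e.g. y_pred.argmax(dim=1).tolist()). The function names accuracy_sketch and cohen_kappa_sketch are hypothetical and do not reflect the signatures of metrics.accuracy or metrics.kappa.

```python
from collections import Counter


def accuracy_sketch(y_pred, y_true):
    """Fraction of predicted labels that match the true labels."""
    if len(y_pred) != len(y_true) or not y_true:
        raise ValueError("inputs must be non-empty and the same length")
    return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)


def cohen_kappa_sketch(y_pred, y_true):
    """Cohen's kappa: observed agreement corrected for agreement expected by chance."""
    n = len(y_true)
    p_observed = accuracy_sketch(y_pred, y_true)
    pred_counts, true_counts = Counter(y_pred), Counter(y_true)
    # Chance agreement if predictions and labels were independent
    p_expected = sum(pred_counts[c] * true_counts[c] for c in true_counts) / (n * n)
    if p_expected == 1:
        return float("nan")  # undefined: a single class appears in both
    return (p_observed - p_expected) / (1 - p_expected)


# Illustrative labels (made-up data)
y_pred = [0, 1, 1, 2, 2, 0]
y_true = [0, 1, 2, 2, 2, 1]
print(f"accuracy: {accuracy_sketch(y_pred, y_true):.3f}")
print(f"kappa: {cohen_kappa_sketch(y_pred, y_true):.3f}")
```

Kappa is lower than raw accuracy here because part of the observed agreement would be expected by chance given the class frequencies.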
Data Sources
[descriptions and links to data]
Dependencies/Licensing
[list of dependencies and their licenses, including data]