xt-training
Description
This repo contains utilities for training deep learning models in PyTorch, developed by Xtract AI.
Installation
From PyPI:

```
pip install xt-training
```

From source:

```
git clone https://github.com/XtractTech/xt-training.git
pip install ./xt-training
```
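After installation, a quick sanity check is to import the package and the objects used throughout this README:

```python
# Minimal post-install check: the package and its main entry points should import cleanly.
import xt_training
from xt_training import Runner, metrics

print(xt_training.__file__)  # location of the installed package
```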
Usage
See specific help on a class or function using `help()`, e.g. `help(Runner)`.
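For example, from an interactive session you can pull up the docstrings for the classes used in this README:

```python
from xt_training import Runner, metrics

help(Runner)            # constructor arguments and call behaviour
help(metrics.Accuracy)  # documentation for an individual metric
```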
Training a model
Using xt-training (High Level)
First, you must define a config file with the necessary items. To generate a template config file, run:
```
python -m xt_training template path/to/save/dir
```
To generate template files for NNI, add the `--nni` flag.
Instructions for defining a valid config file are included at the top of the generated template.
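As a rough illustration only (the generated template is the authority on which names it expects; the variable names, dataset, and hyperparameters below are assumptions), a filled-in config is an ordinary Python module that builds the objects used for training:

```python
# config.py -- illustrative sketch; generate the real template with
# `python -m xt_training template` and follow the instructions at its top.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from xt_training import metrics

# Data loaders (FakeData stands in for a real dataset here)
train_loader = DataLoader(
    datasets.FakeData(size=256, transform=transforms.ToTensor()),
    batch_size=32, shuffle=True,
)
val_loader = DataLoader(
    datasets.FakeData(size=64, transform=transforms.ToTensor()),
    batch_size=32,
)

# Model, loss, optimizer, and scheduler
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)

# Metrics reported during training (exact key names assumed)
batch_metrics = {'acc': metrics.Accuracy()}
```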
After defining a valid config file, you can train your model by running:
```
python -m xt_training train path/to/config.py /path/to/save_dir
```
You can test the model by running:

```
python -m xt_training test path/to/config.py /path/to/save_dir
```
Using Runner (Low Level)
If you want a little more control and would rather write the training code yourself, you can use the Runner directly:
```python
from xt_training import Runner, metrics
from torch.utils.tensorboard import SummaryWriter

# Here, define class instances for the required objects
# model =
# optimizer =
# scheduler =
# loss_fn =

# Define metrics - each of these will be printed for each iteration
# Either per-batch or running-average values can be printed
batch_metrics = {
    'eps': metrics.EPS(),
    'acc': metrics.Accuracy(),
    'kappa': metrics.Kappa(),
    'cm': metrics.ConfusionMatrix()
}

# Define tensorboard writer
writer = SummaryWriter()

# Create runner
runner = Runner(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    batch_metrics=batch_metrics,
    device='cuda:0',
    writer=writer
)

# Define dataset and loaders
# dataset =
# train_loader =
# val_loader =

# Train
model.train()
runner(train_loader)
batch_metrics['cm'].print()

# Evaluate
model.eval()
runner(val_loader)
batch_metrics['cm'].print()

# Print training and evaluation history
print(runner)
```
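The example above makes a single pass over each loader. Assuming each call to `runner(loader)` corresponds to one epoch over that loader (as the example suggests), a multi-epoch loop is simply repeated calls:

```python
# Sketch of a multi-epoch loop built from the calls shown above.
n_epochs = 10  # hypothetical value

for epoch in range(n_epochs):
    model.train()
    runner(train_loader)   # one training pass

    model.eval()
    runner(val_loader)     # one validation pass

print(runner)  # summary of training and evaluation history
```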
Scoring a model
```python
import torch
from xt_training import Runner

# Here, define the model
# model =
# model.load_state_dict(torch.load(<checkpoint file>))

# Create runner
# (alternatively, can use a fully-specified training runner as in the example above)
runner = Runner(model=model, device='cuda:0')

# Define dataset and loaders
# dataset =
# test_loader =

# Score
model.eval()
y_pred, y = runner(test_loader, return_preds=True)
```
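The returned predictions and targets can then be passed to whatever evaluation you need. A minimal sketch, assuming `y_pred` holds per-class scores (e.g. logits) and `y` the corresponding integer labels:

```python
# Assumes y_pred: (N, num_classes) scores and y: (N,) integer targets.
predicted_labels = y_pred.argmax(dim=1)
accuracy = (predicted_labels == y).float().mean().item()
print(f"Test accuracy: {accuracy:.3f}")
```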