Utilities for training models in PyTorch

xt-training

Description

This repo contains utilities for training deep learning models in PyTorch, developed by Xtract AI.

Installation

From PyPI:

pip install xt-training

From source:

git clone https://github.com/XtractTech/xt-training.git
pip install ./xt-training

Usage

For help on a specific class or function, use Python's built-in help, e.g., help(Runner).

Training a model

Using xt-training (High Level)

First, you must define a config file with the necessary items. To generate a template config file, run:

python -m xt_training template path/to/save/dir

To generate template files for nni, add the --nni flag.

Instructions for defining a valid config file can be found at the top of the config file.

After defining a valid config file, you can train your model by running:

python -m xt_training train path/to/config.py /path/to/save_dir

You can test the model by running:

python -m xt_training test path/to/config.py /path/to/save_dir

Using functional train (Middle Level)

As of version 2.0.0, xt-training provides functional calls for the train and test functions. This is useful if you want to run other code after training, or want values or metrics returned after training. It can be called like so:

from xt_training.utils import functional

# model = 
# train_loader = 
# optimizer = 
# scheduler = 
# loss_fn = 
# metrics = 
# epochs = 
# save_dir = 
def on_exit(test_loaders, runner, save_dir, model):
    # Do what you want after training.
    # As of version 2.0.0, whatever is returned here is also returned from the functional call.
    return runner, model

runner, model = functional.train(
    save_dir,
    train_loader,
    model,
    optimizer,
    epochs,
    loss_fn,
    val_loader=None,
    test_loaders=None,
    scheduler=scheduler,
    is_batch_scheduler=False,  # Whether to call scheduler.step() every step (batch) instead of every epoch
    eval_metrics=metrics,
    tokenizer=None,
    on_exit=on_exit,  # the callback defined above
    use_nni=False
)

# Do something after with runner and/or model...

A similar functional call exists for test.
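
The is_batch_scheduler and on_exit arguments in the call above can be illustrated with a toy loop. This is a minimal sketch in plain Python, not xt-training's implementation: CountingScheduler and train_sketch are hypothetical names, and the sketch assumes that is_batch_scheduler=True means scheduler.step() runs once per batch rather than once per epoch.

```python
from typing import Any, Callable, Iterable, Optional


class CountingScheduler:
    """Hypothetical stand-in for a learning-rate scheduler; it only counts step() calls."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def train_sketch(
    train_loader: Iterable[Any],
    epochs: int,
    scheduler: CountingScheduler,
    is_batch_scheduler: bool = False,
    on_exit: Optional[Callable[[CountingScheduler], Any]] = None,
) -> Any:
    """Toy loop (an assumption, not the library's code) showing both behaviours."""
    for _ in range(epochs):
        for _batch in train_loader:
            # A real loop would run the forward/backward pass and optimizer.step() here.
            if is_batch_scheduler:
                scheduler.step()  # assumed: once per batch
        if not is_batch_scheduler:
            scheduler.step()  # assumed: once per epoch
    # Whatever on_exit returns becomes the return value of the call.
    return on_exit(scheduler) if on_exit is not None else None


batches = [0, 1, 2, 3]  # four dummy batches
per_epoch = train_sketch(
    batches, epochs=3, scheduler=CountingScheduler(), on_exit=lambda s: s.steps
)
per_batch = train_sketch(
    batches,
    epochs=3,
    scheduler=CountingScheduler(),
    is_batch_scheduler=True,
    on_exit=lambda s: s.steps,
)
print(per_epoch, per_batch)  # 3 12
```

With three epochs of four batches, the per-epoch setting steps the scheduler 3 times and the per-batch setting steps it 12 times. In both cases the caller receives whatever on_exit returned.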

Using Runner (Low Level)

If you want a little more control and want to define the training code yourself, you can use the Runner like so:

from xt_training import Runner, metrics
from torch.utils.tensorboard import SummaryWriter

# Here, define class instances for the required objects
# model = 
# optimizer = 
# scheduler = 
# loss_fn = 

# Define metrics - each of these will be printed for each iteration
# Either per-batch or running-average values can be printed
batch_metrics = {
    'eps': metrics.EPS(),
    'acc': metrics.Accuracy(),
    'kappa': metrics.Kappa(),
    'cm': metrics.ConfusionMatrix()
}

# Define tensorboard writer
writer = SummaryWriter()

# Create runner
runner = Runner(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    batch_metrics=batch_metrics,
    device='cuda:0',
    writer=writer
)

# Define dataset and loaders
# dataset = 
# train_loader = 
# val_loader = 

# Train
model.train()
runner(train_loader)
batch_metrics['cm'].print()

# Evaluate
model.eval()
runner(val_loader)
batch_metrics['cm'].print()

# Print training and evaluation history
print(runner)
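
To make the difference between per-batch and running-average values concrete, here is a minimal sketch in plain Python. It does not reproduce the internals of metrics.Accuracy, metrics.Kappa or metrics.ConfusionMatrix. RunningAccuracy, confusion_matrix and cohen_kappa are hypothetical helpers, and the sketch assumes that Kappa refers to Cohen's kappa.

```python
from typing import List, Sequence


class RunningAccuracy:
    """Hypothetical metric exposing a per-batch value and a running average."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0
        self.last_batch = 0.0

    def update(self, preds: Sequence[int], labels: Sequence[int]) -> None:
        hits = sum(int(p == t) for p, t in zip(preds, labels))
        self.last_batch = hits / len(labels)  # value for this batch only
        self.correct += hits
        self.total += len(labels)

    @property
    def running(self) -> float:
        # Average over every sample seen so far
        return self.correct / self.total


def confusion_matrix(
    preds: Sequence[int], labels: Sequence[int], num_classes: int
) -> List[List[int]]:
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, labels):
        cm[t][p] += 1
    return cm


def cohen_kappa(cm: List[List[int]]) -> float:
    """Cohen's kappa from a confusion matrix (undefined if chance agreement is 1)."""
    n = sum(sum(row) for row in cm)
    k = len(cm)
    observed = sum(cm[i][i] for i in range(k)) / n
    expected = sum(sum(cm[i]) * sum(row[i] for row in cm) for i in range(k)) / (n * n)
    return (observed - expected) / (1 - expected)


# Two dummy batches of (predictions, labels)
batch_data = [([0, 1, 1, 0], [0, 1, 0, 0]), ([1, 0], [1, 1])]
acc = RunningAccuracy()
all_preds: List[int] = []
all_labels: List[int] = []
for preds, labels in batch_data:
    acc.update(preds, labels)
    all_preds.extend(preds)
    all_labels.extend(labels)

cm = confusion_matrix(all_preds, all_labels, num_classes=2)
kappa = cohen_kappa(cm)
print(acc.last_batch, round(acc.running, 3))  # 0.5 0.667
print(cm, round(kappa, 3))  # [[2, 1], [1, 2]] 0.333
```

The last batch alone scores 0.5, while the running average over all six samples is about 0.667. The two values diverge whenever batches differ in size or difficulty.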

Scoring a model

import torch
from xt_training import Runner

# Here, define the model
# model = 
# model.load_state_dict(torch.load(<checkpoint file>))

# Create runner
# (alternatively, can use a fully-specified training runner as in the example above)
runner = Runner(model=model, device='cuda:0')

# Define dataset and loaders
# dataset = 
# test_loader = 

# Score
model.eval()
y_pred, y = runner(test_loader, return_preds=True)
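
What comes next depends on the model. As a minimal sketch, assume y_pred holds one row of class scores per sample and y holds integer class labels. Both are assumptions about the returned values; if they are tensors, they could first be converted with .tolist(). The helpers to_labels and misclassified are hypothetical and use plain Python lists.

```python
from typing import List, Sequence


def to_labels(scores: Sequence[Sequence[float]]) -> List[int]:
    """Pick the index of the highest score in each row (argmax)."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


def misclassified(pred_labels: Sequence[int], labels: Sequence[int]) -> List[int]:
    """Return the sample indices where prediction and label disagree."""
    return [i for i, (p, t) in enumerate(zip(pred_labels, labels)) if p != t]


# Dummy stand-ins for the values returned by the runner
y_pred = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 0.9, 0.4]]
y = [0, 2, 2]

pred_labels = to_labels(y_pred)
errors = misclassified(pred_labels, y)
accuracy = 1 - len(errors) / len(y)
print(pred_labels, errors)  # [0, 2, 1] [2]
```

Keeping the indices of the misclassified samples makes it easy to look up the corresponding inputs in the dataset afterwards.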
