Utilities for training models in PyTorch

xt-training

Description

This repo contains utilities for training deep learning models in PyTorch, developed by Xtract AI.

Installation

From PyPI:

pip install xt-training

From source:

git clone https://github.com/XtractTech/xt-training.git
pip install ./xt-training

Usage

Use Python's built-in help function to see documentation for a specific class or function, e.g., help(Runner).

Training a model

from xt_training import Runner, metrics
from torch.utils.tensorboard import SummaryWriter

# Here, define class instances for the required objects
# model = 
# optimizer = 
# scheduler = 
# loss_fn = 

# Define metrics - each of these will be printed for each iteration
# Either per-batch or running-average values can be printed
batch_metrics = {
    'eps': metrics.EPS(),
    'acc': metrics.Accuracy(),
    'kappa': metrics.Kappa(),
    'cm': metrics.ConfusionMatrix()
}

# Define tensorboard writer
writer = SummaryWriter()

# Create runner
runner = Runner(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    batch_metrics=batch_metrics,
    device='cuda:0',
    writer=writer
)

# Define dataset and loaders
# dataset = 
# train_loader = 
# val_loader = 

# Train
model.train()
runner(train_loader)
batch_metrics['cm'].print()

# Evaluate
model.eval()
runner(val_loader)
batch_metrics['cm'].print()

# Print training and evaluation history
print(runner)
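
To make the metric names in the example above more concrete, here is a minimal pure-Python sketch of what a confusion matrix, accuracy, and Cohen's kappa measure. It is not the xt_training implementation: the function names are hypothetical, and it assumes that metrics.Kappa refers to Cohen's kappa, which this page does not state.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs; rows are true classes, columns are predictions."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def accuracy(cm):
    """Fraction of samples on the diagonal of the confusion matrix."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / total


def cohen_kappa(cm):
    """Agreement corrected for chance: (p_observed - p_expected) / (1 - p_expected).

    Undefined (division by zero) when p_expected equals 1, i.e. when all
    labels and all predictions fall in a single class.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    p_observed = sum(cm[i][i] for i in range(n)) / total
    row_totals = [sum(row) for row in cm]
    col_totals = [sum(cm[r][c] for r in range(n)) for c in range(n)]
    p_expected = sum(row_totals[i] * col_totals[i] for i in range(n)) / (total * total)
    return (p_observed - p_expected) / (1 - p_expected)


# Hypothetical labels and predictions for a 3-class problem
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)               # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(accuracy(cm))     # 4 of 6 correct
print(cohen_kappa(cm))  # 0.5 for this data, up to floating-point rounding
```

Kappa is lower than accuracy here because it subtracts the agreement expected by chance from the class frequencies.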

Scoring a model

import torch
from xt_training import Runner

# Here, define the model
# model = 
# model.load_state_dict(torch.load(<checkpoint file>))

# Create runner
# (alternatively, use a fully specified training runner as in the example above)
runner = Runner(model=model, device='cuda:0')

# Define dataset and loaders
# dataset = 
# test_loader = 

# Score
model.eval()
y_pred, y = runner(test_loader, return_preds=True)
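
For readers who want to see what scoring involves, the sketch below collects model outputs and targets over a loader using plain PyTorch. It is only an illustration under stated assumptions: it does not use Runner, it is not claimed to match what Runner does internally with return_preds=True, and the tiny linear model, the synthetic data, and the helper name collect_predictions are all hypothetical stand-ins for the placeholders above. It runs on the CPU so that no GPU is required.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for the placeholders: a tiny classifier and synthetic data
model = nn.Linear(4, 3)
features = torch.randn(10, 4)
labels = torch.randint(0, 3, (10,))
test_loader = DataLoader(TensorDataset(features, labels), batch_size=4)


def collect_predictions(model, loader, device="cpu"):
    """Run the model over a loader without gradients; return (outputs, targets)."""
    model.to(device)
    model.eval()  # disable dropout and batch-norm updates during scoring
    outputs, targets = [], []
    with torch.no_grad():
        for x, y in loader:
            outputs.append(model(x.to(device)).cpu())
            targets.append(y)
    # Concatenate the per-batch results along the sample dimension
    return torch.cat(outputs), torch.cat(targets)


y_pred, y = collect_predictions(model, test_loader)
print(y_pred.shape, y.shape)
```

The loader is not shuffled, so the returned targets come back in the same order as the dataset.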
