Utilities for training models in PyTorch

xt-training

Description

This repository contains utilities, developed by Xtract AI, for training deep learning models in PyTorch.

Installation

From PyPI:

pip install xt-training

From source:

git clone https://github.com/XtractTech/xt-training.git
pip install ./xt-training

Usage

For details on a specific class or function, use Python's built-in help function, e.g., help(Runner).

Training a model

from xt_training import Runner, metrics
from torch.utils.tensorboard import SummaryWriter

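# Define the required objects here:
# model = ...      # a torch.nn.Module
# optimizer = ...  # e.g., a torch.optim optimizer
# scheduler = ...  # e.g., a torch.optim.lr_scheduler scheduler
# loss_fn = ...    # a loss function, e.g., torch.nn.CrossEntropyLoss()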

# Define metrics; each one is printed at every iteration,
# as either a per-batch or a running-average value
batch_metrics = {
    'eps': metrics.EPS(),
    'acc': metrics.Accuracy(),
    'kappa': metrics.Kappa(),
    'cm': metrics.ConfusionMatrix()
}

# Define tensorboard writer
writer = SummaryWriter()

# Create runner
runner = Runner(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    batch_metrics=batch_metrics,
    device='cuda:0',
    writer=writer
)

# Define dataset and loaders
# dataset = 
# train_loader = 
# val_loader = 

# Train
model.train()
runner(train_loader)
batch_metrics['cm'].print()

# Evaluate
model.eval()
runner(val_loader)
batch_metrics['cm'].print()

# Print training and evaluation history
print(runner)
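
The training example notes that metrics can be printed either per batch or as running averages. The pure-Python sketch below illustrates that distinction for accuracy. It is not xt-training's implementation: the class RunningAccuracy is hypothetical, and weighting the running value by the number of samples in each batch is an assumption about how such averages are typically computed.

```python
class RunningAccuracy:
    """Hypothetical tracker (not xt_training.metrics.Accuracy) that reports
    both the per-batch accuracy and the running accuracy over all batches."""

    def __init__(self):
        self.correct = 0  # correct predictions seen so far
        self.total = 0    # samples seen so far

    def update(self, preds, targets):
        """Record one batch of predicted and true labels; return the batch accuracy."""
        if len(preds) != len(targets) or not targets:
            raise ValueError("preds and targets must be non-empty and equal in length")
        batch_correct = sum(p == t for p, t in zip(preds, targets))
        self.correct += batch_correct
        self.total += len(targets)
        return batch_correct / len(targets)

    @property
    def running(self):
        """Accuracy over every sample seen so far (assumed: weighted by batch size)."""
        return self.correct / self.total if self.total else 0.0


acc = RunningAccuracy()
batches = [([1, 0, 1, 1], [1, 0, 0, 1]), ([0, 0], [0, 1])]  # (preds, targets)
for preds, targets in batches:
    batch_acc = acc.update(preds, targets)
    print(f"batch acc: {batch_acc:.3f}  running acc: {acc.running:.3f}")
# batch acc: 0.750  running acc: 0.750
# batch acc: 0.500  running acc: 0.667
```

With these two batches, the running value is 4/6 (about 0.667) rather than the unweighted mean of the two batch values (0.625), because the smaller second batch contributes fewer samples.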

Scoring a model

import torch
from xt_training import Runner

# Define the model and load its trained weights
# model = 
# model.load_state_dict(torch.load(<checkpoint file>))

# Create a runner for scoring only
# (alternatively, reuse a fully specified training runner, as in the example above)
runner = Runner(model=model, device='cuda:0')

# Define dataset and loaders
# dataset = 
# test_loader = 

# Score
model.eval()
y_pred, y = runner(test_loader, return_preds=True)
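
Scored predictions can be summarized in the same terms as the ConfusionMatrix and Kappa metrics used during training. The pure-Python sketch below assumes that y_pred holds per-class scores for each sample and y holds integer class labels, both already converted to lists (for PyTorch tensors, e.g. with .tolist()); check help(Runner) for the actual return format. The helper functions are hypothetical and not part of xt-training. Cohen's kappa is computed as (p_o - p_e) / (1 - p_e), where p_o is the observed agreement and p_e the agreement expected by chance from the class totals.

```python
def argmax(scores):
    """Index of the largest score (the first index wins on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)


def confusion_matrix(targets, preds, n_classes):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm


def cohens_kappa(cm):
    """Hypothetical helper: Cohen's kappa from a square confusion matrix."""
    n = sum(sum(row) for row in cm)
    p_o = sum(cm[k][k] for k in range(len(cm))) / n  # observed agreement
    row_totals = [sum(row) for row in cm]            # true-class counts
    col_totals = [sum(col) for col in zip(*cm)]      # predicted-class counts
    p_e = sum(r * c for r, c in zip(row_totals, col_totals)) / n ** 2  # chance agreement
    if p_e == 1:
        # Undefined: every sample has the same true and predicted class
        return float("nan")
    return (p_o - p_e) / (1 - p_e)


# Toy stand-ins for the y_pred (per-class scores) and y (labels) returned above
y_pred = [[2.0, 0.1], [0.3, 1.2], [1.5, 0.2], [0.1, 0.9], [0.8, 0.4], [0.2, 0.7]]
y = [0, 1, 1, 1, 0, 0]

preds = [argmax(scores) for scores in y_pred]
cm = confusion_matrix(y, preds, n_classes=2)
print(cm)                         # [[2, 1], [1, 2]]
print(f"{cohens_kappa(cm):.3f}")  # 0.333
```

Here accuracy alone is 4/6, but kappa is only 0.333 because, with balanced classes, half of that agreement would be expected by chance.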

Download files

Source distribution: xt-training-1.7.1.tar.gz (12.5 kB)
Built distribution: xt_training-1.7.1-py3-none-any.whl (14.5 kB, Python 3)

Both files were uploaded via twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.3, without Trusted Publishing.

Hashes for xt-training-1.7.1.tar.gz

  • SHA256: 3fa7e6cc3100556ae6562cb1eb4ce2b50d68f632a5183b87a10ece6af93127a9
  • MD5: f4ce8f5e67c9513ebbdfcb74f9e2c4a3
  • BLAKE2b-256: e59558eefa9d88c5a910a926076fc7e087c8d6e4704fcc4a4541824b9832de12

Hashes for xt_training-1.7.1-py3-none-any.whl

  • SHA256: 66f216a1c2849add4d50538ac86425bf03da15e25fa03b4b72f23bfcb7123d2b
  • MD5: fe1a3f4a9f90b5299a3279e0a0648e96
  • BLAKE2b-256: 199283ee7730236be93355f41e4a7307d285ae4d47d3b7d1ab1e3a2613e61954
