Utilities for training models in PyTorch

xt-training

Description

This repository contains utilities, developed by Xtract AI, for training deep learning models in PyTorch.

Installation

From PyPI:

pip install xt-training

From source:

git clone https://github.com/XtractTech/xt-training.git
pip install ./xt-training
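To confirm that the installation worked, a short check like the one below can query the installed version through the standard library's importlib.metadata module (Python 3.8+). The helper name installed_version is hypothetical. Note that the distribution is named xt-training, while the import name is xt_training.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "xt-training") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # Raised by importlib.metadata when no such distribution is installed
        return None


if __name__ == "__main__":
    found = installed_version()
    print(found if found is not None else "xt-training is not installed")
```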

Usage

For detailed documentation on a specific class or function, use Python's built-in help(), e.g., help(Runner).

Training a model

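import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
from xt_training import Runner, metrics

device = 'cuda:0'

# Define the required objects. A toy two-class classifier is used here for
# illustration; replace these with your own model, optimizer, scheduler and loss.
# The model is moved to the same device that is passed to the Runner.
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
loss_fn = torch.nn.CrossEntropyLoss()

# Define metrics: each is printed at every iteration, either as a per-batch
# or as a running-average value
batch_metrics = {
    'eps': metrics.BatchTimer(),
    'acc': metrics.accuracy,
    'kappa': metrics.kappa
}

# Define TensorBoard writer
writer = SummaryWriter()

# Create runner
runner = Runner(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    batch_metrics=batch_metrics,
    device=device,
    writer=writer
)

# Define dataset and loaders (random toy data for illustration: 100 samples with
# 10 features each and binary labels, split 80/20 into training and validation)
dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))
train_set, val_set = random_split(dataset, [80, 20])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)

# Train
model.train()
train_loss, train_metrics = runner(train_loader)

# Evaluate
model.eval()
val_loss, val_metrics = runner(val_loader)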
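The training example performs a single training pass and a single validation pass. To repeat this over several epochs while tracking the best validation loss, the loop could be wrapped in a helper such as the one below. This is a sketch, not part of xt-training: the name fit and the on_best callback are hypothetical, and the helper assumes only what the example shows, namely that runner(loader) returns a (loss, metrics) pair and that the model is switched between modes with train() and eval().

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


def fit(
    runner: Callable[[Any], Tuple[Any, Dict[str, Any]]],
    model: Any,
    train_loader: Any,
    val_loader: Any,
    epochs: int,
    on_best: Optional[Callable[[int], None]] = None,
) -> Tuple[List[Dict[str, Any]], int]:
    """Alternate training and validation passes; return the history and best epoch.

    The hypothetical on_best callback is invoked with the epoch index
    whenever the validation loss improves on the best value seen so far.
    """
    history: List[Dict[str, Any]] = []
    best_loss = float("inf")
    best_epoch = -1
    for epoch in range(epochs):
        model.train()
        train_loss, train_metrics = runner(train_loader)
        model.eval()
        val_loss, val_metrics = runner(val_loader)
        # float() accepts plain numbers as well as 0-dim tensors
        val_loss = float(val_loss)
        history.append({
            "epoch": epoch,
            "train_loss": float(train_loss),
            "val_loss": val_loss,
            "train_metrics": train_metrics,
            "val_metrics": val_metrics,
        })
        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, epoch
            if on_best is not None:
                on_best(epoch)
    return history, best_epoch
```

With the objects from the training example, a call might look like history, best_epoch = fit(runner, model, train_loader, val_loader, epochs=10, on_best=lambda epoch: torch.save(model.state_dict(), "best.pt")), where "best.pt" is a placeholder path. The saved state dict is the kind of checkpoint that the scoring example loads with model.load_state_dict(torch.load(...)).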

Scoring a model

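import torch
from torch.utils.data import DataLoader, TensorDataset
from xt_training import Runner

device = 'cuda:0'

# Here, define the model (a toy classifier for illustration) and load its weights
model = torch.nn.Linear(10, 2)
# model.load_state_dict(torch.load(<checkpoint file>, map_location=device))
model = model.to(device)

# Create runner
# (alternatively, a fully specified training runner, as in the example above, can be used)
runner = Runner(model=model, device=device)

# Define dataset and loader (random toy data for illustration)
test_dataset = TensorDataset(torch.randn(20, 10), torch.randint(0, 2, (20,)))
test_loader = DataLoader(test_dataset, batch_size=16)

# Score
model.eval()
y_pred, y = runner.score(test_loader)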
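The exact format of the tensors returned by runner.score is not described here. Assuming y_pred holds per-class scores (logits) of shape (N, C) and y holds integer class labels of shape (N,), a sketch like the following turns them into an accuracy and a confusion matrix. The helper name summarize_scores is hypothetical and not part of xt-training.

```python
import torch


def summarize_scores(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int):
    """Return (accuracy, confusion matrix) from per-class scores and integer labels.

    Assumes y_pred has shape (N, num_classes) and y has shape (N,).
    Confusion-matrix rows are true classes; columns are predicted classes.
    """
    # Move to CPU and take the highest-scoring class as the prediction
    preds = y_pred.detach().cpu().argmax(dim=1)
    targets = y.detach().cpu().long()
    accuracy = (preds == targets).float().mean().item()
    # Encode each (true, predicted) pair as one index, then count occurrences
    flat = targets * num_classes + preds
    confusion = torch.bincount(flat, minlength=num_classes ** 2)
    return accuracy, confusion.reshape(num_classes, num_classes)
```

For the two-class toy model in the scoring example, this would be called as accuracy, confusion = summarize_scores(y_pred, y, num_classes=2).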

Data Sources

[descriptions and links to data]

Dependencies/Licensing

[list of dependencies and their licenses, including data]

Download files

Source Distribution

xt-training-0.2.0.tar.gz (5.1 kB)

Built Distribution

xt_training-0.2.0-py3-none-any.whl (5.5 kB)

File details

Details for the file xt-training-0.2.0.tar.gz.

File metadata

  • Download URL: xt-training-0.2.0.tar.gz
  • Upload date:
  • Size: 5.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/40.8.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.3

File hashes

Hashes for xt-training-0.2.0.tar.gz
Algorithm Hash digest
SHA256 b2e3043496f4394a73449d1b5546ca59d3f222bb8cba2a5f5e77d49e7d95f048
MD5 8490572a37330ab260b04ad6b97ec296
BLAKE2b-256 7bd3057521eb4136400e15f9d108e5ecb8a4d0d6a21e96ebed68e384858d25a1
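To check a downloaded file against its listed SHA256 digest, a minimal sketch using Python's standard hashlib module is shown below. The helper names sha256_of and verify_sha256 are illustrative; the file is read in chunks so that large archives need not fit in memory.

```python
import hashlib
from pathlib import Path
from typing import Union


def sha256_of(path: Union[str, Path], chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_sha256(path: Union[str, Path], expected: str) -> bool:
    """Compare a file's digest with an expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()


if __name__ == "__main__":
    archive = Path("xt-training-0.2.0.tar.gz")
    expected = "b2e3043496f4394a73449d1b5546ca59d3f222bb8cba2a5f5e77d49e7d95f048"
    if archive.exists():
        print("OK" if verify_sha256(archive, expected) else "HASH MISMATCH")
```

The same helpers apply to the wheel with its own digest. pip's hash-checking mode (--hash=sha256:... entries in a requirements file, enforced with --require-hashes) performs an equivalent check at install time.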

File details

Details for the file xt_training-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: xt_training-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 5.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/40.8.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.3

File hashes

Hashes for xt_training-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ba86f8b31e3a8c5af2e683ecd5a42f1874b3111a4159792037838ce5aa151406
MD5 38918f5273aea2547a3d54b6a4d4a046
BLAKE2b-256 ba6fb1c1ae7b7d3241dcf4b1b31b3b7efbba7c9191349b942da2aae811cb835e
