
Package for training and evaluating neural network models built with PyTorch

Project description

run-torch-model

A simple program for running a PyTorch-compatible model. It also includes a helper function, create_dataloader, for creating the DataLoader objects required for training, testing, and validation.

Requirements

  • torch
  • torchmetrics
  • scikit-learn

Install

Install using pip:

pip install run-torch-model

Usage

Use create_dataloader to create the DataLoader objects for training and testing:

from run_torch_model import create_dataloader

dataloader_train, dataloader_test = create_dataloader(features=features, 
                                                      targets=targets,
                                                      batch_size=batch_size,
                                                      train_size=train_size,
                                                      test_size=test_size)
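The package's own implementation of create_dataloader is not shown here. As a rough, hypothetical sketch of what a helper like it does, the function below (make_loaders, a name invented for illustration) wraps features and targets in a TensorDataset and splits it into two DataLoaders. It assumes train_size and test_size are fractions that sum to 1, which may not match the real function's conventions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(features, targets, batch_size=32, train_size=0.8, test_size=0.2, seed=0):
    """Hypothetical stand-in for create_dataloader: split one dataset into train/test loaders."""
    # Assumption: the two sizes are fractions of the whole dataset
    if abs(train_size + test_size - 1.0) > 1e-6:
        raise ValueError("train_size and test_size must sum to 1")

    # Convert array-like inputs to float tensors so they can be batched
    x = torch.as_tensor(features, dtype=torch.float32)
    y = torch.as_tensor(targets, dtype=torch.float32)
    dataset = TensorDataset(x, y)

    # Integer split lengths; any rounding remainder goes to the test split
    n_train = int(round(train_size * len(dataset)))
    n_test = len(dataset) - n_train
    generator = torch.Generator().manual_seed(seed)  # reproducible split
    train_set, test_set = random_split(dataset, [n_train, n_test], generator=generator)

    # Shuffle only the training data; keep test order stable for evaluation
    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(test_set, batch_size=batch_size, shuffle=False))
```

For example, 100 samples with the default fractions give 80 training and 20 test samples.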

To train a model, define the optimizer, its arguments, and a loss criterion, pass them to the RunTorchNN class, and call the resulting object to start training.

import torch
from run_torch_model import RunTorchNN

optimizer = 'torch.optim.Adam'  # Must be a string; with CUDA, the optimizer is initialized after .cuda() is called, for speed-up
optimizer_args = {'lr': 0.001}  # Keyword arguments passed to the optimizer
criterion = torch.nn.MSELoss()

run_model = RunTorchNN(model,  # Some PyTorch model
                       epochs=100,
                       optimizer=optimizer,
                       optimizer_args=optimizer_args,
                       dataloaders=(dataloader_train, dataloader_test),
                       criterion=criterion)

run_model()  # Executes the training
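The torch.optim documentation recommends moving a model to the GPU before constructing its optimizer, and passing the optimizer as a string lets a class defer that construction until after the move. The sketch below illustrates the idea under stated assumptions: resolve_optimizer and train_sketch are hypothetical names, and the loop is a generic PyTorch training loop, not RunTorchNN's actual code.

```python
import importlib

import torch


def resolve_optimizer(path):
    """Turn a dotted string such as 'torch.optim.Adam' into the optimizer class."""
    module_name, class_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


def train_sketch(model, epochs, optimizer, optimizer_args, dataloader, criterion):
    """Hypothetical training loop; returns the per-sample average loss of each epoch."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model first, then build the optimizer over its (possibly GPU) parameters
    model.to(device)
    opt = resolve_optimizer(optimizer)(model.parameters(), **optimizer_args)

    history = []
    for _ in range(epochs):
        model.train()
        total, count = 0.0, 0
        for xb, yb in dataloader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            opt.step()
            # Weight each batch's mean loss by its size so the epoch average is per-sample
            total += loss.item() * xb.size(0)
            count += xb.size(0)
        history.append(total / count)
    return history
```

Keeping the optimizer as a string also means the same configuration can be reused for a fresh model without carrying over a stale optimizer bound to old parameters.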

To fetch metrics:

R2 = run_model.get_r2score()
loss = run_model.get_average_loss()
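For reference, the coefficient of determination is conventionally R² = 1 - SS_res / SS_tot, where SS_res is the sum of squared residuals and SS_tot the sum of squared deviations of the targets from their mean. The pure-Python sketch below computes it together with a batch-size-weighted average loss. It is not the package's code (which lists torchmetrics as a dependency), and whether get_average_loss weights batches this way is an assumption.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R² is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


def average_loss(batch_losses, batch_sizes):
    """Per-sample mean of per-batch mean losses, weighting each batch by its size."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)
```

A perfect prediction gives R² = 1.0, and always predicting the target mean gives 0.0; worse-than-mean predictions give negative values.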

To evaluate the trained model on a different set of features:

predictions, loss = run_model.predict(new_features)
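How predict computes the loss it returns is not documented here, so the hypothetical sketch below (predict_sketch is an invented name) covers only the prediction side: the usual PyTorch pattern of switching the model to evaluation mode and running it under torch.no_grad() so no autograd graph is built.

```python
import torch


def predict_sketch(model, features):
    """Hypothetical inference helper: run a model on new features without tracking gradients."""
    model.eval()  # dropout and batch norm switch to inference behaviour
    device = next(model.parameters()).device  # send inputs to wherever the model lives
    x = torch.as_tensor(features, dtype=torch.float32, device=device)
    with torch.no_grad():  # no gradient bookkeeping, which saves memory and time
        return model(x).cpu()
```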

To evaluate the model on a validation set:

loss, r2 = run_model.evaluate(dataloader_validation)
predictions = run_model.get_predictions()  # To get predictions, if necessary 
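An evaluation pass over a validation DataLoader typically accumulates the loss batch by batch and keeps the predictions for later inspection. The sketch below is a hypothetical illustration of that pattern (evaluate_sketch is an invented name, and it returns the predictions directly rather than through a separate getter). It assumes CPU tensors and a criterion with mean reduction such as torch.nn.MSELoss, and computes R² over the pooled values, which matches the usual definition for a single target column.

```python
import torch


def evaluate_sketch(model, dataloader, criterion):
    """Hypothetical evaluation pass returning (average loss, R², predictions)."""
    model.eval()
    total, count = 0.0, 0
    preds, targets = [], []
    with torch.no_grad():
        for xb, yb in dataloader:
            out = model(xb)
            # criterion returns a batch mean; weight it by batch size for a per-sample average
            total += criterion(out, yb).item() * xb.size(0)
            count += xb.size(0)
            preds.append(out)
            targets.append(yb)
    y_pred = torch.cat(preds)
    y_true = torch.cat(targets)
    ss_res = torch.sum((y_true - y_pred) ** 2)
    ss_tot = torch.sum((y_true - y_true.mean()) ** 2)
    return total / count, (1 - ss_res / ss_tot).item(), y_pred
```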



Download files


Source Distribution

run_torch_model-1.0.2.tar.gz (7.9 kB)

Uploaded Source

Built Distribution

run_torch_model-1.0.2-py3-none-any.whl (8.5 kB)

Uploaded Python 3

File details

Details for the file run_torch_model-1.0.2.tar.gz.

File metadata

  • Download URL: run_torch_model-1.0.2.tar.gz
  • Upload date:
  • Size: 7.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/32.0 requests/2.26.0 requests-toolbelt/0.9.1 urllib3/1.26.7 tqdm/4.62.3 importlib-metadata/4.8.1 keyring/23.2.1 rfc3986/2.0.0 colorama/0.4.4 CPython/3.8.8

File hashes

Hashes for run_torch_model-1.0.2.tar.gz
Algorithm Hash digest
SHA256 87a9c052970b00f940fcfa17c4acd5c86f447330678f8f457d1a070bb7130d11
MD5 25af0fbc61f711e053fbb1b9ab965950
BLAKE2b-256 c70cc347c314dcc59d48e2dc60e1ade2d7a0d77f62afd6ec7710180c46ea6402
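A downloaded file can be checked against the published SHA256 digest with Python's standard hashlib module. The sketch below assumes the source distribution has been saved locally; the helper names sha256_of and verify are illustrative, not part of the package.

```python
import hashlib

# Published SHA256 digest for run_torch_model-1.0.2.tar.gz, copied from the table above
EXPECTED_SDIST_SHA256 = "87a9c052970b00f940fcfa17c4acd5c86f447330678f8f457d1a070bb7130d11"


def sha256_of(path, chunk_size=65536):
    """Return the SHA256 hex digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file at path matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, verify("run_torch_model-1.0.2.tar.gz") should return True for an intact download. pip can also enforce this itself in hash-checking mode, using --hash=sha256:... entries in a requirements file together with --require-hashes.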


File details

Details for the file run_torch_model-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: run_torch_model-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 8.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/32.0 requests/2.26.0 requests-toolbelt/0.9.1 urllib3/1.26.7 tqdm/4.62.3 importlib-metadata/4.8.1 keyring/23.2.1 rfc3986/2.0.0 colorama/0.4.4 CPython/3.8.8

File hashes

Hashes for run_torch_model-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 ca15c26bfcaddb4cbf57a49b233d0c62b6aee3553f916ff1be334149f2545875
MD5 ed42aa88b5b98c05d347dbda9af68be9
BLAKE2b-256 5f1fac9987bafdbe16ff5b3321d826cf2b97b16eddc3f1af6b1a073213420b97

