A toolbelt for the PyTorch framework

Torchest

PyTorch tools and utilities (trainers, data generators, functions, and more)

Trainers

One of the most common tasks in PyTorch is writing training loops, which is tedious because they are almost always the same. Torchest provides trainers you can use in your projects:

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchest.trainer import SimpleTrainer

# Define the model and prepare the data beforehand, e.g.:
# train_dataloader = DataLoader(train_data)
# dev_dataloader = DataLoader(dev_data)
# test_dataloader = DataLoader(test_data)
# model = nn.Sequential(...)
...

# Prepare the trainer
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

trainer = SimpleTrainer(model, loss_fn, optimizer)

# Train
epochs = 500
trainer.train(
    data_train=train_dataloader,
    data_dev=dev_dataloader,
    data_test=test_dataloader,
    epochs=epochs,
)
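For comparison, here is a minimal sketch of the kind of hand-written loop a trainer like SimpleTrainer saves you from writing. It uses plain PyTorch only and is not torchest's implementation; the toy data, the model, and the helper names train_epoch and evaluate are hypothetical.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def train_epoch(model, loader, loss_fn, optimizer):
    """Hypothetical helper: one pass over the training data.

    Returns the mean cost and the accuracy in percent.
    """
    model.train()
    total_cost, correct, seen = 0.0, 0, 0
    for X, y in loader:
        optimizer.zero_grad()
        logits = model(X)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
        total_cost += loss.item() * X.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += X.size(0)
    return total_cost / seen, 100.0 * correct / seen


@torch.no_grad()
def evaluate(model, loader, loss_fn):
    """Hypothetical helper: mean cost on a held-out set, no weight updates."""
    model.eval()
    total_cost, seen = 0.0, 0
    for X, y in loader:
        total_cost += loss_fn(model(X), y).item() * X.size(0)
        seen += X.size(0)
    return total_cost / seen


# Toy data (2 features, 3 classes) standing in for your own datasets
torch.manual_seed(0)
X = torch.randn(300, 2)
y = torch.randint(0, 3, (300,))
train_dataloader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
dev_dataloader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

# Record one train cost and one dev cost per epoch
costs = {"train": [], "dev": []}
for epoch in range(20):
    train_cost, train_acc = train_epoch(model, train_dataloader, loss_fn, optimizer)
    costs["train"].append(train_cost)
    costs["dev"].append(evaluate(model, dev_dataloader, loss_fn))
```

Every project repeats this boilerplate (zero_grad, forward, loss, backward, step, then a no-grad evaluation pass), which is the repetition the trainer is meant to absorb.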

While training, the trainer displays a progress bar like this one:

31%|██████████████████▍      | 156/500 [00:02<00:05, 58.01epoch/s, train_accuracy=76.2, train_cost=0.238290]
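The progress bar above looks like output from the tqdm library. Assuming that is what torchest uses (the README does not say), a similar bar with per-epoch metrics can be produced as in this sketch; the metric values are placeholders, not real training results.

```python
import io

from tqdm import tqdm

epochs = 50
# Write the bar to a buffer here; in practice leave `file` at its default (stderr)
log = io.StringIO()

with tqdm(range(epochs), unit="epoch", file=log) as pbar:
    for epoch in pbar:
        # placeholder metrics; a real loop would compute these from the data
        train_cost = 1.0 / (epoch + 1)
        train_accuracy = 100.0 * epoch / epochs
        # set_postfix appends "key=value" pairs to the right of the bar
        pbar.set_postfix(train_accuracy=round(train_accuracy, 1),
                         train_cost=f"{train_cost:f}")
```

Passing unit="epoch" is what produces rates such as "epoch/s" in the bar.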

The trainer also records the costs of the train, dev, and test passes. Call trainer.plot_costs() to display a graph of them.

[Figure: trainer costs]
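trainer.plot_costs() is torchest's own method. As a rough sketch of what such a plot involves, the following draws one cost curve per split with matplotlib, assuming the costs are kept as per-epoch lists in a dict. The function plot_cost_history and the dict layout are hypothetical, and the curves are placeholder values, not real results.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def plot_cost_history(costs):
    """Hypothetical helper: one line per split from {split: [cost per epoch]}."""
    fig, ax = plt.subplots()
    for split, values in costs.items():
        ax.plot(range(1, len(values) + 1), values, label=split)
    ax.set_xlabel("epoch")
    ax.set_ylabel("cost")
    ax.legend()
    return fig, ax


# Placeholder curves for illustration only
costs = {
    "train": [1.0 / (e + 1) for e in range(10)],
    "dev": [1.2 / (e + 1) + 0.05 for e in range(10)],
    "test": [1.25 / (e + 1) + 0.06 for e in range(10)],
}
fig, ax = plot_cost_history(costs)
```

In an interactive session, drop the matplotlib.use("Agg") line and call plt.show(), or save the figure with fig.savefig(...).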

Wandb visualization

Note: You must have the wandb Python package installed to use this feature:

pip install wandb

Trainers support wandb to visualize loss and accuracy and to save model versions. To enable it, set two parameters when initializing the trainer, name and wandb:

trainer = SimpleTrainer(model, loss_fn, optimizer, name="MyProject", wandb=True)

On the first run you will be asked to enter your API key, which you can obtain at https://wandb.ai/settings.

Data generators

Spiral data

This creates a matrix of 2D points arranged in a spiral. It is ideal for testing whether your network can learn non-linear decision boundaries.

from torchest.datagen import spiral_datagen

class_num = 3
X, Y = spiral_datagen(450, class_num)  # 450 elements per class

[Figure: spiral data]
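The README does not document how spiral_datagen builds its points. The sketch below uses a common construction (popularized by the Stanford CS231n course notes) in pure Python: each class gets one spiral arm whose radius grows from 0 to 1 while its angle sweeps 4 radians plus Gaussian noise. The function spiral_points and its noise and seed parameters are hypothetical and are not torchest's API.

```python
import math
import random


def spiral_points(points_per_class, num_classes, noise=0.2, seed=0):
    """Hypothetical spiral generator (not torchest's spiral_datagen).

    Returns X, a list of [x, y] points, and Y, the matching integer labels.
    """
    if points_per_class < 2:
        raise ValueError("points_per_class must be at least 2")
    rng = random.Random(seed)
    X, Y = [], []
    for label in range(num_classes):
        for i in range(points_per_class):
            # radius grows linearly from 0 (center) to 1 (end of the arm)
            r = i / (points_per_class - 1)
            # angle sweeps 4 radians, offset by 4 radians per class index,
            # plus Gaussian noise so each arm is slightly fuzzy
            t = 4 * label + 4 * r + rng.gauss(0, noise)
            X.append([r * math.sin(t), r * math.cos(t)])
            Y.append(label)
    return X, Y


X, Y = spiral_points(450, 3)  # 450 points per class, 3 classes
```

To feed the result to a model, convert it with torch.tensor(X, dtype=torch.float32) and torch.tensor(Y); the labels become int64, which is what nn.CrossEntropyLoss expects. No linear classifier can separate the arms, which is why this kind of data is useful for checking that a network actually learns non-linear boundaries.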

Download files

Source Distribution

torchest-0.0.4.tar.gz (6.9 kB)

Uploaded Source

Built Distribution

torchest-0.0.4-py3-none-any.whl (7.3 kB)

Uploaded Python 3

File details

Details for the file torchest-0.0.4.tar.gz.

File metadata

  • Download URL: torchest-0.0.4.tar.gz
  • Upload date:
  • Size: 6.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torchest-0.0.4.tar.gz
Algorithm Hash digest
SHA256 b9e589e26f2ba5efa90b95a450091bd4a649f9897a428be054856efe0a43558e
MD5 2df7cf2579c6ffe4e0f28ae8856fcbc0
BLAKE2b-256 c7ee9b437a2f6d399992abbf023a506b9ae4a4f3e3e05f07436e575c791c3c62

File details

Details for the file torchest-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: torchest-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 7.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torchest-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 da3e1c558b14cc048e38bd0e1dd14c00f39b005c45d7f213c869c1071b5da830
MD5 738ef060d4cb81e9563a7b88486d910e
BLAKE2b-256 c59487918fa2c17ebcf3baf22cabb56d2bc1d16bfd209e761db10da0bf77a10e
