
Toolbelt for the PyTorch framework


Torchest

PyTorch tools and utilities (trainers, data generators, functions, and more)

Trainers

One of the most common tasks in PyTorch is writing training loops, which is tedious because they are almost always the same. Torchest provides trainers you can use in your projects.

from torch import nn, optim
from torch.utils.data import DataLoader  # used when building the loaders below

from torchest.trainer import SimpleTrainer

# previous definition of model and data preparation
# train_dataloader = DataLoader(train_data)
# dev_dataloader = DataLoader(dev_data)
# test_dataloader = DataLoader(test_data)
# model = nn.Sequential(...)
...

# Prepare the trainer: loss function, optimizer, and model
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

trainer = SimpleTrainer(model, loss_fn, optimizer)

# Train, evaluating on the dev and test sets as well
epochs = 500
trainer.train(
    data_train=train_dataloader,
    data_dev=dev_dataloader,
    data_test=test_dataloader,
    epochs=epochs,
)

Training loop progress

31%|██████████████████▍      | 156/500 [00:02<00:05, 58.01epoch/s, train_accuracy=76.2, train_cost=0.238290]
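The progress line follows the layout of a tqdm bar (an assumption; the README does not say which library draws it): the percentage is epochs done over total epochs, and the remaining-time estimate is the epochs left divided by the current rate. A quick check against the numbers shown above:

```python
# Values copied from the progress line above.
done, total = 156, 500
rate = 58.01  # epochs per second, as reported by the bar

percent = 100 * done / total         # 31.2, shown rounded as 31%
remaining_s = (total - done) / rate  # about 5.93 s, shown as 00:05
elapsed_s = done / rate              # about 2.69 s, shown as 00:02 (the rate is smoothed, so this is approximate)

print(f"{percent:.1f}% done, ~{remaining_s:.2f}s left, ~{elapsed_s:.2f}s elapsed")
```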

The trainer also records the cost of the train, dev, and test passes. Call trainer.plot_costs() to display a plot of these costs.
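For context, this is roughly the kind of loop a trainer saves you from writing. It is a minimal sketch, not Torchest's code: to stay runnable without PyTorch, it fits a one-weight linear model with hand-written gradient descent, where a PyTorch loop would call loss.backward() and optimizer.step(). Each epoch it records the train, dev, and test costs, the kind of history plot_costs() presumably draws. The names run_epochs, mse, mse_grad, make_split, and history are hypothetical.

```python
import random


def mse(w, data):
    """Mean squared error of the model y_hat = w * x over (x, y) pairs."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def mse_grad(w, data):
    """Derivative of mse with respect to w: mean of 2 * x * (w * x - y)."""
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)


def run_epochs(train, dev, test, epochs, lr):
    """Train on `train`, then evaluate on dev and test once per epoch.

    Returns the final weight and a dict of per-epoch costs, one list per
    split (a hypothetical structure, not Torchest's).
    """
    w = 0.0
    history = {"train": [], "dev": [], "test": []}
    for _ in range(epochs):
        # Training pass: one full-batch gradient-descent step.
        w -= lr * mse_grad(w, train)
        # Record the cost of each pass after the update.
        history["train"].append(mse(w, train))
        history["dev"].append(mse(w, dev))
        history["test"].append(mse(w, test))
    return w, history


# Toy data drawn from y = 3x plus a little Gaussian noise.
rng = random.Random(0)


def make_split(n):
    xs = [rng.uniform(-1, 1) for _ in range(n)]
    return [(x, 3 * x + rng.gauss(0, 0.1)) for x in xs]


train, dev, test = make_split(100), make_split(30), make_split(30)
w, history = run_epochs(train, dev, test, epochs=200, lr=0.1)
print(f"w = {w:.3f}, final dev cost = {history['dev'][-1]:.4f}")
```

Keeping one list per split makes it straightforward to plot all three curves on the same axes and spot where the dev cost stops improving.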

Trainer costs

Wandb visualization

Note: the wandb Python package must be installed to use this feature.

pip install wandb

Trainers support wandb for visualizing loss and accuracy and for saving model versions. To enable it, set two parameters when initializing the trainer: name and wandb.

trainer = SimpleTrainer(model, loss_fn, optimizer, name="MyProject", wandb=True)

On the first run you will have to enter your API key, which you can obtain at https://wandb.ai/settings.
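The name and wandb flags suggest a common pattern: import wandb only when it is requested, start a run under the project name, and log metrics every epoch. The sketch below is hypothetical (MetricsLogger is not part of Torchest, and how SimpleTrainer actually uses wandb is an assumption); wandb.init(project=...) and the run's log(..., step=...) and finish() methods are real wandb calls.

```python
class MetricsLogger:
    """Hypothetical helper: keeps metrics locally and optionally mirrors them to wandb."""

    def __init__(self, name=None, use_wandb=False):
        self.history = []
        self.run = None
        if use_wandb:
            if not name:
                raise ValueError("a project name is required when wandb is enabled")
            try:
                import wandb  # optional dependency, imported only when requested
            except ImportError as exc:
                raise ImportError(
                    "wandb logging was requested but the package is not "
                    "installed; run `pip install wandb`"
                ) from exc
            # Starts a run in the given project (prompts for an API key if not logged in).
            self.run = wandb.init(project=name)

    def log(self, step, **metrics):
        """Record metrics for one epoch, e.g. log(3, train_cost=0.24, train_accuracy=76.2)."""
        self.history.append({"step": step, **metrics})
        if self.run is not None:
            self.run.log(metrics, step=step)

    def finish(self):
        """Close the wandb run, if one was started."""
        if self.run is not None:
            self.run.finish()


# Local-only usage: no wandb account or package needed.
logger = MetricsLogger(name="MyProject", use_wandb=False)
for epoch in range(3):
    logger.log(epoch, train_cost=1.0 / (epoch + 1))
logger.finish()
print(logger.history)
```

Importing wandb lazily keeps it an optional dependency: users who never pass wandb=True do not need it installed.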

Data generators

Spiral data

This generates 2D points arranged in a spiral, returned as a matrix of points X and their class labels Y. It is ideal for testing whether your network can model non-linear decision boundaries.

from torchest.datagen import spiral_datagen

class_num = 3
X, Y = spiral_datagen(450, class_num)  # 450 elements per class

Spiral data screenshot
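To show how such a dataset is typically built, here is a minimal pure-Python sketch following the widely used CS231n spiral recipe. make_spiral is a hypothetical stand-in, not Torchest's spiral_datagen, which may use a different radius, noise level, or return type (for example tensors).

```python
import math
import random


def make_spiral(points_per_class, num_classes, noise=0.2, seed=0):
    """Return (X, Y): 2D points on num_classes interleaved spiral arms, plus labels.

    Hypothetical stand-in for spiral_datagen, following the CS231n recipe.
    """
    rng = random.Random(seed)
    X, Y = [], []
    n = points_per_class
    for label in range(num_classes):
        for i in range(n):
            # Radius grows linearly from 0 to 1 along the arm.
            r = i / (n - 1) if n > 1 else 0.0
            # Each class sweeps its own 4-radian band of angles, jittered by noise.
            t = label * 4 + 4 * r + rng.gauss(0, 1) * noise
            X.append((r * math.sin(t), r * math.cos(t)))
            Y.append(label)
    return X, Y


X, Y = make_spiral(450, 3)
print(len(X), len(Y))  # 1350 points, one label each
```

Because every arm curls around the origin, no straight line separates the classes, so a model without non-linear activations cannot fit this data well.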



Download files


Source Distribution

torchest-0.0.3.tar.gz (6.8 kB)

Uploaded Source

Built Distribution


torchest-0.0.3-py3-none-any.whl (7.2 kB)

Uploaded Python 3

File details

Details for the file torchest-0.0.3.tar.gz.

File metadata

  • Download URL: torchest-0.0.3.tar.gz
  • Size: 6.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torchest-0.0.3.tar.gz
Algorithm Hash digest
SHA256 73cc5f113b3b4951d34a0b818f7bde2aaab87b63da28fdc1ef5d34ccf761dd90
MD5 e8a9f2d4baca4afd745672df371f152e
BLAKE2b-256 d22723ff992733a6a7e590a34c16cfa6926866cabce90660713da8bfc57bea2a


File details

Details for the file torchest-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: torchest-0.0.3-py3-none-any.whl
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torchest-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 d94847ce00a615435ddfdaf26ccfc1899cf73340b2af8c8c86d02792c0b1b7cf
MD5 fb73791bdfd9697bf0aa8a4b2ef41712
BLAKE2b-256 f7898bf47242a14719e3d697241cc73a0c4f6dcfbe5d399a38f0a24aa7c5ee7e

