
Toolbelt for the PyTorch framework

Project description

Torchest

PyTorch tools and utilities (trainers, data generators, functions, and more).

Trainers

One of the most common tasks in PyTorch is defining training loops, which is tedious because they are almost always the same. Torchest provides trainers you can reuse in your projects.

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchest.trainer import SimpleTrainer

# previous definition of model and data preparation
# train_dataloader = DataLoader(train_data)
# dev_dataloader = DataLoader(dev_data)
# test_dataloader = DataLoader(test_data)
# model = nn.Sequential(...)
...

# Prepare the trainer
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

trainer = SimpleTrainer(model, loss_fn, optimizer)

# Train
epochs = 500
trainer.train(data_train=train_dataloader, data_dev=dev_dataloader, data_test=test_dataloader, epochs=epochs)

Training loop progress

31%|██████████████████▍      | 156/500 [00:02<00:05, 58.01epoch/s, train_accuracy=76.2, train_cost=0.238290]
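For comparison, the kind of hand-written PyTorch loop that a trainer saves you from writing might look like the minimal sketch below. It is not Torchest's implementation: the helper name train_one_epoch and the toy data are hypothetical, chosen only to show how per-epoch values like the train_cost and train_accuracy in the progress bar can be computed.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, dataloader, loss_fn, optimizer):
    """Run one training pass and return (average cost, accuracy in percent)."""
    model.train()
    total_cost, correct, seen = 0.0, 0, 0
    for X, y in dataloader:
        optimizer.zero_grad()      # clear gradients left over from the previous step
        logits = model(X)
        cost = loss_fn(logits, y)  # mean cost over the batch
        cost.backward()            # backpropagate
        optimizer.step()           # update the parameters
        total_cost += cost.item() * X.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += X.size(0)
    return total_cost / seen, 100.0 * correct / seen


# Hypothetical toy data: 2D points labelled 0, 1 or 2 by how many coordinates are positive
torch.manual_seed(0)
X = torch.randn(300, 2)
y = (X[:, 0] > 0).long() + (X[:, 1] > 0).long()
train_dataloader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

history = []
for epoch in range(5):
    train_cost, train_accuracy = train_one_epoch(model, train_dataloader, loss_fn, optimizer)
    history.append((train_cost, train_accuracy))
```

Because CrossEntropyLoss averages over the batch by default, each batch cost is multiplied by the batch size before summing, so the returned value is a per-sample average even when the last batch is smaller.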

The trainer also records the cost of the train, dev, and test passes. Call trainer.plot_costs() to display a graph of the costs.

Trainer costs
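Recording a dev or test cost usually means running a pass over the data without computing gradients. The sketch below assumes a hypothetical helper evaluate_cost and a plain dictionary of lists as the cost history; neither is taken from Torchest, whose internal storage format is not documented here.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # dev/test passes need no gradients
def evaluate_cost(model, dataloader, loss_fn):
    """Return the average per-sample cost of one pass over `dataloader`."""
    model.eval()
    total_cost, seen = 0.0, 0
    for X, y in dataloader:
        total_cost += loss_fn(model(X), y).item() * X.size(0)
        seen += X.size(0)
    return total_cost / seen


# Hypothetical cost history: one list per pass, one value appended per epoch
costs = {"train": [], "dev": [], "test": []}

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 3))
loss_fn = nn.CrossEntropyLoss()
dev_X, dev_y = torch.randn(60, 2), torch.randint(0, 3, (60,))
dev_dataloader = DataLoader(TensorDataset(dev_X, dev_y), batch_size=16)

costs["dev"].append(evaluate_cost(model, dev_dataloader, loss_fn))
```

Appending one value per epoch to each list produces three curves that can be plotted against the epoch number.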

Wandb visualization

Trainers support wandb (Weights & Biases) for visualizing loss and accuracy and for saving model versions. To enable it, set two parameters when initializing the trainer: name and wandb.

trainer = SimpleTrainer(model, loss_fn, optimizer, name="MyProject", wandb=True)

On the first run you will be asked for your API key, which you can obtain at https://wandb.ai/settings.

Data generators

Spiral data

This creates a 2D matrix of points that form a spiral when plotted. It is ideal for testing whether your network can handle non-linearity.

from torchest.datagen import spiral_datagen

class_num = 3
X, Y = spiral_datagen(450, class_num)  # 450 elements per class
Spiral data screenshot
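To illustrate what such a dataset looks like, here is a minimal pure-Python sketch of a common spiral generator, where each class is one arm of the spiral and the arms are rotated around the origin. The function spiral_points, its noise and seed parameters, and its list-based return values are hypothetical; this is not the code behind spiral_datagen, whose exact noise and return types may differ.

```python
import math
import random


def spiral_points(points_per_class, class_num, noise=0.2, seed=0):
    """Return (X, Y): a list of [x, y] points and their integer class labels."""
    rng = random.Random(seed)
    X, Y = [], []
    for label in range(class_num):
        for i in range(points_per_class):
            # radius grows linearly from 0 to 1 along each arm
            r = i / (points_per_class - 1) if points_per_class > 1 else 0.0
            # angle grows with the radius; each class arm is offset by 4 radians
            theta = label * 4 + 4 * r + rng.gauss(0, noise)
            X.append([r * math.sin(theta), r * math.cos(theta)])
            Y.append(label)
    return X, Y


X, Y = spiral_points(450, 3)  # 450 points per class, 3 classes
```

Since the arms wrap around each other, no straight line separates the classes, which is why a model without non-linear activations struggles on this data. To feed the lists to a model, convert them with torch.tensor(X, dtype=torch.float32) and torch.tensor(Y).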



Download files


Source Distribution

torchest-0.0.5.tar.gz (6.8 kB)

Uploaded Source

Built Distribution


torchest-0.0.5-py3-none-any.whl (7.2 kB)

Uploaded Python 3

File details

Details for the file torchest-0.0.5.tar.gz.

File metadata

  • Download URL: torchest-0.0.5.tar.gz
  • Upload date:
  • Size: 6.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torchest-0.0.5.tar.gz
Algorithm Hash digest
SHA256 472c633ce452e5b9be26f883881d74319824bd6e16a38df0393f5fa865dbc597
MD5 547539e99d17d51dab5f44fb19088a8d
BLAKE2b-256 57e43acfa2097b0c368733b46064bd4c6c5acdac6e19431bd552bd759ff9c788


File details

Details for the file torchest-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: torchest-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torchest-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 0108ac0c2516c1100937fd4445593579f0203341e8ac152ab26c53fa64053d7e
MD5 e823805947f9c8cf8b41c5d6b21845a0
BLAKE2b-256 3e44bf1fa9f6fb819afbb26cefefba6dff154ea4321e805dac91f0772ac23cc1
