A toolbelt for the PyTorch framework

Torchest

PyTorch tools and utilities (trainers, data generators, functions, and more)

Trainers

Defining training loops is one of the most common tasks in PyTorch, and it is a hassle because the loops are almost always the same. Torchest provides trainers you can reuse across your projects:

from torch import nn, optim
from torchest.trainer import SimpleTrainer

# Model and data are assumed to be defined beforehand, for example:
# train_dataloader = DataLoader(train_data)
# dev_dataloader = DataLoader(dev_data)
# test_dataloader = DataLoader(test_data)
# model = nn.Sequential(...)
...

# Prepare the trainer
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

trainer = SimpleTrainer(model, loss_fn, optimizer)

# Train
epochs = 500
trainer.train(
    data_train=train_dataloader,
    data_dev=dev_dataloader,
    data_test=test_dataloader,
    epochs=epochs,
)

Training loop progress

31%|██████████████████▍      | 156/500 [00:02<00:05, 58.01epoch/s, train_accuracy=76.2, train_cost=0.238290]

The trainer also records the cost for the train, dev, and test passes. Call trainer.plot_costs() to display a graph of these costs.

Trainer costs
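
For context, the kind of loop a trainer wraps, including per-epoch cost tracking, can be sketched in plain PyTorch. This is only an illustration and not Torchest's implementation: the helper run_epoch, the toy data, and the train_costs and dev_costs lists are all hypothetical names chosen for this example.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loss_fn, dataloader, optimizer=None):
    """Return the mean cost of one pass; weights change only if an optimizer is given."""
    training = optimizer is not None
    model.train(training)
    total, batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for inputs, targets in dataloader:
            loss = loss_fn(model(inputs), targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item()
            batches += 1
    return total / batches


torch.manual_seed(0)
# Toy two-class data, purely illustrative (an assumption of this sketch).
features = torch.randn(96, 2)
labels = (features[:, 0] > 0).long()
train_dataloader = DataLoader(TensorDataset(features[:64], labels[:64]), batch_size=16)
dev_dataloader = DataLoader(TensorDataset(features[64:], labels[64:]), batch_size=16)

model = nn.Linear(2, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

# One cost value per epoch and per split, ready to be plotted afterwards.
train_costs, dev_costs = [], []
for epoch in range(50):
    train_costs.append(run_epoch(model, loss_fn, train_dataloader, optimizer))
    dev_costs.append(run_epoch(model, loss_fn, dev_dataloader))
```

Keeping one list per split is what makes a cost plot straightforward: each list can be drawn as a single curve over the epochs.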

Wandb visualization

Trainers support wandb for visualizing loss and accuracy, as well as for saving model versions. To enable it, set two parameters when initializing your trainer: wandb_project_name and wandb.

trainer = SimpleTrainer(model, loss_fn, optimizer, wandb_project_name="MyProject", wandb=True)

On the first run you will have to enter your API key, which you can obtain at https://wandb.ai/settings

Data generators

Spiral data

This generates a 2D dataset of points arranged in a spiral, which is ideal for testing whether your network can handle non-linearity.

from torchest.datagen import spiral_datagen

class_num = 3
X, Y = spiral_datagen(450, class_num)  # 450 elements per class

Spiral data screenshot
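
To give an idea of what spiral data looks like numerically, here is a minimal pure-Python sketch of a common way to build such a dataset. It is not the code behind spiral_datagen: the function name spiral_points, the noise and seed parameters, and the list-based return type are assumptions made for illustration, and the real function may return tensors or arrays instead.

```python
import math
import random


def spiral_points(points_per_class, class_num, noise=0.2, seed=0):
    """Hypothetical stand-in for a spiral generator; not torchest's code."""
    rng = random.Random(seed)
    X, Y = [], []
    for label in range(class_num):
        for i in range(points_per_class):
            # Radius grows linearly from 0 to 1 along each arm.
            radius = i / max(points_per_class - 1, 1)
            # Each class sweeps its own 4-radian angular window, plus jitter.
            angle = 4.0 * (label + radius) + rng.gauss(0.0, noise)
            X.append((radius * math.sin(angle), radius * math.cos(angle)))
            Y.append(label)
    return X, Y


X, Y = spiral_points(450, 3)
print(len(X), len(Y), sorted(set(Y)))  # prints: 1350 1350 [0, 1, 2]
```

Because the arms wind around each other, no straight line separates the classes, which is why this kind of data is a handy check for non-linear models.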
