Toolbelt for the PyTorch framework
Torchest
PyTorch tools and utilities: trainers, data generators, functions, and more.
Trainers
One of the most common tasks in PyTorch is writing training loops, which is tedious because they are almost always the same. Torchest provides trainers you can use in your projects:
```python
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchest.trainer import SimpleTrainer

# Previous definition of model and data preparation
# train_dataloader = DataLoader(train_data)
# dev_dataloader = DataLoader(dev_data)
# test_dataloader = DataLoader(test_data)
# model = nn.Sequential(...)
...

# Prepare the trainer
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
trainer = SimpleTrainer(model, loss_fn, optimizer)

# Train
epochs = 500
trainer.train(
    data_train=train_dataloader,
    data_dev=dev_dataloader,
    data_test=test_dataloader,
    epochs=epochs,
)
```
Training loop progress:
```text
31%|██████████████████▍ | 156/500 [00:02<00:05, 58.01epoch/s, train_accuracy=76.2, train_cost=0.238290]
```
The trainer also records the cost for the train, dev, and test passes. Call `trainer.plot_costs()` to display a graph of the costs.
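For context, here is a minimal plain-PyTorch sketch of the boilerplate a trainer such as SimpleTrainer is meant to replace: one function that runs either a training pass (when given an optimizer) or an evaluation pass (without one) and returns the mean cost and accuracy, with costs collected per split in a dict. The helper name `run_epoch`, the toy dataset, and the `costs` dict are hypothetical illustrations; this is not Torchest's actual implementation.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, loss_fn, optimizer=None):
    """One pass over `loader`. Weights are updated only when an optimizer is given.

    Returns (mean cost per sample, accuracy in percent).
    """
    training = optimizer is not None
    model.train(training)  # switches to eval mode for dev/test passes
    total_cost, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            logits = model(inputs)
            cost = loss_fn(logits, targets)
            if training:
                optimizer.zero_grad()
                cost.backward()
                optimizer.step()
            batch = targets.size(0)
            total_cost += cost.item() * batch  # assumes loss_fn averages over the batch
            correct += (logits.argmax(dim=1) == targets).sum().item()
            seen += batch
    return total_cost / seen, 100.0 * correct / seen


# Hypothetical toy data: the label depends on the sign of x0 * x1 (not linearly separable).
torch.manual_seed(0)
X = torch.randn(300, 2)
Y = (X[:, 0] * X[:, 1] > 0).long()
train_dataloader = DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

costs = {"train": []}  # one list per split, the data a cost plot needs
for epoch in range(50):
    train_cost, train_accuracy = run_epoch(model, train_dataloader, loss_fn, optimizer)
    costs["train"].append(train_cost)
```

For dev and test passes, the same helper would be called without an optimizer, for example `run_epoch(model, dev_dataloader, loss_fn)`, and its cost appended to a `"dev"` or `"test"` list.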
Wandb visualization
Note: You must have the wandb Python package installed to use this feature.
```shell
pip install wandb
```
Trainers support wandb for visualizing loss and accuracy and for saving model versions. To enable it, set two parameters when initializing your trainer: name and wandb.
```python
trainer = SimpleTrainer(model, loss_fn, optimizer, name="MyProject", wandb=True)
```
On the first run, you will have to enter your API key, which you can obtain at https://wandb.ai/settings.
Data generators
Spiral data
This creates a 2D matrix of points that form a spiral pattern when plotted, which makes it ideal for testing whether your network can handle non-linearity.
```python
from torchest.datagen import spiral_datagen

class_num = 3
X, Y = spiral_datagen(450, class_num)  # 450 elements per class
```
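The exact output format of `spiral_datagen` is not documented here, so the following is only a hypothetical stand-in showing one common way such spiral data is built: for each class, the radius grows linearly from 0 to 1 while the angle sweeps a class-specific range, with Gaussian noise added to the angle. The function name `make_spirals`, the `noise` and `seed` parameters, and the use of PyTorch tensors are assumptions, not Torchest's API.

```python
import torch


def make_spirals(points_per_class: int, num_classes: int, noise: float = 0.2, seed: int = 0):
    """Hypothetical spiral generator.

    Returns X with shape (points_per_class * num_classes, 2) and Y with integer class ids.
    """
    gen = torch.Generator().manual_seed(seed)
    xs, ys = [], []
    for k in range(num_classes):
        # Radius goes from the center (0) to the edge (1) along each arm.
        radius = torch.linspace(0.0, 1.0, points_per_class)
        # Each class sweeps its own 4-radian angular range; noise blurs the arms.
        theta = torch.linspace(k * 4.0, (k + 1) * 4.0, points_per_class)
        theta = theta + noise * torch.randn(points_per_class, generator=gen)
        xs.append(torch.stack([radius * torch.sin(theta), radius * torch.cos(theta)], dim=1))
        ys.append(torch.full((points_per_class,), k, dtype=torch.long))
    return torch.cat(xs), torch.cat(ys)


X, Y = make_spirals(450, 3)
print(X.shape, Y.shape)  # torch.Size([1350, 2]) torch.Size([1350])
```

Because no straight line separates the arms, a model without a non-linear activation cannot classify these points well. The tensors can be wrapped in `torch.utils.data.TensorDataset` and passed to a `DataLoader` for training.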
Download files
File details
Details for the file torchest-0.0.3.tar.gz.
File metadata
- Download URL: torchest-0.0.3.tar.gz
- Upload date:
- Size: 6.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 73cc5f113b3b4951d34a0b818f7bde2aaab87b63da28fdc1ef5d34ccf761dd90 |
| MD5 | e8a9f2d4baca4afd745672df371f152e |
| BLAKE2b-256 | d22723ff992733a6a7e590a34c16cfa6926866cabce90660713da8bfc57bea2a |
File details
Details for the file torchest-0.0.3-py3-none-any.whl.
File metadata
- Download URL: torchest-0.0.3-py3-none-any.whl
- Upload date:
- Size: 7.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d94847ce00a615435ddfdaf26ccfc1899cf73340b2af8c8c86d02792c0b1b7cf |
| MD5 | fb73791bdfd9697bf0aa8a4b2ef41712 |
| BLAKE2b-256 | f7898bf47242a14719e3d697241cc73a0c4f6dcfbe5d399a38f0a24aa7c5ee7e |