
Project description

mytorch is your torch :fire:


A transparent boilerplate + bag of tricks to ease my (yours?) (our?) PyTorch dev time.

Some parts here are inspired by/copied from fast.ai. However, I've tried to keep it such that control of the model (model architecture), vocabulary, and preprocessing always stays outside of this library. The training loop, data samplers, etc. can be used independently of anything else in here, but of course they work better together.

I'll be adding proper documentation and examples here, gradually.

Installation

pip install my-torch

(Added a hyphen because someone beat me to the mytorch package name.)

Idea

Use or ignore most parts of the library. It will not hide code from you, and you retain control over your models. If you need just one thing, no fluff: feel free to copy-paste snippets of the code from this repo into yours. I'd be delighted if you dropped me a line if you found this stuff helpful.

Features

  1. Customizable Training Loop

    • Callbacks @ epoch start and end
    • Weight Decay (see this blog post)
    • :scissors: Gradient Clipping (a plain-PyTorch sketch of these two follows this list)
    • :floppy_disk: Model Saving
    • :bell: Mobile push notifications @ the end of training :ghost: (see Usage)
  2. Sortist Sampling

  3. Custom Learning Rate Schedules

  4. Customisability & Flat Hierarchy
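
For orientation, the weight decay and gradient clipping that the loop applies internally correspond to the standard PyTorch calls below. This is a minimal plain-PyTorch sketch of the technique, not the library's actual code.

import torch, torch.nn as nn

model = nn.Linear(10, 2)
# Decoupled weight decay, in the spirit of the blog post linked above.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

# Clip the global gradient norm before the optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
opt.zero_grad()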

Usage

Simplest Use Case

import torch, torch.nn as nn, numpy as np
from mytorch import loops  # the package installs as my-torch but is imported as mytorch

# Assuming that you have a torch model with a predict and a forward function.
# model = MyModel()
assert isinstance(model, nn.Module)

# X, Y are input and output labels for a text classification task with four classes. 200 examples.
X_trn = np.random.randint(0, 100, (200, 4))
Y_trn = np.random.randint(0, 4, (200, 1))
X_val = np.random.randint(0, 100, (100, 4))
Y_val = np.random.randint(0, 4, (100, 1))

# Preparing data
data = {"train":{"x":X_trn, "y":Y_trn}, "valid":{"x":X_val, "y":Y_val} }

# Specifying other hyperparameters
epochs = 10
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
loss_function = nn.functional.cross_entropy
train_function = model      # or model.forward
predict_function = model.predict

train_acc, valid_acc, train_loss = loops.simplest_loop(epochs=epochs, data=data, opt=optimizer,
                                                        loss_fn=loss_function, 
                                                        train_fn=train_function,
                                                        predict_fn=predict_function)
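
The snippet above assumes a model that exposes both forward and predict. MyModel below is a hypothetical stand-in (not shipped with the library) that satisfies this contract for the random data used above.

import torch, torch.nn as nn

class MyModel(nn.Module):
    # Tiny text classifier: embed token ids, mean-pool, project to 4 classes.
    def __init__(self, vocab_size=100, emb_dim=32, n_classes=4):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.clf = nn.Linear(emb_dim, n_classes)

    def forward(self, x):
        # x: (batch, seq_len) of token ids -> (batch, n_classes) logits
        return self.clf(self.emb(x).mean(dim=1))

    def predict(self, x):
        # Class probabilities at eval time, without tracking gradients.
        with torch.no_grad():
            return torch.softmax(self.forward(x), dim=-1)

model = MyModel()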

Slightly more complex examples

@TODO: They exist! Just need to add examples :sweat_smile:

  1. Custom eval
  2. Custom data sampler
  3. Custom learning rate annealing schedules (a plain-PyTorch sketch of this one follows)
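
For the third item, a custom annealing schedule in plain PyTorch boils down to something like the sketch below; the library's own schedule helpers are separate from this, so treat it as illustrative only.

import torch, torch.nn as nn

model = nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Warm up linearly for 5 epochs, then decay; purely illustrative numbers.
schedule = lambda epoch: min(1.0, (epoch + 1) / 5) / (1 + 0.1 * epoch)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=schedule)

for epoch in range(10):
    # ... run one training epoch here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())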

Saving the model

@TODO
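
Until this section is filled in, the standard PyTorch pattern below is what saving boils down to; it is a generic sketch, not necessarily what the loop's model-saving option does.

import torch

# Save only the weights; re-instantiate the architecture yourself when loading.
torch.save(model.state_dict(), "model.pt")

# Later, on the same architecture:
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()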

Notifications

The training loop can send notifications to your phone, informing you that your model is done training, along with its metrics. We use push.techulus.com to do so, and you'll need the app on your phone. If you're not bothered, this part of the code will stay out of your way. But if you'd like this completely unnecessary gimmick, follow along:

  1. Get the app. Play Store | App Store
  2. Sign in/up and get your API key.
  3. Make the key available. Options:
    1. In a file named ./push-techulus-key, in plaintext, at the root dir of this folder. You could just echo 'your-api-key' >> ./push-techulus-key.
    2. Through an argument to the training loop, as a string (see the sketch after this list).
  4. Pass the flag to the loop to enable notifications.
  5. Done :balloon: You'll be notified when your model's done training.
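
A small sketch of option 3.1/3.2: read the key from the file and hand it to the loop as a string. The exact argument names belong to the loop's signature and are deliberately not guessed here.

from pathlib import Path

# Key written earlier via: echo 'your-api-key' >> ./push-techulus-key
key_file = Path("./push-techulus-key")
api_key = key_file.read_text().strip() if key_file.exists() else None

# api_key (plus the enable-notifications flag) is then passed to the training
# loop; check the loop's signature for the exact parameter names.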

Changelog

v0.0.6

  1. Interfaced some metrics from torchmetrics, and implemented some more into a neat little package (pending).

v0.0.2

  1. Added negative sampling
  2. [TODO] Added multiple evaluation functions
  3. [TODO] Logging
  4. [TODO] Typing all functions

v0.0.1

  1. Added some tests.
  2. Wrapping spaCy tokenizers, with some vocab management.
  3. Packaging :confetti:

Upcoming

  1. Models
    1. Classifiers
    2. Encoders
    3. Transformers (USE pytorch-transformers by :huggingface:)
  2. Using FastProgress for progress + live plotting
  3. W&B integration
  4. ?? (tell me here)

Contributions

I'm eager to implement more tricks/features in the library, while maintaining the flat structure (and ensuring backward compatibility). Open to suggestions and contributions. Thanks!

PS: Always appreciate more tests.

Acknowledgements

An important part of the code was designed and tested by:

Gaurav Maheshwari  ·  GitHub @saist1993  ·  Twitter @__gauravm

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

my-torch-0.0.13.tar.gz (19.2 kB)

Uploaded Source

Built Distribution

my_torch-0.0.13-py3-none-any.whl (21.4 kB)

Uploaded Python 3

File details

Details for the file my-torch-0.0.13.tar.gz.

File metadata

  • Download URL: my-torch-0.0.13.tar.gz
  • Upload date:
  • Size: 19.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.1 setuptools/58.0.4 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.8.5

File hashes

Hashes for my-torch-0.0.13.tar.gz
Algorithm Hash digest
SHA256 506a13eccf21941913e01751b654bdcea950861cb3135b52d6937ffb6044f93d
MD5 0b675cdf704c18453c4ec7a35a07d25c
BLAKE2b-256 a151e61688259c3b9675b7c92cc5e38b57abab6696ffd60519c786a677cedbf2

See more details on using hashes here.

File details

Details for the file my_torch-0.0.13-py3-none-any.whl.

File metadata

  • Download URL: my_torch-0.0.13-py3-none-any.whl
  • Upload date:
  • Size: 21.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.1 setuptools/58.0.4 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.8.5

File hashes

Hashes for my_torch-0.0.13-py3-none-any.whl
Algorithm Hash digest
SHA256 22d1604b1ef7ce3eebdb9aabc62243a890444f4f26083327df6497357e619920
MD5 5655a68227cf1937d9bcd253fcb2ce1f
BLAKE2b-256 9f7da62bc1bdd4abc805b29365f521c34f844ab1e892a4bb534345fc0a24c5cf

See more details on using hashes here.
