Fitting simple torch models

Fitting PyTorch models

License: MIT

This repository follows a workshop that covers setting up a Python package, building neural networks with PyTorch, and publishing the models on Hugging Face.

There are two datasets that we simulate:

  1. Colored shapes (circle, rectangle, triangle, diamond) in a pixelated image
  2. One-dimensional sinusoids
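As a rough picture of the first dataset, the sketch below draws one filled shape into a small RGB pixel grid using only the standard library. It is not the package's Shape class: the function draw_shape, the geometry chosen for each shape, and the grid size are assumptions made for illustration.

```python
# Hypothetical sketch: rasterize one filled shape into an RGB pixel grid.
# draw_shape is illustrative only; it is NOT the package's Shape class.

def draw_shape(kind, size, center, radius, color, background=(0, 0, 0)):
    """Return a size x size grid (list of rows) of RGB tuples."""
    cx, cy = center
    grid = [[background] * size for _ in range(size)]
    for y in range(size):
        for x in range(size):
            dx, dy = x - cx, y - cy
            if kind == "circle":
                inside = dx * dx + dy * dy <= radius * radius
            elif kind == "rectangle":
                # Twice as wide as it is tall.
                inside = abs(dx) <= radius and abs(dy) <= radius / 2
            elif kind == "diamond":
                inside = abs(dx) + abs(dy) <= radius
            elif kind == "triangle":
                # Apex at dy = -radius; half-width grows linearly to radius at the base.
                inside = -radius <= dy <= radius and abs(dx) <= (dy + radius) / 2
            else:
                raise ValueError(f"unknown shape: {kind}")
            if inside:
                grid[y][x] = color
    return grid


img = draw_shape("diamond", size=9, center=(4, 4), radius=2, color=(255, 0, 0))
for row in img:
    print("".join("#" if px != (0, 0, 0) else "." for px in row))
```

Printing the grid shows a 13-pixel diamond centered in a 9 x 9 image; changing kind and color gives the other classes.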

On these datasets, we fit models for the following objectives:

  1. (Classification) Predict the shape and color in the image
  2. (Regression) Predict the next time steps of the sine function
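For the regression objective, one common way to frame next-step prediction is to slide a window along the series and pair each run of past values with the values that follow it. This is a sketch under assumptions: make_sine, make_windows, the 20-step period, and the 30-in / 5-out window lengths are hypothetical choices, not the repository's settings.

```python
import math

def make_sine(n_points, period=20.0, amplitude=1.0, phase=0.0):
    """Sample amplitude * sin(2*pi*t / period + phase) at t = 0, 1, ..., n_points - 1."""
    return [amplitude * math.sin(2 * math.pi * t / period + phase) for t in range(n_points)]

def make_windows(series, n_in, n_out):
    """Pair each run of n_in past values with the n_out values that follow it."""
    pairs = []
    for start in range(len(series) - n_in - n_out + 1):
        past = series[start:start + n_in]
        future = series[start + n_in:start + n_in + n_out]
        pairs.append((past, future))
    return pairs

series = make_sine(100)
pairs = make_windows(series, n_in=30, n_out=5)
print(len(pairs))  # 100 - 30 - 5 + 1 = 66
```

Each pair is one training example: the model sees the past window and is scored on how well it predicts the future one.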

Usage

Learn how to train neural networks from scratch.

  1. Run answers/simulate-exercise.ipynb to simulate the data.
  2. Fill in the # TO DO parts in examples/modeling-exercise-*.ipynb.
  3. Compare to solutions in answers/modeling-exercise-*.ipynb.
  4. Explore different hyperparameters on larger models with scripts/modeling.py.
    • Write a shell script that invokes scripts/modeling.py with your arguments and submit it to Slurm.
  5. Run scripts/modeling-final.py to fit your best model choice on the combined training and validation data.
  6. (Optional) Compare to the benchmark here.
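Step 4 suggests wrapping scripts/modeling.py for Slurm. As one possible alternative to a hand-written shell script, the sketch below builds one sbatch --wrap command per hyperparameter combination. The flag names --lr and --batch-size are hypothetical placeholders; check the arguments scripts/modeling.py actually accepts, and add whatever partition or GPU options your cluster requires.

```python
import itertools
import shlex

# Hypothetical search space: these flag names are placeholders only;
# check scripts/modeling.py for the arguments it actually accepts.
GRID = {
    "--lr": [1e-3, 1e-4],
    "--batch-size": [32, 64],
}

def sweep_commands(script="scripts/modeling.py", grid=GRID):
    """Build one `sbatch --wrap` command per hyperparameter combination."""
    keys = list(grid)
    commands = []
    for values in itertools.product(*(grid[k] for k in keys)):
        args = " ".join(f"{k} {v}" for k, v in zip(keys, values))
        job = f"python {script} {args}"
        # Quote the job so sbatch receives it as a single --wrap argument.
        commands.append(f"sbatch --wrap {shlex.quote(job)}")
    return commands

for cmd in sweep_commands():
    print(cmd)
```

To actually submit, pass each command to subprocess.run(shlex.split(cmd), check=True) on a machine where sbatch is available.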

The package defined under src/ provides:

  • A class Shape that instantiates an image with one colored shape
  • A function simulate_shapes() to make many images for an image classifier
  • A model class MyCNN that implements a standard convolutional neural network (CNN) architecture
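A standard CNN such as MyCNN has to know the spatial size of its last feature map before the flattening step and the final linear layer. With dilation 1, convolution and max pooling both follow out = floor((in + 2 * padding - kernel) / stride) + 1, the formula given in the PyTorch documentation for nn.Conv2d and nn.MaxPool2d. The layer stack and the 64-pixel input below are hypothetical, not MyCNN's real configuration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# Hypothetical stack of (name, kernel, stride, padding); not MyCNN's real layers.
LAYERS = [
    ("conv1", 3, 1, 1),  # 'same' padding keeps the size
    ("pool1", 2, 2, 0),  # halves the size
    ("conv2", 3, 1, 1),
    ("pool2", 2, 2, 0),
]

def trace_sizes(size, layers=LAYERS):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [size]
    for _, kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes

print(trace_sizes(64))  # [64, 64, 32, 32, 16]
```

With 16 x 16 maps after the last pooling layer, a stack ending in C channels would flatten to C * 16 * 16 features for the linear layer.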

Caution: you may need GPU resources if your models or data are large.

Requirements

  • Python 3.10+

Install

To install the package from PyPI:

pip install zootopia3

To set up an isolated environment and install from a local clone:

python -m venv path-to/your-environment
source path-to/your-environment/bin/activate
pip install -e .

Run the pip command from the root of this repository.

Data

I made the training and validation sets with:

  • 2000 samples for each combo
  • min_x = 20
  • max_x = 100
  • shades = True
  • magnitude = 50
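Fixing the number of samples per combination makes the classes balanced by construction. The sketch below enumerates every (color, shape) label and splits the indices into training and validation with a seeded shuffle. The color palette and the 80/20 split are assumptions; the README does not state how many colors there are or how the validation set was carved out.

```python
import itertools
import random

SHAPES = ["circle", "rectangle", "triangle", "diamond"]
COLORS = ["red", "green", "blue"]  # hypothetical palette; the package's colors may differ

def balanced_labels(n_per_combo, shapes=SHAPES, colors=COLORS):
    """One (color, shape) label per sample, n_per_combo for every combination."""
    return [combo for combo in itertools.product(colors, shapes) for _ in range(n_per_combo)]

def train_val_split(labels, val_fraction=0.2, seed=0):
    """Shuffle indices with a fixed seed and split them into train and validation."""
    idx = list(range(len(labels)))
    random.Random(seed).shuffle(idx)
    n_val = int(len(idx) * val_fraction)
    return idx[n_val:], idx[:n_val]

labels = balanced_labels(2000)
train_idx, val_idx = train_val_split(labels)
print(len(labels), len(train_idx), len(val_idx))  # 24000 19200 4800
```

A plain shuffle keeps each split only approximately balanced; splitting each combination separately would make the validation set exactly balanced as well.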

I made the test set with:

  • 20 samples for each combo
  • min_x = 10
  • max_x = 50

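Because the test images use a different, smaller range (min_x = 10 to max_x = 50, versus 20 to 100 for training and validation), the test set is a more difficult prediction problem.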

You can find the test data here.

Sharing your work

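To upload your model to Hugging Face, convert the PyTorch checkpoint (.pt) to the safetensors format and then upload it, replacing sdtemple/color-classifier with your own repository ID: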

python scripts/hf-convert.py your-model.pt your-model.safetensors
python scripts/hf-model.py your-model.safetensors sdtemple/color-classifier

To upload your data to Hugging Face, run this script.

python scripts/hf-dataset.py images.npy target_color.txt target_shape.txt your-username/colored-shapes
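hf-dataset.py takes the images plus two label files, so a quick consistency check before uploading can catch a misaligned dataset. This sketch assumes one label per line in target_color.txt and target_shape.txt (an assumption about the format) and takes the number of images as an argument, which you could obtain with numpy.load("images.npy").shape[0]. read_labels and check_labels are hypothetical helpers.

```python
import tempfile
from collections import Counter
from pathlib import Path

def read_labels(path):
    """Assumed format: one label per line; blank lines are ignored."""
    return [line.strip() for line in Path(path).read_text().splitlines() if line.strip()]

def check_labels(color_path, shape_path, n_images):
    """Raise if the label files and image count disagree; else count each combination."""
    colors = read_labels(color_path)
    shapes = read_labels(shape_path)
    if not (len(colors) == len(shapes) == n_images):
        raise ValueError(
            f"mismatch: {len(colors)} colors, {len(shapes)} shapes, {n_images} images"
        )
    return Counter(zip(colors, shapes))

# Demo with throwaway files; replace with your real paths and image count.
with tempfile.TemporaryDirectory() as tmp:
    color_file = Path(tmp) / "target_color.txt"
    shape_file = Path(tmp) / "target_shape.txt"
    color_file.write_text("red\nblue\nred\n")
    shape_file.write_text("circle\ncircle\ncircle\n")
    counts = check_labels(color_file, shape_file, n_images=3)
print(counts)
```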

Test

You can run the test scripts in tests/ with the following:

python -m pytest
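pytest discovers files named test_*.py or *_test.py and runs the functions whose names start with test, reporting any failing assert. A minimal test file might look like the sketch below; the file name and the encode_labels helper are hypothetical, and a real test would import the function from the package rather than define it inline.

```python
# Hypothetical tests/test_labels.py; encode_labels is illustrative, not part of zootopia3.

def encode_labels(labels):
    """Map string labels to integer class indices, with classes in sorted order."""
    classes = sorted(set(labels))
    index = {c: i for i, c in enumerate(classes)}
    return [index[c] for c in labels], classes

def test_encode_labels_round_trip():
    labels = ["triangle", "circle", "triangle", "diamond"]
    codes, classes = encode_labels(labels)
    assert classes == ["circle", "diamond", "triangle"]
    assert codes == [2, 0, 2, 1]
    # Decoding the indices must recover the original labels.
    assert [classes[i] for i in codes] == labels

def test_encode_labels_empty():
    assert encode_labels([]) == ([], [])
```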

Download files

Source Distribution

zootopia3-1.0.tar.gz (13.5 kB)

Built Distribution

zootopia3-1.0-py3-none-any.whl (16.2 kB)

File details

Details for the file zootopia3-1.0.tar.gz.

File metadata

  • Download URL: zootopia3-1.0.tar.gz
  • Upload date:
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.10

File hashes

Hashes for zootopia3-1.0.tar.gz
Algorithm    Hash digest
SHA256       e50086d379b1c8b72946893527cf985d46b00246db4cff58e175db6b48d703d3
MD5          c47039f8492a0107e5f7add55a3edd3e
BLAKE2b-256  4f3ed18dfa441792b98fce4280fa1d0d7b375aa44081d3f03fe94ae6f0ebff18

File details

Details for the file zootopia3-1.0-py3-none-any.whl.

File metadata

  • Download URL: zootopia3-1.0-py3-none-any.whl
  • Upload date:
  • Size: 16.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.10

File hashes

Hashes for zootopia3-1.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       722af70004bbfd1b65e035615fdb52ae9694801e59cb4f25607c90966821c8a4
MD5          bdb24440304eacf6e3192d5047348ee2
BLAKE2b-256  a99b03d7b1ca2e9cf1ec216388683713d33ba8b230860f04165c205cdeda93ea
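To check a downloaded file against the SHA256 digests listed above, you can hash it locally with Python's hashlib and compare. This is a sketch: sha256_of and verify are illustrative helpers, and the demo hashes an empty temporary file rather than a real download.

```python
import hashlib
import tempfile
from pathlib import Path

# SHA256 digests as published for the two release files.
EXPECTED_SHA256 = {
    "zootopia3-1.0.tar.gz": "e50086d379b1c8b72946893527cf985d46b00246db4cff58e175db6b48d703d3",
    "zootopia3-1.0-py3-none-any.whl": "722af70004bbfd1b65e035615fdb52ae9694801e59cb4f25607c90966821c8a4",
}

def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so it never has to fit in memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path):
    """True if the file's SHA256 matches the published digest for its file name."""
    return sha256_of(path) == EXPECTED_SHA256[Path(path).name]

# Demo on a throwaway empty file; with a real download, call verify("zootopia3-1.0.tar.gz").
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "empty.bin"
    demo.write_bytes(b"")
    empty_digest = sha256_of(demo)
print(empty_digest)
```

pip can enforce the same check at install time in hash-checking mode, by pinning zootopia3==1.0 in a requirements file with a --hash=sha256:<digest> option.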
