Fitting PyTorch models

License: MIT

This repository accompanies a workshop on setting up a Python package, building some neural networks with torch, and publishing the models on Hugging Face.

There are two datasets that we simulate:

  1. Colored shapes (circle, rectangle, triangle, diamond) in a pixellated image
  2. One-dimensional sinusoids

We fit models to these datasets with the following objectives:

  1. (Classification) Predict the shape and color in the image
  2. (Regression) Predict the next time steps of the sine function
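The regression objective can be sketched without torch. This is a minimal illustration only, not the package's simulator: the function names, the default amplitude and frequency, and the window lengths are all assumptions made for the example.

```python
import math


def simulate_sinusoid(n_points, amplitude=1.0, frequency=1.0, phase=0.0, dt=0.1):
    """Sample amplitude * sin(2*pi*frequency*t + phase) at t = 0, dt, 2*dt, ..."""
    return [
        amplitude * math.sin(2 * math.pi * frequency * i * dt + phase)
        for i in range(n_points)
    ]


def make_windows(series, n_inputs, n_targets):
    """Pair each run of n_inputs past values with the n_targets values that follow."""
    pairs = []
    last_start = len(series) - n_inputs - n_targets
    for start in range(last_start + 1):
        x = series[start:start + n_inputs]
        y = series[start + n_inputs:start + n_inputs + n_targets]
        pairs.append((x, y))
    return pairs


# Hypothetical settings: 50 time steps, predict 3 steps from the previous 10.
series = simulate_sinusoid(n_points=50)
pairs = make_windows(series, n_inputs=10, n_targets=3)
```

Each pair is one training example: the model sees x and is scored on how well it predicts y.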

Usage

Learn how to train neural networks from scratch.

  1. Run answers/simulate-exercise.ipynb to get data.
  2. Fill in the # TO DO parts in examples/modeling-exercise-*.ipynb.
  3. Compare to solutions in answers/modeling-exercise-*.ipynb.
  4. You can explore different parameters on big models with scripts/modeling.py.
    • Write a shell script that invokes scripts/modeling.py and passes its arguments through Slurm.
  5. Run scripts/modeling-final.py to fit your best model choice (train + val data).
  6. (Optional) Compare to the benchmark here.
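For step 4, one way to build a parameter sweep is to generate one Slurm submission per combination of values. The sketch below only builds the command strings. The option names (lr, epochs) are placeholders and not the real arguments of scripts/modeling.py, so substitute whatever that script's argument parser accepts.

```python
import itertools
import shlex


def build_sweep_commands(script, grid):
    """Return one sbatch command line per combination of values in grid."""
    names = sorted(grid)
    commands = []
    for values in itertools.product(*(grid[name] for name in names)):
        args = []
        for name, value in zip(names, values):
            args += [f"--{name}", str(value)]
        # The command Slurm should run on the compute node.
        inner = shlex.join(["python", script] + args)
        # sbatch --wrap submits a single command without a separate batch file.
        commands.append(shlex.join(["sbatch", f"--wrap={inner}"]))
    return commands


# Placeholder grid: these option names are assumptions for illustration.
grid = {"lr": [0.01, 0.001], "epochs": [10]}
commands = build_sweep_commands("scripts/modeling.py", grid)
for command in commands:
    print(command)
```

You can paste the printed lines into a shell script, or add resource options such as a GPU request to the sbatch part as your cluster requires.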

The package defined under src/ provides:

  • A class Shape that instantiates an image with 1 colored shape
  • A function simulate_shapes() to make many images for an image classifier
  • A model class MyCNN to fit a standard architecture
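To give a feel for what an image with one colored shape looks like as data, here is a minimal pure-Python sketch. It is not the package's Shape class: the function name draw_shape, its parameters, the black background, and the centered placement are all assumptions for illustration.

```python
def draw_shape(shape, color, size=16, half_width=5):
    """Return a size x size image (nested lists of RGB tuples) with one centered shape."""
    background = (0, 0, 0)
    center = (size - 1) / 2
    image = []
    for row in range(size):
        pixels = []
        for col in range(size):
            ry = row - center          # signed vertical offset from the center
            dx = abs(col - center)     # horizontal distance from the center
            dy = abs(ry)
            if shape == "rectangle":
                inside = dx <= half_width and dy <= half_width
            elif shape == "circle":
                inside = dx * dx + dy * dy <= half_width ** 2
            elif shape == "diamond":
                inside = dx + dy <= half_width
            elif shape == "triangle":
                # Apex at the top, widening linearly toward the base.
                inside = dy <= half_width and dx <= (ry + half_width) / 2
            else:
                raise ValueError(f"unknown shape: {shape}")
            pixels.append(color if inside else background)
        image.append(pixels)
    return image


def count_colored(image, color):
    """Count the pixels that carry the shape's color."""
    return sum(pixel == color for row in image for pixel in row)


red = (255, 0, 0)
circle = draw_shape("circle", red)
```

A classifier such as MyCNN would receive many arrays like this one, with the shape name and the color as its two targets.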

Caution: you may need GPU resources if your models or data are large.

Requirements

  • Python 3.10+

Install

If you only want to install the package from PyPI:

pip install zootopia3

If you want to set up an isolated environment and build locally:

python -m venv path-to/your-environment
source path-to/your-environment/bin/activate
pip install -e .

Run the pip install -e . command from the root of this repository.
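To confirm that either install route worked, you can ask Python for the installed version. This sketch uses only the standard library; the helper name installed_version is made up for the example.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints a version string when zootopia3 is installed in the active
# environment, and None otherwise.
print(installed_version("zootopia3"))
```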

Data

I made a training and validation set with:

  • 2000 samples for each combo
  • min_x = 20
  • max_x = 100
  • shades = True
  • magnitude = 50

I made a testing set with:

  • 20 samples for each combo
  • min_x = 10
  • max_x = 50

The test set uses a different min_x and max_x range (10 to 50) than the training and validation set (20 to 100), so it is a more difficult prediction problem.

You can find the test data here.
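The "samples for each combo" settings determine the dataset sizes. The arithmetic below assumes a combo is one (shape, color) pair. The four shapes come from the dataset description above, but the color list is a placeholder, so the totals are illustrative only.

```python
from itertools import product

SHAPES = ["circle", "rectangle", "triangle", "diamond"]  # from the dataset description
COLORS = ["red", "green", "blue"]  # placeholder: the real color list may differ


def split_size(samples_per_combo, shapes=SHAPES, colors=COLORS):
    """Total number of images when every (shape, color) pair gets the same count."""
    combos = list(product(shapes, colors))
    return samples_per_combo * len(combos)


train_val_size = split_size(2000)  # 2000 samples per combo
test_size = split_size(20)         # 20 samples per combo
print(train_val_size, test_size)
```

With these placeholder colors there are 12 combos, so the script prints 24000 and 240.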

Sharing your work

To upload your model to Hugging Face, run this short script:

python hf-push.py your-model.pth sdtemple/color-prediction-model

A less elegant alternative, where you have to write the config.json data by hand, is:

python scripts/hf-convert.py your-model.pt your-model.safetensors
python scripts/hf-model.py your-model.safetensors sdtemple/color-prediction-model

To upload your data to Hugging Face, run this script.

python scripts/hf-dataset.py images.npy target_color.txt target_shape.txt your-username/colored-shapes

Once you have successfully uploaded a pretrained model to Hugging Face, you can run it on a CPU with answers/inference-exercise.ipynb.

Test

You can run the test scripts in tests/ with the following:

python -m pytest
