
Fitting simple torch models


Fitting PyTorch models

License: MIT

This repository accompanies a workshop on setting up a Python package, building neural networks with PyTorch, and publishing the models on Hugging Face.

There are two datasets that we simulate:

  1. Colored shapes (circle, rectangle, triangle, diamond) in a pixelated image
  2. One-dimensional sinusoids

For these datasets, we fit models for the following objectives:

  1. (Classification) Predict the type and color of the shape in each image
  2. (Regression) Predict the next time steps of the sine function
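The regression objective can be framed as supervised pairs by sliding a window over a sampled sine curve: each run of past values is an input, and the values right after it are the target. This is a minimal sketch; the function names, step size, and window lengths are assumptions for illustration and are not the package's API.

```python
import math


def sine_series(n_points: int, step: float = 0.1, freq: float = 1.0) -> list[float]:
    """Sample sin(freq * t) at t = 0, step, 2 * step, ..."""
    return [math.sin(freq * i * step) for i in range(n_points)]


def make_windows(
    series: list[float], n_in: int, n_out: int
) -> list[tuple[list[float], list[float]]]:
    """Pair each run of n_in consecutive values with the n_out values that follow it."""
    pairs = []
    for start in range(len(series) - n_in - n_out + 1):
        past = series[start : start + n_in]
        future = series[start + n_in : start + n_in + n_out]
        pairs.append((past, future))
    return pairs


series = sine_series(100)
pairs = make_windows(series, n_in=20, n_out=5)
print(len(pairs))  # 100 - 20 - 5 + 1 = 76 windows
```

Changing n_out controls how many time steps ahead the model has to predict.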

Usage

Learn how to train neural networks from scratch.

  1. Run answers/simulate-exercise.ipynb to generate the data.
  2. Fill in the # TO DO parts in examples/modeling-exercise-*.ipynb.
  3. Compare your work to the solutions in answers/modeling-exercise-*.ipynb.
  4. Explore different hyperparameters on larger models with scripts/modeling.py.
    • Write a shell script that invokes scripts/modeling.py with your arguments, and submit it as a Slurm job.
  5. Run scripts/modeling-final.py to fit your best model choice on the combined training and validation data.
  6. (Optional) Compare to the benchmark here.
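For step 4, one way to sweep parameters is to render one Slurm batch script per setting. This is a hypothetical sketch: the --lr and --epochs flags are placeholders (check scripts/modeling.py for the arguments it actually accepts), and the resource requests are assumptions that depend on your cluster.

```python
# Hypothetical flag names: check scripts/modeling.py for the arguments it actually accepts.
JOB_TEMPLATE = """#!/bin/bash
#SBATCH --job-name={name}
#SBATCH --gres=gpu:1
#SBATCH --time=02:00:00
#SBATCH --output={name}-%j.out

python scripts/modeling.py {args}
"""


def build_job_script(name: str, params: dict[str, object]) -> str:
    """Render one Slurm batch script that runs scripts/modeling.py with the given flags."""
    args = " ".join(f"--{key} {value}" for key, value in params.items())
    return JOB_TEMPLATE.format(name=name, args=args)


# One job per learning rate in a small, illustrative grid.
for lr in (0.001, 0.0001):
    print(build_job_script(f"lr-{lr}", {"lr": lr, "epochs": 20}))
```

Save each rendered script to a file such as lr-0.001.sh and submit it with sbatch lr-0.001.sh. In the output file name, %j expands to the Slurm job ID.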

The package defined under src/ provides:

  • A class Shape that instantiates an image with one colored shape
  • A function simulate_shapes() to generate many images for training an image classifier
  • A model class MyCNN that implements a standard convolutional neural network architecture
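To make the idea behind Shape concrete, here is a stdlib-only sketch that rasterizes one filled shape on a blank image. It does not use or reproduce the package's Shape class; the color palette, image size, centered placement, and the geometry chosen for each shape are assumptions.

```python
Pixel = tuple[int, int, int]

# Hypothetical palette; the package may use different colors or shading.
COLORS: dict[str, Pixel] = {"red": (255, 0, 0), "green": (0, 255, 0), "blue": (0, 0, 255)}
BACKGROUND: Pixel = (0, 0, 0)


def inside(kind: str, dy: int, dx: int, r: int) -> bool:
    """Is the offset (dy, dx) from the center inside a shape of half-size r?"""
    if kind == "circle":
        return dx * dx + dy * dy <= r * r
    if kind == "rectangle":
        return abs(dx) <= r and abs(dy) <= r // 2  # twice as wide as tall
    if kind == "diamond":
        return abs(dx) + abs(dy) <= r
    if kind == "triangle":
        # Apex at dy = -r; the half-width grows from 0 to r at the base (dy = r).
        return -r <= dy <= r and abs(dx) <= (dy + r) // 2
    raise ValueError(f"unknown shape: {kind}")


def draw_shape(
    kind: str, color: str, size: int, height: int = 64, width: int = 64
) -> list[list[Pixel]]:
    """Return a height x width grid of RGB tuples with one centered, filled shape."""
    cy, cx = height // 2, width // 2
    r = size // 2
    fill = COLORS[color]
    return [
        [fill if inside(kind, row - cy, col - cx, r) else BACKGROUND for col in range(width)]
        for row in range(height)
    ]


image = draw_shape("triangle", "blue", size=20)
```

A full simulator would also randomize position and size; Shape and simulate_shapes() in src/ are the package's actual implementations.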

Caution: you may need GPU resources if your models or data are large.

Requirements

  • Python 3.10+

Install

To install the released package from PyPI:

pip install zootopia3

If you want to set up an isolated environment and build locally:

python -m venv path-to/your-environment
source path-to/your-environment/bin/activate
pip install -e .

Run pip install -e . from the root of this repository.

Data

I made the training and validation sets with:

  • 2000 samples for each shape-color combination
  • min_x = 20
  • max_x = 100
  • shades = True
  • magnitude = 50

I made the test set with:

  • 20 samples for each shape-color combination
  • min_x = 10
  • max_x = 50

Because these settings differ from those used for the training and validation sets, the test set is a more difficult prediction problem.
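To see how the per-combination counts translate into dataset sizes, the sketch below enumerates every shape-color combination and draws a size for each sample. The three-color palette is hypothetical (the colors are not listed here), and treating min_x and max_x as inclusive bounds on the shape size is an assumption about what those parameters mean.

```python
import itertools
import random

SHAPES = ["circle", "rectangle", "triangle", "diamond"]
COLORS = ["red", "green", "blue"]  # hypothetical: the README does not list the colors


def sample_specs(n_per_combo: int, min_x: int, max_x: int, seed: int = 0) -> list[dict]:
    """Draw n_per_combo specs per (shape, color) pair, with size uniform on [min_x, max_x].

    Treating min_x/max_x as bounds on the shape size is an assumption.
    """
    rng = random.Random(seed)
    specs = [
        {"shape": shape, "color": color, "size": rng.randint(min_x, max_x)}
        for shape, color in itertools.product(SHAPES, COLORS)
        for _ in range(n_per_combo)
    ]
    rng.shuffle(specs)  # avoid long runs of a single class
    return specs


train_val = sample_specs(2000, min_x=20, max_x=100)
held_out = sample_specs(20, min_x=10, max_x=50, seed=1)
print(len(train_val), len(held_out))  # 4 shapes x 3 colors: 24000 240
```

With 4 shapes and C colors, the training and validation pool holds 2000 × 4 × C images and the test set 20 × 4 × C.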

You can find the test data here.

Sharing your work

To upload your model to Hugging Face, run these short scripts.

python scripts/hf-convert.py your-model.pt your-model.safetensors
python scripts/hf-model.py your-model.safetensors your-username/color-prediction-model
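Before uploading, it can help to confirm that the converted file contains the tensors you expect. A .safetensors file starts with an 8-byte little-endian header length followed by a JSON header that lists each tensor's dtype, shape, and byte offsets, so a stdlib-only check is possible. This sketch is not part of hf-convert.py; the function names are hypothetical.

```python
import json
import struct


def parse_header(data: bytes) -> dict:
    """Decode a safetensors header: 8-byte little-endian length, then that many JSON bytes."""
    (header_len,) = struct.unpack("<Q", data[:8])
    header = json.loads(data[8 : 8 + header_len].decode("utf-8"))
    header.pop("__metadata__", None)  # optional string-to-string metadata, not a tensor
    return header


def read_header(path: str) -> dict:
    """Read only the header of a .safetensors file, without loading tensor data."""
    with open(path, "rb") as f:
        prefix = f.read(8)
        (header_len,) = struct.unpack("<Q", prefix)
        return parse_header(prefix + f.read(header_len))


def summarize(path: str) -> None:
    """Print one line per tensor, e.g. summarize("your-model.safetensors")."""
    for name, info in read_header(path).items():
        print(f"{name}: dtype={info['dtype']} shape={info['shape']}")
```

If a layer you expect is missing from the summary, re-check the conversion before pushing the file.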

To upload your data to Hugging Face, run this script.

python scripts/hf-dataset.py images.npy target_color.txt target_shape.txt your-username/colored-shapes
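A classifier trained with a cross-entropy loss usually expects integer class indices rather than strings. Assuming target_color.txt and target_shape.txt hold one label per line, in the same order as the images in images.npy (an assumption; check the files you generated), a small hypothetical helper can build the mapping:

```python
from collections.abc import Iterable


def encode_labels(lines: Iterable[str]) -> tuple[list[int], list[str]]:
    """Map string labels to integer class indices (sorted, so the mapping is deterministic).

    Blank lines, such as a trailing newline at the end of the file, are skipped.
    """
    labels = [line.strip() for line in lines if line.strip()]
    classes = sorted(set(labels))
    index = {name: i for i, name in enumerate(classes)}
    return [index[label] for label in labels], classes


def load_labels(path: str) -> tuple[list[int], list[str]]:
    """Assumes one label per line, aligned with the rows of images.npy."""
    with open(path, encoding="utf-8") as f:
        return encode_labels(f)


targets, classes = encode_labels(["red\n", "blue\n", "red\n", "green\n"])
print(targets, classes)  # [2, 0, 2, 1] ['blue', 'green', 'red']
```

Keep the classes list alongside the model so predicted indices can be decoded back to names at inference time.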

Once you have uploaded a pretrained model to Hugging Face, you can run it on a CPU with answers/inference-exercise.ipynb.

Test

Run the tests in tests/ with:

python -m pytest



Download files


Source Distribution

zootopia3-1.1.tar.gz (13.6 kB)

Uploaded Source

Built Distribution


zootopia3-1.1-py3-none-any.whl (16.4 kB)

Uploaded Python 3

File details

Details for the file zootopia3-1.1.tar.gz.

File metadata

  • Download URL: zootopia3-1.1.tar.gz
  • Upload date:
  • Size: 13.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.10

File hashes

Hashes for zootopia3-1.1.tar.gz
Algorithm Hash digest
SHA256 40805d67ba38a4361584ea1a41536fc2be273b906ec3ed3abe2529ea83b53890
MD5 cf2ada4ce49dff185e0591e96641331b
BLAKE2b-256 2cf79705f7b18ccc2d4cd9ff91a91e280b0ac7d70dd08a8573fb19dce86e214b


File details

Details for the file zootopia3-1.1-py3-none-any.whl.

File metadata

  • Download URL: zootopia3-1.1-py3-none-any.whl
  • Upload date:
  • Size: 16.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.10

File hashes

Hashes for zootopia3-1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 bb1ef080d0eb40bb05884f270b2466726c0a53b8476e323ff8817945a642ff4d
MD5 9f30de8349f369b28313e3c4c6a8c8c8
BLAKE2b-256 37828056f574743ef9e4c1da7e0b54d9405086ce1e92ba59f19656f65f9c6aa4

