
AI Toolkit

Project description

ai-toolkit

[Badges: Python 3.7+ · Build Status · GitHub license · codecov]

Motivation

When working on ML projects, especially supervised learning, a lot of code ends up repeated: in every project we want a way to checkpoint our work, visualize loss curves in TensorBoard, add additional metrics, and see example output. Some projects do this better than others. Ideally, we want a single place where all of this code is consolidated.

The problem is that PyTorch examples generally look very different from one another. As with most data exploration, we want the ability to modify every part of the codebase to handle different loss metrics, different types of data, or different visualizations based on our data dimensions. Combining everything into a single repository often overcomplicates the underlying logic (making the training loop extremely unreadable, for example). We want to strike a balance between extremely minimalistic, readable code and the flexibility to add extra functionality when needed.

This project is for developers and ML scientists who want the features of a fully functioning ML pipeline from the beginning. Each project comes with consistent styling and an opinionated way of handling logging, metrics, and checkpointing, including resuming training from checkpoints. It also integrates seamlessly with Google Colab and AWS/Google Cloud GPUs.

Try It Out!

The first thing you should do is go into one of the output_*/ folders and try training a model. We currently have the following models:

Notable Features

  • In train.py, the code performs verification checks on all models to make sure you aren't mixing up your batch dimensions (a sketch of this kind of check follows this list).
  • Try stopping and restarting training after a couple of epochs - it should resume from the same place.
  • In TensorBoard, loss curves should already plot seamlessly across runs.
  • All checkpoints should be available in checkpoints/, which contains activation layers, input data, and best models.
  • Scheduling runs is as easy as specifying a file in the configs/ folder.
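The batch-dimension check mentioned above can be sketched in a few lines. This is a hypothetical illustration of the technique, not the toolkit's actual verify.py code; the function name and shapes are made up:

```python
import torch

def check_batch_independence(model, input_shape, batch_size=4):
    # Perturbing one sample in a batch must not change the outputs for
    # the other samples; if it does, a batch dimension is mixed up.
    model.eval()  # BatchNorm/Dropout in train mode would mix samples
    x = torch.randn(batch_size, *input_shape)
    with torch.no_grad():
        baseline = model(x)
        x[0] += 1.0  # perturb only the first sample
        perturbed = model(x)
    if not torch.allclose(baseline[1:], perturbed[1:]):
        raise ValueError("model mixes information across the batch dimension")
```

Calling check_batch_independence(model, (3, 32, 32)) on an image model, for example, catches reshapes that silently fold the batch dimension into the feature dimensions.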

Evaluation Criteria

The goal is for this repository to contain a series of clean ML examples at different levels of complexity that I can draw from and use as examples, test models, etc. Essentially, I want to gather all of the best-practice code gists I have found or used in the past and make them modular and easy to import or export for later use.

The goal is not for this to be yet another ML framework built on PyTorch, but to focus on a single researcher/developer workflow and make it very easy to begin working. It is great for Kaggle competitions, simple data exploration, or experimenting with different models.

The rough metric for this repo's success is how fast I can start working on a Kaggle challenge after downloading the data: getting insights into the data and its distributions, running baseline and fine-tuned models, and producing loss curves and plots.

Current Workflow

  1. Add data to your data/ folder and edit the corresponding DatasetLoader in datasets/.
  2. Add your config and model to configs/ and models/.
  3. Run train.py, which saves model checkpoints, output predictions, and TensorBoard logs in the same folder.
  4. Start TensorBoard on the checkpoints/ folder with tensorboard --logdir=checkpoints/
  5. Stop and restart training with python train.py --checkpoint=<checkpoint name>. The code should automatically resume training at the previous epoch and continue logging to the previous TensorBoard run (a sketch of this resume logic follows this list).
  6. Run python test.py --checkpoint=<checkpoint name> to get final predictions.
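Steps 3-5 rely on checkpoints that carry enough state to resume cleanly, and on reusing the same TensorBoard log directory across runs. A minimal sketch of that pattern, with hypothetical checkpoint keys and paths (the toolkit's actual format may differ):

```python
import torch
from torch.utils.tensorboard import SummaryWriter

def save_checkpoint(path, model, optimizer, epoch):
    # Hypothetical checkpoint layout; the real keys may differ.
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)

def load_checkpoint(path, model, optimizer):
    ckpt = torch.load(path)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"] + 1  # resume at the next epoch

# Writing to the same log_dir across restarts is what makes loss
# curves continue seamlessly in TensorBoard.
writer = SummaryWriter(log_dir="checkpoints/my_run")
```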

Directory Structure

  • checkpoints/ (Only created once you run train.py)
  • data/
  • configs/
  • ai_toolkit/
    • datasets/
    • losses/
    • metrics/
    • models/
      • layers/
      • ...
    • visualizations/
    • args.py (Modify default hyperparameters manually; a hypothetical sketch follows this list)
    • metric_tracker.py
    • test.py
    • train.py
    • util.py
    • verify.py
    • viz.py (Stub, create more visualizations if necessary)
  • tests/
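For a rough picture of what args.py holds, here is a hypothetical sketch; the flags are chosen to match the commands above and may not match the real defaults:

```python
import argparse

def get_args():
    # Hypothetical defaults; the real args.py may expose different flags.
    parser = argparse.ArgumentParser(description="training configuration")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--checkpoint", type=str, default="",
                        help="checkpoint name to resume from")
    return parser.parse_args()
```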

Goal Workflow

  1. Move data into data/.
  2. Fill in preprocess.py and dataset.py (explore the data by running python viz.py); a dataset sketch follows this list.
  3. Change args.py to specify input/output dimensions, batch size, etc.
  4. Run train.py, which saves model checkpoints, output predictions, and TensorBoard logs in the same folder, and automatically starts a TensorBoard server in a tmux session. Resume training at any point.
  5. Run test.py to get final predictions.
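The dataset.py referred to in step 2 would typically be a small torch.utils.data.Dataset subclass. The skeleton below is a hypothetical sketch; the file names and preprocessing layout are assumptions:

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Hypothetical dataset.py skeleton; the file layout is an assumption."""

    def __init__(self, data_dir):
        # Load tensors written by a (hypothetical) preprocess.py step.
        self.inputs = torch.load(f"{data_dir}/inputs.pt")
        self.targets = torch.load(f"{data_dir}/targets.pt")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]
```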

Download files

Download the file for your platform.

Source Distribution

ai_toolkit-0.0.2.tar.gz (30.9 kB)


Built Distribution

ai_toolkit-0.0.2-py3-none-any.whl (43.4 kB)


File details

Details for the file ai_toolkit-0.0.2.tar.gz.

File metadata

  • Download URL: ai_toolkit-0.0.2.tar.gz
  • Upload date:
  • Size: 30.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.2

File hashes

Hashes for ai_toolkit-0.0.2.tar.gz

  • SHA256: 5665e6ea0f0f20aef43c3e9d4a4d2edb5bac64df9fe56e794a44b6ed3ee93bea
  • MD5: 3d557a5cfafbf269d2b006ab3a81010b
  • BLAKE2b-256: 21c1a8b8c6acd94b82e20eb7c7f2d1f267dd1c55efb901bb1f731f696cb13e54
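To check a download against the SHA256 above, Python's standard hashlib suffices. A minimal sketch, assuming the archive sits in the current directory:

```python
import hashlib

def sha256_of(path, chunk_size=8192):
    # Stream the file so large downloads never need to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "5665e6ea0f0f20aef43c3e9d4a4d2edb5bac64df9fe56e794a44b6ed3ee93bea"
assert sha256_of("ai_toolkit-0.0.2.tar.gz") == expected
```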


File details

Details for the file ai_toolkit-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: ai_toolkit-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 43.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.2

File hashes

Hashes for ai_toolkit-0.0.2-py3-none-any.whl

  • SHA256: 4597ff6e5e5245e66677dce45f549030f81d5878480bd1936e42f356f2bee3ba
  • MD5: 877d01b287e9c26246501d0c564aeb3a
  • BLAKE2b-256: 52746714bb0a32c1e08890d016cd26c63199ed427881476cf0721233b60e4ec2

