
NetsPresso Python Package

Project description


Start training models (including ViTs) with NetsPresso Trainer, compress and deploy your model with NetsPresso!


Getting started

Write your training script in train.py as follows:

from netspresso_trainer import set_arguments, train

args_parsed, args = set_arguments(is_graphmodule_training=False)
train(args_parsed, args, is_graphmodule_training=False)

Then, train your model with your own configuration:

netspresso-train \
  --data config/data/beans.yaml \
  --augmentation config/augmentation/resnet.yaml \
  --model config/model/resnet.yaml \
  --training config/training/resnet.yaml \
  --logging config/logging.yaml \
  --environment config/environment.yaml
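If you launch training from Python rather than from a shell (for example, to sweep over several data configurations), the command above can be assembled programmatically. This is a minimal sketch: build_train_command and run_training are hypothetical helpers, not part of NetsPresso Trainer, and they only pass the six config flags shown above.

```python
import shlex
import subprocess
from typing import Dict, List

# Config paths from the example command above; point these at your own YAML files.
DEFAULT_CONFIGS: Dict[str, str] = {
    "data": "config/data/beans.yaml",
    "augmentation": "config/augmentation/resnet.yaml",
    "model": "config/model/resnet.yaml",
    "training": "config/training/resnet.yaml",
    "logging": "config/logging.yaml",
    "environment": "config/environment.yaml",
}


def build_train_command(configs: Dict[str, str]) -> List[str]:
    """Turn a {flag_name: yaml_path} mapping into a netspresso-train argv list."""
    cmd = ["netspresso-train"]
    for flag, path in configs.items():
        cmd.extend([f"--{flag}", path])
    return cmd


def run_training(configs: Dict[str, str]) -> None:
    """Print the assembled command, then run it (hypothetical helper)."""
    cmd = build_train_command(configs)
    print(shlex.join(cmd))  # shlex.join is available from Python 3.8
    # check=True raises CalledProcessError if the trainer exits with an error.
    subprocess.run(cmd, check=True)
```

For example, run_training({**DEFAULT_CONFIGS, "data": "config/data/my_dataset.yaml"}) would swap in a different data config, assuming that (hypothetical) file exists. Passing an argv list instead of a single shell string avoids quoting problems with paths that contain spaces.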

Please refer to scripts/example_train.sh.

NetsPresso Trainer is compatible with the NetsPresso service. We provide a NetsPresso Trainer tutorial that covers the whole procedure, from model training to model compression and benchmarking. Please refer to our Colab tutorial.

Installation

Prerequisites

  • Python 3.8 | 3.9 | 3.10
  • PyTorch 1.13.0 (recommended; compatible with 1.11.x through 1.13.x)
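A quick way to confirm that your environment matches these prerequisites is a small version check. This sketch is our own illustration, not a utility shipped with the package: it reads sys.version_info and, if PyTorch is importable, torch.__version__, and compares the major.minor version against the ranges listed above.

```python
import re
import sys
from typing import Optional, Tuple

# Supported ranges as listed under Prerequisites.
SUPPORTED_PYTHON = {(3, 8), (3, 9), (3, 10)}
TORCH_MIN = (1, 11)
TORCH_MAX = (1, 13)


def python_supported(version_info=sys.version_info) -> bool:
    """True if the interpreter's major.minor is 3.8, 3.9 or 3.10."""
    return (version_info[0], version_info[1]) in SUPPORTED_PYTHON


def parse_major_minor(version: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from strings like '1.13.0' or '1.13.1+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def torch_supported(version: str) -> bool:
    """True if a PyTorch version string falls within 1.11.x to 1.13.x."""
    parsed = parse_major_minor(version)
    return parsed is not None and TORCH_MIN <= parsed <= TORCH_MAX


def check_environment() -> None:
    """Print a warning for an unsupported Python or PyTorch installation."""
    if not python_supported():
        print(f"Warning: Python {sys.version.split()[0]} is not 3.8, 3.9 or 3.10")
    try:
        import torch  # imported lazily so the check also runs without PyTorch
    except ImportError:
        print("Warning: PyTorch is not installed")
        return
    if not torch_supported(torch.__version__):
        print(f"Warning: PyTorch {torch.__version__} is outside 1.11.x to 1.13.x")
```

Comparing (major, minor) tuples keeps the check tolerant of patch releases and local build suffixes such as +cu117.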

Install with PyPI (stable)

pip install netspresso_trainer

Install from GitHub

pip install git+https://github.com/Nota-NetsPresso/netspresso-trainer.git@stable

To install in editable mode:

git clone https://github.com/Nota-NetsPresso/netspresso-trainer.git
pip install -e netspresso-trainer

Set up with Docker

Clone this repository and refer to the Dockerfile and docker-compose-example.yml.
For Docker users, we provide a more detailed guide in our Docs.

TensorBoard

We provide basic TensorBoard logging to track your training status. Launch TensorBoard with the following command:

tensorboard --logdir ./outputs --port 50001 --bind_all

Here, TensorBoard serves on port 50001, and --bind_all makes it reachable on all network interfaces.
Note that training results are saved to the ./outputs directory by default.
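To confirm that training is actually writing logs TensorBoard can read, you can look for its event files under the output directory. The sketch below assumes the default ./outputs location and relies on TensorBoard's usual events.out.tfevents.* file naming; find_event_files is a hypothetical helper, not part of NetsPresso Trainer.

```python
from pathlib import Path
from typing import List


def find_event_files(logdir: str = "./outputs") -> List[Path]:
    """Return TensorBoard event files anywhere under logdir, sorted by path.

    An empty list usually means nothing has been logged to this directory yet.
    """
    root = Path(logdir)
    if not root.is_dir():
        return []
    # rglob searches logdir and all of its subdirectories.
    return sorted(root.rglob("events.out.tfevents.*"))
```

If this returns an empty list while training is running, check that --logdir points at the same directory your logging config writes to.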

Pretrained weights

For now, we provide pretrained weights taken from other excellent repositories, which we have converted to fit our own model architectures.
In the near future, we plan to provide pretrained weights trained directly on our own resources.
We are grateful to all the original authors, and we will do our best to contribute further value of our own.

Download all weights (Google Drive)

Family           Model                     Link                         Original repository
ResNet           resnet50                  Google Drive                 torchvision
ViT              vit_tiny                  Google Drive                 apple/ml-cvnets
MobileViT        mobilevit_s               Google Drive                 apple/ml-cvnets
SegFormer        segformer                 Google Drive (Hugging Face)  nvidia
EfficientFormer  efficientformer_l1_3000d  Google Drive                 snap-research/EfficientFormer
PIDNet           pidnet_s                  Google Drive                 XuJiacong/PIDNet
MobileNetV3      mobilenetv3_small         Google Drive                 torchvision
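Converting weights from another repository typically comes down to renaming state-dict keys so that they match the target architecture's module names. The exact mapping used for the weights above is not documented here, so the sketch below is only a generic illustration: remap_state_dict and the prefix rules in its usage are hypothetical, and the values can be any objects (for example, PyTorch tensors).

```python
from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")


def remap_state_dict(
    state_dict: Dict[str, T], rules: List[Tuple[str, str]]
) -> Dict[str, T]:
    """Rename parameter keys by replacing the first matching prefix.

    rules is an ordered list of (old_prefix, new_prefix) pairs; keys that
    match no rule are copied over unchanged.
    """
    remapped: Dict[str, T] = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rules:
            if key.startswith(old):
                new_key = new + key[len(old):]
                break  # only the first matching rule is applied
        remapped[new_key] = value
    return remapped
```

With PyTorch, the remapped dictionary could then be passed to model.load_state_dict(...); keeping strict=True (the default) makes PyTorch report any key that no rule covered.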

Download files

Download the file for your platform.

Source Distribution

netspresso_trainer-0.0.10.tar.gz (119.7 kB)

Built Distribution


netspresso_trainer-0.0.10-py3-none-any.whl (170.2 kB)
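The wheel name above encodes its compatibility tags using the layout {distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl; here, py3-none-any marks a pure-Python wheel for any Python 3 interpreter on any platform. As a small illustration, the sketch below splits such a name into its parts. It skips the validation that parse_wheel_filename from the packaging library's packaging.utils module performs, so prefer that function in real tooling.

```python
from typing import NamedTuple, Optional


class WheelName(NamedTuple):
    distribution: str
    version: str
    build: Optional[str]
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelName:
    """Split a wheel file name into its dash-separated components."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file name: {filename}")
    # Distribution names in wheel file names use '_' instead of '-',
    # so splitting on '-' yields 5 parts, or 6 when a build tag is present.
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        dist, version, py, abi, plat = parts
        build = None
    elif len(parts) == 6:
        dist, version, build, py, abi, plat = parts
    else:
        raise ValueError(f"unexpected number of components in {filename}")
    return WheelName(dist, version, build, py, abi, plat)
```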

File details

Details for the file netspresso_trainer-0.0.10.tar.gz.

File metadata

  • Download URL: netspresso_trainer-0.0.10.tar.gz
  • Size: 119.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.6

File hashes

Hashes for netspresso_trainer-0.0.10.tar.gz
Algorithm    Hash digest
SHA256       04114096aad6cbf3f16649952c93ca501a9a1c36e71e89a856750618ec7d84c3
MD5          70f3bafd62f9be40c175409a688baf64
BLAKE2b-256  084be00c3370048e7f3628885cb737f889d8ecfb4c3355af11fbea7fef155d69
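To check a downloaded archive against the published digest, compute its SHA256 hash locally and compare the two. This is a minimal sketch using only hashlib; sha256_of and verify_download are illustrative helper names, and the expected value is the SHA256 digest listed above for the source distribution.

```python
import hashlib

# SHA256 digest published above for netspresso_trainer-0.0.10.tar.gz.
EXPECTED_SDIST_SHA256 = (
    "04114096aad6cbf3f16649952c93ca501a9a1c36e71e89a856750618ec7d84c3"
)


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the SHA256 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected: str = EXPECTED_SDIST_SHA256) -> bool:
    """True if the file's SHA256 digest matches the expected (case-insensitive) one."""
    return sha256_of(path) == expected.lower()
```

pip can also enforce digests itself in hash-checking mode, using --hash options in a requirements file together with --require-hashes.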


File details

Details for the file netspresso_trainer-0.0.10-py3-none-any.whl.


File hashes

Hashes for netspresso_trainer-0.0.10-py3-none-any.whl
Algorithm    Hash digest
SHA256       936efa93bf146847b3c4f0c76cf47d78ddac81cbb52bb6ced08f1eea9009ad9b
MD5          fab0c4b2221a992fdc8134f036ad28d8
BLAKE2b-256  26ac2d993d78e8fe652b4957daaa55433003fa171e944e590ad9e3a811eacc01

