Train PyTorch models with ONNX Runtime
PyTorch/ORT is a Python package that uses ONNX Runtime to accelerate PyTorch model training.
Pre-requisites
You need a machine with at least one NVIDIA GPU to run PyTorch/ORT.
You can install and run PyTorch/ORT in your local environment or with Docker. If you are using Docker, nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04 is a suitable base image.
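To confirm the GPU prerequisite before installing anything, a small check along these lines may help. It is only a sketch: it assumes the NVIDIA driver's `nvidia-smi` tool is on your `PATH`, and the helper name `list_nvidia_gpus` is made up for this example.

```python
import shutil
import subprocess


def list_nvidia_gpus():
    """Return the GPU lines printed by `nvidia-smi -L`, or [] if none are visible."""
    # nvidia-smi ships with the NVIDIA driver; if it is missing, assume no usable GPU.
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "-L"], capture_output=True, text=True, check=False
    )
    if result.returncode != 0:
        return []
    # Each detected device is reported on a line starting with "GPU <index>:".
    return [
        line.strip()
        for line in result.stdout.splitlines()
        if line.strip().startswith("GPU ")
    ]


gpus = list_nvidia_gpus()
print(f"Found {len(gpus)} NVIDIA GPU(s)")
for gpu in gpus:
    print(" ", gpu)
```

If the script reports zero GPUs inside a Docker container, check that the container was started with GPU access enabled.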
Install
- Install CUDA
- Install cuDNN
- Install PyTorch/ORT and dependencies
pip install onnx ninja
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly.html
pip install torch-ort
To install the release package of onnxruntime-training instead of the nightly build:
pip install onnxruntime-training
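After the pip commands finish, you may want to confirm which versions actually ended up in the environment. The following is a minimal sketch using only the Python standard library (Python 3.8 or later); the `installed_versions` helper is a hypothetical name, and the package list simply mirrors the pip commands above.

```python
from importlib import metadata

# Distribution names taken from the pip commands in the install steps.
PACKAGES = ["onnx", "ninja", "torch", "onnxruntime-training", "torch-ort"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name:24s} {version or 'NOT INSTALLED'}")
```

Any package shown as not installed points to a pip step that failed or ran in a different environment.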
Test your installation
- Clone this repo
git clone git@github.com:pytorch/ort.git
- Install extra dependencies
pip install wget pandas scikit-learn transformers
- Run the training script
python ./ort/tests/bert_for_sequence_classification.py
Add PyTorch/ORT to your PyTorch training script
import onnxruntime
from torch_ort import ORTModule
model = ORTModule(model)
# PyTorch training script follows
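For context, here is what a complete (toy) training script might look like with the wrapper in place. This is a sketch under assumptions, not the project's test script: the model, tensor shapes, learning rate, and the `train_toy_model` function are placeholders invented for illustration, and the data is random. The only PyTorch/ORT-specific line is the `ORTModule(model)` call; everything else is ordinary PyTorch.

```python
def train_toy_model(num_steps=5, device="cuda"):
    """Train a tiny regression model wrapped in ORTModule; return per-step losses."""
    # Imported inside the function so this file can still be imported on a
    # machine where torch and torch_ort are not installed.
    import torch
    from torch_ort import ORTModule

    torch.manual_seed(0)

    # Placeholder model: any torch.nn.Module is wrapped the same way.
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 1),
    ).to(device)

    # The only ORT-specific step: wrap the module before training.
    model = ORTModule(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.MSELoss()

    # Random stand-in data; replace with a real DataLoader.
    inputs = torch.randn(64, 16, device=device)
    targets = torch.randn(64, 1, device=device)

    losses = []
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses
```

On a machine that meets the prerequisites, calling `train_toy_model()` runs a few optimizer steps and returns the loss at each step.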
Versioning
CUDA
The PyTorch/ORT package was built with CUDA 11.1. If you have a different version of CUDA installed, you should install the CUDA 11.1 toolkit.
This is a limitation that will be removed in the next release.
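One way to see which CUDA toolkit is installed is to read the release number printed by `nvcc --version`. The sketch below assumes `nvcc` is on your `PATH` and that its output contains a `release <major>.<minor>` token; the function names are hypothetical, and `REQUIRED` reflects the CUDA 11.1 build mentioned above.

```python
import re
import shutil
import subprocess

REQUIRED = (11, 1)  # CUDA version the package was built with, per this section


def parse_nvcc_release(text):
    """Extract (major, minor) from nvcc's version banner, or None if absent."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def installed_cuda_toolkit():
    """Return the installed toolkit version as (major, minor), or None if unknown."""
    if shutil.which("nvcc") is None:
        return None
    result = subprocess.run(
        ["nvcc", "--version"], capture_output=True, text=True, check=False
    )
    return parse_nvcc_release(result.stdout)


found = installed_cuda_toolkit()
if found is None:
    print("Could not detect a CUDA toolkit (nvcc not found or output not recognized)")
elif found == REQUIRED:
    print("CUDA toolkit matches the required 11.1")
else:
    print(f"Found CUDA {found[0]}.{found[1]}; this package expects 11.1")
```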