Training PyTorch models with ONNX Runtime
Train PyTorch models with ONNX Runtime
PyTorch/ORT is a Python package that uses ONNX Runtime to accelerate PyTorch model training.
Prerequisites
You need a machine with at least one NVIDIA or AMD GPU to run PyTorch/ORT.
You can install and run PyTorch/ORT in your local environment or with Docker. If you use Docker, the following base images are suitable for NVIDIA and AMD GPUs respectively:
nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04
rocm/pytorch:rocm4.1.1_ubuntu18.04_py3.6_pytorch
Install for Nvidia GPUs
- Install CUDA
- Install cuDNN
- Install PyTorch/ORT and dependencies
Nvidia CUDA version 11.1
pip install onnx ninja
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu111.html
pip install torch-ort
Nvidia CUDA version 10.2
pip install onnx ninja
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu102.html
pip install torch-ort
Install for AMD GPUs
- Install the ROCm 4.1 base package
- Install the ROCm 4.1 libraries
- Install ROCm 4.1 RCCL
- Install PyTorch/ORT and dependencies
AMD ROCm version 4.1
pip install onnx ninja
pip install --pre torch -f https://download.pytorch.org/whl/nightly/rocm4.1/torch_nightly.html
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_rocm41.html
pip install torch-ort
To install the release package of onnxruntime-training instead of the nightly build:
pip install onnxruntime-training
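The nightly install sequences for CUDA 11.1, CUDA 10.2 and ROCm 4.1 differ only in the two package index URLs. The sketch below assembles the four pip commands for each stack, in the order given above, so the variants can be compared side by side. The stack keys (cu111, cu102, rocm41) and the function name install_commands are hypothetical; the URLs are copied verbatim from the commands above.

```python
# Nightly index URLs copied from the install commands above.
# The dictionary keys are hypothetical short names for each GPU stack.
NIGHTLY_INDEXES = {
    "cu111": (
        "https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html",
        "https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu111.html",
    ),
    "cu102": (
        "https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html",
        "https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu102.html",
    ),
    "rocm41": (
        "https://download.pytorch.org/whl/nightly/rocm4.1/torch_nightly.html",
        "https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_rocm41.html",
    ),
}


def install_commands(stack):
    """Return the nightly pip commands for one GPU stack, in install order."""
    torch_index, ort_index = NIGHTLY_INDEXES[stack]  # KeyError for unknown stacks
    return [
        "pip install onnx ninja",
        f"pip install --pre torch -f {torch_index}",
        f"pip install --pre onnxruntime-training -f {ort_index}",
        "pip install torch-ort",
    ]
```

For example, printing each entry of install_commands("cu102") reproduces the CUDA 10.2 sequence. The helper only builds strings; it does not run pip.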
Test your installation
- Clone the pytorch/ort repository
git clone git@github.com:pytorch/ort.git
- Install extra dependencies
pip install wget pandas scikit-learn transformers
- Run the training script
python ./ort/tests/bert_for_sequence_classification.py
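Before running the test script, it can help to confirm that every dependency is importable, since a missing package otherwise surfaces only partway through the run. A minimal sketch using the standard library's importlib.util.find_spec; the function name missing_modules is hypothetical. Note that import names can differ from pip names: scikit-learn imports as sklearn, torch-ort as torch_ort, and onnxruntime-training as onnxruntime.

```python
import importlib.util

# Import names (not pip names) for the packages installed above.
REQUIRED_MODULES = (
    "torch", "onnxruntime", "torch_ort",
    "wget", "pandas", "sklearn", "transformers",
)


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level modules in `names` that cannot be found on sys.path."""
    # find_spec locates a module without importing it; None means "not found".
    return [name for name in names if importlib.util.find_spec(name) is None]
```

An empty list from missing_modules() means all packages were found; any names it returns still need to be installed.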
Add PyTorch/ORT to your PyTorch training script
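import onnxruntime
from torch_ort import ORTModule

# Wrap your existing torch.nn.Module; its forward and backward passes then run through ONNX Runtime.
model = ORTModule(model)
# PyTorch training script follows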
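Once the model is wrapped, the rest of the training script is ordinary PyTorch. The sketch below shows a complete minimal loop around ORTModule. The names wrap_with_ort and train are hypothetical, and the fallback to the unwrapped model when torch-ort is absent is a convenience of this sketch, not behavior provided by torch-ort.

```python
import importlib.util
import warnings


def wrap_with_ort(model):
    """Wrap `model` in torch_ort.ORTModule if torch-ort is installed.

    Returning the unwrapped model otherwise is an assumption of this
    sketch, not part of the torch-ort package.
    """
    if importlib.util.find_spec("torch_ort") is None:
        warnings.warn("torch_ort not found; training with plain PyTorch")
        return model
    from torch_ort import ORTModule
    return ORTModule(model)


def train(model, batches, epochs=1, lr=1e-3):
    """Run SGD over `batches`, a re-iterable of (inputs, targets) tensor pairs."""
    import torch  # imported lazily so this file loads even without torch

    model = wrap_with_ort(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    model.train()
    for _ in range(epochs):
        for inputs, targets in batches:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()  # with ORTModule, the backward pass also runs in ONNX Runtime
            optimizer.step()
    return model
```

Usage, assuming torch is installed: train(torch.nn.Linear(4, 1), [(torch.randn(8, 4), torch.randn(8, 1))]). Pass a list or a DataLoader rather than a generator when epochs is greater than 1, because a generator is exhausted after the first pass.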
Versioning
CUDA
The PyTorch/ORT package was built with CUDA 11.1. If you have a different version of CUDA installed, you should install the CUDA 11.1 toolkit.
This is a limitation that will be removed in the next release.
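To check whether the installed toolkit matches CUDA 11.1, you can read the release number reported by nvcc --version, which prints a line such as "Cuda compilation tools, release 11.1, V11.1.105". The sketch below assumes that the nvcc found on PATH belongs to the toolkit you train with; the function names are hypothetical.

```python
import re
import shutil
import subprocess

REQUIRED_TOOLKIT = (11, 1)  # CUDA version this torch-ort package was built with


def parse_nvcc_release(text):
    """Extract (major, minor) from `nvcc --version` output, or None if absent."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def installed_toolkit_version():
    """Return (major, minor) for the nvcc on PATH, or None if nvcc is not found."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    # PIPE and universal_newlines keep this compatible with Python 3.6.
    result = subprocess.run(
        [nvcc, "--version"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    return parse_nvcc_release(result.stdout)


def toolkit_matches(required=REQUIRED_TOOLKIT):
    """True if the nvcc on PATH reports the required CUDA release."""
    return installed_toolkit_version() == required
```

If toolkit_matches() returns False while a CUDA 11.1 toolkit is installed, another toolkit's bin directory may come first on PATH.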
Hashes for torch_ort-0.0.10.dev20210505-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 590ccce4eed376266e13bedb4b0901d1d330870a2f7d786704f781d2447274bb
MD5 | b16ec804e27f90557409607475b8daf3
BLAKE2b-256 | 7651faab82727fdb02fadfb1b976100ff5e5dffcdf247b7dcb7faece46b7e836