ViT-TensorRT
A re-implementation of the original Vision Transformer (ViT) model in PyTorch. This repository also makes it easy to productionize the model with ONNX and TensorRT export.
Installation
ViT-TensorRT requires a system with an NVIDIA GPU and CUDA installed. CUDA 12.4 has been tested with this repository.
To install vit-tensorrt and its dependencies, run:
pip install vit-tensorrt
Training
Training the model can be achieved in a few lines of code:
from pathlib import Path
from vit_tensorrt.config import TrainConfig
from vit_tensorrt.model import ViT
model = ViT()
model.fit(TrainConfig(data_path=Path("path/to/data")))
This assumes the data is stored in a directory with the following structure:
data
├── train
│   ├── images
│   │   ├── uuid1.jpg
│   │   ├── uuid2.jpg
│   │   └── ...
│   └── labels
│       ├── uuid1.txt
│       ├── uuid2.txt
│       └── ...
├── val
│   └── ...
└── test
    └── ...
Each label file contains a single integer indicating the class of the corresponding image.
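To make the image-label pairing concrete, here is a minimal sketch (not part of vit-tensorrt) that walks the train split and reads each image's class index, following the uuid-based naming above:
from pathlib import Path

data = Path("path/to/data")
for image_path in sorted((data / "train" / "images").glob("*.jpg")):
    # Each image uuid pairs with a label file of the same stem
    label_path = data / "train" / "labels" / f"{image_path.stem}.txt"
    class_id = int(label_path.read_text().strip())  # one class index per file
    print(image_path.name, class_id)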
Export
ONNX export
Exporting the model to ONNX can be done with:
model.export_onnx("path/to/output.onnx")
The model will be exported with ONNX opset version 18.
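As a quick sanity check, the exported file can be loaded and run with onnxruntime. This is a hedged sketch (not part of vit-tensorrt); the 3x256x256 input size is an assumption carried over from the Triton example below:
import numpy as np
import onnxruntime as ort

# Load the exported model and run a dummy batch through it
session = ort.InferenceSession("path/to/output.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
dummy = np.random.rand(1, 3, 256, 256).astype(np.float32)
outputs = session.run(None, {input_name: dummy})
print(outputs[0].shape)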
TensorRT export
Exporting the model to a TensorRT engine can be done with:
model.export_tensorrt("path/to/output.engine")
The model will be exported with TensorRT version 10.4.0.
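If you want to confirm the engine deserializes correctly before deploying it, a minimal sketch with the TensorRT Python API (assuming tensorrt 10.x is installed and the output path used above) looks like:
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("path/to/output.engine", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
assert engine is not None  # a corrupt or incompatible engine deserializes to None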
Deploy with Triton Inference Server
TensorRT engines are exported using TensorRT version 10.4.0, so any Triton Inference Server compatible with this version of TensorRT can be used. The deployment of ViT-TensorRT has been tested with nvcr.io/nvidia/tritonserver:24.09-py3. If you want to investigate alternatives, please refer to Triton's Containers.
The easiest way to deploy the model is by running the following in your terminal:
nvidia-docker run \
-dit \
--net=host \
--name vit-triton \
-v <path-to-model>:/models/vit/1 \
nvcr.io/nvidia/tritonserver:24.09-py3 \
tritonserver --model-repository=/models
The above assumes you have renamed your model to model.plan. If you want to configure how your model runs in Triton, please refer to Model Configuration.
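Once the container is running, you can confirm the server and model are ready using Triton's standard HTTP endpoints (port 8000 is Triton's default HTTP port, and vit is the model name implied by the /models/vit/1 path above):
curl -v localhost:8000/v2/health/ready
curl localhost:8000/v2/models/vit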
Triton inference from Python client
To perform inference on the deployed model, you can use the Triton Inference Server Python client (tritonclient):
import numpy as np
from tritonclient import http

client = http.InferenceServerClient("localhost:8000")

# Create the input data for a batch of 32 images
inputs = [http.InferInput("input", [32, 3, 256, 256], "FP32")]
inputs[0].set_data_from_numpy(np.random.rand(32, 3, 256, 256).astype(np.float32))

# Run inference and fetch the output tensor
results = client.infer("vit", inputs)
output = results.as_numpy("output")
Here input and output are the names of the input and output layers of the model, respectively, and vit is the name of the model in Triton. Make sure the input size you specify matches the size the model was trained with.
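Assuming the output holds one row of per-class scores per image (an assumption consistent with the classification setup above), the predicted classes can be recovered with an argmax:
# Hypothetical post-processing: assumes `output` has shape (32, num_classes)
predicted_classes = output.argmax(axis=1)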
To increase model throughput, calls to Triton Inference Server can be made using shared GPU memory. This is more complicated to set up, but if you are interested, please raise an issue on the repository and an example can be provided.
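For reference, here is a hedged sketch of what CUDA shared-memory inference can look like, based on the tritonclient cuda_shared_memory utilities from Triton's own client examples rather than on this repository (the region name input_data and device 0 are arbitrary choices):
import numpy as np
import tritonclient.http as httpclient
import tritonclient.utils.cuda_shared_memory as cudashm

client = httpclient.InferenceServerClient("localhost:8000")

input_data = np.random.rand(32, 3, 256, 256).astype(np.float32)
byte_size = input_data.size * input_data.itemsize

# Allocate a CUDA shared-memory region on GPU 0, copy the batch into it,
# and register the region with the server
shm_handle = cudashm.create_shared_memory_region("input_data", byte_size, 0)
cudashm.set_shared_memory_region(shm_handle, [input_data])
client.register_cuda_shared_memory(
    "input_data", cudashm.get_raw_handle(shm_handle), 0, byte_size
)

# Point the input at the region instead of sending the bytes over HTTP
inputs = [httpclient.InferInput("input", [32, 3, 256, 256], "FP32")]
inputs[0].set_shared_memory("input_data", byte_size)
results = client.infer("vit", inputs)
output = results.as_numpy("output")

# Clean up once finished
client.unregister_cuda_shared_memory("input_data")
cudashm.destroy_shared_memory_region(shm_handle)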