PyTriton - Flask/FastAPI-like interface to simplify Triton's deployment in Python environments.

PyTriton is a Flask/FastAPI-like interface that simplifies Triton’s deployment in Python environments. The library allows serving Machine Learning models directly from Python through NVIDIA’s Triton Inference Server.

In PyTriton, as in Flask or FastAPI, you can define any Python function that executes a machine learning model prediction and expose it through an HTTP/gRPC API. PyTriton installs Triton Inference Server in your environment and uses it to handle HTTP/gRPC requests and responses. The library provides a Python API for attaching a Python function to Triton, plus a communication layer that sends and receives data between Triton and the function. This lets you use the performance features of Triton Inference Server, such as dynamic batching or response cache, without changing your model environment, which can improve the performance of GPU inference for models implemented in Python. The solution is framework-agnostic and can be used along with frameworks like PyTorch, TensorFlow, or JAX.

Installation

The package can be installed from PyPI using:

pip install -U nvidia-pytriton

More details about installation can be found in the documentation.

Example

This example shows how to run a Python model in Triton Inference Server without changing the current working environment. It uses a simple PyTorch Linear model.

The example requires PyTorch to be installed in your environment. You can install it by running:

pip install torch

Next, define the Linear model:

import torch

model = torch.nn.Linear(2, 3).to("cuda").eval()

Create a function for handling inference requests:

import numpy as np
from pytriton.decorators import batch


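@batch
def infer_fn(**inputs: np.ndarray):
    # The model has a single input, so unpack the only array from the inputs dict
    (input1_batch,) = inputs.values()
    # Move the NumPy batch to the GPU, where the model was placed
    input1_batch_tensor = torch.from_numpy(input1_batch).to("cuda")
    output1_batch_tensor = model(input1_batch_tensor)  # Calling the Python model inference
    # Copy the result back to the CPU and convert it to NumPy for Triton
    output1_batch = output1_batch_tensor.cpu().detach().numpy()
    return [output1_batch]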
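
Conceptually, batching means gathering the rows of several requests into one batch, calling the inference function once, and splitting the result back per request. The sketch below illustrates that idea with plain Python lists. It is not PyTriton's implementation, and the names run_batched and double_rows are hypothetical helpers introduced only for this illustration.

```python
from typing import Callable, List, Sequence

Row = List[float]


def run_batched(
    requests: Sequence[List[Row]],
    fn: Callable[[List[Row]], List[Row]],
) -> List[List[Row]]:
    """Concatenate the rows of all requests, call fn once, split the result back."""
    sizes = [len(rows) for rows in requests]
    batch = [row for rows in requests for row in rows]
    outputs = fn(batch)
    if len(outputs) != len(batch):
        raise ValueError("fn must return one output row per input row")
    results: List[List[Row]] = []
    start = 0
    for size in sizes:
        # Each request gets back exactly as many rows as it sent
        results.append(outputs[start:start + size])
        start += size
    return results


def double_rows(batch: List[Row]) -> List[Row]:
    """Stand-in for a model: doubles every value in every row."""
    return [[2.0 * x for x in row] for row in batch]


# Two requests, with one row and two rows, are served by a single call to double_rows
responses = run_batched([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]], double_rows)
```

Here responses holds one list per request, in the order the requests were given.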

In the next step, create the connection between the model and Triton Inference Server using the bind method:

from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

# Connecting inference callback with Triton Inference Server
with Triton() as triton:
    # Load model into Triton Inference Server
    triton.bind(
        model_name="Linear",
        infer_func=infer_fn,
        inputs=[
            Tensor(dtype=np.float32, shape=(-1,)),
        ],
        outputs=[
            Tensor(dtype=np.float32, shape=(-1,)),
        ],
        config=ModelConfig(max_batch_size=128)
    )

Finally, serve the model with Triton Inference Server by calling serve inside the same with block, after the bind call:

from pytriton.triton import Triton

with Triton() as triton:
    ...  # Load models here
    triton.serve()

The bind method creates a connection between Triton Inference Server and infer_fn, which handles the inference queries. The inputs and outputs describe the model inputs and outputs that are exposed in Triton. The config field accepts additional parameters for model deployment.
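
When batching is enabled, the shape declared for a tensor describes a single sample: the batch dimension is left out and -1 marks a dimension of variable size. The sketch below shows how a request shape relates to the declared shape=(-1,) and max_batch_size=128 from the example. The function matches_declared_shape is a hypothetical helper for illustration and is not part of PyTriton.

```python
from typing import Sequence


def matches_declared_shape(
    request_shape: Sequence[int],
    sample_shape: Sequence[int],
    max_batch_size: int,
) -> bool:
    """Check a batched request shape against a per-sample shape declaration."""
    # The request carries one extra leading dimension: the batch size
    if len(request_shape) != len(sample_shape) + 1:
        return False
    batch_size, *dims = request_shape
    if not 1 <= batch_size <= max_batch_size:
        return False
    # -1 in the declaration accepts any size for that dimension
    return all(declared in (-1, actual) for declared, actual in zip(sample_shape, dims))


ok_single = matches_declared_shape([1, 2], (-1,), 128)       # one sample with 2 features
ok_too_large = matches_declared_shape([129, 2], (-1,), 128)  # exceeds max_batch_size
ok_no_batch = matches_declared_shape([2], (-1,), 128)        # batch dimension missing
```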

The serve method is blocking: once it is called, the application waits for incoming HTTP/gRPC requests. From that moment, the model is available under the name Linear in the Triton server. Inference queries can be sent to localhost:8000/v2/models/Linear/infer, and they are passed to the infer_fn function.
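
A query to that endpoint is a JSON document posted over HTTP. The sketch below builds such a request body using only the Python standard library. It assumes that the input tensor is named INPUT_1 (no name is given in the bind call above, so check the name your server reports) and that the server is already running on localhost:8000. The helpers build_infer_request and send_infer_request are hypothetical names introduced for this illustration.

```python
import json
import urllib.request
from typing import Any, Dict, List

INFER_URL = "http://localhost:8000/v2/models/Linear/infer"


def build_infer_request(rows: List[List[float]], input_name: str = "INPUT_1") -> Dict[str, Any]:
    """Build a JSON-serializable request body for a batch of float32 rows."""
    if not rows or any(len(row) != len(rows[0]) for row in rows):
        raise ValueError("rows must be a non-empty list of equally sized rows")
    return {
        "inputs": [
            {
                "name": input_name,  # assumed default name, adjust if yours differs
                "shape": [len(rows), len(rows[0])],  # [batch size, features]
                "datatype": "FP32",
                "data": rows,
            }
        ]
    }


def send_infer_request(payload: Dict[str, Any], url: str = INFER_URL) -> Dict[str, Any]:
    """POST the payload to the server and return the decoded JSON response."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


# One sample with two features, matching the Linear(2, 3) model from the example
payload = build_infer_request([[0.5, -1.0]])
```

With the server from the example running, send_infer_request(payload) would post the query and return the parsed response.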
