
Serve PyTorch models through an HTTP API in one line.

Project description

torch-deploy

Usage

Deploying a pretrained ResNet-18:

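import torch
import torchvision.models as models
from torch_deploy import deploy

# Load ImageNet-pretrained weights (newer torchvision releases prefer
# models.resnet18(weights="IMAGENET1K_V1") over the deprecated pretrained=True)
resnet18 = models.resnet18(pretrained=True)
resnet18.eval()  # inference mode: batch norm layers use their running statistics
# pre=torch.tensor converts the JSON "inputs" list into a tensor before the model is called
deploy(resnet18, pre=torch.tensor)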

By default the server listens on host 0.0.0.0, port 8000.

Endpoints

/predict

Request body: application/json
Response body: application/json

Here's an example of calling the /predict endpoint:

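import requests
from PIL import Image
import numpy as np
from torchvision import transforms

# Force 3 channels: grayscale or RGBA images would otherwise break Normalize and the model
im = Image.open('palm.jpg').convert('RGB')
resize = transforms.Resize(224)      # scale the shorter side to 224 pixels
to_tensor = transforms.ToTensor()    # PIL image -> CHW float tensor with values in [0, 1]
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
tensor = normalize(to_tensor(resize(im))).unsqueeze(0)  # add a batch dimension
body = {"inputs": tensor.tolist()}   # tensors are not JSON-serializable; send a nested list
r = requests.post("http://127.0.0.1:8000/predict", json=body)
r.raise_for_status()                 # fail loudly on HTTP errors
response = r.json()
output = np.array(response["output"])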
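For reference, transforms.Normalize in the example rescales each channel c of the [0, 1] tensor produced by ToTensor independently, using the per-channel mean and standard deviation that torchvision's ImageNet-pretrained models expect:

```latex
x'_c = \frac{x_c - \mu_c}{\sigma_c}, \qquad
\mu = (0.485,\ 0.456,\ 0.406), \quad
\sigma = (0.229,\ 0.224,\ 0.225)
```

For example, a fully saturated red-channel value of 1.0 becomes (1.0 - 0.485) / 0.229 ≈ 2.249.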

Send the model input in the JSON request body under the "inputs" field. Tensors and NumPy arrays are not JSON-serializable, so convert them to nested lists first (for example with .tolist()).

The model's output is returned in the JSON response body under the "output" field.

Sample response format:

response = {"output": (your numpy array as a list here)}
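If you would rather not depend on requests, the request and response formats above can be handled with the standard library alone. This is a minimal sketch that assumes only what this page documents: a POST to /predict whose JSON body carries the input under "inputs", and a JSON reply carrying the result under "output". The helper names to_jsonable, build_predict_request and parse_predict_response are hypothetical and not part of torch-deploy.

```python
import json
import urllib.request
from typing import Any


def to_jsonable(inputs: Any) -> Any:
    """Turn tensors/NumPy arrays into nested lists; pass other values through.

    torch.Tensor and numpy.ndarray both expose .tolist(), so duck typing
    avoids importing either library here.
    """
    return inputs.tolist() if hasattr(inputs, "tolist") else inputs


def build_predict_request(url: str, inputs: Any) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for the /predict endpoint."""
    body = json.dumps({"inputs": to_jsonable(inputs)}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def parse_predict_response(raw: bytes) -> Any:
    """Extract the model output from a raw JSON response body."""
    return json.loads(raw)["output"]
```

With a server running, urllib.request.urlopen(build_predict_request("http://127.0.0.1:8000/predict", tensor)) sends the request; pass the bytes read from the returned response to parse_predict_response.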

Documentation

torch_deploy.deploy(
    model: nn.Module,
    pre: Union[List[Callable], Callable] = None,
    post: Union[List[Callable], Callable] = None,
    host: str = "0.0.0.0",
    port: int = 8000,
    ssl_keyfile: str = None,
    ssl_certfile: str = None,
    ssl_ca_certs: str = None,
    logdir: str = "./deploy_logs/",
    inference_fn: str = None
)

Serves a PyTorch model as an HTTP API for production use.

  • model: The PyTorch model to serve. It must subclass nn.Module (and is therefore callable).
  • pre: A function, or list of functions, applied to the request input before it is passed to the model.
  • post: A function, or list of functions, applied to the model output before it is sent as the response.
  • host: The address the server binds to.
  • port: The port the server listens on.
  • ssl_keyfile, ssl_certfile, ssl_ca_certs: SSL settings passed through to uvicorn.
  • logdir: Directory for a log file that records the date, client IP address, and input size of each API request. If None, no log file is created.
  • inference_fn: Name of the model method to call on the inputs. If None, the model itself is called, which for an nn.Module runs model.forward(inputs).
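The parameters above imply a per-request order: pre, then the model (or the method named by inference_fn), then post. The pure-Python sketch below mirrors that reading so the semantics can be checked without a server. It is an assumption about the behavior, not torch-deploy's source; in particular, applying a list of functions left to right is assumed, and the helper names as_fn_list and run_inference are hypothetical.

```python
from typing import Any, Callable, List, Optional, Union

Fns = Union[List[Callable], Callable, None]


def as_fn_list(fns: Fns) -> List[Callable]:
    """Normalize None, a single callable, or a list of callables to a list."""
    if fns is None:
        return []
    if callable(fns):
        return [fns]
    return list(fns)


def run_inference(model: Any, inputs: Any, pre: Fns = None, post: Fns = None,
                  inference_fn: Optional[str] = None) -> Any:
    """Assumed flow for one request: pre -> model call -> post."""
    x = inputs
    for fn in as_fn_list(pre):   # e.g. pre=torch.tensor turns the nested list into a tensor
        x = fn(x)
    # inference_fn=None calls the model itself; otherwise look up the named method.
    call = model if inference_fn is None else getattr(model, inference_fn)
    y = call(x)
    for fn in as_fn_list(post):  # post-process the output before it becomes the response
        y = fn(y)
    return y
```

Normalizing pre and post to lists up front keeps the single-callable and list forms on one code path.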

Testing

To run the tests, start python test_server.py, then run python test_client.py in a separate terminal.

Dependencies

torch, torchvision, fastapi[all], requests, numpy, pydantic


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch-deploy-0.0.2.tar.gz (4.5 kB)

Uploaded Source

Built Distribution

pytorch_deploy-0.0.2-py3-none-any.whl (6.4 kB)

Uploaded Python 3

File details

Details for the file pytorch-deploy-0.0.2.tar.gz.

File metadata

  • Download URL: pytorch-deploy-0.0.2.tar.gz
  • Upload date:
  • Size: 4.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for pytorch-deploy-0.0.2.tar.gz
Algorithm Hash digest
SHA256 0220abe0e91c6696a1c1d599720d6bce6d83da5c3cb56bcf35d3ae12cb2e1a3a
MD5 3106ef5fba5fe4b901871d3d9f1f621c
BLAKE2b-256 e27078ef40be44057a9d4501d7924a71a4d2edb82f9ae716cfd30c1a5529ebcc

File details

Details for the file pytorch_deploy-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: pytorch_deploy-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 6.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for pytorch_deploy-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 f88bd7cd5626130d0ac3c5e163fbaab450191c595be8ed70ccc03c48cce71b79
MD5 76621112694bf8fab563539636611eb7
BLAKE2b-256 209ebb0c42ff010f3d08cdc6940316220ef6e150ed18e5057695cde6772e37ec
