Serve PyTorch models over an API in one line.
torch-deploy
Usage
Deploying a pretrained ResNet-18:
import torch
import torchvision.models as models
from torch_deploy import deploy
resnet18 = models.resnet18(pretrained=True)
resnet18.eval()
deploy(resnet18, pre=torch.tensor)
The default host and port are 0.0.0.0:8000.
Endpoints
/predict
Request body: application/json
Response body: application/json
Here's an example of how to use the /predict endpoint.
import requests
from PIL import Image
import numpy as np
from torchvision import transforms
im = Image.open('palm.jpg')
resize = transforms.Resize(224)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
tensor = normalize(to_tensor(resize(im))).unsqueeze(0)
body = {"inputs": tensor.tolist()}
r = requests.post("http://127.0.0.1:8000/predict", json=body)
response = r.json()
output = np.array(response["output"])
Note that you need to send the model input in the request JSON body under the field "inputs". If you want to send a tensor or a numpy array in the request, you need to turn it into a list first.
The output of the model will be in the response JSON body under the "output" field.
Sample response format:
response = {"output": (your numpy array as a list here)}
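The request and response shapes described above can be sketched with the standard library alone. This is a minimal illustration, not part of torch-deploy: the helper names `to_nested_list`, `build_request_body`, and `extract_output` are hypothetical, and the only facts taken from this README are the "inputs" and "output" field names.

```python
import json


def to_nested_list(value):
    """Turn a tensor- or array-like object into plain nested lists."""
    # torch.Tensor and numpy.ndarray both expose .tolist()
    if hasattr(value, "tolist"):
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return [to_nested_list(v) for v in value]
    return value


def build_request_body(model_input):
    """Wrap the model input under the "inputs" field expected by /predict."""
    return {"inputs": to_nested_list(model_input)}


def extract_output(response_text):
    """Read the "output" field from a /predict JSON response body."""
    return json.loads(response_text)["output"]


# The body is JSON-serializable once everything is a plain list.
body = build_request_body(((1.0, 2.0), (3.0, 4.0)))
payload = json.dumps(body)
```

The resulting `body` can be passed as `json=body` to `requests.post`, as in the example above, and `extract_output(r.text)` returns the same list as `r.json()["output"]`.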
Documentation
torch_deploy.deploy(
model: nn.Module,
pre: Union[List[Callable], Callable] = None,
post: Union[List[Callable], Callable] = None,
host: str = "0.0.0.0",
port: int = 8000,
ssl_keyfile: str = None,
ssl_certfile: str = None,
ssl_ca_certs: str = None,
logdir: str = "./deploy_logs/",
inference_fn: str = None
)
Easily converts a PyTorch model into an API for production usage.
model: A PyTorch model which subclasses nn.Module and is callable. Model used for the API.
pre: A function or list of functions to be applied to the input.
post: A function or list of functions applied to the model output before it is sent as a response.
host: The address for serving the model.
port: The port for serving the model.
ssl_keyfile, ssl_certfile, ssl_ca_certs: SSL configurations that are passed to uvicorn.
logfile (named logdir in the signature above): Filename to create a file that stores the date, IP address, and size of input for each access of the API. If None, no file will be created.
inference_fn: Name of the method of the model that should be called for the inputs. If None, the model itself will be called (if model is an nn.Module, this is equivalent to calling model.forward(inputs)).
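To make the roles of pre, post, and inference_fn concrete, here is a rough sketch of one plausible request flow. It is an assumption about behavior, not torch-deploy's actual implementation: the helpers `apply_hooks` and `run_inference` are hypothetical, and applying a list of functions in order is assumed rather than documented.

```python
from typing import Any, Callable, List, Optional, Union

Hooks = Union[List[Callable], Callable, None]


def apply_hooks(hooks: Hooks, value: Any) -> Any:
    """Apply a single function or a list of functions to value, in order."""
    if hooks is None:
        return value
    if callable(hooks):
        hooks = [hooks]
    for fn in hooks:
        value = fn(value)
    return value


def run_inference(model: Any, inputs: Any, pre: Hooks = None,
                  post: Hooks = None,
                  inference_fn: Optional[str] = None) -> Any:
    """Run pre hooks, then the model (or a named method), then post hooks."""
    x = apply_hooks(pre, inputs)
    # With no inference_fn the model itself is called; otherwise the
    # method with that name is looked up on the model.
    target = model if inference_fn is None else getattr(model, inference_fn)
    return apply_hooks(post, target(x))
```

Under this reading, `pre=torch.tensor` in the usage example converts the JSON list into a tensor before the model sees it, and a post function could, for example, reduce the output to a class index before it is serialized.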
Testing
Run python test_server.py first, and then python test_client.py in another window, to test.
Dependencies
torch, torchvision, fastapi[all], requests, numpy, pydantic