Open-source library for running inference workloads with Hugging Face Deep Learning Containers on Amazon SageMaker.
SageMaker Hugging Face Inference Toolkit
SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers models on Amazon SageMaker. This library provides default pre-processing, prediction, and post-processing for certain 🤗 Transformers models and tasks. It uses the SageMaker Inference Toolkit to start the model server, which is responsible for handling inference requests.
For Training, see Run training on Amazon SageMaker.
For the Dockerfiles used for building SageMaker Hugging Face Containers, see AWS Deep Learning Containers.
For information on running Hugging Face jobs on Amazon SageMaker, please refer to the 🤗 Transformers documentation.
For notebook examples, see SageMaker Notebook Examples.
💻 Getting Started with 🤗 Inference Toolkit
Note: the snippets in this section are still pseudocode and may need to be adjusted for your setup.
Install Amazon SageMaker Python SDK
pip install sagemaker --upgrade
Create an Amazon SageMaker endpoint with a trained model.
import sagemaker
from sagemaker.huggingface import HuggingFaceModel

# IAM role with SageMaker permissions (assumes this runs in a SageMaker notebook or Studio)
role = sagemaker.get_execution_role()

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.6',
    pytorch_version='1.7',
    py_version='py36',
    model_data='s3://my-trained-model/artifacts/model.tar.gz',
    role=role,
)

# deploy model to SageMaker Inference
huggingface_model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")
Create an Amazon SageMaker endpoint with a model from the 🤗 Hub.
Note: This is an experimental feature in which the model is loaded after the endpoint is created. Not all SageMaker features are supported, e.g. multi-model endpoints (MME).
import sagemaker
from sagemaker.huggingface import HuggingFaceModel

# IAM role with SageMaker permissions (assumes this runs in a SageMaker notebook or Studio)
role = sagemaker.get_execution_role()

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID': 'distilbert-base-uncased-distilled-squad',
    'HF_TASK': 'question-answering'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.6',
    pytorch_version='1.7',
    py_version='py36',
    env=hub,
    role=role,
)

# deploy model to SageMaker Inference
huggingface_model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")
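Once an endpoint is running, requests are typically sent as JSON. The following is a minimal sketch that only builds a request body for the question-answering model configured above. It assumes the default handler reads the pipeline input from an `inputs` key, and `build_qa_payload` is a hypothetical helper name introduced here for illustration, not part of the toolkit.

```python
import json


def build_qa_payload(question: str, context: str) -> str:
    """Serialize a question-answering request as a JSON string.

    Assumption: the endpoint's default handler reads the pipeline input
    from an "inputs" key.
    """
    if not question or not context:
        raise ValueError("question and context must be non-empty")
    return json.dumps({"inputs": {"question": question, "context": context}})


payload = build_qa_payload(
    question="What does the toolkit serve?",
    context="The toolkit serves Transformers models on Amazon SageMaker.",
)
print(payload)
```

With the SageMaker Python SDK, the object returned by `deploy()` can usually send the equivalent dictionary through its `predict()` method, so building the JSON by hand is mainly useful for other HTTP clients.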
🛠️ Environment variables
The SageMaker Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below.
HF_TASK
The HF_TASK environment variable defines the task for the 🤗 Transformers pipeline that is used. A full list of tasks can be found here.
HF_TASK="question-answering"
HF_MODEL_ID
The HF_MODEL_ID environment variable defines the model ID, which is automatically loaded from huggingface.co/models when creating a SageMaker Endpoint. The 🤗 Hub provides more than 10,000 models, all available through this environment variable.
HF_MODEL_ID="distilbert-base-uncased-finetuned-sst-2-english"
HF_MODEL_REVISION
The HF_MODEL_REVISION environment variable is an extension to HF_MODEL_ID. It lets you pin a revision of the model so that you always load the same model on your SageMaker Endpoint.
HF_MODEL_REVISION="03b4d196c19d0a73c7e0322684e97db1ec397613"
HF_API_TOKEN
The HF_API_TOKEN environment variable defines your Hugging Face authorization token. The HF_API_TOKEN is used as an HTTP bearer authorization for remote files, like private models. You can find your token on your settings page.
HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
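The variables above are passed to the container as a plain dictionary, like the `hub` dictionary in the Getting Started section. As a rough sketch, they can be collected with a small helper; `build_hub_env` is a hypothetical name introduced here, and treating both the model ID and the task as required is an assumption of this sketch rather than a documented rule. The model ID and revision are the example values from this page and are not verified to belong together.

```python
from typing import Dict, Optional


def build_hub_env(
    model_id: str,
    task: str,
    revision: Optional[str] = None,
    api_token: Optional[str] = None,
) -> Dict[str, str]:
    """Collect the toolkit's environment variables into one dict.

    Optional values are only included when they are set.
    """
    # Assumption of this sketch: both values must be provided.
    if not model_id or not task:
        raise ValueError("model_id and task are both required in this sketch")
    env = {"HF_MODEL_ID": model_id, "HF_TASK": task}
    if revision:
        env["HF_MODEL_REVISION"] = revision
    if api_token:
        env["HF_API_TOKEN"] = api_token
    return env


hub = build_hub_env(
    model_id="distilbert-base-uncased-finetuned-sst-2-english",
    task="text-classification",
    revision="03b4d196c19d0a73c7e0322684e97db1ec397613",
)
print(hub)
```

The resulting dictionary can then be passed as the `env` argument of `HuggingFaceModel`.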
🧑🏻‍💻 User defined code/modules
The Hugging Face Inference Toolkit allows users to override the default methods of the HuggingFaceHandlerService. To do so, create a folder named code/ containing an inference.py file.
For example:
model.tar.gz/
|- pytorch_model.bin
|- ....
|- code/
  |- inference.py
  |- requirements.txt
In this example, pytorch_model.bin is the model file saved from training, inference.py is the custom inference module, and requirements.txt is a requirements file for adding additional dependencies.
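One way to produce an archive with this layout is Python's standard `tarfile` module. The sketch below uses placeholder files in a temporary directory; `package_model` is a hypothetical helper name, and the file contents are dummies, not a real model.

```python
import tarfile
import tempfile
from pathlib import Path


def package_model(model_dir: Path, output_path: Path) -> Path:
    """Write every file under model_dir into a gzip-compressed tarball.

    Files are stored relative to model_dir, so code/inference.py ends up
    at the top level of the archive rather than under a parent folder.
    """
    with tarfile.open(output_path, "w:gz") as tar:
        for path in sorted(model_dir.rglob("*")):
            if path.is_file():
                tar.add(path, arcname=path.relative_to(model_dir).as_posix())
    return output_path


# Build a throwaway directory with placeholder files to demonstrate the layout.
workdir = Path(tempfile.mkdtemp())
model_dir = workdir / "model"
(model_dir / "code").mkdir(parents=True)
(model_dir / "pytorch_model.bin").write_bytes(b"placeholder weights")
(model_dir / "code" / "inference.py").write_text("# custom handlers go here\n")
(model_dir / "code" / "requirements.txt").write_text("# extra dependencies\n")

archive = package_model(model_dir, workdir / "model.tar.gz")
with tarfile.open(archive, "r:gz") as tar:
    members = sorted(tar.getnames())
print(members)
```

Writing the archive outside `model_dir` keeps the tarball from including itself.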
The custom module can override the following methods:
model_fn(model_dir): overrides the default method for loading the model. The return value model is passed to predict() for predictions. It receives model_dir, the path to your unzipped model.tar.gz.
transform_fn(model, data, content_type, accept_type): overrides the default transform function with a custom implementation. If you use it, you must implement the preprocess, predict, and postprocess steps inside transform_fn. NOTE: This method can't be combined with input_fn, predict_fn, or output_fn mentioned below.
input_fn(input_data, content_type): overrides the default method for preprocessing. The return value data is passed to predict() for predictions. The inputs are input_data, the raw body of your request, and content_type, the content type from the request header.
predict_fn(processed_data, model): overrides the default method for predictions. The return value predictions is passed to postprocess(). The input is processed_data, the result of preprocess().
output_fn(prediction, accept): overrides the default method for postprocessing. The return value result is the response to your request (e.g. JSON). The inputs are prediction, the result of predict(), and accept, the accept type from the HTTP request, e.g. application/json.
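As an illustration of how these hooks fit together, here is a minimal sketch of an inference.py. It is not the toolkit's default implementation: `EchoLengthModel` is a hypothetical stand-in so the block runs without 🤗 Transformers, the `"inputs"` key is an assumed request shape, and the directory passed to `model_fn` is a placeholder.

```python
import json


class EchoLengthModel:
    """Stand-in for a real model so this sketch runs without Transformers."""

    def __call__(self, text: str) -> dict:
        return {"text": text, "num_chars": len(text)}


def model_fn(model_dir):
    # A real implementation would load weights from model_dir here.
    return EchoLengthModel()


def input_fn(input_data, content_type):
    # Decode the raw request body; only JSON is handled in this sketch.
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")
    return json.loads(input_data)


def predict_fn(processed_data, model):
    # Run the model on the decoded request (assumes an "inputs" key).
    return model(processed_data["inputs"])


def output_fn(prediction, accept):
    # Serialize the prediction for the response.
    if accept != "application/json":
        raise ValueError(f"Unsupported accept type: {accept}")
    return json.dumps(prediction)


# Chain the hooks locally in the order described above.
model = model_fn("/tmp/model")  # placeholder path, unused by the stand-in
data = input_fn('{"inputs": "hello"}', "application/json")
response = output_fn(predict_fn(data, model), "application/json")
print(response)
```

Chaining the functions locally like this is a quick way to check a custom module before packaging it. A module that defines transform_fn instead would fold the three request steps into that single function.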
🤝 Contributing
Please read CONTRIBUTING.md for details on our code of conduct, and the process for submitting pull requests to us.
📜 License
SageMaker Hugging Face Inference Toolkit is licensed under the Apache 2.0 License.
Hashes for sagemaker-huggingface-inference-toolkit-1.3.1.tar.gz
Algorithm | Hash digest
---|---
SHA256 | d65395f3b652fa4dbd248d2a22469f7345a0bfe592c833613d87972cd421fe22
MD5 | 7596132b19007b95fd7b89b53292156b
BLAKE2b-256 | 3974daf733107ddf975c48a6f3d9f8570bb26685b6960e2ff342191d989dedc5
Hashes for sagemaker_huggingface_inference_toolkit-1.3.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 398f82d429e6b0a17a711dde64931b0c3fa30922e9629fae649d0168659ce35f
MD5 | ae821f01b8e820d3595f9643ce003867
BLAKE2b-256 | bade10f4a4c49de1b41e36da4bc88b82313b1f571081630f6aca651fa1471ede