ML_adapter for torch.
waylay-ml-adapter-torch
Provides the ml_adapter.torch module as Waylay ML Adapter for PyTorch.
Installation
pip install waylay-ml-adapter-torch
You might want to install additional libraries such as torchaudio or torchvision.
Usage
This ML Adapter uses the standard torch mechanisms to save and load models within a waylay plugin or webscript.
The model_path argument defines the file name of the serialized model in the function archive:
- A model_path ending in weights.pt or weights.pth saves/loads the model weights using its state_dict. This is a recommended, more robust method, but it requires you to also specify a model_class.
- Any other model_path with a .pt or .pth suffix saves/loads the entire model. This implicitly saves (references to) the model class used, so you have to make sure that all dependencies it uses are also included or declared in the archive.
- You can also pass an instantiated model directly to the adapter.
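To make the difference between the first two options concrete, here is a plain-PyTorch sketch of both serialization modes, independent of the adapter. It assumes only torch (1.13 or later, for the weights_only argument); the nn.Sequential model and the file names are illustrative, and this is not necessarily how the adapter calls torch internally.

```python
import os
import tempfile

import torch
from torch import nn


def make_model() -> nn.Module:
    # Small stand-in model; any torch.nn.Module behaves the same way.
    return nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))


model = make_model()
x = torch.tensor([[2.0, 3.0, 4.0]])

with tempfile.TemporaryDirectory() as tmp:
    # 1. Weights only: save the state_dict (parameter names mapped to tensors).
    weights_path = os.path.join(tmp, "model-weights.pt")
    torch.save(model.state_dict(), weights_path)
    # Loading needs the model class: build a fresh instance, then restore weights.
    restored = make_model()
    restored.load_state_dict(torch.load(weights_path))

    # 2. Entire model: pickles the module object, including a reference to its class.
    full_path = os.path.join(tmp, "model.pt")
    torch.save(model, full_path)
    # Recent torch versions default to weights_only=True, which rejects full-model pickles.
    loaded = torch.load(full_path, weights_only=False)

    with torch.no_grad():
        y_model, y_restored, y_loaded = model(x), restored(x), loaded(x)
```

The weights file stays loadable as long as the class definition is available, which is why the adapter asks for a model_class in that mode; the full-model file only works if every class it references can be imported when it is loaded.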
Creating a model for a webscript
import os
from ml_adapter.torch import V1TorchAdapter
# assuming we saved an AutoEncoder torch.nn.Module class in an `autoencoder.py` file
from autoencoder import AutoEncoder
model = AutoEncoder()
# ... train the model ...
# a local directory to prepare the webscript archive (expanduser resolves the `~`)
ARCHIVE_LOC = os.path.expanduser('~/webscripts/autoencoder-pytorch')
# a model path ending in `weights.pt` selects weights (state_dict) serialization
MODEL_PATH = 'model-weights.pt'
adapter = V1TorchAdapter(
    model=model,
    model_path=MODEL_PATH,
    location=ARCHIVE_LOC,
)
# add our model script to the webscript archive
await adapter.add_script('autoencoder.py')
# write the archive
await adapter.save()
# inspect the archive:
list(adapter.assets)
#> [requirements.txt <ml_adapter.base.assets.python.PythonRequirementsAsset>,
#> main.py <ml_adapter.base.assets.python.PythonScriptAsset>,
#> model-weights.pt <ml_adapter.torch.adapter.TorchModelWeightsAsset>,
#> autoencoder.py <ml_adapter.base.assets.python.PythonScriptAsset>]
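The choice of MODEL_PATH in this example is what selects weights serialization, and the listing confirms it with a TorchModelWeightsAsset. The helper below is a hypothetical illustration of the file-name rule described in the Usage section; it is not part of ml_adapter.torch, and raising an error for other suffixes is an assumption, since the rule only covers .pt and .pth files.

```python
from pathlib import PurePath


def serialization_mode(model_path: str) -> str:
    """Classify a model path by the file-name rule described in this README.

    Hypothetical helper for illustration only; not part of ml_adapter.torch.
    """
    name = PurePath(model_path).name
    if name.endswith(("weights.pt", "weights.pth")):
        # state_dict serialization: a model_class is needed to load it
        return "weights"
    if name.endswith((".pt", ".pth")):
        # entire-model serialization: the class is pickled by reference
        return "full"
    # assumption: other suffixes are not covered by the documented rule
    raise ValueError(f"unsupported model path: {model_path!r}")
```

For example, both 'model-weights.pt' and 'autoencoder.weights.pt' classify as "weights", while 'model.pt' classifies as "full".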
Upload the adapter archive as a webscript using the ml_tool SDK plugin:
from waylay.sdk import WaylayClient
client = WaylayClient.from_profile('staging')
ref = await client.ml_tool.create_webscript(adapter, name='MyAutoEncoder', version='0.0.1')
ref = await client.ml_tool.wait_until_ready(ref)
await client.ml_tool.test_webscript(ref, [2,3,4])
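The examples above use top-level await, which works in Jupyter or IPython where an event loop is already running. In a plain Python script the coroutines have to be driven explicitly, for instance with asyncio.run. The wrapper below is a hypothetical convenience; it only assumes that the adapter exposes the add_script and save coroutines used above.

```python
import asyncio


async def build_archive(adapter, scripts):
    """Add helper scripts to the archive, then write it.

    Hypothetical wrapper: `adapter` is expected to expose the coroutine
    methods `add_script` and `save` used in the example above.
    """
    for script in scripts:
        await adapter.add_script(script)
    await adapter.save()
    return adapter


# In a plain script (no running event loop), drive the coroutines like this:
# asyncio.run(build_archive(adapter, ["autoencoder.py"]))
```

The same pattern applies to the ml_tool calls: put them in one async function and pass it to asyncio.run once, rather than calling asyncio.run for each step.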
The generated code in main.py uses the following to load your model:
import os
from ml_adapter.torch import V1TorchAdapter
MODEL_PATH = os.environ.get('MODEL_PATH', 'model-weights.pt')
MODEL_CLASS = os.environ.get('MODEL_CLASS', 'autoencoder.AutoEncoder')
adapter = V1TorchAdapter(model_path=MODEL_PATH, model_class=MODEL_CLASS)
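MODEL_CLASS is given as a 'module.ClassName' string, which works because autoencoder.py is shipped next to main.py in the archive and is therefore importable. The helper below is a hypothetical sketch of how such a string can be turned into a class with importlib; the adapter's own resolution logic may differ.

```python
import importlib


def resolve_class(dotted_path: str) -> type:
    """Resolve a 'module.ClassName' string such as 'autoencoder.AutoEncoder'.

    Hypothetical helper for illustration; not part of ml_adapter.torch.
    """
    module_name, _, class_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {dotted_path!r}")
    # Import the module (it must be on sys.path), then look up the class by name.
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
```

For example, resolve_class('collections.OrderedDict') returns the standard-library OrderedDict class; inside the webscript, resolve_class('autoencoder.AutoEncoder') would return the model class, which can then be instantiated before its state_dict is loaded.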
You can modify that loading mechanism, e.g. by creating the model yourself and passing it directly:
adapter = V1TorchAdapter(model=model)
Exported classes
This module exports the following classes:
ml_adapter.torch.V1TorchAdapter
Adapts a callable with torch arrays as input and output.
ml_adapter.torch.V1TorchMarshaller
Convert v1 payload from and to torch tensors.
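For intuition about what the marshaller and adapter do together, the sketch below decodes a JSON-like request into a tensor, calls an arbitrary callable, and encodes the result. It assumes a KServe-V1-like body with an "instances" key in the request and a "predictions" key in the response; the actual V1 payload handled by V1TorchMarshaller is defined in waylay-ml-adapter-api and may differ. The names decode_v1, encode_v1 and invoke are hypothetical.

```python
import torch


def decode_v1(payload: dict) -> torch.Tensor:
    """Turn a V1-style request body into a tensor (assumes an 'instances' key)."""
    return torch.as_tensor(payload["instances"], dtype=torch.float32)


def encode_v1(result: torch.Tensor) -> dict:
    """Turn a model output tensor into a JSON-serializable V1-style response."""
    return {"predictions": result.detach().cpu().tolist()}


def invoke(model, payload: dict) -> dict:
    # Decode, call the model (any callable on tensors) without autograd, encode.
    with torch.no_grad():
        return encode_v1(model(decode_v1(payload)))
```

With a plain function as the model, invoke(lambda x: x * 2, {"instances": [[2, 3, 4]]}) returns {"predictions": [[4.0, 6.0, 8.0]]}, which mirrors the description of V1TorchAdapter as adapting any callable with torch tensors as input and output, not only an nn.Module.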
See also
- waylay-ml-adapter-sklearn ML adapter for scikit-learn models.
- waylay-ml-adapter-sdk provides the ml_tool extension to the waylay-sdk.
- waylay-ml-adapter-base provides the basic ML adapter infrastructure.
- waylay-ml-adapter-api defines the remote data interfaces.
File details
Details for the file waylay_ml_adapter_torch-0.0.3.tar.gz.
File metadata
- Download URL: waylay_ml_adapter_torch-0.0.3.tar.gz
- Upload date:
- Size: 5.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 05089008bcf53647eed1baa22c428d554f408e46181f29ec541138a7399e1957
MD5 | 1c1a4109f589ac9640c66d5178640347
BLAKE2b-256 | 22ab218a7201c329e9228c0ad5088f5dceb5973cf4ece6fe4b36a585d2b26e53
File details
Details for the file waylay_ml_adapter_torch-0.0.3-py3-none-any.whl.
File metadata
- Download URL: waylay_ml_adapter_torch-0.0.3-py3-none-any.whl
- Upload date:
- Size: 6.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | a2c14b966f4959f3362668c588847fa049d081a5074eea706c710692f7b9f414
MD5 | 1290df11821a5d096a24880bf6af92ff
BLAKE2b-256 | 217618470a43b970e43c7bbf5ee8b769ead82975467730b9a07a13a842d61719