
ML_adapter for torch.

Project description

waylay-ml-adapter-torch

Provides the ml_adapter.torch module as Waylay ML Adapter for PyTorch.

Installation

pip install waylay-ml-adapter-torch

You might want to install additional libraries such as torchaudio or torchvision.

Usage

This ML Adapter uses the standard torch mechanisms to save and load models within a Waylay plugin or webscript. The model_path argument defines the file name of the serialized model in the function archive:

  • A model_path ending in weights.pt or weights.pth saves/loads the model weights using its state_dict. This is the recommended, more robust method, but it requires you to also specify a model_class.
  • Any other model_path with a .pt or .pth suffix saves/loads the entire model. This implicitly saves (references to) the model class used, so you have to make sure that all its dependencies are also included or declared in the archive.
  • You can also pass an instantiated model directly to the adapter.
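The file-name rule above can be summarized as a small classification function. This is a minimal sketch, not the adapter's implementation: serialization_mode is a hypothetical helper, and rejecting paths without a .pt or .pth suffix is an assumption, since the rules above do not say what happens in that case.

```python
from pathlib import PurePath


def serialization_mode(model_path: str) -> str:
    """Classify a model_path by the naming rule above (hypothetical helper)."""
    name = PurePath(model_path).name
    if name.endswith(("weights.pt", "weights.pth")):
        # state_dict serialization: a model_class is also needed to rebuild the model
        return "weights"
    if name.endswith((".pt", ".pth")):
        # whole-model serialization: the saved model references its class
        return "full"
    # assumption: other suffixes are not supported by this sketch
    raise ValueError(f"unsupported model_path suffix: {model_path!r}")


for path in ["model-weights.pt", "autoencoder.weights.pth", "model.pt"]:
    print(path, "->", serialization_mode(path))
```

Both model-weights.pt and autoencoder.weights.pth end in weights.pt(h) and therefore select weights serialization, while model.pt selects whole-model serialization.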

Creating a model for a webscript

import os

from ml_adapter.torch import V1TorchAdapter

# assuming an AutoEncoder torch.nn.Module class is defined in an `autoencoder.py` file
from autoencoder import AutoEncoder
model = AutoEncoder()
# ... train the model ...

# a local directory to prepare the webscript archive ('~' expanded explicitly)
ARCHIVE_LOC = os.path.expanduser('~/webscripts/autoencoder-pytorch')
# a model path ending in `weights.pt` selects _weights_ (state_dict) serialization
MODEL_PATH = 'model-weights.pt'

adapter = V1TorchAdapter(
    model=model,
    model_path=MODEL_PATH,
    location=ARCHIVE_LOC,
)

# add our model script to the webscript archive
# (top-level `await` needs an async context, e.g. a Jupyter notebook)
await adapter.add_script('autoencoder.py')
# write the archive
await adapter.save()
# inspect the archive:
list(adapter.assets)
#> [requirements.txt <ml_adapter.base.assets.python.PythonRequirementsAsset>,
#> main.py <ml_adapter.base.assets.python.PythonScriptAsset>,
#> model-weights.pt <ml_adapter.torch.adapter.TorchModelWeightsAsset>,
#> autoencoder.py <ml_adapter.base.assets.python.PythonScriptAsset>]

Upload the adapter archive as a webscript using the ml_tool SDK plugin:

from waylay.sdk import WaylayClient
client = WaylayClient.from_profile('staging')
ref = await client.ml_tool.create_webscript(adapter, name='MyAutoEncoder', version='0.0.1')
ref = await client.ml_tool.wait_until_ready(ref)
await client.ml_tool.test_webscript(ref, [2,3,4])

The generated code in main.py uses the following to load your model:

import os
from ml_adapter.torch import V1TorchAdapter

MODEL_PATH = os.environ.get('MODEL_PATH', 'model-weights.pt')
MODEL_CLASS = os.environ.get('MODEL_CLASS', 'autoencoder.AutoEncoder')
adapter = V1TorchAdapter(model_path=MODEL_PATH, model_class=MODEL_CLASS)

You can modify that loading mechanism, e.g. by creating the model yourself and passing it directly:

adapter = V1TorchAdapter(model=model)
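In the generated main.py, MODEL_CLASS is a dotted 'module.ClassName' string such as 'autoencoder.AutoEncoder'. A minimal sketch of how such a string can be resolved to a class with the standard importlib module follows; load_class is a hypothetical helper, and this is an assumption about the mechanism, not the adapter's actual code. A standard-library class stands in for AutoEncoder so the sketch runs on its own.

```python
import importlib


def load_class(dotted_name: str) -> type:
    """Resolve 'package.module.ClassName' to the class object (hypothetical helper)."""
    module_name, _, class_name = dotted_name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {dotted_name!r}")
    # import the module (e.g. 'autoencoder'), then look up the class attribute on it
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# a stdlib class used as a stand-in for 'autoencoder.AutoEncoder'
cls = load_class("collections.OrderedDict")
print(cls)  # <class 'collections.OrderedDict'>
```

This only works if the module (here autoencoder.py) is importable, which is why the model script is added to the archive with add_script.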

Exported classes

This module exports the following classes:

ml_adapter.torch.V1TorchAdapter

Adapts a callable with torch arrays as input and output.

ml_adapter.torch.V1TorchMarshaller

Convert v1 payload from and to torch tensors.
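The request/response round trip that V1TorchMarshaller handles can be sketched as below, assuming "v1" refers to the KServe (TensorFlow Serving style) V1 inference protocol, where a request body carries an `instances` list and a response a `predictions` list. The helper names and the doubling "model" are hypothetical, and plain nested lists stand in for the torch tensors the real marshaller would produce, so the sketch runs without torch.

```python
import json


def decode_v1_request(body: str) -> list:
    """Extract the input instances from a V1-style request body (assumed format)."""
    payload = json.loads(body)
    # the real marshaller would convert these instances to torch tensors
    return payload["instances"]


def encode_v1_response(predictions: list) -> str:
    """Wrap model outputs in a V1-style response body (assumed format)."""
    return json.dumps({"predictions": predictions})


def double(instances: list) -> list:
    # hypothetical stand-in for the torch model call
    return [[2 * x for x in row] for row in instances]


request = '{"instances": [[2, 3, 4]]}'
response = encode_v1_response(double(decode_v1_request(request)))
print(response)  # {"predictions": [[4, 6, 8]]}
```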



Download files


Source Distribution

waylay_ml_adapter_torch-0.0.8.tar.gz (5.9 kB)

Uploaded Source

Built Distribution


waylay_ml_adapter_torch-0.0.8-py3-none-any.whl (6.3 kB)

Uploaded Python 3

File details

Details for the file waylay_ml_adapter_torch-0.0.8.tar.gz.

File metadata

  • Download URL: waylay_ml_adapter_torch-0.0.8.tar.gz
  • Size: 5.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.0

File hashes

Hashes for waylay_ml_adapter_torch-0.0.8.tar.gz
Algorithm Hash digest
SHA256 279d728b816e253f69c03e73f3e13bc6b75cee0e5cd7581ef455971f94210ff0
MD5 93944ba6adbcc2c3646b6d9f934e9c40
BLAKE2b-256 eeeade165f9a939577038af32c7ee16bbe1d68ebad09390df4dc77e976e24608


File details

Details for the file waylay_ml_adapter_torch-0.0.8-py3-none-any.whl.

File hashes

Hashes for waylay_ml_adapter_torch-0.0.8-py3-none-any.whl
Algorithm Hash digest
SHA256 1f2ad10314dcf20792a334eee0c8726b16f32875e12802350e5e8232262eeef1
MD5 b365cf9ec78c7487bf641cbccccd4ac1
BLAKE2b-256 296959a65f9747cedc742a2bc5380e0ea20edd92c15fd9ecd161ccf6dd5fa75d

