
ML_adapter for torch.

Project description

waylay-ml-adapter-torch

Provides the ml_adapter.torch module as Waylay ML Adapter for PyTorch.

Installation

pip install waylay-ml-adapter-torch

The torch package is required but NOT declared as a dependency of this package, leaving you full control over how the (rather heavy) PyTorch installation is done. Use pip install waylay-ml-adapter-torch[torch] to include it.

You might want to install additional libraries such as torchaudio or torchvision.

Usage

This ML Adapter uses the standard torch mechanisms to save and load models within a waylay plugin or webscript. The model_path argument defines the file name of the serialized model in the function archive:

  • A model_path ending in weights.pt or weights.pth saves/loads the model weights using its state_dict. This is the recommended, more robust method, but requires you to also specify a model_class.
  • Any other model_path with a .pt or .pth suffix saves/loads the entire model. This implicitly saves (references to) the model class used. You'll have to make sure that all dependencies used are also included or declared in the archive.
  • You can also pass an instantiated model directly to the adapter.
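In plain torch terms (a minimal sketch, independent of the adapter, using a hypothetical stand-in model), the two serialization styles differ as follows:

```python
import torch
import torch.nn as nn

# a tiny stand-in model (hypothetical; any torch.nn.Module works)
model = nn.Linear(4, 2)

# weights serialization (recommended): only the state_dict is stored,
# so loading needs the model class to re-instantiate the module first
torch.save(model.state_dict(), "model.weights.pt")
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("model.weights.pt"))

# full-model serialization: the pickle references the model class,
# which must therefore be importable wherever the model is loaded
torch.save(model, "model.pt")
restored_full = torch.load("model.pt", weights_only=False)
```

The weights route survives refactorings of the surrounding code as long as the state_dict keys match; the full-model route breaks whenever the pickled class can no longer be imported.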

Creating a model for a webscript

from ml_adapter.torch import V1TorchAdapter

# assuming we save an AutoEncoder torch.nn.Module class in an `autoencoder.py` file
from autoencoder import AutoEncoder
model = AutoEncoder()
# ... train the model ...

# a local directory to prepare the webscript archive
ARCHIVE_LOC='~/webscripts/autoencoder-pytorch'
# use a `weights` model path to use _weights_ serialization
MODEL_PATH='model-weights.pt'

adapter = V1TorchAdapter(
    model=model,
    model_path=MODEL_PATH,
    location=ARCHIVE_LOC,
)

# add our model script to the webscript archive
await adapter.add_script('autoencoder.py')
# write the archive
await adapter.save()
# inspect the archive:
list(adapter.assets)
#> [requirements.txt <ml_adapter.base.assets.python.PythonRequirementsAsset>,
#> main.py <ml_adapter.base.assets.python.PythonScriptAsset>,
#> model-weights.pt <ml_adapter.torch.adapter.TorchModelWeightsAsset>,
#> autoencoder.py <ml_adapter.base.assets.python.PythonScriptAsset>]

Upload the adapter archive as a webscript using the ml_tool SDK plugin:

from waylay.sdk import WaylayClient
client = WaylayClient.from_profile('staging')
ref = await client.ml_tool.create_webscript(adapter, name='MyAutoEncoder', version='0.0.1')
ref = await client.ml_tool.wait_until_ready(ref)
await client.ml_tool.test_webscript(ref, [2,3,4])

The generated code in main.py uses the following to load your model:

MODEL_PATH = os.environ.get('MODEL_PATH', 'model-weights.pt')
MODEL_CLASS = os.environ.get('MODEL_CLASS', 'autoencoder.AutoEncoder')
adapter = V1TorchAdapter(model_path=MODEL_PATH, model_class=MODEL_CLASS)

You can modify that loading mechanism, e.g. by creating the model yourself and providing it as:

adapter = V1TorchAdapter(model=model)

When you want additional processing before or after the model invocation that prevents your model from being loadable by the default torch.load mechanisms, you can alternatively use ml_adapter.torch.V1TorchNoLoadAdapter.

In that case a wrapper model class is responsible for loading the model itself.

Exported classes

This module exports the following classes:

ml_adapter.torch.V1TorchAdapter

Adapts a torch model with torch tensors as input and output.

When initialized with a trained model (using a model parameter):

  • will store the model weights as model_weights.pt (alternatively, set the model_path parameter)
  • requires the model class to be available in a library or asset file (e.g. a class extending torch.nn.Module in a mymodel.py script asset). The generated server script will use this name as model_class.

To load from a serialized model, use the model_path (default model_weights.pt) and model_class (no default) parameters.

Alternatively, when the model_path does not end in weights.pt or weights.pth, the adapter will try to load it as a dill-serialized model. This is not recommended, because this serialization method is brittle across library versions.

ml_adapter.torch.V1TorchNoLoadAdapter

Adapts a callable with torch tensors as input and output.

This adapter does not manage the model as a standard asset. It relies on the model or model_class constructor arguments to define and load the model. When model is not provided, any model_path is passed as a constructor argument to model_class if the signature allows it.

Note that if you internally rely on torch models, the wrapper's constructor is responsible for loading them and preparing them for inference (e.g. loading the state_dict and calling eval()), as in the example below.

The model adapter will still enforce a torch.no_grad context around model invocations.

import torch

def load_my_model(weights_file='my_weights.pt'):
    wrapped_model = AWrappedTorchModel()
    wrapped_model.load_state_dict(torch.load(weights_file))
    wrapped_model.eval()
    wrapped_model.to('cpu')
    return wrapped_model

class MyTorchWrappingModel:
    def __init__(self, model_file):
        self.torch_model = load_my_model(model_file)

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        result = self.torch_model(x)
        # postprocess
        return result[0]

adapter = V1TorchNoLoadAdapter(model_class=MyTorchWrappingModel)
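The no_grad guard the adapter enforces is equivalent to the following (a minimal sketch with a hypothetical stand-in model):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(1, 4)

# inside no_grad no autograd graph is recorded: outputs do not
# require gradients and no backward bookkeeping memory is kept
with torch.no_grad():
    y = model(x)
```

This is why the wrapper only needs to worry about loading and pre/post-processing, not about disabling gradient tracking.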

If all you need is to add pre- and post-processing of torch tensors, you can still use V1TorchAdapter to load the wrapped model, but you might want to wrap the __call__ method or set another model_method:

class MyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ... initialize layers

    def forward(self, x):
        # ... inference

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        result = super().__call__(x)
        # postprocess
        return result[0]

ml_adapter.torch.V1TorchMarshaller

Converts v1 payloads from and to torch tensors.
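In essence (a hand-rolled sketch, not the marshaller's actual API), this amounts to mapping the JSON-compatible nested lists of a v1 payload to tensors and back:

```python
import torch

# hypothetical v1 payload data: JSON-compatible nested lists
payload = [[1.0, 2.0], [3.0, 4.0]]

# request direction: payload -> tensor
tensor = torch.as_tensor(payload)

# response direction: tensor -> JSON-serializable nested lists
result = (tensor * 2).tolist()
```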

Download files


Source Distribution

waylay_ml_adapter_torch-0.0.10.tar.gz (9.9 kB)


Built Distribution


waylay_ml_adapter_torch-0.0.10-py3-none-any.whl (9.3 kB)


File details

Details for the file waylay_ml_adapter_torch-0.0.10.tar.gz.

File metadata

File hashes

Hashes for waylay_ml_adapter_torch-0.0.10.tar.gz
SHA256       813bc22d9445e8186ad22364cfd72b4e28447f914b3c7e149e17cb3b8a8d9c87
MD5          befa4c021dad907db3f7e3378e366778
BLAKE2b-256  018110be23e381b9df9302ab1ed17fa3c5f9d479d33b4fed9df97042f4105b96


File details

Details for the file waylay_ml_adapter_torch-0.0.10-py3-none-any.whl.

File metadata

File hashes

Hashes for waylay_ml_adapter_torch-0.0.10-py3-none-any.whl
SHA256       fb0c45f6d2fd8fe4dfd14ec78b029a0bb679ec43448a1534b1c6096e1e008830
MD5          e8803dc81e793ad3188df2d728c9f84e
BLAKE2b-256  d6b8883f85a96363930b30732dafe5728120acc344e58d10004a9636eea0bb22

