
ML_adapter for torch.

Project description

waylay-ml-adapter-torch

Provides the ml_adapter.torch module as Waylay ML Adapter for pytorch.

Installation

pip install waylay-ml-adapter-torch

The torch module is needed, but NOT declared as a package dependency, to leave the user full control over the way the (rather heavy) pytorch installation is done. Use pip install waylay-ml-adapter-torch[torch] to include it.

You might want to install additional libraries such as torchaudio or torchvision.

Usage

This ML Adapter uses the standard torch mechanisms to save and load models within a waylay plugin or webscript. The model_path argument defines the file name of the serialized model in the function archive:

  • A model_path ending in weights.pt or weights.pth saves/loads the model weights using its state_dict. This is the recommended, more robust method, but it requires you to also specify a model_class.
  • Any other model_path with a .pt or .pth suffix saves/loads the entire model. This implicitly saves (references to) the model class used, so you have to make sure that all its dependencies are also included or declared in the archive.
  • You can also pass an instantiated model directly to the adapter.
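The suffix rules above can be summarised in a few lines of plain Python. This is a minimal illustrative sketch: serialization_mode is a hypothetical helper, not part of ml_adapter.torch, and it returns None for suffixes that the rules above do not describe.

```python
from typing import Optional


def serialization_mode(model_path: str) -> Optional[str]:
    """Classify a model_path by the suffix rules (hypothetical helper)."""
    # Weights-only serialization: restored via the state_dict,
    # so a model_class must also be provided.
    if model_path.endswith(("weights.pt", "weights.pth")):
        return "state_dict"
    # Any other .pt/.pth file holds the entire serialized model.
    if model_path.endswith((".pt", ".pth")):
        return "full_model"
    # Not covered by the rules listed above.
    return None


for path in ["autoencoder.weights.pt", "model-weights.pth", "model.pt", "model.onnx"]:
    print(path, "->", serialization_mode(path))
```

Note that the weights check has to come first: every weights.pt path also ends in .pt.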

Creating a model for a webscript

from ml_adapter.torch import V1TorchAdapter

# assuming an AutoEncoder torch.nn.Module class is saved in an `autoencoder.py` file
from autoencoder import AutoEncoder
model = AutoEncoder()
# ... train the model ...

# a local directory to prepare the webscript archive
ARCHIVE_LOC = '~/webscripts/autoencoder-pytorch'
# a model path ending in `weights.pt` selects _weights_ (state_dict) serialization
MODEL_PATH = 'model-weights.pt'

adapter = V1TorchAdapter(
    model=model,
    model_path=MODEL_PATH,
    location=ARCHIVE_LOC,
)

# the adapter methods are async: the top-level `await` calls below assume
# an environment that supports it, such as a Jupyter notebook
# add our model script to the webscript archive
await adapter.add_script('autoencoder.py')
# write the archive
await adapter.save()
# inspect the archive:
list(adapter.assets)
#> [requirements.txt <ml_adapter.base.assets.python.PythonRequirementsAsset>,
#> main.py <ml_adapter.base.assets.python.PythonScriptAsset>,
#> model-weights.pt <ml_adapter.torch.adapter.TorchModelWeightsAsset>,
#> autoencoder.py <ml_adapter.base.assets.python.PythonScriptAsset>]

Upload the adapter archive as a webscript using the ml_tool SDK plugin:

from waylay.sdk import WaylayClient
client = WaylayClient.from_profile('staging')
ref = await client.ml_tool.create_webscript(adapter, name='MyAutoEncoder', version='0.0.1')
ref = await client.ml_tool.wait_until_ready(ref)
await client.ml_tool.test_webscript(ref, [2,3,4])

The generated code in main.py uses the following to load your model:

import os
from ml_adapter.torch import V1TorchAdapter

MODEL_PATH = os.environ.get('MODEL_PATH', 'model-weights.pt')
MODEL_CLASS = os.environ.get('MODEL_CLASS', 'autoencoder.AutoEncoder')
adapter = V1TorchAdapter(model_path=MODEL_PATH, model_class=MODEL_CLASS)
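The MODEL_CLASS value is a dotted module.ClassName string. The README does not show how ml_adapter resolves it internally; assuming the usual importlib pattern, a minimal standard-library sketch looks like this (resolve_class is a hypothetical name, and a stdlib class stands in for autoencoder.AutoEncoder):

```python
import importlib


def resolve_class(dotted_name: str) -> type:
    """Import 'package.module.ClassName' and return the class (hypothetical helper)."""
    module_name, _, class_name = dotted_name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {dotted_name!r}")
    # Import the module part, then look up the class attribute on it.
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


print(resolve_class("collections.OrderedDict"))
```

For 'autoencoder.AutoEncoder' to resolve this way, autoencoder.py must be importable from the webscript's working directory, which is why the script is added to the archive.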

You can modify that loading mechanism, e.g. by creating the model yourself and providing it as:

adapter = V1TorchAdapter(model=model)

When you need additional processing before or after the model invocation that prevents your model from being loadable by the default torch.load mechanisms, you can use ml_adapter.torch.V1TorchNoLoadAdapter instead.

The wrapper model class is then responsible for loading the model.

Classes exported by ml_adapter.torch

The module ml_adapter.torch exports the following classes

V1TorchAdapter

ml_adapter.torch.adapter.V1TorchAdapter extending ml_adapter.base.adapter.ModelAdapterBase

Adapts a torch model with torch tensors as input and output.

When initialized with a trained model (using a model parameter):

  • stores the model weights as model_weights.pt (alternatively, set the model_path parameter)
  • requires that the model class is defined in a library or asset file (e.g. a class extending torch.nn.Module in a mymodel.py script asset). The generated server script will use this class name as model_class.

To load from a serialized model, use the model_path (default model_weights.pt) and model_class (no default).

Alternatively, when the model_path does not end in weights.pt or weights.pth, the adapter will try to load it as a dill-serialized model. This is not recommended, because this serialization method is brittle across library versions.

V1TorchNoLoadAdapter

ml_adapter.torch.adapter.V1TorchNoLoadAdapter extending ml_adapter.torch.adapter.V1TorchAdapter

Adapts a callable with torch tensors as input and output.

This adapter does not manage the model as a standard asset. It relies on the model or model_class constructor arguments to define and load the model. When model is not provided, any model_path is passed as a constructor argument to model_class if the signature allows it.
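One way to implement "if the signature allows it" is to test-bind the argument against the constructor signature. This is a sketch under that assumption, not the adapter's actual code; instantiate is a hypothetical helper.

```python
import inspect
from typing import Any, Optional


def instantiate(model_class: type, model_path: Optional[str] = None) -> Any:
    """Create model_class, passing model_path only if the constructor accepts it.

    Hypothetical helper illustrating the "if the signature allows it" rule.
    """
    if model_path is not None:
        try:
            # bind() raises TypeError when the argument doesn't fit the signature;
            # signature() may raise ValueError if no signature can be determined.
            inspect.signature(model_class).bind(model_path)
        except (TypeError, ValueError):
            pass
        else:
            return model_class(model_path)
    # No path given, or the constructor does not accept one.
    return model_class()
```

With this rule, a wrapper such as MyTorchWrappingModel below, whose constructor takes a single file argument, would receive the model_path, while a no-argument constructor is simply called as is.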

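Note that if you internally rely on torch models, the model constructor is responsible for loading them, e.g. restoring their weights and switching them to evaluation mode, as load_my_model does in the example below.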

The model adapter will still enforce a torch.no_grad context around model invocations.

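import torch
from ml_adapter.torch import V1TorchNoLoadAdapter

# AWrappedTorchModel stands for your own torch.nn.Module subclass
def load_my_model(weights_file='my_weights.pt'):
    wrapped_model = AWrappedTorchModel()
    # map_location='cpu' lets weights saved on a GPU load on a CPU-only host
    wrapped_model.load_state_dict(torch.load(weights_file, map_location='cpu'))
    wrapped_model.eval()  # evaluation mode (affects e.g. dropout, batch norm)
    wrapped_model.to('cpu')
    return wrapped_model

class MyTorchWrappingModel:
    def __init__(self, model_file='my_weights.pt'):
        self.torch_model = load_my_model(model_file)

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        result = self.torch_model(x)
        # postprocess
        return result[0]

adapter = V1TorchNoLoadAdapter(model_class=MyTorchWrappingModel)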

If all you need is to add pre- and post-processing of torch tensors, you can still use V1TorchAdapter to load the wrapped model, but you might want to wrap the __call__ method or set another model_method:

import torch

class MyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ... initialize layers

    def forward(self, x):
        # ... inference
        raise NotImplementedError('replace with your inference logic')

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        # nn.Module.__call__ runs forward() (plus any registered hooks)
        result = super().__call__(x)
        # postprocess
        return result[0]
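As an alternative to overriding __call__ on the module itself, the same pre- and post-processing can be composed around any callable. This plain-Python sketch uses a hypothetical with_processing helper and a stand-in model; in the example above, the wrapped callable would be the torch module.

```python
from typing import Any, Callable


def with_processing(
    model: Callable[[Any], Any],
    preprocess: Callable[..., Any],
    postprocess: Callable[[Any], Any],
) -> Callable[..., Any]:
    """Return a callable running preprocess -> model -> postprocess (hypothetical helper)."""
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        x = preprocess(*args, **kwargs)
        return postprocess(model(x))
    return wrapped


def toy_model(x):
    # Stand-in for a torch module: returns a two-element result.
    return [x * 2, x * 3]


# Same steps as MyTorchModel.__call__: sum the inputs, run the model, keep result[0].
predict = with_processing(toy_model, lambda x, y, z: x + y + z, lambda r: r[0])
print(predict(1, 2, 3))
```

Composition keeps the torch module's own __call__ untouched, which can be convenient when the same trained module is also used elsewhere without the extra processing.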

V1TorchMarshaller

ml_adapter.torch.marshall.V1TorchMarshaller extending ml_adapter.base.marshall.v1.base.V1ValueOrDictRequestMarshallerBase, ml_adapter.base.marshall.v1.base.V1ValueOrDictResponseMarshallerBase, ml_adapter.torch.marshall.V1TorchNativeEncoding

Convert v1 payload from and to torch tensors.
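The exact V1 payload layout is not described here. Assuming a KServe/TF-Serving style V1 body, with inputs under "instances" (or "inputs") as either a single value or a dict of named values, and outputs returned under "predictions", the marshalling can be sketched as follows. The function names and key choices are assumptions, and tensor conversion is injected as a parameter so the sketch runs without torch.

```python
from typing import Any, Callable, Dict


def _identity(value: Any) -> Any:
    return value


def decode_request(
    payload: Dict[str, Any], to_tensor: Callable[[Any], Any] = _identity
) -> Any:
    """Extract model input(s) from a V1-style request body (assumed layout)."""
    # Assumption: inputs arrive under "instances", or else "inputs".
    data = payload["instances"] if "instances" in payload else payload["inputs"]
    if isinstance(data, dict):
        # Named inputs: convert each value separately, keep the names.
        return {name: to_tensor(value) for name, value in data.items()}
    # A single (unnamed) input value.
    return to_tensor(data)


def encode_response(
    result: Any, to_list: Callable[[Any], Any] = _identity
) -> Dict[str, Any]:
    """Wrap model output(s) in a V1-style response body (assumed layout)."""
    if isinstance(result, dict):
        return {"predictions": {name: to_list(v) for name, v in result.items()}}
    return {"predictions": to_list(result)}
```

With torch installed, passing to_tensor=torch.as_tensor and to_list=lambda t: t.tolist() would perform the actual tensor conversion.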




Download files


Source Distribution

waylay_ml_adapter_torch-0.1.0.tar.gz (10.4 kB)

Uploaded Source

Built Distribution


waylay_ml_adapter_torch-0.1.0-py3-none-any.whl (9.8 kB)

Uploaded Python 3

File details

Details for the file waylay_ml_adapter_torch-0.1.0.tar.gz.

File metadata

  • Download URL: waylay_ml_adapter_torch-0.1.0.tar.gz
  • Upload date:
  • Size: 10.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.12

File hashes

Hashes for waylay_ml_adapter_torch-0.1.0.tar.gz
Algorithm Hash digest
SHA256 f0f5b3cae6f5bd25999593159045666c5f6f318480500749d62e1093b72f092e
MD5 067032860cf0e6471889835fd4e23933
BLAKE2b-256 1fb66e198d32e66417755c1a0a87c2d0d8ff6f0226fb6d1e1c7db6f253f793e7


File details

Details for the file waylay_ml_adapter_torch-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for waylay_ml_adapter_torch-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c11cf48bf2ddbcc1e497e68b0cbe5c0f88f283605cde955e41dd09cdb8fe01be
MD5 c888ac9231a0843841e68d1362f9102b
BLAKE2b-256 f8a299310fc207a9742e67a57324a0c3d9204066f0de018453b827675b8fca1e

