
waylay-ml-adapter-torch

Provides the ml_adapter.torch module as a Waylay ML Adapter for PyTorch.

Installation

pip install waylay-ml-adapter-torch

The torch module is required but NOT declared as a package dependency, leaving you full control over how the (rather heavy) PyTorch installation is done. Use pip install waylay-ml-adapter-torch[torch] to include it.

You might want to install additional libraries such as torchaudio or torchvision.

Usage

This ML Adapter uses the standard torch mechanisms to save and load models within a Waylay plugin or webscript. The model_path argument defines the file name of the serialized model in the function archive:

  • A model_path ending in weights.pt or weights.pth saves/loads the model weights using its state_dict. This is the recommended, more robust method, but requires you to also specify a model_class.
  • Any other model_path with a .pt or .pth suffix saves/loads the entire model. This implicitly saves (references to) the model class used, so you'll have to make sure that all dependencies it uses are also included or declared in the archive.
  • You can also pass an instantiated model directly to the adapter.
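The path convention above boils down to a suffix check. A minimal illustrative sketch (not the library's actual implementation):

```python
# Illustrative only: mirrors the documented model_path convention.
def uses_weights_serialization(model_path: str) -> bool:
    """True when the path selects state_dict (weights) serialization."""
    return model_path.endswith(('weights.pt', 'weights.pth'))

print(uses_weights_serialization('autoencoder.weights.pt'))  # True: weights only
print(uses_weights_serialization('model.pt'))                # False: entire model
```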

Creating a model for a webscript

from ml_adapter.torch import V1TorchAdapter

# assuming we save an AutoEncoder torch.nn.Module class in an `autoencoder.py` file
from autoencoder import AutoEncoder
model = AutoEncoder()
# ... train the model ...

# a local directory to prepare the webscript archive
ARCHIVE_LOC='~/webscripts/autoencoder-pytorch'
# use a `weights` model path to use _weights_ serialization
MODEL_PATH='model-weights.pt'

adapter = V1TorchAdapter(
    model=model,
    model_path=MODEL_PATH,
    location=ARCHIVE_LOC,
)

# add our model script to the webscript archive
await adapter.add_script('autoencoder.py')
# write the archive
await adapter.save()
# inspect the archive:
list(adapter.assets)
#> [requirements.txt <ml_adapter.base.assets.python.PythonRequirementsAsset>,
#> main.py <ml_adapter.base.assets.python.PythonScriptAsset>,
#> model-weights.pt <ml_adapter.torch.adapter.TorchModelWeightsAsset>,
#> autoencoder.py <ml_adapter.base.assets.python.PythonScriptAsset>]

Upload the adapter archive as a webscript using the ml_tool SDK plugin:

from waylay.sdk import WaylayClient
client = WaylayClient.from_profile('staging')
ref = await client.ml_tool.create_webscript(adapter, name='MyAutoEncoder', version='0.0.1')
ref = await client.ml_tool.wait_until_ready(ref)
await client.ml_tool.test_webscript(ref, [2,3,4])

The generated code in main.py uses the following to load your model:

import os
from ml_adapter.torch import V1TorchAdapter

MODEL_PATH = os.environ.get('MODEL_PATH', 'model-weights.pt')
MODEL_CLASS = os.environ.get('MODEL_CLASS', 'autoencoder.AutoEncoder')
adapter = V1TorchAdapter(model_path=MODEL_PATH, model_class=MODEL_CLASS)
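Because these values are read from the environment, you can repoint a deployed function at another model asset without editing main.py. A sketch with hypothetical values:

```shell
# hypothetical override: select a different model asset and class
export MODEL_PATH='autoencoder-v2.weights.pt'
export MODEL_CLASS='autoencoder.AutoEncoder'
```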

You can modify that loading mechanism, e.g. by creating the model yourself and providing it as:

adapter = V1TorchAdapter(model=model)

When you want additional processing before or after the model invocation that prevents your model from being loadable by the default torch.load mechanisms, you can alternatively use the ml_adapter.torch.V1TorchNoLoadAdapter.

Your wrapper model class is then responsible for loading the underlying model.

Classes exported by ml_adapter.torch

The module ml_adapter.torch exports the following classes

V1TorchAdapter

ml_adapter.torch.adapter.V1TorchAdapter extending ml_adapter.base.adapter.ModelAdapterBase

Adapts a torch model with torch arrays as input and output.

When initialized with a trained model (using the model parameter):

  • will store the model weights as model_weights.pt (alternatively, set the model_path parameter)
  • requires that the model class is available in a library or asset file (e.g. a class extending torch.nn.Module in a mymodel.py script asset). The generated server script will use this name as model_class.

To load from a serialized model, use the model_path (default model_weights.pt) and model_class (no default) parameters.

Alternatively, when the model_path does not have a weights.pt or weights.pth extension, the adapter will try to load it as a dill-serialized model. This is not recommended because of the brittleness of this serialization method with respect to versions.
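That brittleness stems from pickle-style serialization (which torch.save and dill build on) storing an import path to the model class rather than the class definition itself, so loading breaks when that module moves or changes. A minimal illustration with plain pickle:

```python
import pickle

class MyModel:  # stand-in for a torch.nn.Module subclass
    pass

payload = pickle.dumps(MyModel())
# The payload references the class by module and name only;
# the class's code is not stored.
print(b'MyModel' in payload)        # the name reference is embedded
print(b'class MyModel' in payload)  # the definition is not
```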

V1TorchNoLoadAdapter

ml_adapter.torch.adapter.V1TorchNoLoadAdapter extending ml_adapter.torch.adapter.V1TorchAdapter

Adapts a callable with torch arrays as input and output.

This adapter does not manage the model as a standard asset. It relies on the model or model_class constructor arguments to define and load the model. When model is not provided, any model_path is passed as a constructor argument to model_class if the signature allows it.

Note that if you internally rely on torch models, your wrapper's constructor is responsible for loading them: restoring the weights, switching to evaluation mode and moving the model to the right device.

The model adapter will still enforce a torch.no_grad context around model invocations.

import torch

from ml_adapter.torch import V1TorchNoLoadAdapter

# AWrappedTorchModel is your own torch.nn.Module subclass
def load_my_model(weights_file='my_weights.pt'):
    wrapped_model = AWrappedTorchModel()
    wrapped_model.load_state_dict(torch.load(weights_file))
    wrapped_model.eval()
    wrapped_model.to('cpu')
    return wrapped_model

class MyTorchWrappingModel():
    def __init__(self, model_file):
        self.torch_model = load_my_model(model_file)

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        result = self.torch_model(x)
        # postprocess
        return result[0]

adapter = V1TorchNoLoadAdapter(model_class=MyTorchWrappingModel)

If all you need is to add pre- and post-processing of torch tensors, you can still use V1TorchAdapter to load the wrapped model, but you might want to wrap the __call__ method or set another model_method:

import torch

class MyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ... initialize layers

    def forward(self, x):
        # ... inference
        ...

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        result = super().__call__(x)
        # postprocess
        return result[0]
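Setting model_method presumably amounts to the adapter resolving a named entry point on the model instead of __call__. A plain-Python sketch of that dispatch ('predict' and MyModel are illustrative names, not part of the library):

```python
# Illustrative only: dispatch on a configured method name,
# as a model_method setting implies.
class MyModel:
    def predict(self, x):
        return [v * 2 for v in x]

model_method = 'predict'
invoke = getattr(MyModel(), model_method)
print(invoke([1, 2, 3]))  # [2, 4, 6]
```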

V1TorchMarshaller

ml_adapter.torch.marshall.V1TorchMarshaller extending ml_adapter.base.marshall.v1.base.V1ValueOrDictRequestMarshallerBase, ml_adapter.base.marshall.v1.base.V1ValueOrDictResponseMarshallerBase, ml_adapter.torch.marshall.V1TorchNativeEncoding

Convert v1 payload from and to torch tensors.
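As a rough pure-Python sketch of such a round trip (the instances/predictions field names are assumptions borrowed from common V1 inference protocols, not confirmed for Waylay, and plain lists stand in for tensors):

```python
import json

# Hypothetical V1-style request; the marshaller would map the input values
# to torch tensors before the model call, and tensors back to lists after.
request = json.loads('{"instances": [[1.0, 2.0], [3.0, 4.0]]}')
rows = request['instances']
predictions = [[sum(row)] for row in rows]  # stand-in for a model invocation
response = {'predictions': predictions}
print(json.dumps(response))  # {"predictions": [[3.0], [7.0]]}
```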

