ML_adapter for torch.
Project description
waylay-ml-adapter-torch
Provides the ml_adapter.torch module as Waylay ML Adapter for pytorch.
Installation
pip install waylay-ml-adapter-torch
The torch module is needed, but NOT declared as a package dependency, to leave the user full control over the way the (rather heavy) pytorch installation is done.
Use pip install waylay-ml-adapter-torch[torch] to include it.
You might want to install additional libraries such as torchaudio or torchvision.
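Because torch is not pulled in automatically, it can help to check that it is importable before preparing an archive. The following is a minimal sketch using only the standard library; the helper name `missing_modules` is hypothetical and not part of the adapter.

```python
import importlib.util


def missing_modules(*names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# torch is required by the adapter; torchaudio and torchvision are optional extras.
missing = missing_modules("torch", "torchaudio", "torchvision")
if missing:
    print("not installed:", ", ".join(missing))
```

The check only looks the modules up; it does not import them, so it stays cheap even when the heavy torch package is present.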
Usage
This ML Adapter uses the standard torch mechanisms to save and load models within a waylay plugin or webscript.
The model_path argument defines the file name of the serialized model in the function archive:
- A model_path ending in weights.pt or weights.pth saves/loads the model weights using its state_dict. This is the recommended, more robust method, but requires you to also specify a model_class.
- Any other model_path with a .pt or .pth suffix saves/loads the entire model. It implicitly saves (references to) the model class used. You'll have to make sure that all dependencies used are also included or declared in the archive.
- You can also pass an instantiated model directly to the adapter.
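The naming rule in the list above can be summarised in a few lines. This is an illustrative sketch of the rule as described here, not the adapter's own code; the function name `serialization_mode` and its return values are hypothetical.

```python
def serialization_mode(model_path):
    """Classify a model path according to the naming rule described above."""
    if model_path.endswith(("weights.pt", "weights.pth")):
        # state_dict serialization: a model_class is needed to rebuild the model
        return "weights"
    if model_path.endswith((".pt", ".pth")):
        # whole-model serialization: the file references the model class
        return "full-model"
    # any other suffix is not covered by the description above
    return "unknown"
```

For example, `serialization_mode('model-weights.pt')` gives `'weights'`, while `serialization_mode('autoencoder.pt')` gives `'full-model'`.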
Creating a model for a webscript
from ml_adapter.torch import V1TorchAdapter
# assuming we save an AutoEncoder torch.nn.Module class in an `autoencoder.py` file
from autoencoder import AutoEncoder
model = AutoEncoder()
# ... train the model ...
# a local directory to prepare the webscript archive
ARCHIVE_LOC='~/webscripts/autoencoder-pytorch'
# use a `weights` model path to use _weights_ serialization
MODEL_PATH='model-weights.pt'
adapter = V1TorchAdapter(
    model=model,
    model_path=MODEL_PATH,
    location=ARCHIVE_LOC,
)
# add our model script to the webscript archive
await adapter.add_script('autoencoder.py')
# write the archive
await adapter.save()
# inspect the archive:
list(adapter.assets)
#> [requirements.txt <ml_adapter.base.assets.python.PythonRequirementsAsset>,
#> main.py <ml_adapter.base.assets.python.PythonScriptAsset>,
#> model-weights.pt <ml_adapter.torch.adapter.TorchModelWeightsAsset>,
#> autoencoder.py <ml_adapter.base.assets.python.PythonScriptAsset>]
Upload the adapter archive as webscript using the ml_tool SDK plugin
from waylay.sdk import WaylayClient
client = WaylayClient.from_profile('staging')
ref = await client.ml_tool.create_webscript(adapter, name='MyAutoEncoder', version='0.0.1')
ref = await client.ml_tool.wait_until_ready(ref)
await client.ml_tool.test_webscript(ref, [2,3,4])
The generated code in main.py uses the following to load your model:
MODEL_PATH = os.environ.get('MODEL_PATH', 'model-weights.pt')
MODEL_CLASS = os.environ.get('MODEL_CLASS', 'autoencoder.AutoEncoder')
adapter = V1TorchAdapter(model_path=MODEL_PATH, model_class=MODEL_CLASS)
You can modify that loading mechanism, e.g. by creating the model yourself and providing it as:
adapter = V1TorchAdapter(model=model)
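The default MODEL_CLASS value in the generated code, 'autoencoder.AutoEncoder', is a dotted module.ClassName string. How the adapter resolves it is not shown here; a common way to turn such a string into a class, sketched with the standard library only (the helper name `resolve_class` is hypothetical), looks like this:

```python
import importlib


def resolve_class(dotted_name):
    """Import the module part of 'module.ClassName' and return the class object."""
    module_name, _, class_name = dotted_name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {dotted_name!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Demonstrated with a standard-library class, since autoencoder.py
# only exists inside the webscript archive.
ordered_dict_class = resolve_class("collections.OrderedDict")
```

Under this assumption, the module that defines the class (autoencoder.py in the example) has to be importable at load time, which is consistent with adding it to the archive as a script asset.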
When you need additional processing before or after the model invocation that prevents
your model from being loadable by the default torch.load mechanisms, you can alternatively
use the ml_adapter.torch.V1TorchNoLoadAdapter.
In that case, the wrapper model class that you provide is responsible for loading the model.
Classes exported by ml_adapter.torch
The module ml_adapter.torch exports the following classes
V1TorchAdapter
ml_adapter.torch.adapter.V1TorchAdapter extending ml_adapter.base.adapter.ModelAdapterBase
Adapts a torch model with torch arrays as input and output.
When initialized with a trained model (using a model parameter):
- will store the model weights as model_weights.pt (alt: set the model_path parameter)
- requires that the model class is defined in a library or asset file (e.g. a class extending torch.nn.Module in a mymodel.py script asset)
The generated server script will use this name as model_class.
To load from a serialized model, use the model_path (default model_weights.pt)
and model_class (no default).
Alternatively, when the model_path does not have a weights.pt or weights.pth
extension, the adapter will try to load it as a dill-serialized model.
This is not recommended because of the brittleness of this serialization method
with respect to versions.
V1TorchNoLoadAdapter
ml_adapter.torch.adapter.V1TorchNoLoadAdapter extending ml_adapter.torch.adapter.V1TorchAdapter
Adapts a callable with torch arrays as input and output.
This adapter does not manage the model as a standard asset.
It relies on the model or model_class constructor arguments
to define and load the model.
When model is not provided, any model_path is passed as a constructor
argument to model_class if the signature allows it.
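The phrase "if the signature allows it" can be pictured with the standard inspect module. The sketch below is an assumption about the general technique, not the adapter's actual implementation; the helper name `build_model` is hypothetical.

```python
import inspect


def build_model(model_class, model_path=None):
    """Instantiate model_class, passing model_path only if the constructor can take it."""
    params = inspect.signature(model_class).parameters.values()
    accepts_positional = any(
        p.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
        )
        for p in params
    )
    if model_path is not None and accepts_positional:
        # the constructor has room for a positional argument: hand over the path
        return model_class(model_path)
    # otherwise fall back to a no-argument construction
    return model_class()
```

With this approach, a wrapper class whose constructor takes a single file argument receives the path, while a class with a no-argument constructor is created without it.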
Note that if you internally rely on torch models, the model constructor is responsible for
- setting that wrapped model to evaluation mode
- setting the model to the correct device and/or dtype (normally to "cuda" if torch.cuda.is_available() else "cpu")
The model adapter will still enforce a torch.no_grad context around model invocations.
import torch
from ml_adapter.torch import V1TorchNoLoadAdapter

def load_my_model(weights_file='my_weights.pt'):
    # AWrappedTorchModel stands for your own torch.nn.Module class
    wrapped_model = AWrappedTorchModel()
    wrapped_model.load_state_dict(torch.load(weights_file))
    wrapped_model.eval()
    wrapped_model.to('cpu')
    return wrapped_model

class MyTorchWrappingModel():
    def __init__(self, model_file):
        self.torch_model = load_my_model(model_file)

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        result = self.torch_model(x)
        # postprocess
        return result[0]

adapter = V1TorchNoLoadAdapter(model_class=MyTorchWrappingModel)
If all you need is to add pre- and post-processing of torch tensors,
you can still use V1TorchAdapter to load the wrapped model,
but you might want to override the __call__ method or set another model_method:
import torch

class MyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ... initialize layers

    def forward(self, x):
        # ... inference
        ...

    # custom pre/postprocessing
    def __call__(self, x, y, z):
        # preprocess
        x = x + y + z
        result = super().__call__(x)
        # postprocess
        return result[0]
V1TorchMarshaller
ml_adapter.torch.marshall.V1TorchMarshaller extending ml_adapter.base.marshall.v1.base.V1ValueOrDictRequestMarshallerBase, ml_adapter.base.marshall.v1.base.V1ValueOrDictResponseMarshallerBase, ml_adapter.torch.marshall.V1TorchNativeEncoding
Convert v1 payload from and to torch tensors.
See also
- waylay-ml-adapter-sklearn ML adapter for scikit-learn models.
- waylay-ml-adapter-sdk provides the ml_tool extension to the waylay-sdk.
- waylay-ml-adapter-base provides the basic ML adapter infrastructure.
- waylay-ml-adapter-api defines the remote data interfaces.
Download files
File details
Details for the file waylay_ml_adapter_torch-0.1.0.tar.gz.
File metadata
- Download URL: waylay_ml_adapter_torch-0.1.0.tar.gz
- Upload date:
- Size: 10.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f0f5b3cae6f5bd25999593159045666c5f6f318480500749d62e1093b72f092e |
| MD5 | 067032860cf0e6471889835fd4e23933 |
| BLAKE2b-256 | 1fb66e198d32e66417755c1a0a87c2d0d8ff6f0226fb6d1e1c7db6f253f793e7 |
File details
Details for the file waylay_ml_adapter_torch-0.1.0-py3-none-any.whl.
File metadata
- Download URL: waylay_ml_adapter_torch-0.1.0-py3-none-any.whl
- Upload date:
- Size: 9.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c11cf48bf2ddbcc1e497e68b0cbe5c0f88f283605cde955e41dd09cdb8fe01be |
| MD5 | c888ac9231a0843841e68d1362f9102b |
| BLAKE2b-256 | f8a299310fc207a9742e67a57324a0c3d9204066f0de018453b827675b8fca1e |