Sparse ML Callback project.
Project description
SparseML Callback
All credit to Sean Narenthiran; I am merely using his code for demonstration purposes.
SparseML allows you to leverage sparsity to improve inference times substantially. SparseML requires you to fine-tune your model with the SparseMLCallback and a SparseML Recipe. By training with the SparseMLCallback, you can leverage the DeepSparse engine to exploit the introduced sparsity, resulting in large performance improvements.
The SparseML callback requires the model to be ONNX exportable. This can be tricky when the model requires dynamic sequence lengths, as RNNs do.
To leverage SparseML & DeepSparse, follow the steps below:
Choose your Sparse Recipe
To choose a recipe, have a look at recipes and Sparse Zoo.
It may be easier to infer a recipe via the Sparsify UI dashboard, which allows you to tweak and configure a recipe. This requires importing an ONNX model, which you can export from your LightningModule by calling model.to_onnx(output_path).
Train with SparseMLCallback
from pytorch_lightning import LightningModule, Trainer
from pl_bolts.callbacks import SparseMLCallback

class MyModel(LightningModule):
    ...

model = MyModel()
trainer = Trainer(callbacks=SparseMLCallback(recipe_path='recipe.yaml'))
Export to ONNX!
Using the helper function, we handle any quantization/pruning internally and export the model into ONNX format.
Note that this assumes you have either implemented the example_input_array property on the model, or you provide a sample batch as shown below.
import torch

model = MyModel()
...

# export the ONNX model, using `model.example_input_array`
SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/')

# export the ONNX model, providing a sample batch
SparseMLCallback.export_to_sparse_onnx(
    model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32)
)
Once your model has been exported, you can import this into either Sparsify or DeepSparse.
Project details
Release history
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file pl_hub_sparse_ml_callback-0.0.2.tar.gz.
File metadata
- Download URL: pl_hub_sparse_ml_callback-0.0.2.tar.gz
- Upload date:
- Size: 8.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | fcd2a45e4f13980e0a6178d509bf5268c8c17c58cfa46689b49d54a4a6e67865
MD5 | 060c85eaf263013139a1f5c718ad6a75
BLAKE2b-256 | 1584994d5868632354e2635e4b811884b0a78d1aaec475832c82ba0c27cf7124
File details
Details for the file pl_hub_sparse_ml_callback-0.0.2-py3-none-any.whl.
File metadata
- Download URL: pl_hub_sparse_ml_callback-0.0.2-py3-none-any.whl
- Upload date:
- Size: 9.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | ab28a006f4af3ca9ea67037ea84df77b2e12ffee8d2da42570d73fbf0ca5d0c2
MD5 | 56fdb8122d6b1e002771487f4168c98a
BLAKE2b-256 | 4eece0eac4c760fa29b2236f28123ae8e70f634e5edc29cacac236703ec7b389