
Project description

SparseML Callback

All credits to Sean Narenthiran. I am merely using his code for demonstration purposes.

SparseML allows you to leverage sparsity to improve inference times substantially.

SparseML requires you to fine-tune your model with the SparseMLCallback and a SparseML recipe. By training with the SparseMLCallback, you can leverage the DeepSparse engine to exploit the introduced sparsity, resulting in large performance improvements.

The SparseML callback requires the model to be ONNX exportable. This can be tricky when the model requires dynamic sequence lengths such as RNNs.

To leverage SparseML & DeepSparse, follow the steps below:

Choose your Sparse Recipe

To choose a recipe, have a look at recipes and Sparse Zoo.
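
A SparseML recipe is a YAML file composed of modifiers. As a minimal sketch of what one looks like (the epoch range and sparsity targets below are illustrative placeholders, not tuned values):

    modifiers:
        - !EpochRangeModifier
            start_epoch: 0.0
            end_epoch: 10.0

        - !GMPruningModifier
            start_epoch: 1.0
            end_epoch: 8.0
            init_sparsity: 0.05
            final_sparsity: 0.85
            update_frequency: 0.5
            params: __ALL_PRUNABLE__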

It may be easier to infer a recipe via the Sparsify UI dashboard, which allows you to tweak and configure a recipe. This requires importing an ONNX model, which you can export from your LightningModule with model.to_onnx(output_path), as sketched below.
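
As a minimal sketch of that hand-off (the output path and input shape are placeholders; input_sample is only needed if example_input_array is not set on the model):

    import torch

    model = MyModel()

    # export the LightningModule to ONNX so it can be imported into Sparsify
    model.to_onnx('model.onnx', input_sample=torch.randn(1, 128, 128))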

Train with SparseMLCallback

    from pytorch_lightning import LightningModule, Trainer
    from pl_bolts.callbacks import SparseMLCallback

    class MyModel(LightningModule):
        ...

    model = MyModel()

    trainer = Trainer(
        callbacks=[SparseMLCallback(recipe_path='recipe.yaml')]
    )
    trainer.fit(model)

Export to ONNX!

The helper function handles any quantization/pruning internally and exports the model into ONNX format. Note this assumes that you have either implemented the example_input_array property on the model or that you provide a sample batch, as shown below.

    import torch

    model = MyModel()
    ...

    # export the ONNX model using `model.example_input_array`
    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/')

    # export the ONNX model, providing a sample batch
    SparseMLCallback.export_to_sparse_onnx(
        model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32)
    )
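
If you rely on the first form, example_input_array can simply be set as an attribute on your LightningModule (the shape below is illustrative):

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            ...
            # used by `to_onnx`/`export_to_sparse_onnx` when no sample batch is passed
            self.example_input_array = torch.randn(1, 128, 128)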

Once your model has been exported, you can import this into either Sparsify or DeepSparse.
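
As a rough sketch of the DeepSparse side (the ONNX filename inside onnx_export/ and the input shape are assumptions):

    import numpy as np
    from deepsparse import compile_model

    # compile the exported ONNX model for sparsity-aware CPU inference
    engine = compile_model('onnx_export/model.onnx', batch_size=1)

    # run a dummy batch matching the shape used at export time
    inputs = [np.random.randn(1, 128, 128).astype(np.float32)]
    outputs = engine.run(inputs)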


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pl_hub_sparse_ml_callback-0.0.2.tar.gz (8.1 kB)

Uploaded Source

Built Distribution

pl_hub_sparse_ml_callback-0.0.2-py3-none-any.whl (9.0 kB)

Uploaded Python 3

File details

Details for the file pl_hub_sparse_ml_callback-0.0.2.tar.gz.

File metadata

  • Download URL: pl_hub_sparse_ml_callback-0.0.2.tar.gz
  • Upload date:
  • Size: 8.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.8

File hashes

Hashes for pl_hub_sparse_ml_callback-0.0.2.tar.gz
Algorithm Hash digest
SHA256 fcd2a45e4f13980e0a6178d509bf5268c8c17c58cfa46689b49d54a4a6e67865
MD5 060c85eaf263013139a1f5c718ad6a75
BLAKE2b-256 1584994d5868632354e2635e4b811884b0a78d1aaec475832c82ba0c27cf7124


File details

Details for the file pl_hub_sparse_ml_callback-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: pl_hub_sparse_ml_callback-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 9.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.8

File hashes

Hashes for pl_hub_sparse_ml_callback-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 ab28a006f4af3ca9ea67037ea84df77b2e12ffee8d2da42570d73fbf0ca5d0c2
MD5 56fdb8122d6b1e002771487f4168c98a
BLAKE2b-256 4eece0eac4c760fa29b2236f28123ae8e70f634e5edc29cacac236703ec7b389

