
Lightning Neural Compressor


This repository contains the implementation of the Lightning Neural Compressor. The main goal of this project is to provide PyTorch Lightning callbacks for Intel® Neural Compressor. The callbacks compress a neural network so that it can run on edge devices (e.g., mobile phones, Raspberry Pi). This project is a work in progress and is not ready for production use.

Current Status

The project is currently under development, starting with Quantization-Aware Training (QAT), since the built-in QAT callback has been removed from PyTorch Lightning.

The project also supports weight pruning and should work at least with pruners related to the PytorchBasicPruner, configured through a WeightPruningConfig as shown in the Usage section below.

Installation

To install this project from PyPI, use the following command:

pip install -U lightning-nc

or install it directly by cloning the main branch:

git clone https://github.com/clementpoiret/lightning-nc
cd lightning-nc
pip install -e .

Usage

To use the Lightning Neural Compressor, import the callbacks from the lightning_nc module, for example:
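from lightning_nc import QATCallback, WeightPruningCallback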

WARNING: Currently, the callbacks require the PyTorch model to be an nn.Module contained inside your LightningModule. This is not a major limitation, as the refactoring is straightforward, such as:

import os

import lightning as L
import timm
import torch
import torch.nn.functional as F
from neural_compressor import QuantizationAwareTrainingConfig
from neural_compressor.config import Torch2ONNXConfig
from neural_compressor.training import WeightPruningConfig
from lightning_nc import QATCallback, WeightPruningCallback
from torch import Tensor, nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor


# Define your main model here
class VeryComplexModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        # Note: "best_pretrained_model" is a placeholder; use any valid timm model name
        self.backbone = timm.create_model("best_pretrained_model",
                                          pretrained=True)

        self.mlp = nn.Sequential(nn.Linear(self.backbone.num_features, 128),
                                 nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.mlp(self.backbone(x))


# Then, define your LightningModule as usual
class Classifier(L.LightningModule):

    def __init__(self):
        super().__init__()

        # This is mandatory for the callbacks
        self.model = VeryComplexModel()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch

        # This is just to use MNIST images on a pretrained timm model, you can skip that
        x = x.repeat(1, 3, 1, 1)
        x = F.interpolate(x, size=(224, 224))

        y_hat = self.forward(x)

        loss = F.cross_entropy(y_hat, y)

        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)

        return [optimizer]


clf = Classifier()

# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

Now that everything is set up, the callbacks can be integrated into a PyTorch Lightning training routine:

# Define the configs for Pruning and Quantization
q_config = QuantizationAwareTrainingConfig()
p_config = WeightPruningConfig([{
    "op_names": ["backbone.*"],
    "start_step": 1,
    "end_step": 100,
    "target_sparsity": 0.5,
    "pruning_frequency": 1,
    "pattern": "4x1",
    "min_sparsity_ratio_per_op": 0.,
    "pruning_scope": "global",
}])

callbacks = [
    QATCallback(config=q_config),
    WeightPruningCallback(config=p_config),
]

trainer = L.Trainer(accelerator="gpu",
                    strategy="auto",
                    limit_train_batches=100,
                    max_epochs=1,
                    callbacks=callbacks)
trainer.fit(model=clf, train_dataloaders=train_loader)
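
After training, it can be useful to verify that the pruning callback actually reached the requested sparsity. Below is a minimal sketch (not part of lightning-nc) that measures the fraction of zeroed weights in the backbone:

# Sanity check (assumption: run right after trainer.fit finishes):
# count zeroed elements in the backbone's weight tensors.
total, zeros = 0, 0
for name, param in clf.model.backbone.named_parameters():
    if param.dim() > 1:  # only weight matrices / conv kernels
        total += param.numel()
        zeros += (param == 0).sum().item()
print(f"Backbone weight sparsity: {zeros / total:.2%}")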

Models can now be exported easily, for example to ONNX:

clf.model.export(
    "model.onnx",
    Torch2ONNXConfig(
        dtype="int8",
        opset_version=17,
        quant_format="QOperator",  # or QDQ
        example_inputs=torch.randn(1, 3, 224, 224),
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {
                0: "batch_size"
            },
            "output": {
                0: "batch_size"
            },
        },
    ))
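
To sanity-check the exported file, the model can be loaded back and run with ONNX Runtime. This is a minimal sketch under the assumption that onnxruntime is installed; it is not a dependency of this project:

import numpy as np
import onnxruntime as ort

# Load the exported int8 model and run a dummy forward pass.
session = ort.InferenceSession("model.onnx")
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {"input": dummy})
print(outputs[0].shape)  # expected: (1, 10)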

Contributing

If you would like to contribute to this project, please submit a pull request. All contributions are welcome!

License

This project is licensed under the MIT License. See the LICENSE.md file for details.
