Skip to main content

Light AI tools developed by Gang Zhang

Project description

light_deep_ai: Deep AI modules developed by the MOGO RTX team, aiming to accelerate distributed training, int8-aware distributed training, distributed evaluation and inference, model tracing and optimization, and TensorRT deployment.

1 Dependency

torch>=1.10.0
tensorrt>=7.0
graphviz

2 Installation

pip3 install graphviz
apt-get install graphviz
python3 setup.py install

3 Examples

3.1 Graph Tracing and Model Optimization
import torch
import torch.nn as nn
import torch.nn.functional as F

import light_deep_ai


class conv3x3_bn_relu(nn.Module):
    """3x3 convolution (no bias) followed by batch norm and an in-place ReLU.

    The convolution padding is set equal to ``dilation``, so with a 3x3
    kernel and ``stride=1`` the spatial size of the input is preserved.
    """

    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super().__init__()
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        act = nn.ReLU(inplace=True)
        # Keep the layers under ``self.net`` in conv -> bn -> relu order.
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.net(x)


class Model(nn.Module):
    """Example network: two stacked 64-channel ``conv3x3_bn_relu`` blocks."""

    def __init__(self):
        super().__init__()
        stages = [conv3x3_bn_relu(64, 64) for _ in range(2)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both blocks and return the result."""
        return self.net(x)

# Build the example model, switch it to eval mode and move it to the GPU.
model = Model()
model.eval()
model.cuda()

# Random sample input on the GPU; its 64 channels match the model's first conv.
input_data = torch.randn(1, 64, 1024, 1024).cuda()

# Graph tracing: trace the module on its own, without optimization.
model_fx = light_deep_ai.graph_tracer.ad_trace.graph_trace(model, function_name=None)

# Model optimization.
# graph_optim_from_module performs graph tracing automatically, so it is
# given the original module rather than model_fx.
model_fx_optim = light_deep_ai.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None, sample_inputs=(input_data,))
3.2 Quantization-Aware Training
import torch
import torch.nn as nn
import torch.nn.functional as F

import light_deep_ai


class conv3x3_bn_relu(nn.Module):
    """Bias-free 3x3 convolution, then batch normalization, then in-place ReLU.

    ``padding`` always equals ``dilation``; for ``stride=1`` this keeps the
    output the same height and width as the input.
    """

    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        layers = [
            nn.Conv2d(in_planes, out_planes, **conv_kwargs),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after conv -> batch norm -> ReLU."""
        return self.net(x)


class Model(nn.Module):
    """Example network made of two consecutive 64 -> 64 ``conv3x3_bn_relu`` blocks."""

    def __init__(self):
        super().__init__()
        first = conv3x3_bn_relu(64, 64)
        second = conv3x3_bn_relu(64, 64)
        self.net = nn.Sequential(first, second)

    def forward(self, x):
        """Feed ``x`` through the two blocks in sequence."""
        return self.net(x)

# Build the example model, switch it to eval mode and move it to the GPU.
# NOTE(review): the model stays in eval mode before prepare_qat — confirm
# whether QAT training is expected to call model.train() afterwards.
model = Model()
model.eval()
model.cuda()

# Random sample input on the GPU; its 64 channels match the model's first conv.
input_data = torch.randn(1, 64, 1024, 1024).cuda()

# Model optimization.
# graph_optim_from_module performs graph tracing automatically.
# NOTE(review): unlike the other examples in this document, no sample_inputs
# argument is passed here — confirm that this is intended.
model_fx_optim = light_deep_ai.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None)

# Prepare the optimized graph for quantization-aware training (QAT).
# The quantization range is [-127, 127] with is_symmetric=True; the exact
# meaning of the observer/quant options is defined by prepare_qat
# (presumably symmetric int8 fake-quantization — TODO confirm).
# disable_prefix is left empty here.
model_qat = light_deep_ai.quant_lib.quant_utils.prepare_qat(model_fx_optim,
    sample_inputs=[input_data],
    observe_config_dic=dict(averaging_constant=0.05),
    quant_config_dic=dict(quant_min=-127, quant_max=127, is_symmetric=True, is_quant=True),
    disable_prefix=[])


# Visualize both networks: the optimized graph and the QAT-prepared graph
# are each written to a PNG file in the current directory.
light_deep_ai.graph_tracer.vis_model.vis(model_fx_optim, './model_fx_optim.png')
light_deep_ai.graph_tracer.vis_model.vis(model_qat, './model_qat.png')

# QAT training: placeholder — the training loop is not part of this example.
...
3.3 TensorRT Deployment
import torch
import torch.nn as nn
import torch.nn.functional as F

import light_deep_ai
import light_deep_ai_plugin
from light_deep_ai.deploy_lib.convert_trt import InputTensor, torch2trt

class conv3x3_bn_relu(nn.Module):
    """Conv2d (3x3, no bias) + BatchNorm2d + in-place ReLU building block.

    The padding follows ``dilation``, which keeps the spatial size unchanged
    when ``stride`` is 1.
    """

    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super().__init__()
        convolution = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        batch_norm = nn.BatchNorm2d(out_planes)
        activation = nn.ReLU(inplace=True)
        self.net = nn.Sequential(convolution, batch_norm, activation)

    def forward(self, x):
        """Return the block's output for input ``x``."""
        return self.net(x)


class Model(nn.Module):
    """Two stacked 64-channel ``conv3x3_bn_relu`` blocks followed by a plugin max op."""

    def __init__(self):
        super().__init__()
        stages = [conv3x3_bn_relu(64, 64) for _ in range(2)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run the conv stack, then apply ``light_deep_ai_plugin.max_op`` over dim 1."""
        features = self.net(x)
        # Reduction along dim=1 is delegated to the plugin op.
        return light_deep_ai_plugin.max_op(features, dim=1)

# Build the example model, switch it to eval mode and move it to the GPU.
model = Model()
model.eval()
model.cuda()

# Random sample input on the GPU; its 64 channels match the model's first conv.
input_data = torch.randn(1, 64, 1024, 1024).cuda()

# Model optimization.
# graph_optim_from_module performs graph tracing automatically.
model_fx_optim = light_deep_ai.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None, sample_inputs=(input_data,))

# TensorRT deployment: convert the optimized graph with torch2trt.
# The sample tensor is wrapped in InputTensor together with the input name.
# Two output names are given — presumably one per value returned by
# max_op (max value and max index); confirm against the plugin.
model_trt = torch2trt(
    model=model_fx_optim,
    input_specs=[InputTensor(input_data, 'input_data')],
    output_names=['max_value', 'max_index'],
    fp16_mode=True,
    #dla_core=0,
    strict_type_constraints=True,
    explicit_precision=True
)

# Visualize the TensorRT network and write it to test.png.
light_deep_ai.deploy_lib.tools.vis_trt.vis(model_trt.network, 'test.png')

# Compare the first output of the PyTorch model with the first output of the
# TensorRT model and print the largest absolute difference.
error = model(input_data)[0] - model_trt(input_data)[0]
print(error.abs().max())

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distributions

light_deep_ai-1.4.8-py311-none-any.whl (138.2 kB view hashes)

Uploaded Python 3.11

light_deep_ai-1.4.8-py310-none-any.whl (89.3 kB view hashes)

Uploaded Python 3.10

light_deep_ai-1.4.8-py39-none-any.whl (88.6 kB view hashes)

Uploaded Python 3.9

light_deep_ai-1.4.8-py38-none-any.whl (88.9 kB view hashes)

Uploaded Python 3.8

light_deep_ai-1.4.8-py37-none-any.whl (88.9 kB view hashes)

Uploaded Python 3.7

light_deep_ai-1.4.8-py36-none-any.whl (88.5 kB view hashes)

Uploaded Python 3.6

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page