Skip to main content

Deep AI modules developed by the MOGO RTX team

Project description

rtx_deep: Deep AI modules developed by the MOGO RTX team. It aims to accelerate distributed training, int8-aware distributed training, distributed evaluation and inference, model tracing and optimization, and TensorRT deployment.

1 Dependency

torch>=1.10.0
tensorrt>=7.0
graphviz

2 Installation

pip3 install graphviz
apt-get install graphviz
python3 setup.py install

3 Examples

3.1 Graph Tracing and Model Optimization
import torch
import torch.nn as nn
import torch.nn.functional as F

import rtx_deep


class conv3x3_bn_relu(nn.Module):
    """3x3 convolution followed by batch normalization and an in-place ReLU.

    Args:
        in_planes: Number of input channels of the convolution.
        out_planes: Number of output channels of the convolution.
        stride: Convolution stride.
        dilation: Convolution dilation; the same value is used as padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super().__init__()
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        act = nn.ReLU(inplace=True)
        # Keep the layers under ``self.net`` so parameter names stay net.0.*, net.1.*.
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the conv -> batch-norm -> ReLU stack."""
        return self.net(x)


class Model(nn.Module):
    """Example network: two stacked ``conv3x3_bn_relu(64, 64)`` blocks."""

    def __init__(self):
        super().__init__()
        # Two identical 64 -> 64 blocks, registered as net.0 and net.1.
        blocks = [conv3x3_bn_relu(64, 64) for _ in range(2)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply both conv blocks to ``x`` and return the result."""
        return self.net(x)

# Build the example model, switch it to evaluation mode and move it to the GPU.
model = Model()
model.eval()
model.cuda()

# Random sample input on the GPU; 64 channels to match Model's conv3x3_bn_relu(64, 64) blocks.
input_data = torch.randn(1, 64, 1024, 1024).cuda()

# Graph tracing of the module (no sample input is passed to this call).
model_fx = rtx_deep.graph_tracer.ad_trace.graph_trace(model, function_name=None)

# Model optimization.
# Graph tracing is conducted automatically inside graph_optim_from_module,
# so the original `model` (not `model_fx`) is passed here, together with a sample input.
model_fx_optim = rtx_deep.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None, sample_inputs=(input_data,))
3.2 Quantization-Aware Training
import torch
import torch.nn as nn
import torch.nn.functional as F

import rtx_deep


class conv3x3_bn_relu(nn.Module):
    """3x3 convolution followed by batch normalization and an in-place ReLU.

    Args:
        in_planes: Number of input channels of the convolution.
        out_planes: Number of output channels of the convolution.
        stride: Convolution stride.
        dilation: Convolution dilation; the same value is used as padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super().__init__()
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        act = nn.ReLU(inplace=True)
        # Keep the layers under ``self.net`` so parameter names stay net.0.*, net.1.*.
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the conv -> batch-norm -> ReLU stack."""
        return self.net(x)


class Model(nn.Module):
    """Example network: two stacked ``conv3x3_bn_relu(64, 64)`` blocks."""

    def __init__(self):
        super().__init__()
        # Two identical 64 -> 64 blocks, registered as net.0 and net.1.
        blocks = [conv3x3_bn_relu(64, 64) for _ in range(2)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply both conv blocks to ``x`` and return the result."""
        return self.net(x)

# Build the example model, switch it to evaluation mode and move it to the GPU.
model = Model()
model.eval()
model.cuda()

# Random sample input on the GPU; 64 channels to match Model's conv3x3_bn_relu(64, 64) blocks.
input_data = torch.randn(1, 64, 1024, 1024).cuda()

# Model optimization.
# Graph tracing is conducted automatically inside graph_optim_from_module.
# NOTE(review): the other examples in this document pass sample_inputs=(input_data,)
# to this call; it is omitted here — confirm that the argument is optional.
model_fx_optim = rtx_deep.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None)

# Prepare the optimized model for quantization-aware training (QAT).
# The quant config requests a symmetric [-127, 127] range; disable_prefix is left empty.
model_qat = rtx_deep.quant_lib.quant_utils.prepare_qat(model_fx_optim,
    sample_inputs=[input_data],
    observe_config_dic=dict(averaging_constant=0.05),
    quant_config_dic=dict(quant_min=-127, quant_max=127, is_symmetric=True, is_quant=True),
    disable_prefix=[])


# Visualize both networks: the optimized model and the QAT-prepared model are written to PNG files.
rtx_deep.graph_tracer.vis_model.vis(model_fx_optim, './model_fx_optim.png')
rtx_deep.graph_tracer.vis_model.vis(model_qat, './model_qat.png')

# QAT training (placeholder: the actual training loop is left to the user).
...
3.3 TensorRT Deployment
import torch
import torch.nn as nn
import torch.nn.functional as F

import rtx_deep
import rtx_deep_plugin
from rtx_deep.deploy_lib.convert_trt import InputTensor, torch2trt

class conv3x3_bn_relu(nn.Module):
    """3x3 convolution followed by batch normalization and an in-place ReLU.

    Args:
        in_planes: Number of input channels of the convolution.
        out_planes: Number of output channels of the convolution.
        stride: Convolution stride.
        dilation: Convolution dilation; the same value is used as padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super().__init__()
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        act = nn.ReLU(inplace=True)
        # Keep the layers under ``self.net`` so parameter names stay net.0.*, net.1.*.
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the conv -> batch-norm -> ReLU stack."""
        return self.net(x)


class Model(nn.Module):
    """Example network: two ``conv3x3_bn_relu(64, 64)`` blocks followed by a plugin max op."""

    def __init__(self):
        super().__init__()
        # Two identical 64 -> 64 blocks, registered as net.0 and net.1.
        blocks = [conv3x3_bn_relu(64, 64) for _ in range(2)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply both conv blocks, then ``rtx_deep_plugin.max_op`` along ``dim=1``."""
        features = self.net(x)
        return rtx_deep_plugin.max_op(features, dim=1)

# Build the example model, switch it to evaluation mode and move it to the GPU.
model = Model()
model.eval()
model.cuda()

# Random sample input on the GPU; 64 channels to match Model's conv3x3_bn_relu(64, 64) blocks.
input_data = torch.randn(1, 64, 1024, 1024).cuda()

# Model optimization.
# Graph tracing is conducted automatically inside graph_optim_from_module.
model_fx_optim = rtx_deep.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None, sample_inputs=(input_data,))

# TensorRT deployment: convert the optimized model with torch2trt.
# The sample tensor is wrapped in an InputTensor named 'input_data', and two
# output names are given (presumably the two results of max_op — confirm).
model_trt = torch2trt(
    model=model_fx_optim,
    input_specs=[InputTensor(input_data, 'input_data')],
    output_names=['max_value', 'max_index'],
    fp16_mode=True,
    #dla_core=0,
    strict_type_constraints=True,
    explicit_precision=True
)

# Visualize the TensorRT network and write it to 'test.png'.
rtx_deep.deploy_lib.tools.vis_trt.vis(model_trt.network, 'test.png')

# Compare the first output of the PyTorch model with the first output of the
# TensorRT model on the same input, and print the maximum absolute difference.
error = model(input_data)[0] - model_trt(input_data)[0]
print(error.abs().max())

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distributions

rtx_deep-1.3.8-py311-none-any.whl (125.2 kB view hashes)

Uploaded Python 3.11

rtx_deep-1.3.8-py310-none-any.whl (81.1 kB view hashes)

Uploaded Python 3.10

rtx_deep-1.3.8-py39-none-any.whl (80.5 kB view hashes)

Uploaded Python 3.9

rtx_deep-1.3.8-py38-none-any.whl (80.7 kB view hashes)

Uploaded Python 3.8

rtx_deep-1.3.8-py37-none-any.whl (80.8 kB view hashes)

Uploaded Python 3.7

rtx_deep-1.3.8-py36-none-any.whl (80.4 kB view hashes)

Uploaded Python 3.6

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page