Deep AI modules developed by the MOGO RTX team
Project description
rtx_deep: Deep AI modules developed by the MOGO RTX team. The package aims to accelerate distributed training, int8 quantization-aware distributed training, distributed evaluation and inference, model tracing and optimization, and TensorRT deployment.
1 Dependencies
torch>=1.8.0
tensorrt>=7.0
graphviz
2 Installation
pip3 install graphviz
apt-get install graphviz
python3 setup.py install
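A quick way to confirm the installation succeeded (a sketch; rtx_deep_plugin is the companion extension used in the examples below and may need to be installed separately):

# Smoke test: both packages should import without errors.
import rtx_deep
import rtx_deep_plugin  # assumption: installed alongside rtx_deep

print(rtx_deep.__name__, rtx_deep_plugin.__name__)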
3 Examples
3.1 Graph Tracing and Model Optimization
import torch
import torch.nn as nn
import torch.nn.functional as F

import rtx_deep
import rtx_deep_plugin


class conv3x3_bn_relu(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super(conv3x3_bn_relu, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.net(x)
        return x1


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.net = nn.Sequential(
            conv3x3_bn_relu(64, 64),
            conv3x3_bn_relu(64, 64)
        )

    def forward(self, x):
        x1 = self.net(x)
        x2 = rtx_deep_plugin.max_op(x1, dim=1)
        return x2


model = Model()
model.eval()
model.cuda()

input_data = torch.randn(1, 64, 1024, 1024).cuda()

# graph tracing
model_fx = rtx_deep.graph_tracer.ad_trace.graph_trace(model, function_name=None)

# model optimization
# graph tracing is conducted automatically inside graph_optim_from_module
model_fx_optim = rtx_deep.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None, sample_inputs=(input_data,))
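Before deploying, it is worth verifying that the optimized module is numerically equivalent to the original one. A minimal sanity check (a sketch, assuming graph_optim_from_module returns a callable torch.fx-style module and that the model output is a tensor or a tuple of tensors):

# Numerical sanity check between the original and the optimized module (a sketch).
def as_tensor_list(out):
    # Normalize a single tensor or a tuple/list of tensors into a list.
    return list(out) if isinstance(out, (tuple, list)) else [out]

with torch.no_grad():
    ref_out = as_tensor_list(model(input_data))
    opt_out = as_tensor_list(model_fx_optim(input_data))

for ref, opt in zip(ref_out, opt_out):
    assert torch.allclose(ref.float(), opt.float(), atol=1e-5)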
3.2 Quantization-Aware Training
import torch
import torch.nn as nn
import torch.nn.functional as F

import rtx_deep
import rtx_deep_plugin


class conv3x3_bn_relu(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super(conv3x3_bn_relu, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.net(x)
        return x1


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.net = nn.Sequential(
            conv3x3_bn_relu(64, 64),
            conv3x3_bn_relu(64, 64)
        )

    def forward(self, x):
        x1 = self.net(x)
        x2 = rtx_deep_plugin.max_op(x1, dim=1)
        return x2


model = Model()
model.eval()
model.cuda()

input_data = torch.randn(1, 64, 1024, 1024).cuda()

# model optimization
# graph tracing is conducted automatically inside graph_optim_from_module
model_fx_optim = rtx_deep.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None, sample_inputs=(input_data,))

# qat preparation
model_qat = rtx_deep.quant_lib.quant_utils.prepare_qat(model_fx_optim,
                                                       sample_inputs=[input_data],
                                                       observe_config_dic=dict(averaging_constant=0.05),
                                                       quant_config_dic=dict(quant_min=-127, quant_max=127, is_symmetric=True, is_quant=True),
                                                       disable_prefix=[])

# qat training
...
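The QAT training loop itself is elided above. A minimal fine-tuning sketch, assuming the prepared model is trained like any other torch.nn.Module (the data, loss, and optimizer settings below are placeholders, not part of rtx_deep):

# A minimal QAT fine-tuning loop (a sketch with placeholder data and loss).
model_qat.train()
optimizer = torch.optim.SGD(model_qat.parameters(), lr=1e-4, momentum=0.9)

for step in range(100):                               # placeholder step count
    batch = torch.randn(1, 64, 1024, 1024).cuda()     # replace with a real data loader
    out = model_qat(batch)
    out = out[0] if isinstance(out, (tuple, list)) else out
    loss = out.float().mean()                         # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()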
3.3 TensorRT Deployment
import torch
import torch.nn as nn
import torch.nn.functional as F

import rtx_deep
import rtx_deep_plugin
from rtx_deep.deploy_lib.convert_trt import InputTensor, torch2trt


class conv3x3_bn_relu(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1):
        super(conv3x3_bn_relu, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.net(x)
        return x1


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.net = nn.Sequential(
            conv3x3_bn_relu(64, 64),
            conv3x3_bn_relu(64, 64)
        )

    def forward(self, x):
        x1 = self.net(x)
        x2 = rtx_deep_plugin.max_op(x1, dim=1)
        return x2


model = Model()
model.eval()
model.cuda()

input_data = torch.randn(1, 64, 1024, 1024).cuda()

# model optimization
# graph tracing is conducted automatically inside graph_optim_from_module
model_fx_optim = rtx_deep.graph_tracer.graph_utils.graph_optim_from_module(model, function_name=None, sample_inputs=(input_data,))

# TensorRT deployment
model_trt = torch2trt(
    model=model_fx_optim,
    input_specs=[InputTensor(input_data, 'input_data')],
    output_names=['max_value', 'max_index'],
    fp16_mode=True,
    # dla_core=0,
    strict_type_constraints=True,
    explicit_precision=True
)

# visualize the TensorRT network
rtx_deep.deploy_lib.tools.vis_trt.vis(model_trt.network, 'test.png')
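A quick parity check between the PyTorch module and the generated engine can catch precision or conversion issues early. A sketch, assuming the object returned by torch2trt is callable on the same CUDA inputs and returns the two outputs named above (this mirrors the common TRTModule convention and is an assumption, not documented behavior):

# Compare PyTorch and TensorRT outputs (a sketch; see the assumptions above).
with torch.no_grad():
    ref_value, ref_index = model(input_data)      # assumption: max_op returns (values, indices)
    trt_value, trt_index = model_trt(input_data)

print('max abs diff of max_value:', (ref_value.float() - trt_value.float()).abs().max().item())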