将 PyTorch 模型转换为 Nexus 代码
Project description
Nexus Converter
将 PyTorch 模型转换为 Nexus 代码。
需要 PyTorch 模型的构造过程符合以下范式:
class MyModel(torch.nn.Module):
def __init__(self, num_class):
super().__init__()
""" 1. 定义每个网络层的参数 """
self.conv1 = Conv2d(...)
self.relu1 = ReLU()
...
def forward(self, x):
""" 2. 前向计算 """
x = self.conv1(x)
x = self.relu1(x)
...
即,在 __init__()
中构造模型的树形结构(将子模块,也就是网络中的每一层,设为这个模型的属性),在 forward()
中描述前向计算。在前向计算中只使用 __init__()
中的子模块,而不使用 torch.functional
中的函数式的接口。
然后,在主程序中首先引入 convert
模块中的 toNexusCode
:
from convert import toNexusCode
主程序中首先生成模型的实例,然后将其设为估值模式:
model = MyModel()
model.eval()
toNexusCode
只需要 model
和输入样例 data
即可导出代码:
code = toNexusCode(model, data)
生成的 code
为字符串的形式,如下:
auto data0 = Symbol::Variable("data0");
auto conv1 = Convolution("conv1", data0, {3, 3}, 64, false, {1, 1}, {1, 1}, {1, 1});
auto relu1 = Relu("relu1", conv1);
auto conv2 = Convolution("conv2", relu1, {3, 3}, 64, false, {1, 1}, {1, 1}, {1, 1});
auto relu2 = Relu("relu2", conv2);
auto pool1 = Pooling("pool1", relu2, 2, 2, 0);
auto conv3 = Convolution("conv3", pool1, {3, 3}, 128, false, {1, 1}, {1, 1}, {1, 1});
auto relu3 = Relu("relu3", conv3);
auto conv4 = Convolution("conv4", relu3, {3, 3}, 128, false, {1, 1}, {1, 1}, {1, 1});
auto relu4 = Relu("relu4", conv4);
auto pool2 = Pooling("pool2", relu4, 2, 2, 0);
auto conv5 = Convolution("conv5", pool2, {3, 3}, 256, false, {1, 1}, {1, 1}, {1, 1});
auto relu5 = Relu("relu5", conv5);
auto conv6 = Convolution("conv6", relu5, {3, 3}, 256, false, {1, 1}, {1, 1}, {1, 1});
auto relu6 = Relu("relu6", conv6);
auto conv7 = Convolution("conv7", relu6, {3, 3}, 256, false, {1, 1}, {1, 1}, {1, 1});
auto relu7 = Relu("relu7", conv7);
auto pool3 = Pooling("pool3", relu7, 2, 2, 0);
auto conv8 = Convolution("conv8", pool3, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu8 = Relu("relu8", conv8);
auto conv9 = Convolution("conv9", relu8, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu9 = Relu("relu9", conv9);
auto conv10 = Convolution("conv10", relu9, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu10 = Relu("relu10", conv10);
auto pool4 = Pooling("pool4", relu10, 2, 2, 0);
auto conv11 = Convolution("conv11", pool4, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu11 = Relu("relu11", conv11);
auto conv12 = Convolution("conv12", relu11, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu12 = Relu("relu12", conv12);
auto conv13 = Convolution("conv13", relu12, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu13 = Relu("relu13", conv13);
auto pool5 = Pooling("pool5", relu13, 2, 2, 0);
auto flatten = Flatten("flatten", pool5);
auto fc1 = FullyConnected("fc1", flatten, 4096, false, true);
auto drop1 = Dropout("drop1", fc1, 0.5);
auto fc2 = FullyConnected("fc2", drop1, 4096, false, true);
auto drop2 = Dropout("drop2", fc2, 0.5);
auto fc3 = FullyConnected("fc3", drop2, 10, false, true);
可以打印出来并粘贴到文本编辑器中。
目前在 PyTorch 中可以被转换的算子包括:
- Conv2d
- MaxPool2d
- Linear
- Dropout
- ReLU
- Flatten
- Sigmoid
- Tanh
- Softmax
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
nexusconverter-0.0.3.tar.gz
(3.9 kB
view hashes)
Built Distribution
Close
Hashes for nexusconverter-0.0.3-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 9c6c0d52377ffd8b8108ad7b22c70bb4da6730549e7ef2b3b9af55fd6555a4b3 |
|
MD5 | 9b2a6e1c3ed4c3b9abf4a511875a9c56 |
|
BLAKE2b-256 | 2b72fc508f21d404ce9c41ee245c1137309b82df9fd1975ba62db7c0ab192b43 |