Nexus Converter

Convert PyTorch models into Nexus code.

The construction of the PyTorch model must follow this pattern:

import torch
from torch.nn import Conv2d, ReLU

class MyModel(torch.nn.Module):
    def __init__(self, num_class):
        super().__init__()

        """ 1. Define the parameters of every network layer as submodules """
        self.conv1 = Conv2d(...)
        self.relu1 = ReLU()
        ...

    def forward(self, x):
        """ 2. The forward computation """
        x = self.conv1(x)
        x = self.relu1(x)
        ...

That is, build the model's tree structure in __init__() (register each layer of the network as a submodule, i.e. an attribute of the model), and describe the forward computation in forward(). The forward pass must use only the submodules defined in __init__(), not the functional interfaces in torch.nn.functional.
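
For example, the following sketch (a hypothetical pair of models, not taken from the package) contrasts a forward pass written with torch.nn.functional, which does not satisfy this requirement, with the equivalent module-based version that does:

import torch
import torch.nn.functional as F

# Not convertible: the activation is a functional call inside forward(),
# so it is never registered as a submodule in __init__().
class FunctionalStyle(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        return F.relu(self.conv1(x))

# Convertible: every layer, including the activation, is a submodule attribute.
class ModuleStyle(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu1 = torch.nn.ReLU()

    def forward(self, x):
        return self.relu1(self.conv1(x))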

Then, in the main program, first import toNexusCode from the convert module:

from convert import toNexusCode

Next, create an instance of the model and switch it to evaluation mode:

model = MyModel(num_class=10)
model.eval()

toNexusCode then needs only the model and a sample input data to export the code:

code = toNexusCode(model, data)
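
Putting the steps together, here is a minimal sketch of the full flow; it assumes MyModel above has been filled in and takes 3-channel 32x32 images, so the shape of data is an assumption rather than a requirement of the tool:

import torch
from convert import toNexusCode

model = MyModel(num_class=10)
model.eval()

# Sample input used to drive the export; the shape (1, 3, 32, 32) is an
# assumption for a 3-channel 32x32 image -- use whatever your model expects.
data = torch.randn(1, 3, 32, 32)

code = toNexusCode(model, data)
print(code)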

The generated code is a string, for example:

auto data0 = Symbol::Variable("data0");
auto conv1 = Convolution("conv1", data0, {3, 3}, 64, false, {1, 1}, {1, 1}, {1, 1});
auto relu1 = Relu("relu1", conv1);
auto conv2 = Convolution("conv2", relu1, {3, 3}, 64, false, {1, 1}, {1, 1}, {1, 1});
auto relu2 = Relu("relu2", conv2);
auto pool1 = Pooling("pool1", relu2, 2, 2, 0);
auto conv3 = Convolution("conv3", pool1, {3, 3}, 128, false, {1, 1}, {1, 1}, {1, 1});
auto relu3 = Relu("relu3", conv3);
auto conv4 = Convolution("conv4", relu3, {3, 3}, 128, false, {1, 1}, {1, 1}, {1, 1});
auto relu4 = Relu("relu4", conv4);
auto pool2 = Pooling("pool2", relu4, 2, 2, 0);
auto conv5 = Convolution("conv5", pool2, {3, 3}, 256, false, {1, 1}, {1, 1}, {1, 1});
auto relu5 = Relu("relu5", conv5);
auto conv6 = Convolution("conv6", relu5, {3, 3}, 256, false, {1, 1}, {1, 1}, {1, 1});
auto relu6 = Relu("relu6", conv6);
auto conv7 = Convolution("conv7", relu6, {3, 3}, 256, false, {1, 1}, {1, 1}, {1, 1});
auto relu7 = Relu("relu7", conv7);
auto pool3 = Pooling("pool3", relu7, 2, 2, 0);
auto conv8 = Convolution("conv8", pool3, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu8 = Relu("relu8", conv8);
auto conv9 = Convolution("conv9", relu8, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu9 = Relu("relu9", conv9);
auto conv10 = Convolution("conv10", relu9, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu10 = Relu("relu10", conv10);
auto pool4 = Pooling("pool4", relu10, 2, 2, 0);
auto conv11 = Convolution("conv11", pool4, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu11 = Relu("relu11", conv11);
auto conv12 = Convolution("conv12", relu11, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu12 = Relu("relu12", conv12);
auto conv13 = Convolution("conv13", relu12, {3, 3}, 512, false, {1, 1}, {1, 1}, {1, 1});
auto relu13 = Relu("relu13", conv13);
auto pool5 = Pooling("pool5", relu13, 2, 2, 0);
auto flatten = Flatten("flatten", pool5);
auto fc1 = FullyConnected("fc1", flatten, 4096, false, true);
auto drop1 = Dropout("drop1", fc1, 0.5);
auto fc2 = FullyConnected("fc2", drop1, 4096, false, true);
auto drop2 = Dropout("drop2", fc2, 0.5);
auto fc3 = FullyConnected("fc3", drop2, 10, false, true);

The string can be printed and then pasted into a text editor.
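
Alternatively, the string can be written directly to a file; a small sketch (the file name is arbitrary):

# Save the generated Nexus code; "model_nexus.cpp" is just an example path.
with open("model_nexus.cpp", "w") as f:
    f.write(code)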

The PyTorch operators that can currently be converted are listed below; a small example model built only from these operators is sketched after the list:

  • Conv2d
  • MaxPool2d
  • Linear
  • Dropout
  • ReLU
  • Flatten
  • Sigmoid
  • Tanh
  • Softmax
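
As a reference, the following is a hypothetical model (not part of the package) that uses only operators from this list and follows the construction pattern described above; the layer sizes assume single-channel 28x28 inputs:

import torch

class SmallNet(torch.nn.Module):
    def __init__(self, num_class):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu1 = torch.nn.ReLU()
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = torch.nn.Flatten()
        # 28x28 input -> 14x14 after pooling (assumed input size)
        self.fc1 = torch.nn.Linear(16 * 14 * 14, 128)
        self.relu2 = torch.nn.ReLU()
        self.drop1 = torch.nn.Dropout(0.5)
        self.fc2 = torch.nn.Linear(128, num_class)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu2(x)
        x = self.drop1(x)
        x = self.fc2(x)
        return self.softmax(x)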

