The deep learning model converter

gluon2pytorch

Gluon to PyTorch model converter with script generation.

Installation

    git clone https://github.com/nerox8664/gluon2pytorch
    cd gluon2pytorch
    pip install -e .

Alternatively, install the released package from PyPI with pip:

    pip install gluon2pytorch
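
Because a conversion needs MXNet (to run the Gluon model) as well as PyTorch (to hold the result), it can help to confirm that all three packages are importable before going further. The following is a minimal sketch using only the standard library; `REQUIRED_MODULES` and `missing_modules` are hypothetical names chosen for this example.

```python
import importlib.util

# Import names of the packages the Gluon-to-PyTorch workflow needs (assumption:
# mxnet and torch are installed separately from gluon2pytorch).
REQUIRED_MODULES = ("mxnet", "torch", "gluon2pytorch")


def missing_modules(names=REQUIRED_MODULES):
    """Return the names from `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("mxnet, torch and gluon2pytorch are all importable")
```

`find_spec` only locates the module without importing it, so the check stays cheap even for large packages such as MXNet.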

How to use

gluon2pytorch converts a Gluon graph into a PyTorch model file (a generated Python script) plus its weights.
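
To illustrate what "script generation" means in general, here is a toy sketch that turns an ordered list of layers into the source text of a PyTorch module. It is not gluon2pytorch's actual generator: `generate_module_source` and its `(attribute, constructor)` input format are hypothetical, and it only handles a plain sequential chain of layers.

```python
def generate_module_source(class_name, layers):
    """Build PyTorch module source for a sequential chain of layers.

    `layers` is a list of (attribute_name, constructor_expression) pairs,
    applied to the input in the given order (hypothetical format).
    """
    lines = [
        "import torch.nn as nn",
        "",
        "",
        f"class {class_name}(nn.Module):",
        "    def __init__(self):",
        f"        super({class_name}, self).__init__()",
    ]
    # One sub-module attribute per converted layer
    for attr, ctor in layers:
        lines.append(f"        self.{attr} = {ctor}")
    lines += ["", "    def forward(self, x):"]
    # Chain the layers in the same order as the source graph
    for attr, _ in layers:
        lines.append(f"        x = self.{attr}(x)")
    lines.append("        return x")
    return "\n".join(lines) + "\n"


source = generate_module_source(
    "ReLUTest",
    [("conv1", "nn.Conv2d(3, 3, kernel_size=32)"), ("relu", "nn.ReLU()")],
)
print(source)
```

A real converter also has to handle branches (additions, concatenations) and copy the weights; the generated text only describes the architecture.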

First, we need to load (or create) a Gluon hybrid model:


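import mxnet as mx
import numpy as np
import torch
from mxnet.gluon import nn

from gluon2pytorch import gluon2pytorch


class ReLUTest(mx.gluon.nn.HybridSequential):
    def __init__(self):
        super(ReLUTest, self).__init__()
        with self.name_scope():
            # Conv2D(channels, kernel_size): 3 output channels, 32x32 kernel
            self.conv1 = nn.Conv2D(3, 32)
            self.relu = nn.Activation('relu')

    def hybrid_forward(self, F, x):
        # Exercise both the Activation block and the F.relu operator
        x = F.relu(self.relu(self.conv1(x)))
        return x


if __name__ == '__main__':
    net = ReLUTest()

    # Make sure it's hybrid and initialized
    net.hybridize()
    net.collect_params().initialize()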
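
Since the converter is given explicit input shapes, it is useful to know what output shape to expect. Gluon's `Conv2D(channels, kernel_size, ...)` and `torch.nn.Conv2d` use the same output-size formula, sketched below (`conv2d_output_hw` is a hypothetical helper). With no padding, a 32x32 kernel on a 224x224 input gives a 193x193 map, so `ReLUTest` should produce a `(1, 3, 193, 193)` output (the ReLUs do not change the shape).

```python
def conv2d_output_hw(height, width, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2-D convolution with a square kernel."""
    def one_dim(size):
        # floor((size + 2p - d*(k - 1) - 1) / s) + 1
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return one_dim(height), one_dim(width)


# ReLUTest: Conv2D(3, 32) applied to a (1, 3, 224, 224) input
print(conv2d_output_hw(224, 224, kernel_size=32))  # (193, 193)
```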

Next, call the converter, passing the list of input shapes:

    pytorch_model = gluon2pytorch(net, [(1, 3, 224, 224)], dst_dir=None, pytorch_module_name='ReLUTest')

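Finally, we can check the difference between the Gluon and PyTorch outputs on the same random input (`check_error` is a comparison helper that is not defined in this snippet):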

    input_np = np.random.uniform(-1, 1, (1, 3, 224, 224))

    gluon_output = net(mx.nd.array(input_np))
    pytorch_output = pytorch_model(torch.FloatTensor(input_np))
    check_error(gluon_output, pytorch_output)
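
One possible `check_error`, written as a sketch under stated assumptions: it compares the largest absolute element-wise difference against a tolerance (`epsilon=1e-5` is an arbitrary choice), and the helper names are hypothetical. It accepts MXNet NDArrays (via `asnumpy()`), PyTorch tensors (via `detach().cpu().numpy()`), NumPy arrays or plain nested lists, and otherwise uses only built-in Python.

```python
def to_nested_list(value):
    """Convert an MXNet NDArray, PyTorch tensor, NumPy array or list to nested lists."""
    if hasattr(value, "asnumpy"):   # MXNet NDArray
        return value.asnumpy().tolist()
    if hasattr(value, "detach"):    # PyTorch tensor
        return value.detach().cpu().numpy().tolist()
    if hasattr(value, "tolist"):    # NumPy array
        return value.tolist()
    return value                    # already a (nested) list or a scalar


def flatten(value):
    """Yield the scalars of an arbitrarily nested list/tuple in order."""
    if isinstance(value, (list, tuple)):
        for item in value:
            yield from flatten(item)
    else:
        yield value


def check_error(gluon_output, pytorch_output, epsilon=1e-5):
    """Assert that both outputs agree within `epsilon`; return the max abs error."""
    a = list(flatten(to_nested_list(gluon_output)))
    b = list(flatten(to_nested_list(pytorch_output)))
    if len(a) != len(b):
        raise ValueError(f"element count mismatch: {len(a)} vs {len(b)}")
    error = max((abs(x - y) for x, y in zip(a, b)), default=0.0)
    print("Max absolute error:", error)
    assert error < epsilon, f"outputs differ by {error}"
    return error
```

Taking the maximum absolute difference (rather than a signed one) makes sure large negative deviations are caught as well.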

Supported layers

Layers:

  • Linear
  • Conv2d
  • ConvTranspose2d (Deconvolution)
  • MaxPool2d
  • AvgPool2d
  • Global average pooling (as special case of AdaptiveAvgPool2d)
  • BatchNorm2d
  • Padding2d (constant, reflection, replication)
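
Global average pooling maps onto `AdaptiveAvgPool2d` with an output size of 1. The pure-Python sketch below shows why, assuming the bin boundaries PyTorch uses for adaptive pooling (bin i spans rows `floor(i*H/out)` up to `ceil((i+1)*H/out)`): with one output bin, that single bin covers the whole feature map. Both function names are hypothetical, and inputs are nested lists shaped `[C][H][W]`.

```python
def adaptive_avg_pool2d(x, out_h, out_w):
    """Adaptive average pooling on a nested list shaped [C][H][W]."""
    result = []
    for channel in x:
        h, w = len(channel), len(channel[0])
        pooled = []
        for i in range(out_h):
            h0 = (i * h) // out_h              # floor(i * H / out_h)
            h1 = -((-(i + 1) * h) // out_h)    # ceil((i + 1) * H / out_h)
            row = []
            for j in range(out_w):
                w0 = (j * w) // out_w
                w1 = -((-(j + 1) * w) // out_w)
                window = [channel[r][c] for r in range(h0, h1) for c in range(w0, w1)]
                row.append(sum(window) / len(window))
            pooled.append(row)
        result.append(pooled)
    return result


def global_avg_pool2d(x):
    """Mean over all spatial positions of each channel, shaped [C][1][1]."""
    return [[[sum(v for row in ch for v in row) / (len(ch) * len(ch[0]))]] for ch in x]


x = [[[1, 2], [3, 4]], [[0, 0], [0, 8]]]
print(adaptive_avg_pool2d(x, 1, 1))  # [[[2.5]], [[2.0]]]
print(global_avg_pool2d(x))          # [[[2.5]], [[2.0]]]
```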

Reshape:

  • Flatten

Activations:

  • ReLU
  • LeakyReLU
  • Sigmoid
  • Softmax
  • SELU

Element-wise:

  • Addition
  • Concatenation
  • Subtraction
  • Multiplication
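
To see whether a model stays within the list above before converting it, one option is to inspect the operator names in its exported graph: a hybridized Gluon block can be saved with `net.export(prefix)` (after at least one forward pass), which writes a `<prefix>-symbol.json` file whose `nodes` each carry an `op` field. The sketch below assumes that format; `SUPPORTED_OPS` is a hypothetical, possibly incomplete mapping written for this example, not the converter's internal table, and operator aliases in your graph may differ.

```python
import json

# Hypothetical mapping from MXNet operator names to the supported layers above.
SUPPORTED_OPS = {
    "FullyConnected": "Linear",
    "Convolution": "Conv2d",
    "Deconvolution": "ConvTranspose2d",
    "Pooling": "MaxPool2d / AvgPool2d / global average pooling",
    "BatchNorm": "BatchNorm2d",
    "Pad": "Padding2d",
    "Flatten": "Flatten",
    "Activation": "ReLU / Sigmoid",
    "relu": "ReLU",
    "sigmoid": "Sigmoid",
    "LeakyReLU": "LeakyReLU / SELU",
    "softmax": "Softmax",
    "elemwise_add": "Addition",
    "Concat": "Concatenation",
    "elemwise_sub": "Subtraction",
    "elemwise_mul": "Multiplication",
}


def unsupported_ops(symbol_json, supported=SUPPORTED_OPS):
    """Return the sorted operator names in an MXNet graph JSON not found in `supported`."""
    graph = json.loads(symbol_json)
    # Inputs and parameters are stored as nodes whose op is "null"
    ops = {node["op"] for node in graph["nodes"] if node["op"] != "null"}
    return sorted(ops - set(supported))
```

For example, after `net.export('relu_test')` you could read `relu_test-symbol.json` and pass its contents to `unsupported_ops`; an empty list means every operator appears in the mapping.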

Models converted with gluon2pytorch:

  • ResNet*
  • SeNet
  • DenseNet*
  • DPN
  • Mobilenet

Code snippets

See the tests directory for more examples.

License

This software is covered by the MIT License.


Download files

Source Distribution

gluon2pytorch-0.0.2.linux-x86_64.tar.gz (12.5 kB)