The deep learning model converter
gluon2pytorch
Gluon to PyTorch model convertor with script generation.
Installation
git clone https://github.com/nerox8664/gluon2pytorch
cd gluon2pytorch
pip install -e .
Or you can use pip:
pip install gluon2pytorch
How to use
It converts a Gluon graph into a PyTorch model file plus weights.
First, load (or create) a Gluon hybrid model:
import mxnet as mx
from mxnet.gluon import nn


class ReLUTest(nn.HybridBlock):
    def __init__(self):
        super(ReLUTest, self).__init__()
        with self.name_scope():
            # nn.Conv2D(channels, kernel_size): 3 output channels, 32x32 kernel
            self.conv1 = nn.Conv2D(3, 32)
            self.relu = nn.Activation('relu')

    def hybrid_forward(self, F, x):
        # F is mx.nd or mx.sym, depending on whether the block is hybridized
        x = F.relu(self.relu(self.conv1(x)))
        return x


if __name__ == '__main__':
    net = ReLUTest()
    # Make sure it's hybrid and initialized
    net.hybridize()
    net.collect_params().initialize()
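The converter is given the input shape, so it can help to know what the test network produces for it. In Gluon, `nn.Conv2D(3, 32)` means 3 output channels and a 32×32 kernel (the positional arguments are `channels` and `kernel_size`), with stride 1, no padding and no dilation by default. Below is a minimal pure-Python sketch of the standard convolution output-size formula; the helper `conv2d_output_shape` is hypothetical and not part of gluon2pytorch.

```python
def conv2d_output_shape(input_shape, channels, kernel_size,
                        stride=1, padding=0, dilation=1):
    """Output shape (N, C, H, W) of a 2-D convolution on an NCHW input.

    Hypothetical helper for illustration. It applies the standard formula
    used by both Gluon's Conv2D and PyTorch's Conv2d:
        out = (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    """
    n, _, h, w = input_shape

    def out_size(size):
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    return (n, channels, out_size(h), out_size(w))


# ReLUTest: nn.Conv2D(3, 32) applied to a (1, 3, 224, 224) input.
# ReLU keeps the shape, so this is also the shape of the network output.
print(conv2d_output_shape((1, 3, 224, 224), channels=3, kernel_size=32))
# -> (1, 3, 193, 193)
```

Both frameworks share this formula, which is one reason a converted convolution can be checked shape-for-shape before comparing values.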
Next, call the converter:
from gluon2pytorch import gluon2pytorch

# The second argument is a list of input shapes (NCHW)
pytorch_model = gluon2pytorch(net, [(1, 3, 224, 224)], dst_dir=None, pytorch_module_name='ReLUTest')
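Finally, compare the outputs of the two models:
import numpy as np
import torch

input_np = np.random.uniform(-1, 1, (1, 3, 224, 224)).astype(np.float32)
gluon_output = net(mx.nd.array(input_np))
pytorch_output = pytorch_model(torch.from_numpy(input_np))
# Largest absolute difference between the two outputs (should be close to 0)
print(np.max(np.abs(gluon_output.asnumpy() - pytorch_output.detach().numpy())))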
Supported layers
Layers:
- Linear
- Conv2d
- ConvTranspose2d (Deconvolution)
- MaxPool2d
- AvgPool2d
- Global average pooling (as special case of AdaptiveAvgPool2d)
- BatchNorm2d
- Padding2d (constant, reflection, replication)
Reshape:
- Flatten
Activations:
- ReLU
- LeakyReLU
- Sigmoid
- Softmax
- SELU
Element-wise:
- Addition
- Concatenation
- Subtraction
- Multiplication
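Two entries in the layer list are mappings rather than one-to-one layers, and both are easy to check numerically. Global average pooling is adaptive average pooling with a 1×1 output, and the three padding modes behave like NumPy's `constant`, `reflect` and `edge` modes of `np.pad`. The NumPy sketch below checks both; `adaptive_avg_pool2d` is a hypothetical reference implementation written for illustration (it follows PyTorch's AdaptiveAvgPool2d bin boundaries), and the padding-mode correspondence is stated here as an assumption, not taken from gluon2pytorch's code.

```python
import numpy as np


def adaptive_avg_pool2d(x, output_size):
    """Reference adaptive average pooling for an NCHW array (hypothetical helper).

    Output cell (i, j) averages input rows floor(i*H/oh) to ceil((i+1)*H/oh)
    and the matching column range, following PyTorch's AdaptiveAvgPool2d bins.
    """
    n, c, h, w = x.shape
    oh, ow = output_size
    out = np.empty((n, c, oh, ow), dtype=x.dtype)
    for i in range(oh):
        h0, h1 = (i * h) // oh, ((i + 1) * h + oh - 1) // oh  # floor / ceil
        for j in range(ow):
            w0, w1 = (j * w) // ow, ((j + 1) * w + ow - 1) // ow
            out[:, :, i, j] = x[:, :, h0:h1, w0:w1].mean(axis=(2, 3))
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 7, 7)).astype(np.float32)

# Global average pooling is the 1x1 special case of adaptive average pooling.
gap = x.mean(axis=(2, 3), keepdims=True)
print(np.allclose(adaptive_avg_pool2d(x, (1, 1)), gap))  # True

# The three Padding2d modes on a 1-D row, padded by one element per side.
row = np.array([1, 2, 3])
print(np.pad(row, 1, mode="constant", constant_values=0))  # [0 1 2 3 0]  (ConstantPad2d)
print(np.pad(row, 1, mode="reflect"))                      # [2 1 2 3 2]  (ReflectionPad2d)
print(np.pad(row, 1, mode="edge"))                         # [1 1 2 3 3]  (ReplicationPad2d)
```

Reflection padding mirrors around the edge element without repeating it, while replication padding repeats the edge element; that difference is what a converted padding layer has to preserve.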
Models converted with gluon2pytorch:
- ResNet*
- SENet
- DenseNet*
- DPN
- MobileNet
Code snippets
See the tests directory for more examples.
License
This software is covered by MIT License.