Skip to main content

Convert a Chainer model into ONNX

Project description


PyPI GitHub license Build Status codecov

This is an add-on package for ONNX support by Chainer. ONNX-Chainer supports opset version <= 7.

Tested environment

  • ONNX >=1.3.0
    • opset version 7, 8
  • Chainer 5.0.0, 6.0.0a1
  • Python 3.5.5, 3.6.7
  • ONNX-Runtime 0.1.3

(You can still specify all opset versions <= 8, but please noted that opset versions <= 6 are not tested)


On Ubuntu 14.04/16.04

Please install Chainer first.

pip install chainer
pip install onnx-chainer

Run Test

1. Build Docker images

cd docker

2. Run tests

bash docker/

Quick Start

First, install ChainerCV to get the pre-trained models.

import numpy as np

import chainer
import chainercv.links as C
import onnx_chainer

model = C.VGG16(pretrained_model='imagenet')

# Pseudo input
x = np.zeros((1, 3, 224, 224), dtype=np.float32)

onnx_chainer.export(model, x, filename='vgg16.onnx')

Load models from MXNet

Please install MXNet 1.3.0b20180830 or newer one via pip: pip install mxnet --pre, then try the following code:

import collections

import mxnet
import numpy as np

import chainer
import chainer.functions as F
import chainercv.links as C
import onnx_chainer

# Prepare an input tensor
x = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255

# Run the model on the data
with chainer.using_config('train', False):
    chainer_out = model(x).array

# Export Chainer model into ONNX
onnx_chainer.export(model, x, fn)

# Load ONNX model into MXNet symbol
sym, arg, aux = mxnet.contrib.onnx.import_model(fn)

# Find the name of input tensor
data_names = [graph_input for graph_input in sym.list_inputs()
                if graph_input not in arg and graph_input not in aux]
data_shapes = [(data_names[0], x.shape)]

# Create MXNet model
mod = mxnet.mod.Module(
    symbol=sym, data_names=data_names, context=mxnet.cpu(),
    for_training=False, data_shapes=data_shapes,
    arg_params=arg, aux_params=aux, allow_missing=True,

# Create input data
Batch = collections.namedtuple('Batch', ['data'])
input_data = Batch([mxnet.nd.array(x)])

# Forward computation using MXNet

# Retrieve the output of forward result
mxnet_out = mod.get_outputs()[0].asnumpy()

# Check the prediction results are same
assert np.argmax(chainer_out) == np.argmax(mxnet_out)

# Check both outputs have same values
np.testing.assert_almost_equal(chainer_out, mxnet_out, decimal=5)

Compile the Chainer model via ONNX

Please install TVM v0.4 first. You can find hwo to build it in this Dockerfile.

import collections

import numpy as np
import onnx

import chainer
import chainer.functions as F
import chainercv.links as C
import nnvm
import onnx_chainer
import tvm

model = C.ResNet50(pretrained_model='imagenet', arch='he')
# Change cover_all option to False to match the default behavior of MXNet's pooling
model.pool1 = lambda x: F.max_pooling_2d(x, ksize=3, stride=2, cover_all=False)
save_as_onnx_then_import_from_nnvm(model, 'resnet50.onnx')

# Prepare an input tensor
x = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255

# Run the model on the data
with chainer.using_config('train', False):
    chainer_out = model(x).array

# Export Chainer model into ONNX
onnx_chainer.export(model, x, fn)

# Load the saved ONNX file using ONNX module
model_onnx = onnx.load(fn)

# Convert the ONNX model object into NNVM symbol
sym, params = nnvm.frontend.from_onnx(model_onnx)

# Choose the compilation target
target = 'llvm'

# Extract the name of input variable in the ONNX graph
input_name = sym.list_input_names()[0]
shape_dict = {input_name: x.shape}

# Compile the model using NNVM
graph, lib, params =
    sym, target, shape_dict, params=params,
    dtype={input_name: 'float32'})

# Convert the compiled model into TVM module
module = tvm.contrib.graph_runtime.create(graph, lib, tvm.cpu(0))

# Set the input tensor x
module.set_input(input_name, tvm.nd.array(x))

# Run the model

# Retrieve the inference result
out_shape = (1, 1000)
output = tvm.nd.empty(out_shape, ctx=tvm.cpu(0))
nnvm_output = module.get_output(0, output).asnumpy()

# Check both outputs have same values
np.testing.assert_almost_equal(chainer_out, nnvm_output, decimal=5)

Options of export function

Export function for chainer.Chain in ONNX format.

    This function performs a forward computation of the given
    :class:`~chainer.Chain`, ``model``, by passing the given arguments ``args``
    directly. It means, the output :class:`~chainer.Variable` object ``y`` to
    make the computational graph will be created by:

    y = model(*args)

        model (~chainer.Chain): The model object you want to export in ONNX
            format. It should have :meth:`__call__` method because the second
            argument ``args`` is directly given to the model by the ``[]``
        args (list or dict): The arguments which are given to the model
        filename (str or file-like object): The filename used for saving the
            resulting ONNX model. If None, nothing is saved to the disk.
        export_params (bool): If True, this function exports all the parameters
            included in the given model at the same time. If False, the
            exported ONNX model doesn't include any parameter values.
        graph_name (str): A string to be used for the ``name`` field of the
            graph in the exported ONNX model.
        save_text (bool): If True, the text format of the output ONNX model is
            also saved with ``.txt`` extention.
        opset_version (int): The operator set version of ONNX. If not specified
            or ``None`` is given, the latest opset version of the onnx module
            is used. If an integer is given, it will be ensured that all the
            operator version in the exported ONNX file is less than this value.

        A ONNX model object.

Supported Functions

Currently 59 Chainer Functions are supported to export in ONNX format.


  • ClippedReLU
  • ELU
  • HardSigmoid
  • LeakyReLU
  • LogSoftmax
  • PReLUFunction
  • ReLU
  • Sigmoid
  • Softmax
  • Softplus
  • Tanh


  • Cast
  • Concat
  • Copy
  • Depth2Space
  • ExpandDims
  • GetItem
  • Pad 12
  • Reshape
  • Space2Depth
  • SplitAxis
  • Squeeze
  • Tile
  • Transpose


  • Convolution2DFunction
  • ConvolutionND
  • Deconvolution2DFunction
  • DeconvolutionND
  • EmbedIDFunction 3
  • LinearFunction


  • Add
  • AddConstant
  • Absolute
  • BroadcastTo
  • Div
  • Mul
  • MulConstant
  • Neg
  • PowVarConst
  • Sub
  • Clip
  • Exp
  • Identity
  • MatMul
  • Maximum
  • Minimum
  • Sqrt
  • Square
  • LogSumExp
  • Max
  • Mean
  • Min
  • Prod
  • Sum
  • LinearInterpolate


  • Dropout 4


  • AveragePooling2D
  • AveragePoolingND
  • MaxPooling2D
  • MaxPoolingND
  • ROIPooling2D


  • BatchNormalization
  • FixedBatchNormalization
  • LocalResponseNormalization
  • NormalizeL2


Any contribution to ONNX-Chainer is welcome!

1: mode should be either 'constant', 'reflect', or 'edge'
2: ONNX doesn't support multiple constant values for Pad operation
3: Current ONNX doesn't support ignore_label for EmbedID
4: In test mode, all dropout layers aren't included in the exported file

Project details

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Filename, size & hash SHA256 hash help File type Python version Upload date
onnx-chainer-1.3.3.tar.gz (20.5 kB) Copy SHA256 hash SHA256 Source None

Supported by

Elastic Elastic Search Pingdom Pingdom Monitoring Google Google BigQuery Sentry Sentry Error logging AWS AWS Cloud computing DataDog DataDog Monitoring Fastly Fastly CDN SignalFx SignalFx Supporter DigiCert DigiCert EV certificate StatusPage StatusPage Status page