
Convert ONNX to PyTorch code.


onnx-pytorch


Generates PyTorch code from ONNX.

Installation

  • From PyPI
pip install onnx-pytorch
  • From source
git clone https://github.com/fumihwh/onnx-pytorch.git
cd onnx-pytorch
pip install -r requirements.txt
pip install -e .
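To confirm which version of the `onnx-pytorch` distribution pip installed, you can query the standard library's `importlib.metadata`. This is a minimal sketch, not part of onnx-pytorch. It assumes Python 3.8 or newer, and `installed_version` is a hypothetical helper name.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent.

    Hypothetical helper; dist_name is the pip distribution name (e.g. "onnx-pytorch").
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("onnx-pytorch:", installed_version("onnx-pytorch") or "not installed")
```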

Usage

By Command Line

python -m onnx_pytorch.code_gen -h

usage: code_gen.py [-h] [--onnx_model_path ONNX_MODEL_PATH] [--output_dir OUTPUT_DIR] [--overwrite OVERWRITE] [--tensor_inplace TENSOR_INPLACE] [--continue_on_error CONTINUE_ON_ERROR] [--simplify_names SIMPLIFY_NAMES]

optional arguments:
  -h, --help            show this help message and exit
  --onnx_model_path ONNX_MODEL_PATH
                        The onnx model path.
  --output_dir OUTPUT_DIR
                        The output dir
  --overwrite OVERWRITE
                        Should overwrite the output dir.
  --tensor_inplace TENSOR_INPLACE
                        Try best to inplace tensor.
  --continue_on_error CONTINUE_ON_ERROR
                        Continue on error.
  --simplify_names SIMPLIFY_NAMES
                        Use indexing shorten name instead of original name.
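You can also assemble the same options programmatically and pass them to `subprocess.run`. The sketch below relies only on the flag names shown in the help text, and `build_code_gen_command` is a hypothetical helper. Every flag takes a value, but the help text does not say how a value such as `True` or `False` is parsed. Check that behavior before relying on a boolean-looking string.

```python
import sys

# Option names copied from the `-h` output above (without the leading "--").
KNOWN_OPTIONS = {"overwrite", "tensor_inplace", "continue_on_error", "simplify_names"}


def build_code_gen_command(onnx_model_path, output_dir, **options):
    """Build an argv list for `python -m onnx_pytorch.code_gen` (hypothetical helper).

    Every option value is converted with str(); how code_gen interprets that
    string is decided by its own argument parser, not by this helper.
    """
    unknown = set(options) - KNOWN_OPTIONS
    if unknown:
        raise ValueError(f"unknown option(s): {sorted(unknown)}")
    cmd = [
        sys.executable, "-m", "onnx_pytorch.code_gen",
        "--onnx_model_path", str(onnx_model_path),
        "--output_dir", str(output_dir),
    ]
    # Sorted so the generated command is deterministic.
    for name in sorted(options):
        cmd += [f"--{name}", str(options[name])]
    return cmd
```

For example, `subprocess.run(build_code_gen_command("resnet18-v2-7.onnx", "./out", overwrite=True), check=True)` would run the generator with `--overwrite True` appended.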

By Python

from onnx_pytorch import code_gen
code_gen.gen("/path/to/onnx_model", "/path/to/output_dir")

A model.py file and a variables/ folder will be created under output_dir/.
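A quick sanity check after generation is to confirm that this layout exists. The minimal sketch below checks only for the two entries named in the previous sentence. It does not validate their contents, and `check_code_gen_output` is a hypothetical helper, not part of onnx-pytorch.

```python
from pathlib import Path


def check_code_gen_output(output_dir):
    """List what is missing from a code_gen output directory (empty list = looks OK).

    Checks only for model.py and a variables/ folder directly under output_dir;
    the contents of either are not inspected.
    """
    out = Path(output_dir)
    problems = []
    if not (out / "model.py").is_file():
        problems.append("missing model.py")
    if not (out / "variables").is_dir():
        problems.append("missing variables/ folder")
    return problems
```

After `code_gen.gen("/path/to/onnx_model", "/path/to/output_dir")`, calling `check_code_gen_output("/path/to/output_dir")` should return an empty list.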

Tutorial

  1. Download the ResNet-18 ONNX model.
wget https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet18-v2-7.onnx
  2. Use onnx-pytorch to generate the PyTorch code and variables in the current directory.
from onnx_pytorch import code_gen
code_gen.gen("resnet18-v2-7.onnx", "./")
  3. Test the result by comparing the PyTorch output with ONNX Runtime.
import numpy as np
import onnx
import onnxruntime
import torch
torch.set_printoptions(precision=8)

# model.py was generated into the current directory in step 2.
from model import Model

model = Model()
model.eval()
inp = np.random.randn(1, 3, 224, 224).astype(np.float32)
with torch.no_grad():
  torch_outputs = model(torch.from_numpy(inp))

onnx_model = onnx.load("resnet18-v2-7.onnx")
sess_options = onnxruntime.SessionOptions()
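# onnxruntime 1.9+ requires providers to be set explicitly when the build
# includes more than the CPU provider (e.g. GPU builds).
session = onnxruntime.InferenceSession(onnx_model.SerializeToString(),
                                       sess_options,
                                       providers=["CPUExecutionProvider"])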
inputs = {session.get_inputs()[0].name: inp}
ort_outputs = session.run(None, inputs)

print(
    "Comparison result:",
    np.allclose(torch_outputs.detach().numpy(),
                ort_outputs[0],
                atol=1e-5,
                rtol=1e-5))
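The comparison above uses `np.allclose`. It accepts each element pair when |a - b| <= atol + rtol * |b|, with `b` (here the ONNX Runtime output) treated as the reference. The pure-Python sketch below spells out that check for flat sequences and adds a max-difference helper, which is useful when the comparison prints `False`. It ignores NumPy's broadcasting and NaN handling, and both helper names are hypothetical. NumPy's own default is `atol=1e-08`, so the tutorial's `atol=1e-5` is a looser tolerance.

```python
def allclose(actual, expected, rtol=1e-5, atol=1e-5):
    """Element-wise |actual - expected| <= atol + rtol * |expected| for every pair.

    Mirrors the test np.allclose applies, for flat sequences of floats only.
    Defaults match the tutorial's call, not NumPy's defaults.
    """
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    return all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(actual, expected))


def max_abs_diff(actual, expected):
    """Largest absolute element-wise difference (0.0 for empty input)."""
    return max((abs(a - b) for a, b in zip(actual, expected)), default=0.0)
```

With the tutorial's arrays, pass flattened lists, e.g. `max_abs_diff(torch_outputs.detach().numpy().ravel().tolist(), ort_outputs[0].ravel().tolist())`.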

Test

pytest onnx_pytorch/tests


Download files

Source Distribution

onnx-pytorch-0.1.5.tar.gz (55.9 kB)

Built Distribution

onnx_pytorch-0.1.5-py3-none-any.whl (109.2 kB)

File details

Details for the file onnx-pytorch-0.1.5.tar.gz.

File metadata

  • Download URL: onnx-pytorch-0.1.5.tar.gz
  • Size: 55.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for onnx-pytorch-0.1.5.tar.gz
Algorithm Hash digest
SHA256 c3b9c20007c98470563c5ee423ac6606dcf70958d559d4f75bb99fc50696c24d
MD5 49f9fe179f529e50a10bbfb1085121d7
BLAKE2b-256 ada83d13a0432e8249c28d9a73db23c0651047e3e4a4a302a401376bad1afe00


File details

Details for the file onnx_pytorch-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: onnx_pytorch-0.1.5-py3-none-any.whl
  • Size: 109.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for onnx_pytorch-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 706defc9f00bf18e576a55bed68121b3fa74751ad223e89d9b6b1d20168f735b
MD5 7e1a14d127cd818a1a0afbcdeb637a2c
BLAKE2b-256 40864f0079b63cdf66055fa39b617d2b0a0870135469ab5063e6541e29b1a23c
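You can check the SHA256 digests listed for the 0.1.5 files locally before installing. The minimal sketch below uses only Python's `hashlib` and reads the file in chunks so large files do not need to fit in memory. The helper names are hypothetical and not part of onnx-pytorch.

```python
import hashlib

# SHA256 digests copied from the file details for version 0.1.5.
PUBLISHED_SHA256 = {
    "onnx-pytorch-0.1.5.tar.gz": "c3b9c20007c98470563c5ee423ac6606dcf70958d559d4f75bb99fc50696c24d",
    "onnx_pytorch-0.1.5-py3-none-any.whl": "706defc9f00bf18e576a55bed68121b3fa74751ad223e89d9b6b1d20168f735b",
}


def sha256_of_file(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read chunk by chunk."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected_hex):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For an intact download in the current directory, `verify_sha256("onnx-pytorch-0.1.5.tar.gz", PUBLISHED_SHA256["onnx-pytorch-0.1.5.tar.gz"])` should return `True`.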
