Array (and numpy) API for ONNX
Project description
onnx-array-api implements a numpy API for ONNX. It lets users write functions that follow the numpy API, convert them into ONNX, and execute them.
import numpy as np
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

def l1_loss(x, y):
    return absolute(x - y).sum()

def l2_loss(x, y):
    return ((x - y) ** 2).sum()

def myloss(x, y):
    # L1 loss on the first column, L2 loss on the second column
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])

# wrap myloss so that it is converted into ONNX and executed
jitted_myloss = jit_onnx(myloss)

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

res = jitted_myloss(x, y)
print(res)

# display the generated ONNX graph as text
print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
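To sanity-check the value returned by the ONNX version, the same loss can be computed with plain numpy. This is a minimal sketch; `myloss_reference` is a hypothetical helper written for illustration and is not part of onnx-array-api.

```python
import numpy as np

def myloss_reference(x: np.ndarray, y: np.ndarray) -> np.float32:
    # Plain-numpy equivalent of myloss: L1 loss on column 0, L2 loss on column 1.
    l1 = np.abs(x[:, 0] - y[:, 0]).sum()
    l2 = ((x[:, 1] - y[:, 1]) ** 2).sum()
    return l1 + l2

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

expected = myloss_reference(x, y)
print(expected)  # approximately 0.042
```

Assuming `res` from the jitted example is in scope, `np.testing.assert_allclose(res, expected, rtol=1e-5)` is one way to compare the two results; a tolerance is used because both computations run in float32.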
It supports eager mode as well:
import numpy as np
from onnx_array_api.npx import absolute, eager_onnx

def l1_loss(x, y):
    err = absolute(x - y).sum()
    # in eager mode, intermediate results can be inspected with .numpy()
    print(f"l1_loss={err.numpy()}")
    return err

def l2_loss(x, y):
    err = ((x - y) ** 2).sum()
    print(f"l2_loss={err.numpy()}")
    return err

def myloss(x, y):
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])

eager_myloss = eager_onnx(myloss)

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

res = eager_myloss(x, y)
print(res)
l1_loss=[0.04]
l2_loss=[0.002]
[0.042]
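The printed values can be checked by hand: `l1_loss` receives column 0 of `x` and `y`, and `l2_loss` receives column 1. Up to float32 rounding:

```latex
\begin{aligned}
\ell_1 &= |0.1 - 0.11| + |0.3 - 0.33| = 0.01 + 0.03 = 0.04 \\
\ell_2 &= (0.2 - 0.22)^2 + (0.4 - 0.44)^2 = 0.0004 + 0.0016 = 0.002 \\
\ell_1 + \ell_2 &= 0.04 + 0.002 = 0.042
\end{aligned}
```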
The library is released on pypi/onnx-array-api and its documentation is published at `
Download files
Source Distributions
No source distribution files are available for this release.
Built Distribution
File details
Details for the file onnx_array_api-0.1.1-py3-none-any.whl
File metadata
- Download URL: onnx_array_api-0.1.1-py3-none-any.whl
- Upload date:
- Size: 64.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 23f70a059def814e878d6f47bfd1abd4fd65defa452d53b8a4e91a97516703d5
MD5 | 29ff3ba38459221de824b26da03a0cb2
BLAKE2b-256 | 9826abb3222ecaae347ed1dc1e71af82aeb500cc540ad798ffa405503dbcd345
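A downloaded wheel can be checked against the SHA256 digest listed above using Python's standard `hashlib` module. This is a minimal sketch that assumes the wheel has already been downloaded to a local path; `sha256_of` and `verify_wheel` are hypothetical helper names, not part of onnx-array-api.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for onnx_array_api-0.1.1-py3-none-any.whl
EXPECTED_SHA256 = "23f70a059def814e878d6f47bfd1abd4fd65defa452d53b8a4e91a97516703d5"

def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    # Hash the file in chunks so it never has to fit in memory at once.
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_wheel(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    # hexdigest() is lowercase, so normalize the expected value before comparing.
    return sha256_of(path) == expected.lower()
```

Usage, assuming the wheel sits in the current directory: `verify_wheel(Path("onnx_array_api-0.1.1-py3-none-any.whl"))` returns `True` only if the file matches the published digest. pip can enforce the same check during installation through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).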