Einsum optimization using opt_einsum and PyTorch FX
opt_einsum_fx
Optimizing einsums and functions involving them using opt_einsum
and PyTorch FX compute graphs.
Issues, questions, PRs, and any thoughts about further optimizing these kinds of operations are welcome!
For more information please see the docs.
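Under the hood, opt_einsum chooses a pairwise contraction order for a multi-operand einsum instead of evaluating it left to right. As a sketch of what that optimization looks like on its own (assuming opt_einsum and NumPy are installed), its contract_path function reports the chosen path and an estimated cost:

```python
import numpy as np
import opt_einsum

# The same three-operand contraction as the minimal example below,
# expressed directly with opt_einsum on NumPy arrays.
a = np.random.rand(7, 4, 5)
b = np.random.rand(7, 5, 3)
vec = np.random.rand(7, 3)

# contract_path returns the pairwise contraction order and a cost summary.
path, info = opt_einsum.contract_path("zij,zjk,zk->zi", a, b, vec)
print(path)  # e.g. [(1, 2), (0, 1)]: contract b with vec first, then with a
print(info)  # FLOP estimate and largest intermediate size
```

opt_einsum_fx applies the same kind of path optimization, but to einsum calls found in a traced PyTorch FX graph.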
Installation
PyPI
The latest release can be installed from PyPI:
$ pip install opt_einsum_fx
Source
To get the latest code, run:
$ git clone https://github.com/Linux-cpp-lisp/opt_einsum_fx.git
and install it by running:
$ cd opt_einsum_fx/
$ pip install .
You can run the tests with:
$ pytest tests/
Minimal example
import torch
import torch.fx
import opt_einsum_fx
def einmatvecmul(a, b, vec):
"""Batched matrix-matrix-vector product using einsum"""
return torch.einsum("zij,zjk,zk->zi", a, b, vec)
graph_mod = torch.fx.symbolic_trace(einmatvecmul)
print("Original code:\n", graph_mod.code)
graph_opt = opt_einsum_fx.optimize_einsums_full(
model=graph_mod,
example_inputs=(
torch.randn(7, 4, 5),
torch.randn(7, 5, 3),
torch.randn(7, 3)
)
)
print("Optimized code:\n", graph_opt.code)
which outputs:
Original code:
import torch
def forward(self, a, b, vec):
einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec); a = b = vec = None
return einsum_1
Optimized code:
import torch
def forward(self, a, b, vec):
einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b); vec = b = None
einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a); einsum_1 = a = None
return einsum_2
We can measure the performance improvement (this is on a CPU):
from torch.utils.benchmark import Timer
batch = 1000
a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8)
g = {"f": graph_mod, "a": a, "b": b, "vec": vec}
t_orig = Timer("f(a, b, vec)", globals=g)
print(t_orig.timeit(10_000))
g["f"] = graph_opt
t_opt = Timer("f(a, b, vec)", globals=g)
print(t_opt.timeit(10_000))
which gives a ~2x improvement:
f(a, b, vec)
276.58 us
1 measurement, 10000 runs , 1 thread
f(a, b, vec)
118.84 us
1 measurement, 10000 runs , 1 thread
Depending on your function and dimensions you may see even larger improvements.
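The improvement tracks a back-of-the-envelope multiplication count: the optimized order contracts the vector in first, so no rank-3 intermediate is ever formed. A sketch using the benchmark dimensions above (per batch element, i=4, j=5, k=8):

```python
# Per-batch-element multiplication counts for the two contraction orders,
# using the benchmark dimensions i=4, j=5, k=8.
i, j, k = 4, 5, 8

# Left-to-right: (a @ b) costs i*j*k multiplies, then (ab @ vec) costs i*k.
naive = i * j * k + i * k   # 160 + 32 = 192

# Optimized order: vec into b costs j*k, then the result into a costs i*j.
optimized = j * k + i * j   # 40 + 20 = 60

print(naive / optimized)    # ~3.2x fewer multiplies
assert optimized < naive
```

The measured ~2x speedup is smaller than the raw 3.2x reduction in multiplies because memory traffic and framework overhead also contribute to runtime; the arithmetic count is an upper bound on the achievable gain.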
License
opt_einsum_fx
is distributed under an MIT license.