Einsum optimization using opt_einsum and PyTorch FX

Project description

opt_einsum_fx

Optimizing einsums and functions involving them using opt_einsum and PyTorch FX compute graphs.

Issues, questions, PRs, and any thoughts about further optimizing these kinds of operations are welcome!

For more information, please see the docs.

Installation

PyPI

The latest release can be installed from PyPI:

$ pip install opt_einsum_fx

Source

To get the latest code, run:

$ git clone https://github.com/Linux-cpp-lisp/opt_einsum_fx.git

and install it by running:

$ cd opt_einsum_fx/
$ pip install .

You can run the tests with:

$ pytest tests/

Minimal example

import torch
import torch.fx
import opt_einsum_fx

def einmatvecmul(a, b, vec):
    """Batched matrix-matrix-vector product using einsum"""
    return torch.einsum("zij,zjk,zk->zi", a, b, vec)

graph_mod = torch.fx.symbolic_trace(einmatvecmul)
print("Original code:\n", graph_mod.code)
graph_opt = opt_einsum_fx.optimize_einsums_full(
    model=graph_mod,
    example_inputs=(
        torch.randn(7, 4, 5),
        torch.randn(7, 5, 3),
        torch.randn(7, 3)
    )
)
print("Optimized code:\n", graph_opt.code)

This outputs:

Original code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec);  a = b = vec = None
    return einsum_1

Optimized code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b);  vec = b = None
    einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a);  einsum_1 = a = None
    return einsum_2
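
opt_einsum found a cheaper contraction order: instead of forming the batched matrix-matrix product first, it contracts vec into b and then the result into a, so every intermediate is a batch of vectors rather than a batch of matrices. The two modules should agree numerically; here is a minimal sanity check (our snippet, not part of the package):

import torch

a, b, vec = torch.randn(7, 4, 5), torch.randn(7, 5, 3), torch.randn(7, 3)
# A different contraction order changes floating-point rounding slightly,
# so compare with a tolerance rather than exact equality.
assert torch.allclose(graph_mod(a, b, vec), graph_opt(a, b, vec), atol=1e-6)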

We can measure the performance improvement (this is on a CPU):

from torch.utils.benchmark import Timer

batch = 1000
a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8)

g = {"f": graph_mod, "a": a, "b": b, "vec": vec}
t_orig = Timer("f(a, b, vec)", globals=g)
print(t_orig.timeit(10_000))

g["f"] = graph_opt
t_opt = Timer("f(a, b, vec)", globals=g)
print(t_opt.timeit(10_000))

gives a ~2x improvement:

f(a, b, vec)
  276.58 us
  1 measurement, 10000 runs , 1 thread

f(a, b, vec)
  118.84 us
  1 measurement, 10000 runs , 1 thread

Depending on your function and dimensions you may see even larger improvements.
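
The optimized module is an ordinary torch.fx.GraphModule, so it can be compiled further. The sketch below assumes your installed version exposes opt_einsum_fx.jitable, a helper for making the optimized graph TorchScript-compatible; check the docs for your release before relying on it:

import torch
import opt_einsum_fx

# ASSUMPTION: opt_einsum_fx.jitable rewrites the optimized FX graph so that
# torch.jit.script can compile it; consult the docs for your installed version.
graph_script = torch.jit.script(opt_einsum_fx.jitable(graph_opt))
out = graph_script(torch.randn(7, 4, 5), torch.randn(7, 5, 3), torch.randn(7, 3))
print(out.shape)  # torch.Size([7, 4])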

License

opt_einsum_fx is distributed under an MIT license.
