
Einsum optimization using opt_einsum and PyTorch FX

Project description

opt_einsum_fx

Optimizing einsums and functions involving them using opt_einsum and PyTorch FX compute graphs.
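A multi-operand einsum such as "zij,zjk,zk->zi" (used in the minimal example) can be evaluated as a sequence of pairwise contractions. The order of those contractions affects cost, and opt_einsum searches for a good order.

The sketch below is a simplified illustration only. It runs a plain greedy search under a naive cost model, where each pairwise contraction costs the product of the sizes of every index either operand carries. `greedy_path` is a hypothetical helper; it is neither opt_einsum's API nor its actual algorithm.

```python
from itertools import combinations
from math import prod


def greedy_path(inputs, output, sizes):
    """Hypothetical greedy contraction-order search (not opt_einsum's API).

    Repeatedly contracts the pair of operands with the lowest cost, where the
    cost of a pair is the product of the sizes of every index either carries.
    The final operand's index order is not permuted to match `output`.
    """
    operands = list(inputs)
    path, total = [], 0
    while len(operands) > 1:
        best = None
        for i, j in combinations(range(len(operands)), 2):
            union = set(operands[i]) | set(operands[j])
            # Indices still needed later: by the output or by another operand.
            rest = [op for k, op in enumerate(operands) if k not in (i, j)]
            needed = set(output).union(*rest)
            # Keep needed indices in order of first appearance; sum out the rest.
            kept = "".join(
                c for c in dict.fromkeys(operands[i] + operands[j]) if c in needed
            )
            cost = prod(sizes[c] for c in union)
            if best is None or cost < best[0]:
                best = (cost, i, j, kept)
        cost, i, j, kept = best
        path.append(f"{operands[i]},{operands[j]}->{kept}")
        total += cost
        # Replace the contracted pair with its intermediate result.
        operands = [op for k, op in enumerate(operands) if k not in (i, j)] + [kept]
    return path, total


# Index sizes from the minimal example: a is (7, 4, 5), b is (7, 5, 3), vec is (7, 3).
sizes = {"z": 7, "i": 4, "j": 5, "k": 3}
path, total = greedy_path(["zij", "zjk", "zk"], "zi", sizes)
print(path)   # → ['zjk,zk->zj', 'zij,zj->zi']
print(total)  # → 245
```

On these shapes, the greedy choice is to contract `vec` with `b` first and then contract the result with `a`. This matches the order in the optimized code of the minimal example. Under the same model, a left-to-right order would cost 420 + 84 = 504.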

Issues, questions, PRs, and any thoughts about further optimizing these kinds of operations are welcome!

For more information, please see the documentation.

Installation

PyPI

The latest release can be installed from PyPI:

$ pip install opt_einsum_fx

Source

To get the latest code, run:

$ git clone https://github.com/Linux-cpp-lisp/opt_einsum_fx.git

and install it by running

$ cd opt_einsum_fx/
$ pip install .

You can run the tests with

$ pytest tests/

Minimal example

import torch
import torch.fx
import opt_einsum_fx

def einmatvecmul(a, b, vec):
    """Batched matrix-matrix-vector product using einsum"""
    return torch.einsum("zij,zjk,zk->zi", a, b, vec)

graph_mod = torch.fx.symbolic_trace(einmatvecmul)
print("Original code:\n", graph_mod.code)
graph_opt = opt_einsum_fx.optimize_einsums_full(
    model=graph_mod,
    example_inputs=(
        torch.randn(7, 4, 5),
        torch.randn(7, 5, 3),
        torch.randn(7, 3)
    )
)
print("Optimized code:\n", graph_opt.code)

This prints:

Original code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec);  a = b = vec = None
    return einsum_1

Optimized code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b);  vec = b = None
    einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a);  einsum_1 = a = None
    return einsum_2
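The rewrite only changes how the contraction is evaluated, not what it computes. As a PyTorch-free sanity check, the sketch below compares the two-step contraction from the generated code with the original one-shot einsum on the minimal example's shapes.

`naive_einsum` is a hypothetical, deliberately slow pure-Python reference that works on nested lists. It handles only the explicit `->` form with a non-empty output.

```python
import itertools
import random


def shape_of(x):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3]] -> [1, 3]."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


def filled(shape, make):
    """Nested list of the given shape, each entry produced by make()."""
    if len(shape) == 1:
        return [make() for _ in range(shape[0])]
    return [filled(shape[1:], make) for _ in range(shape[0])]


def naive_einsum(spec, *operands):
    """Hypothetical reference einsum: loop over every index combination."""
    inputs, output = spec.split("->")
    in_subs = inputs.split(",")
    sizes = {}
    for subs, op in zip(in_subs, operands):
        sizes.update(zip(subs, shape_of(op)))
    letters = sorted(sizes)
    result = filled([sizes[c] for c in output], float)  # float() -> 0.0
    for combo in itertools.product(*(range(sizes[c]) for c in letters)):
        env = dict(zip(letters, combo))
        term = 1.0
        for subs, op in zip(in_subs, operands):
            x = op
            for c in subs:
                x = x[env[c]]
            term *= x
        # Accumulate into result[output indices], summing over contracted ones.
        target = result
        for c in output[:-1]:
            target = target[env[c]]
        target[env[output[-1]]] += term
    return result


def rand():
    return random.uniform(-1.0, 1.0)


random.seed(0)
a, b, vec = filled([7, 4, 5], rand), filled([7, 5, 3], rand), filled([7, 3], rand)

# Original: one three-operand einsum.
original = naive_einsum("zij,zjk,zk->zi", a, b, vec)
# Optimized: the two pairwise einsums from the generated code.
einsum_1 = naive_einsum("cb,cab->ca", vec, b)
einsum_2 = naive_einsum("cb,cab->ca", einsum_1, a)

max_err = max(
    abs(x - y) for row_o, row_p in zip(original, einsum_2) for x, y in zip(row_o, row_p)
)
print(shape_of(einsum_2), max_err)  # [7, 4] and a rounding-level difference
```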

We can measure the performance improvement (this is on a CPU):

from torch.utils.benchmark import Timer

batch = 1000
a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8)

g = {"f": graph_mod, "a": a, "b": b, "vec": vec}
t_orig = Timer("f(a, b, vec)", globals=g)
print(t_orig.timeit(10_000))

g["f"] = graph_opt
t_opt = Timer("f(a, b, vec)", globals=g)
print(t_opt.timeit(10_000))

This gives a speedup of about 2.3x:

f(a, b, vec)
  276.58 us
  1 measurement, 10000 runs , 1 thread

f(a, b, vec)
  118.84 us
  1 measurement, 10000 runs , 1 thread

Depending on your function and tensor dimensions, you may see even larger improvements.
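A rough way to see why dimensions matter is to count loop iterations under a naive cost model. In this model, a pairwise contraction costs the product of the sizes of every index either operand carries.

The sketch below compares two orders for the benchmark's einsum:
- a hypothetical left-to-right order (`a` with `b` first, then with `vec`);
- the order used in the optimized code (`vec` with `b` first, then with `a`).

`pairwise_cost` and `order_costs` are illustrative helpers. This is not necessarily how `torch.einsum` evaluates the original expression, nor the cost function opt_einsum uses.

```python
from math import prod


def pairwise_cost(left, right, sizes):
    """Naive cost model: number of iterations of a loop nest over every index
    that either operand carries."""
    return prod(sizes[c] for c in set(left) | set(right))


def order_costs(sizes):
    """Costs of two pairwise orders for 'zij,zjk,zk->zi' under the model above."""
    # Hypothetical baseline: contract left to right, a with b first, then with vec.
    left_to_right = pairwise_cost("zij", "zjk", sizes) + pairwise_cost("zik", "zk", sizes)
    # The order in the optimized code: vec with b first, then the result with a.
    reordered = pairwise_cost("zk", "zjk", sizes) + pairwise_cost("zj", "zij", sizes)
    return left_to_right, reordered


# Index sizes from the benchmark: a is (1000, 4, 5), b is (1000, 5, 8), vec is (1000, 8).
bench = {"z": 1000, "i": 4, "j": 5, "k": 8}
print(order_costs(bench))  # → (192000, 60000)

# Vary k (the index the reordered version sums out first) and compare the orders.
ratios = {}
for k in (3, 8, 64, 512):
    ltr, reo = order_costs({**bench, "k": k})
    ratios[k] = ltr / reo
    print(f"k={k}: left-to-right / reordered = {ratios[k]:.2f}")
```

Under this model, the ratio at the benchmark sizes is 3.2, which is larger than the measured ~2.3x. Real timings also depend on factors the model ignores, such as memory traffic and per-call overhead.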

License

opt_einsum_fx is distributed under an MIT license.

Download files

Source Distribution

opt_einsum_fx-0.1.4.tar.gz (13.0 kB)

Uploaded Source

Built Distribution

opt_einsum_fx-0.1.4-py3-none-any.whl (13.2 kB)

Uploaded Python 3

File details

Details for the file opt_einsum_fx-0.1.4.tar.gz.

File metadata

  • Download URL: opt_einsum_fx-0.1.4.tar.gz
  • Size: 13.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for opt_einsum_fx-0.1.4.tar.gz
Algorithm     Hash digest
SHA256        7eeb7f91ecb70be65e6179c106ea7f64fc1db6319e3d1289a4518b384f81e74f
MD5           f3a2881b4e011d4dddb826b0b48706f5
BLAKE2b-256   93de856dab99be0360c7275fee075eb0450a2ec82a54c4c33689606f62e9615b

File details

Details for the file opt_einsum_fx-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: opt_einsum_fx-0.1.4-py3-none-any.whl
  • Size: 13.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for opt_einsum_fx-0.1.4-py3-none-any.whl
Algorithm     Hash digest
SHA256        85f489f4c7c31fd88d5faf9669c09e61ec37a30098809fdcfe2a08a9e42f23c9
MD5           3cbf29fec8fb43e10954d4b7b4d64f42
BLAKE2b-256   8d4ce0370709aaf9d7ceb68f975cac559751e75954429a77e83202e680606560
