
Universal Notation for Tensor Operations in Python


einx - Universal Notation for Tensor Operations

Requires Python 3.10+.

einx is a notation and Python library that provides a universal interface to formulate tensor operations in frameworks such as Numpy, PyTorch, Jax, Tensorflow, and MLX.

Quickstart

Installation:

pip install einx

Example code:

import einx
import numpy as np

x = np.ones((10, 20, 30)) # Create some tensor (numpy/torch/jax/tensorflow/mlx/...)

y = einx.sum("a [b] c", x) # Call an einx operation

print(y.shape) # (10, 30): the bracketed axis b is reduced

Documentation and tutorials: https://einx.readthedocs.io

What does einx look like?

z = einx.id("a (b c) -> (b a) c", x, b=2)             # Permute and (un)flatten axes
z = einx.sum("a [b]", x)                              # Sum-reduction along second axis
z = einx.flip("... (g [c])", x, c=2)                  # Flip pairs of values along the last axis
z = einx.mean("b [...] c", x)                         # Spatial mean-pooling
z = einx.multiply("a..., b... -> (a b)...", x, y)     # Kronecker product
z = einx.sum("b (s [ds])... c", x, ds=(2, 2))         # Sum-pooling with 2x2 kernel
z = einx.add("a, b -> a b", x, y)                     # Outer sum
z = einx.dot("a [b], [b] c -> a c", x, y)             # Matrix multiplication
z = einx.get_at("b [h w] c, b i [2] -> b i c", x, y)  # Gather values at coordinates
z = einx.id("b (q + k) -> b q, b k", x, q=2)          # Split
z = einx.id("b c, -> b (c + 1)", x, 42)               # Append number to each channel

See the documentation for more examples.
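To make a few of the one-liners above concrete, here is a sketch of what they should compute, written in plain NumPy from our reading of the notation and the comments. It is not code generated by einx, and the shapes and the names x, y, v, w are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# "a [b]": sum-reduction along the second axis
x = rng.normal(size=(4, 6))
z_sum = x.sum(axis=1)                        # shape (4,)

# "a, b -> a b": outer sum of two vectors
v = rng.normal(size=4)
w = rng.normal(size=5)
z_outer = v[:, None] + w[None, :]            # shape (4, 5)

# "a [b], [b] c -> a c": matrix multiplication (the bracketed b is contracted)
y = rng.normal(size=(6, 3))
z_dot = x @ y                                # shape (4, 3)

# "a (b c) -> (b a) c" with b=2: unflatten the second axis, permute, flatten again
b = 2
c = x.shape[1] // b
z_perm = x.reshape(4, b, c).transpose(1, 0, 2).reshape(b * 4, c)  # shape (8, 3)

# "b (q + k) -> b q, b k" with q=2: split the last axis into 2 and the remainder
z_q, z_k = x[:, :2], x[:, 2:]                # shapes (4, 2) and (4, 4)

# "b c, -> b (c + 1)" with 42: append the scalar to each row
z_app = np.concatenate([x, np.full((x.shape[0], 1), 42.0)], axis=1)  # shape (4, 7)
```

Running the real einx calls on the same inputs and comparing with np.allclose is a quick way to confirm this reading.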

How does the notation work?

An einx operation consists of (1) an elementary operation and (2) an einx expression that describes how the elementary operation is vectorized. For example, the code

z = einx.{OP}("[c d] a, b -> a [e] b", x, y)

vectorizes the elementary operation {OP} according to the expression "[c d] a, b -> a [e] b".

The meaning of the string expression is defined by analogy with loop notation as follows. The full operation einx.{OP} will yield the same output as if the elementary operation {OP} were invoked in an analogous loop expression:

for a in range(...):
    for b in range(...):
        z[a, :, b] = {OP}(x[:, :, a], y[b])

See the tutorial for how an einx expression is mapped to the analogous loop expression.
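As a concrete illustration of the loop definition, the sketch below picks a hypothetical elementary operation elementary_op that flattens its (c, d) argument and scales it by a scalar, so its output has length e = c*d. It evaluates the loop literally and compares it with a loop-free NumPy version; the axis sizes and elementary_op are assumptions for illustration and are not part of einx.

```python
import numpy as np

# Hypothetical sizes for the axes in "[c d] a, b -> a [e] b"
C, D, A, B = 2, 3, 4, 5
E = C * D

def elementary_op(xs, ys):
    # Hypothetical {OP}: maps a (c, d) block and a scalar to an e-vector
    return xs.reshape(-1) * ys

rng = np.random.default_rng(0)
x = rng.normal(size=(C, D, A))
y = rng.normal(size=(B,))

# Literal loop definition: z[a, :, b] = OP(x[:, :, a], y[b])
z_loop = np.empty((A, E, B))
for a in range(A):
    for b in range(B):
        z_loop[a, :, b] = elementary_op(x[:, :, a], y[b])

# Same result without Python loops: move a to the front, flatten (c d) into e,
# then broadcast against y along a new trailing b axis
z_vec = x.transpose(2, 0, 1).reshape(A, E)[:, :, None] * y[None, None, :]
```

The loop only defines the result; the vectorized line is the kind of framework call that einx emits instead.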

How are einx operations implemented?

The analogy with loop notation is used only to define what the output of an operation will be. Internally, einx operations are compiled to Python code snippets that invoke operations from the respective tensor framework, rather than using for loops.

The compiled code snippet can be inspected by passing graph=True to the einx operation. For example:

>>> x = np.zeros((2, 3))
>>> y = np.zeros((3, 4))
>>> code = einx.add("a b, b c -> c b a", x, y, graph=True)
>>> print(code)

import numpy as np
def op(a, b):
    a = np.transpose(a, (1, 0))
    a = np.reshape(a, (1, 3, 2))
    b = np.transpose(b, (1, 0))
    b = np.reshape(b, (4, 3, 1))
    c = np.add(a, b)
    return c
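The printed snippet can be checked against the loop definition. The sketch below copies the function body from the output above (renamed compiled_op here), feeds it random non-zero inputs of our own choosing, and compares it with the literal loop z[c, b, a] = x[a, b] + y[b, c].

```python
import numpy as np

def compiled_op(a, b):
    # Body copied from the snippet printed with graph=True above
    a = np.transpose(a, (1, 0))
    a = np.reshape(a, (1, 3, 2))
    b = np.transpose(b, (1, 0))
    b = np.reshape(b, (4, 3, 1))
    c = np.add(a, b)
    return c

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))  # axes "a b"
y = rng.normal(size=(3, 4))  # axes "b c"

# Loop definition of einx.add("a b, b c -> c b a", x, y)
z_loop = np.empty((4, 3, 2))
for a in range(2):
    for b in range(3):
        for c in range(4):
            z_loop[c, b, a] = x[a, b] + y[b, c]

z_compiled = compiled_op(x, y)
```

The transposes put both inputs into output axis order, and the size-1 axes let np.add broadcast x over c and y over a.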

Different backends may be used to compile an operation to different implementations, for example following Numpy-like notation, vmap-based notation, or einsum notation.

Which operations are supported?

Operations in the API: einx supports a large set of tensor operations in the namespace einx.*, including reduction, scalar, indexing, some shape-preserving operations, identity map and dot-product. See the documentation for a complete list.

Operations not in the API: einx additionally allows adapting custom Python functions to einx notation using einx adapters. For example:

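import einx
import torch

# Example inputs (sizes are illustrative): x has axes "a c", y has axes "b c"
x = torch.ones(4, 3)
y = torch.ones(5, 3)

# Define a custom elementary operation
def myoperation(x, y):
    x = 2 * x
    z = x + torch.sum(y)
    return z

# Adapt the operation to einx notation
einmyoperation = einx.torch.adapt_with_vmap(myoperation)

# Invoke as einx operation
z = einmyoperation("a [c], b [c] -> a b [c]", x, y)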

This will yield the same output as if myoperation were invoked in loop notation:

for a in range(...):
    for b in range(...):
        z[a, b, :] = myoperation(x[a, :], y[b, :])
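The same loop definition can be reproduced without torch. The sketch below is a NumPy port of myoperation (np.sum stands in for torch.sum; the axis sizes are assumptions) and checks the loop against a broadcast expression: each output vector is 2*x[a] shifted by the scalar sum of y[b].

```python
import numpy as np

def myoperation(x, y):
    # NumPy port of the custom operation above (np.sum replaces torch.sum)
    x = 2 * x
    z = x + np.sum(y)
    return z

A, B, C = 4, 5, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(A, C))  # axes "a [c]"
y = rng.normal(size=(B, C))  # axes "b [c]"

# Loop definition of "a [c], b [c] -> a b [c]"
z_loop = np.empty((A, B, C))
for a in range(A):
    for b in range(B):
        z_loop[a, b, :] = myoperation(x[a, :], y[b, :])

# Equivalent broadcast form: z[a, b, c] = 2 * x[a, c] + sum over c' of y[b, c']
z_vec = 2 * x[:, None, :] + y.sum(axis=1)[None, :, None]
```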

The interface of einmyoperation matches that of other einx operations. For example, the compiled code snippet can be inspected using graph=True:

>>> code = einmyoperation("a [c], b [c] -> a b [c]", x, y, graph=True)
>>> print(code)
# Constant const1: <function myoperation at 0x49e9aa3cd8a0>
import torch
def op(a, b):
    def c(d, e):
        f = const1(d, e)
        assert isinstance(f, torch.Tensor), "Expected 1st return value of the adapted function to be a tensor"
        assert (tuple(f.shape) == (3,)), "Expected 1st return value of the adapted function to be a tensor with shape (3,)"
        return f
    c = torch.vmap(c, in_dims=(None, 0), out_dims=0)
    c = torch.vmap(c, in_dims=(0, None), out_dims=0)
    g = c(a, b)
    return g
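To see why two nested vmap calls reproduce the loop definition, here is a loop-based stand-in for torch.vmap written in NumPy. loop_vmap is hypothetical: it only supports in_dims entries of 0 or None with out_dims=0, and it is not how torch implements vmap. It simply traces the same nesting as the printed snippet.

```python
import numpy as np

def loop_vmap(f, in_dims):
    # Hypothetical loop-based stand-in for torch.vmap with out_dims=0:
    # maps f over axis 0 of each argument whose in_dim is 0 and passes
    # arguments whose in_dim is None through unchanged
    def mapped(*args):
        n = next(arg.shape[0] for arg, d in zip(args, in_dims) if d == 0)
        outs = []
        for i in range(n):
            sliced = [arg[i] if d == 0 else arg for arg, d in zip(args, in_dims)]
            outs.append(f(*sliced))
        return np.stack(outs, axis=0)
    return mapped

def myoperation(x, y):
    # NumPy port of the custom operation (np.sum replaces torch.sum)
    return 2 * x + np.sum(y)

# Same nesting as the compiled snippet: the inner map runs over b, the outer over a
c = loop_vmap(myoperation, in_dims=(None, 0))
c = loop_vmap(c, in_dims=(0, None))

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # axes "a [c]"
y = rng.normal(size=(5, 3))  # axes "b [c]"
z = c(x, y)                  # shape (4, 5, 3): axes "a b [c]"
```

Because each map stacks its results along a new leading axis, the outer map over a produces the first output axis and the inner map over b the second, matching "a b [c]".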

See the documentation for a list of supported adapters.



Download files


Source Distribution

einx-0.4.2.tar.gz (114.5 kB)

Uploaded Source

Built Distribution


einx-0.4.2-py3-none-any.whl (139.9 kB)

Uploaded Python 3

File details

Details for the file einx-0.4.2.tar.gz.

File metadata

  • Download URL: einx-0.4.2.tar.gz
  • Upload date:
  • Size: 114.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for einx-0.4.2.tar.gz
Algorithm Hash digest
SHA256 9aeea7053a4f683235e81b62383069085cc919a012748c8575f3557455377e11
MD5 426aa3ed081886bde8fd58f4a898791e
BLAKE2b-256 36b32b0acbde4f763b72f65da58cea54f35386f537569c4c39f3d37de1c03710


File details

Details for the file einx-0.4.2-py3-none-any.whl.

File metadata

  • Download URL: einx-0.4.2-py3-none-any.whl
  • Upload date:
  • Size: 139.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for einx-0.4.2-py3-none-any.whl
Algorithm Hash digest
SHA256 c3d01878ffe34f2e038a14b98f82b4b47bba32ee2f127618ab31851e439ba2f0
MD5 ec2fef0d86065f7eb70bcfc3bb2c73e9
BLAKE2b-256 01977afa0d833cae4d8e651b8265b7c95613199d9749a4976149b5fd66052a06

