
Universal Notation for Tensor Operations in Python


einx - Universal Notation for Tensor Operations


einx is a notation and Python library that provides a universal interface to formulate tensor operations in frameworks such as NumPy, PyTorch, JAX, TensorFlow, and MLX.

Quickstart

Installation:

pip install einx

Example code:

import einx
import numpy as np

x = np.ones((10, 20, 30)) # Create some tensor (numpy/torch/jax/tensorflow/mlx/...)

y = einx.sum("a [b] c", x) # Call an einx operation

print(y.shape)  # (10, 30)
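As a rough mental model of what the quickstart computes, the same reduction can be written in plain NumPy without einx. The bracketed axis b is summed away, and the loops run over the remaining axes a and c. This is a hand-written sketch of the loop semantics described below, not the code einx generates.

```python
import numpy as np

x = np.ones((10, 20, 30))

# Loop-notation reference for einx.sum("a [b] c", x):
# for every (a, c), reduce the bracketed axis b with np.sum.
y_loop = np.empty((10, 30))
for a in range(x.shape[0]):
    for c in range(x.shape[2]):
        y_loop[a, c] = np.sum(x[a, :, c])

# The same result as a single vectorized NumPy call.
y_vec = np.sum(x, axis=1)

print(y_vec.shape)  # (10, 30)
assert np.allclose(y_loop, y_vec)
```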

Documentation and tutorials: https://einx.readthedocs.io

What does einx look like?

z = einx.id("a (b c) -> (b a) c", x, b=2)             # Permute and (un)flatten axes
z = einx.sum("a [b]", x)                              # Sum-reduction along second axis
z = einx.flip("... (g [c])", x, c=2)                  # Flip pairs of values along the last axis
z = einx.mean("b [...] c", x)                         # Spatial mean-pooling
z = einx.multiply("a..., b... -> (a b)...", x, y)     # Kronecker product
z = einx.sum("b (s [ds])... c", x, ds=(2, 2))         # Sum-pooling with 2x2 kernel
z = einx.add("a, b -> a b", x, y)                     # Outer sum
z = einx.dot("a [b], [b] c -> a c", x, y)             # Matrix multiplication
z = einx.get_at("b [h w] c, b i [2] -> b i c", x, y)  # Gather values at coordinates
z = einx.id("b (q + k) -> b q, b k", x, q=2)          # Split
z = einx.id("b c, -> b (c + 1)", x, 42)               # Append a number as an extra channel

See the documentation for more examples.

How does the notation work?

An einx operation consists of (1) an elementary operation and (2) an einx expression that describes how the elementary operation is vectorized. For example, the code

z = einx.{OP}("[c d] a, b -> a [e] b", x, y)

vectorizes the elementary operation {OP} according to the expression "[c d] a, b -> a [e] b".

The meaning of the string expression is defined by analogy with loop notation as follows. The full operation einx.{OP} will yield the same output as if the elementary operation {OP} were invoked in an analogous loop expression:

for a in range(...):
    for b in range(...):
        z[a, :, b] = {OP}(x[:, :, a], y[b])

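In short, axes marked with brackets are passed to {OP} in full (the : slices in the loop), while all other axes are iterated over. See the tutorial for the full rules that map an einx expression to the analogous loop expression.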
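To make the definition concrete, the sketch below fills in {OP} with a hypothetical elementary operation, elementary_op. It maps a (c, d) matrix and a scalar to a vector of length e = c. The sketch runs the loop expression directly and compares it with a hand-vectorized NumPy version. The operation and all shapes are assumptions chosen only for illustration.

```python
import numpy as np

def elementary_op(m, s):
    # Hypothetical elementary operation: maps a (c, d) matrix and a scalar
    # to a vector of length e = c (row sums scaled by s).
    return m.sum(axis=1) * s

rng = np.random.default_rng(0)
C, D, A, B = 3, 4, 5, 6
x = rng.normal(size=(C, D, A))  # "[c d] a"
y = rng.normal(size=(B,))       # "b"

# Loop-notation definition of einx.{OP}("[c d] a, b -> a [e] b", x, y)
z_loop = np.empty((A, C, B))
for a in range(A):
    for b in range(B):
        z_loop[a, :, b] = elementary_op(x[:, :, a], y[b])

# A hand-vectorized equivalent for this particular op: reduce d, move a to
# the front, then broadcast against y along the trailing b axis.
z_vec = x.sum(axis=1).T[:, :, None] * y[None, None, :]

assert np.allclose(z_loop, z_vec)
```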

How are einx operations implemented?

The analogy with loop notation is used only to define what the output of an operation will be. Internally, einx operations are compiled to Python code snippets that invoke operations from the respective tensor framework, rather than using for loops.

The compiled code snippet can be inspected by passing graph=True to the einx operation. For example:

>>> import einx
>>> import numpy as np
>>> x = np.zeros((2, 3))
>>> y = np.zeros((3, 4))
>>> code = einx.add("a b, b c -> c b a", x, y, graph=True)
>>> print(code)

import numpy as np
def op(a, b):
    a = np.transpose(a, (1, 0))
    a = np.reshape(a, (1, 3, 2))
    b = np.transpose(b, (1, 0))
    b = np.reshape(b, (4, 3, 1))
    c = np.add(a, b)
    return c
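One way to convince yourself that such a snippet implements the loop definition is to run it next to an explicit loop. The sketch below copies the printed op verbatim and compares it with the loop expression for "a b, b c -> c b a", which is z[c, b, a] = x[a, b] + y[b, c]. It uses random inputs instead of zeros so the comparison is meaningful.

```python
import numpy as np

# The compiled snippet printed above, copied verbatim.
def op(a, b):
    a = np.transpose(a, (1, 0))
    a = np.reshape(a, (1, 3, 2))
    b = np.transpose(b, (1, 0))
    b = np.reshape(b, (4, 3, 1))
    c = np.add(a, b)
    return c

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))  # "a b"
y = rng.normal(size=(3, 4))  # "b c"

# Loop-notation definition of einx.add("a b, b c -> c b a", x, y)
z_loop = np.empty((4, 3, 2))
for a in range(2):
    for b in range(3):
        for c in range(4):
            z_loop[c, b, a] = x[a, b] + y[b, c]

z = op(x, y)
assert z.shape == (4, 3, 2)
assert np.allclose(z, z_loop)
```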

Different backends may be used to compile an operation to different implementations, for example following NumPy-like notation, vmap-based notation, or einsum notation.
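For intuition, the three styles can be contrasted on matrix multiplication, einx.dot("a [b], [b] c -> a c", x, y), using plain NumPy. These are hand-written illustrations of each style, not the output of any einx backend. Here np.vectorize with a core signature stands in for a framework's vmap.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))  # "a [b]"
y = rng.normal(size=(4, 5))  # "[b] c"

# NumPy-like style: call a dedicated framework function.
z_numpy = np.matmul(x, y)

# einsum style: spell out the index pattern.
z_einsum = np.einsum("ab,bc->ac", x, y)

# vmap-like style: lift an elementary op (dot of two b-vectors) over the
# loop axes a and c. np.vectorize with a core signature stands in for vmap.
dot_b = np.vectorize(np.dot, signature="(n),(n)->()")
z_vmap = dot_b(x[:, None, :], y.T[None, :, :])

assert np.allclose(z_numpy, z_einsum)
assert np.allclose(z_numpy, z_vmap)
```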

Which operations are supported?

Operations in the API: einx supports a large set of tensor operations in the namespace einx.*, including reduction, scalar, indexing, some shape-preserving operations, identity map and dot-product. See the documentation for a complete list.

Operations not in the API: einx additionally allows adapting custom Python functions to einx notation using einx adapters. For example:

import einx
import torch

# Define a custom elementary operation
def myoperation(x, y):
    x = 2 * x
    z = x + torch.sum(y)
    return z

# Adapt the operation to einx notation
einmyoperation = einx.torch.adapt_with_vmap(myoperation)

# Example inputs with shapes "a c" and "b c" (here a=4, b=5, c=3)
x = torch.randn(4, 3)
y = torch.randn(5, 3)

# Invoke as einx operation
z = einmyoperation("a [c], b [c] -> a b [c]", x, y)

This will yield the same output as if myoperation were invoked in loop notation:

for a in range(...):
    for b in range(...):
        z[a, b, :] = myoperation(x[a, :], y[b, :])
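The loop definition can be checked without PyTorch by swapping in a NumPy version of myoperation, with np.sum replacing torch.sum. This swap is an assumption made only so the sketch is self-contained. For this particular function, each y[b] contributes only its sum, so the whole operation also reduces to a single broadcast.

```python
import numpy as np

# NumPy stand-in for the PyTorch myoperation above (np.sum replaces torch.sum).
def myoperation(x, y):
    x = 2 * x
    z = x + np.sum(y)
    return z

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # "a [c]"
y = rng.normal(size=(5, 3))  # "b [c]"

# Loop-notation definition of "a [c], b [c] -> a b [c]"
z_loop = np.empty((4, 5, 3))
for a in range(4):
    for b in range(5):
        z_loop[a, b, :] = myoperation(x[a, :], y[b, :])

# Equivalent broadcasting for this specific function: each y[b] only
# contributes its sum, so it enters as an offset of shape (1, b, 1).
z_vec = 2 * x[:, None, :] + y.sum(axis=1)[None, :, None]

assert np.allclose(z_loop, z_vec)
```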

The interface of einmyoperation matches that of other einx operations. For example, the compiled code snippet can be inspected using graph=True:

>>> code = einmyoperation("a [c], b [c] -> a b [c]", x, y, graph=True)
>>> print(code)
# Constant const1: <function myoperation at 0x49e9aa3cd8a0>
import torch
def op(a, b):
    def c(d, e):
        f = const1(d, e)
        assert isinstance(f, torch.Tensor), "Expected 1st return value of the adapted function to be a tensor"
        assert (tuple(f.shape) == (3,)), "Expected 1st return value of the adapted function to be a tensor with shape (3,)"
        return f
    c = torch.vmap(c, in_dims=(None, 0), out_dims=0)
    c = torch.vmap(c, in_dims=(0, None), out_dims=0)
    g = c(a, b)
    return g
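The two nested torch.vmap calls read inside out. The inner one (in_dims=(None, 0)) maps over the b axis of the second argument, and the outer one (in_dims=(0, None)) maps over the a axis of the first, so the result has layout "a b [c]". The sketch below mimics that nesting with a hypothetical, minimal simple_vmap helper written in NumPy. It is not torch.vmap and only supports in_dims of 0 or None and out_dims=0.

```python
import numpy as np

def simple_vmap(fn, in_dims):
    # Hypothetical, minimal stand-in for torch.vmap with out_dims=0: calls fn once
    # per index along axis 0 of every argument whose in_dim is 0, passes arguments
    # whose in_dim is None through unchanged, and stacks results along a new axis 0.
    def mapped(*args):
        n = next(arg.shape[0] for arg, d in zip(args, in_dims) if d == 0)
        results = []
        for i in range(n):
            call_args = [arg[i] if d == 0 else arg for arg, d in zip(args, in_dims)]
            results.append(fn(*call_args))
        return np.stack(results, axis=0)
    return mapped

def myoperation(x, y):
    # NumPy stand-in for the adapted PyTorch function
    return 2 * x + np.sum(y)

# Mirror the structure of the compiled snippet: first map over b, then over a.
c = simple_vmap(myoperation, in_dims=(None, 0))
c = simple_vmap(c, in_dims=(0, None))

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # "a [c]"
y = rng.normal(size=(5, 3))  # "b [c]"
g = c(x, y)
print(g.shape)  # (4, 5, 3)
```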

See the documentation for a list of supported adapters.



Download files

Download the file for your platform.

Source Distribution

einx-0.4.3.tar.gz (120.1 kB)

Uploaded Source

Built Distribution


einx-0.4.3-py3-none-any.whl (163.8 kB)

Uploaded Python 3

File details

Details for the file einx-0.4.3.tar.gz.

File metadata

  • Download URL: einx-0.4.3.tar.gz
  • Size: 120.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for einx-0.4.3.tar.gz
Algorithm     Hash digest
SHA256        be7d81ea1908b9f00e4a467840998fc483c33aa32aaaaa3ada6c8386f693edf9
MD5           3cfb0ea290b339b75f72a5ffb87cc776
BLAKE2b-256   8496df2cfa7418b175dddcf30a88711d01b79a32c3ac4d64b379ed89a3de2c08


File details

Details for the file einx-0.4.3-py3-none-any.whl.

File metadata

  • Download URL: einx-0.4.3-py3-none-any.whl
  • Size: 163.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for einx-0.4.3-py3-none-any.whl
Algorithm     Hash digest
SHA256        47ce54a0144f6dffcfacdd8fe2cc9e2e5e6485dda2471330ab75ee747dd22f39
MD5           630dd0dd2832987cfc72171f3f668150
BLAKE2b-256   d0286768d342b2b888f9facdddbd430daf015cab936e2598281ab436b1be1b4a
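Before installing from a manually downloaded file, the published SHA256 digest can be checked locally. The following standard-library sketch computes a file's SHA256 in chunks. The helper name sha256_of_file is hypothetical, and the usage line assumes the wheel sits in the current directory. pip's hash-checking mode (--require-hashes) performs a comparable check automatically during installation.

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 16):
    # Stream the file in chunks so large archives need not fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Published SHA256 digest of the wheel, copied from the table above.
EXPECTED_WHEEL_SHA256 = "47ce54a0144f6dffcfacdd8fe2cc9e2e5e6485dda2471330ab75ee747dd22f39"

# Usage, assuming the wheel was downloaded to the current directory:
# assert sha256_of_file("einx-0.4.3-py3-none-any.whl") == EXPECTED_WHEEL_SHA256
```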

