Universal Tensor Operations in Einstein-Inspired Notation for Python

Project description

einx - Universal Tensor Operations in Einstein-Inspired Notation


einx is a Python library that provides a universal interface for formulating tensor operations in frameworks such as NumPy, PyTorch, JAX and TensorFlow. The design is based on the following principles:

  1. Provide a set of elementary tensor operations with NumPy-like naming: einx.{sum|max|where|add|dot|flip|get_at|...}
  2. Use einx notation to express vectorization of the elementary operations. einx notation is inspired by einops, but introduces several novel concepts, such as []-bracket notation and full composability, that allow it to serve as a universal language for tensor operations.

einx can be integrated and mixed with existing code seamlessly. All operations are just-in-time compiled into regular Python functions using Python's exec() and invoke the corresponding operations of the respective framework.
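
For example, the same einx expression can be applied to tensors from different frameworks, with the backend chosen based on the input type (a minimal sketch, assuming NumPy and PyTorch are installed):

import numpy as np
import torch
import einx

x_np = np.ones((4, 8))
x_pt = torch.ones(4, 8)

einx.sum("a [b]", x_np)  # returns a np.ndarray of shape (4,)
einx.sum("a [b]", x_pt)  # returns a torch.Tensor of shape (4,)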

Getting started:

Installation

pip install einx

See Installation for more information.

What does einx look like?

Tensor manipulation

import einx
x = {np.asarray|torch.as_tensor|jnp.asarray|...}(...) # Create some tensor

einx.sum("a [b]", x)                              # Sum-reduction along second axis
einx.flip("... (g [c])", x, c=2)                  # Flip pairs of values along the last axis
einx.mean("b [s...] c", x)                        # Spatial mean-pooling
einx.sum("b (s [s2])... c", x, s2=2)              # Sum-pooling with kernel_size=stride=2
einx.add("a, b -> a b", x, y)                     # Outer sum

einx.get_at("b [h w] c, b i [2] -> b i c", x, y)  # Gather values at coordinates

einx.rearrange("b (q + k) -> b q, b k", x, q=2)   # Split
einx.rearrange("b c, 1 -> b (c + 1)", x, [42])    # Append number to each channel

                                                  # Apply custom operations:
einx.vmap("b [s...] c -> b c", x, op=np.mean)     # Spatial mean-pooling
einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) # Matmul

All einx functions simply forward computation to the respective backend, e.g. by internally calling np.reshape, np.transpose or np.sum with the appropriate arguments.
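
As a quick sanity check, the einx calls can be compared against hand-written NumPy code (a sketch; the reshape/flip round trip below is just the manual equivalent of the bracket notation):

import numpy as np
import einx

x = np.random.rand(4, 6)

# "a [b]" reduces the bracketed axis, like np.sum with axis=1:
assert np.allclose(einx.sum("a [b]", x), np.sum(x, axis=1))

# "... (g [c])" with c=2 flips within pairs along the last axis:
manual = np.flip(x.reshape(4, 3, 2), axis=2).reshape(4, 6)
assert np.allclose(einx.flip("... (g [c])", x, c=2), manual)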

Common neural network operations

# Layer normalization
mean = einx.mean("b... [c]", x, keepdims=True)
var = einx.var("b... [c]", x, keepdims=True)
x = (x - mean) * torch.rsqrt(var + epsilon)

# Prepend class token
einx.rearrange("b s... c, c -> b (1 + (s...)) c", x, cls_token)

# Multi-head attention
attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8)
attn = einx.softmax("b q [k] h", attn)
x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)

# Matmul in linear layers
einx.dot("b...      [c1->c2]",  x, w)              # - Regular
einx.dot("b...   (g [c1->c2])", x, w)              # - Grouped: Same weights per group
einx.dot("b... ([g c1->g c2])", x, w)              # - Grouped: Different weights per group
einx.dot("b  [s...->s2]  c",    x, w)              # - Spatial mixing as in MLP-mixer

See Common neural network ops for more examples.
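
Putting the multi-head attention lines together, a self-contained sketch (the concrete shape values and the 1/sqrt(c) scaling are illustrative additions, not part of the snippet above):

import torch
import einx

b, s, h, c = 2, 16, 8, 32  # batch, sequence, heads, channels per head
q = torch.randn(b, s, h * c)
k = torch.randn(b, s, h * c)
v = torch.randn(b, s, h * c)

attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=h)
attn = einx.softmax("b q [k] h", attn / c**0.5)
x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)
assert x.shape == (b, s, h * c)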

Optional: Deep learning modules

import einx.nn.{torch|flax|haiku|equinox|keras} as einn

batchnorm       = einn.Norm("[b...] c", decay_rate=0.9)
layernorm       = einn.Norm("b... [c]") # as used in transformers
instancenorm    = einn.Norm("b [s...] c")
groupnorm       = einn.Norm("b [s...] (g [c])", g=8)
rmsnorm         = einn.Norm("b... [c]", mean=False, bias=False)

channel_mix     = einn.Linear("b... [c1->c2]", c2=64)
spatial_mix1    = einn.Linear("b [s...->s2] c", s2=64)
spatial_mix2    = einn.Linear("b [s2->s...] c", s=(64, 64))
patch_embed     = einn.Linear("b (s [s2->])... [c1->c2]", s2=4, c2=64)

dropout         = einn.Dropout("[...]",       drop_rate=0.2)
spatial_dropout = einn.Dropout("[b] ... [c]", drop_rate=0.2)
droppath        = einn.Dropout("[b] ...",     drop_rate=0.2)

See examples/train_{torch|flax|haiku|equinox|keras}.py for example trainings on CIFAR10, see GPT-2 and Mamba for working implementations of language models using einx, and see Tutorial: Neural networks for more details.
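
For instance, with the PyTorch integration the modules can be dropped into an ordinary model (a minimal sketch, assuming that einn layers infer their parameter shapes lazily on the first forward pass):

import torch
import einx.nn.torch as einn

model = torch.nn.Sequential(
    einn.Linear("b... [c1->c2]", c2=64),  # channel-mixing linear layer
    torch.nn.GELU(),
    einn.Norm("b... [c]"),                # layer normalization
    einn.Dropout("[...]", drop_rate=0.2),
    einn.Linear("b... [c1->c2]", c2=10),
)

x = torch.randn(16, 32)  # batch of 16 vectors with 32 channels
y = model(x)             # parameter shapes are inferred on this first call
print(y.shape)           # torch.Size([16, 10])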

Just-in-time compilation

einx traces the required backend operations for a given call into a graph representation and just-in-time compiles them into a regular Python function using Python's exec(). This reduces the overhead to a single cache lookup and allows inspecting the generated function. For example:

>>> x = np.zeros((3, 10, 10))
>>> graph = einx.sum("... (g [c])", x, g=2, graph=True)
>>> print(graph)
import numpy as np
def op0(i0):
    x0 = np.reshape(i0, (3, 10, 2, 5))
    x1 = np.sum(x0, axis=3)
    return x1

See Just-in-time compilation for more details.
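
The one-time tracing cost and the cached fast path can be observed directly (a rough sketch; absolute timings depend on the machine):

import timeit
import numpy as np
import einx

x = np.zeros((3, 10, 10))

# First call traces and compiles the function (one-time cost).
t_first = timeit.timeit(lambda: einx.sum("... (g [c])", x, g=2), number=1)

# Subsequent identical calls only perform a cache lookup.
t_cached = timeit.timeit(lambda: einx.sum("... (g [c])", x, g=2), number=1000) / 1000

print(f"first call: {t_first:.6f}s, cached: {t_cached:.6f}s")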

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

einx-0.3.0.tar.gz (84.8 kB)


Built Distribution

einx-0.3.0-py3-none-any.whl (103.0 kB)


File details

Details for the file einx-0.3.0.tar.gz.

File metadata

  • Download URL: einx-0.3.0.tar.gz
  • Upload date:
  • Size: 84.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.4

File hashes

Hashes for einx-0.3.0.tar.gz

  • SHA256: 17ff87c6a0f68ab358c1da489f00e95f1de106fd12ff17d0fb3e210aaa1e5f8c
  • MD5: 86954b3b50240ab25ba6f891cd9fc310
  • BLAKE2b-256: 95af2a2f83f981e969ae3ec5dc30f9b0cd1a258acabc2ff7b33eb9726e334e55


File details

Details for the file einx-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: einx-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 103.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.4

File hashes

Hashes for einx-0.3.0-py3-none-any.whl

  • SHA256: 367d62bab8dbb8c4937308512abb6f746cc0920990589892ba0d281356d39345
  • MD5: b39e9fe21fd0829d6514f9fcc54b8920
  • BLAKE2b-256: 90044a730d74fd908daad86d6b313f235cdf8e0cf1c255b392b7174ff63ea81a

