Project description

einx - Tensor Operations in Einstein-Inspired Notation

einx is a Python library that allows formulating many tensor operations as concise expressions using Einstein notation. It is inspired by einops but follows a novel design:

  • Fully composable and powerful Einstein expressions with []-notation.
  • Support for many tensor operations (einx.{sum|max|where|add|dot|flip|get_at|...}) with Numpy-like naming.
  • Easy integration and mixing with existing code. Supports tensor frameworks Numpy, PyTorch, Tensorflow and Jax.
  • Just-in-time compilation of all operations into regular Python functions using Python's exec().

Optional:

  • Generalized neural network layers in Einstein notation. Supports PyTorch, Flax, Haiku, Equinox and Keras.

Getting started:

Installation

pip install einx

See Installation for more information.

What does einx look like?

Tensor manipulation

import einx
x = {np.asarray|torch.as_tensor|jnp.asarray|tf.convert_to_tensor}(...) # Create some tensor

einx.sum("a [b]", x)                              # Sum-reduction along columns
einx.flip("... (g [c])", x, c=2)                  # Flip pairs of values along the last axis
einx.mean("b [s...] c", x)                        # Global mean-pooling
einx.sum("b (s [s2])... c", x, s2=2)              # Sum-pooling with kernel_size=stride=2
einx.add("b... [c]", x, b)                        # Add bias

einx.get_at("b [h w] c, b i [2] -> b i c", x, y)  # Gather values at coordinates

einx.rearrange("b (q + k) -> b q, b k", x, q=2)   # Split
einx.rearrange("b c, 1 -> b (c + 1)", x, [42])    # Append number to each channel

einx.dot("... [c1|c2]", x, y)                     # Matmul = linear map from c1 to c2 channels

# Vectorizing map
einx.vmap("b [s...] c -> b c", x, op=np.mean)     # Global mean-pooling
einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) # Matmul

All einx functions simply forward computation to the respective backend, e.g. by internally calling np.reshape, np.transpose and np.sum with the appropriate arguments.
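
For instance, the first reduction above corresponds to a plain axis-wise np.sum; a minimal check, assuming the NumPy backend:

import numpy as np
import einx

x = np.arange(12, dtype=np.float32).reshape(3, 4)

y1 = einx.sum("a [b]", x)   # einx: sum-reduce the marked axis b
y2 = np.sum(x, axis=1)      # equivalent direct backend call

assert np.allclose(y1, y2)  # both have shape (3,)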

Common neural network operations

# Layer normalization
mean = einx.mean("b... [c]", x, keepdims=True)
var = einx.var("b... [c]", x, keepdims=True)
x = (x - mean) * torch.rsqrt(var + epsilon)

# Prepend class token
einx.rearrange("b s... c, c -> b (1 + (s...)) c", x, cls_token)

# Multi-head attention
attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8)
attn = einx.softmax("b q [k] h", attn)
x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)

# Matmul in linear layers
einx.dot("b...      [c1|c2]",  x, w)              # - Regular
einx.dot("b...   (g [c1|c2])", x, w)              # - Grouped: Same weights per group
einx.dot("b... ([g c1|g c2])", x, w)              # - Grouped: Different weights per group
einx.dot("b  [s...|s2]  c",    x, w)              # - Spatial mixing as in MLP-mixer

See Common neural network ops for more examples.
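
As a self-contained sketch of the attention snippet above (PyTorch, with illustrative sizes; the 1/sqrt(c) scaling is omitted here, as in the snippet):

import torch
import einx

b, s, h, c = 2, 16, 8, 32        # illustrative: batch, sequence, heads, channels per head
q = torch.randn(b, s, h * c)
k = torch.randn(b, s, h * c)
v = torch.randn(b, s, h * c)

attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=h)  # per-head attention logits
attn = einx.softmax("b q [k] h", attn)                         # normalize over the key axis
out = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)     # weighted sum of values
print(out.shape)  # torch.Size([2, 16, 256])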

Deep learning modules

import einx.nn.{torch|flax|haiku|equinox|keras} as einn

batchnorm       = einn.Norm("[b...] c", decay_rate=0.9)
layernorm       = einn.Norm("b... [c]") # as used in transformers
instancenorm    = einn.Norm("b [s...] c")
groupnorm       = einn.Norm("b [s...] (g [c])", g=8)
rmsnorm         = einn.Norm("b... [c]", mean=False, bias=False)

channel_mix     = einn.Linear("b... [c1|c2]", c2=64)
spatial_mix1    = einn.Linear("b [s...|s2] c", s2=64)
spatial_mix2    = einn.Linear("b [s2|s...] c", s=(64, 64))
patch_embed     = einn.Linear("b (s [s2|])... [c1|c2]", s2=4, c2=64)

dropout         = einn.Dropout("[...]",       drop_rate=0.2)
spatial_dropout = einn.Dropout("[b] ... [c]", drop_rate=0.2)
droppath        = einn.Dropout("[b] ...",     drop_rate=0.2)

See examples/train_{torch|flax|haiku|equinox|keras}.py for example trainings on CIFAR10, the GPT-2 and Mamba examples for working implementations of language models using einx, and Tutorial: Neural networks for more details.
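
A minimal usage sketch for the PyTorch variant (sizes are illustrative; einn layers create their parameters lazily on the first call):

import torch
import einx.nn.torch as einn

layer = einn.Linear("b... [c1|c2]", c2=64)  # maps the last axis from c1 to c2 channels
x = torch.randn(4, 32)                      # illustrative input: batch of 32-channel vectors
y = layer(x)                                # parameters are created on this first call
print(y.shape)  # torch.Size([4, 64])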

Just-in-time compilation

einx traces the required backend operations for a given call into a graph representation and just-in-time compiles them into a regular Python function using Python's exec(). This reduces overhead to a single cache lookup and allows inspecting the generated function. For example:

>>> x = np.zeros((3, 10, 10))
>>> graph = einx.sum("... (g [c])", x, g=2, graph=True)
>>> print(graph)
# backend: einx.backend.numpy
def op0(i0):
    x1 = backend.reshape(i0, (3, 10, 2, 5))
    x0 = backend.sum(x1, axis=3)
    return x0

See Just-in-time compilation for more details.
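
Since the generated function is shown above, its behavior can be checked against handwritten NumPy code; a minimal sketch:

import numpy as np
import einx

x = np.random.rand(3, 10, 10)

y1 = einx.sum("... (g [c])", x, g=2)     # compiled to the function shown above
y2 = x.reshape(3, 10, 2, 5).sum(axis=3)  # the same reshape + sum, written by hand

assert np.allclose(y1, y2)               # result shape (3, 10, 2)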

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

einx-0.1.3.tar.gz (64.6 kB)

Built Distribution

einx-0.1.3-py3-none-any.whl (88.0 kB)

File details

Details for the file einx-0.1.3.tar.gz.

File metadata

  • Download URL: einx-0.1.3.tar.gz
  • Upload date:
  • Size: 64.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for einx-0.1.3.tar.gz

  • SHA256: f85d46193246517d5fe3455cf9e5f5e6bf4c7a159864c83b5d20605e8fb8701d
  • MD5: 3d7230f28ea7a3dbd7c4920f7d781693
  • BLAKE2b-256: 14715f57e76b19a5d48d4f8e28202fd6adf742cf52caf7bc1462772c455eea8f

See more details on using hashes here.
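
For example, a downloaded file can be checked against the SHA256 digest above with Python's standard hashlib (the file is assumed to be in the current directory):

import hashlib

expected = "f85d46193246517d5fe3455cf9e5f5e6bf4c7a159864c83b5d20605e8fb8701d"
with open("einx-0.1.3.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == expected, "hash mismatch: corrupted or tampered download"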

File details

Details for the file einx-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: einx-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 88.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.2

File hashes

Hashes for einx-0.1.3-py3-none-any.whl

  • SHA256: b614e6749dfd8dc24eb35c555b4bd7bd7cdf09b2e740943d115fb6cc9d21bded
  • MD5: f4fab75ded1d1509bf0873d7f0692e53
  • BLAKE2b-256: ecb5fdb2fe8d49bf812b0e3f5ee32c4ae53abe98cb6e7cd5013c7987f0fe36c4

See more details on using hashes here.
