Tensor Operations Expressed in Einstein-Inspired Notation
einx - Tensor Operations in Einstein-Inspired Notation
einx is a Python library that allows formulating many tensor operations as concise expressions using Einstein notation. It is inspired by einops, but follows a novel and unique design:
- Fully composable and powerful Einstein expressions with []-notation.
- Support for many tensor operations (einx.{sum|max|where|add|dot|flip|get_at|...}) with NumPy-like naming.
- Easy integration and mixing with existing code. Supports the tensor frameworks NumPy, PyTorch, TensorFlow, JAX and others.
- Just-in-time compilation of all operations into regular Python functions using Python's exec().
Optional:
- Generalized neural network layers in Einstein notation. Supports PyTorch, Flax, Haiku, Equinox and Keras.
Getting started:
Installation
pip install einx
See Installation for more information.
What does einx look like?
Tensor manipulation
import einx
x = {np.asarray|torch.as_tensor|jnp.asarray|...}(...) # Create some tensor
einx.sum("a [b]", x) # Sum-reduction along second axis
einx.flip("... (g [c])", x, c=2) # Flip pairs of values along the last axis
einx.mean("b [s...] c", x) # Spatial mean-pooling
einx.sum("b (s [s2])... c", x, s2=2) # Sum-pooling with kernel_size=stride=2
einx.add("a, b -> a b", x, y) # Outer sum
einx.get_at("b [h w] c, b i [2] -> b i c", x, y) # Gather values at coordinates
einx.rearrange("b (q + k) -> b q, b k", x, q=2) # Split
einx.rearrange("b c, 1 -> b (c + 1)", x, [42]) # Append a number as an extra channel
einx.dot("... [c1->c2]", x, y) # Matmul = linear map from c1 to c2 channels
# Apply custom operations:
einx.vmap("b [s...] c -> b c", x, op=np.mean) # Global mean-pooling
einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) # Matmul
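To make a few of these expressions concrete, the sketch below shows roughly equivalent plain NumPy code for the gather, split and append examples. It assumes the NumPy backend and small example shapes chosen only for illustration; it is not the code einx actually generates.

```python
import numpy as np

rng = np.random.default_rng(0)
b, h, w, c, i = 2, 4, 5, 3, 6

# Gather: einx.get_at("b [h w] c, b i [2] -> b i c", x, y)
x = rng.normal(size=(b, h, w, c))
coords = np.stack(
    [rng.integers(0, h, size=(b, i)), rng.integers(0, w, size=(b, i))], axis=-1
)  # shape (b, i, 2): one (row, col) pair per query point
batch_idx = np.arange(b)[:, None]  # shape (b, 1), broadcasts against (b, i)
gathered = x[batch_idx, coords[..., 0], coords[..., 1]]  # shape (b, i, c)

# Split: einx.rearrange("b (q + k) -> b q, b k", z, q=2)
z = rng.normal(size=(b, 7))
z_q, z_k = np.split(z, [2], axis=1)  # shapes (b, 2) and (b, 5)

# Append: einx.rearrange("b c, 1 -> b (c + 1)", z, [42])
appended = np.concatenate([z, np.full((b, 1), 42.0)], axis=1)  # shape (b, 8)
```

The bracketed [h w] axes in the gather expression are the ones indexed by the coordinate vector, which is why they disappear from the output while b and c survive.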
All einx functions simply forward computation to the respective backend, e.g. by internally calling np.reshape, np.transpose and np.sum with the appropriate arguments.
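As an illustration of this forwarding, here is a sketch of how some of the reduction, pooling, flip and outer-sum expressions above map onto reshape, flip, sum and broadcasting in NumPy. The shapes are assumptions for the example, and the exact sequence of calls einx emits may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# einx.sum("a [b]", x): reduce the bracketed axis
x = rng.normal(size=(3, 4))
row_sums = np.sum(x, axis=1)  # shape (3,)

# einx.sum("b (s [s2])... c", x, s2=2) on a batch of 2D images:
# split each spatial axis into (s, s2), then sum over both s2 axes
img = rng.normal(size=(2, 8, 6, 3))  # (b, H, W, c)
b, H, W, c = img.shape
blocks = np.reshape(img, (b, H // 2, 2, W // 2, 2, c))
pooled = np.sum(blocks, axis=(2, 4))  # shape (2, 4, 3, 3)

# einx.flip("... (g [c])", v, c=2): flip pairs of values along the last axis
v = np.arange(12).reshape(2, 6)
pairs = np.reshape(v, (2, 3, 2))  # (..., g, c)
flipped = np.reshape(np.flip(pairs, axis=-1), v.shape)

# einx.add("a, b -> a b", p, q): outer sum via broadcasting
p, q = np.arange(3), np.arange(4)
outer = p[:, None] + q[None, :]  # shape (3, 4)
```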
Common neural network operations
# Layer normalization
mean = einx.mean("b... [c]", x, keepdims=True)
var = einx.var("b... [c]", x, keepdims=True)
x = (x - mean) * torch.rsqrt(var + epsilon)
# Prepend class token
einx.rearrange("b s... c, c -> b (1 + (s...)) c", x, cls_token)
# Multi-head attention
attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8)
attn = einx.softmax("b q [k] h", attn)
x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)
# Matmul in linear layers
einx.dot("b... [c1->c2]", x, w) # - Regular
einx.dot("b... (g [c1->c2])", x, w) # - Grouped: Same weights per group
einx.dot("b... ([g c1->g c2])", x, w) # - Grouped: Different weights per group
einx.dot("b [s...->s2] c", x, w) # - Spatial mixing as in MLP-mixer
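The regular and grouped linear layers differ mainly in the shape of the weight tensor. The NumPy einsum sketch below assumes the weight is (c1, c2) when shared across groups and (g, c1, c2) when each group has its own weights; these layouts are assumptions made for illustration, not a statement of how einx orders the weight axes internally.

```python
import numpy as np

rng = np.random.default_rng(0)
b, g, c1, c2 = 4, 3, 5, 2

# Regular: "b... [c1->c2]" is a plain matmul over the channel axis
x = rng.normal(size=(b, c1))
w = rng.normal(size=(c1, c2))
y_regular = x @ w  # shape (b, c2)

# Grouped inputs: channels are laid out as (g c1), so split them first
xg = rng.normal(size=(b, g * c1)).reshape(b, g, c1)

# Same weights per group: "b... (g [c1->c2])", one (c1, c2) matrix shared by all groups
y_shared = np.einsum("bgi,io->bgo", xg, w).reshape(b, g * c2)

# Different weights per group: "b... ([g c1->g c2])"
# (assumed weight layout: (g, c1, c2), i.e. one matrix per group)
w_per_group = rng.normal(size=(g, c1, c2))
y_per_group = np.einsum("bgi,gio->bgo", xg, w_per_group).reshape(b, g * c2)
```

Because g appears on both sides of the arrow, it is not contracted: it acts as a batch axis that selects which weight matrix is applied.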
See Common neural network ops for more examples.
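For comparison, the layer normalization and multi-head attention snippets can be written out with plain NumPy reductions and np.einsum. This is a sketch under assumed shapes (8 heads, 16 channels per head, epsilon = 1e-5); like the snippet above, it omits the usual 1/sqrt(c) scaling of attention logits.

```python
import numpy as np

rng = np.random.default_rng(0)

# Layer normalization over the last axis ("b... [c]" with keepdims=True)
epsilon = 1e-5
x = rng.normal(loc=1.0, scale=3.0, size=(2, 5, 16))
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
x_norm = (x - mean) / np.sqrt(var + epsilon)  # NumPy stand-in for torch.rsqrt

# Multi-head attention with channels laid out as (h c)
bsz, nq, nk, heads, c = 2, 4, 6, 8, 16
q = rng.normal(size=(bsz, nq, heads * c))
k = rng.normal(size=(bsz, nk, heads * c))
v = rng.normal(size=(bsz, nk, heads * c))
qh, kh, vh = (t.reshape(bsz, -1, heads, c) for t in (q, k, v))

# "b q (h c), b k (h c) -> b q k h": contract c, keep h
attn = np.einsum("bqhc,bkhc->bqkh", qh, kh)
# einx.softmax("b q [k] h", attn): numerically stable softmax along k
attn = np.exp(attn - attn.max(axis=2, keepdims=True))
attn = attn / attn.sum(axis=2, keepdims=True)
# "b q k h, b k (h c) -> b q (h c)"
out = np.einsum("bqkh,bkhc->bqhc", attn, vh).reshape(bsz, nq, heads * c)
```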
Deep learning modules
import einx.nn.{torch|flax|haiku|equinox|keras} as einn
batchnorm = einn.Norm("[b...] c", decay_rate=0.9)
layernorm = einn.Norm("b... [c]") # as used in transformers
instancenorm = einn.Norm("b [s...] c")
groupnorm = einn.Norm("b [s...] (g [c])", g=8)
rmsnorm = einn.Norm("b... [c]", mean=False, bias=False)
channel_mix = einn.Linear("b... [c1->c2]", c2=64)
spatial_mix1 = einn.Linear("b [s...->s2] c", s2=64)
spatial_mix2 = einn.Linear("b [s2->s...] c", s=(64, 64))
patch_embed = einn.Linear("b (s [s2->])... [c1->c2]", s2=4, c2=64)
dropout = einn.Dropout("[...]", drop_rate=0.2)
spatial_dropout = einn.Dropout("[b] ... [c]", drop_rate=0.2)
droppath = einn.Dropout("[b] ...", drop_rate=0.2)
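Reading the bracketed axes in the Norm expressions as the axes over which statistics are computed (as in the einx.mean/einx.var layer norm example earlier), the normalization variants can be sketched in NumPy as below. The helper normalize is hypothetical; it ignores learnable scale/bias, running averages (decay_rate) and any other layer details, and treats mean=False as skipping mean subtraction, which is an assumption.

```python
import numpy as np

def normalize(x, axis, center=True, eps=1e-5):
    """Hypothetical helper: normalize x with statistics over `axis` (no scale/bias)."""
    mean = x.mean(axis=axis, keepdims=True) if center else 0.0
    # With center=False this is the mean of squares, i.e. an RMS statistic
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
b, h, w, c, g = 4, 6, 6, 16, 8
x = rng.normal(loc=2.0, scale=3.0, size=(b, h, w, c))

batchnorm = normalize(x, axis=(0, 1, 2))       # "[b...] c"
layernorm = normalize(x, axis=-1)              # "b... [c]"
instancenorm = normalize(x, axis=(1, 2))       # "b [s...] c"
# "b [s...] (g [c])": split channels into g groups, reduce over space and within-group channels
xg = x.reshape(b, h, w, g, c // g)
groupnorm = normalize(xg, axis=(1, 2, 4)).reshape(x.shape)
rmsnorm = normalize(x, axis=-1, center=False)  # "b... [c]" with mean=False
```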
See examples/train_{torch|flax|haiku|equinox|keras}.py for example training runs on CIFAR10, the GPT-2 and Mamba examples for working implementations of language models using einx, and Tutorial: Neural networks for more details.
Just-in-time compilation
einx traces the required backend operations for a given call into graph representation and just-in-time compiles them into a regular Python function using Python's exec()
. This reduces overhead to a single cache lookup and allows inspecting the generated function. For example:
>>> x = np.zeros((3, 10, 10))
>>> graph = einx.sum("... (g [c])", x, g=2, graph=True)
>>> print(graph)
import numpy as np
def op0(i0):
    x0 = np.reshape(i0, (3, 10, 2, 5))
    x1 = np.sum(x0, axis=3)
    return x1
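The general idea of compiling via exec() and caching the result can be illustrated with a toy version that handles only the "... (g [c])" sum shown above. Everything here (compile_grouped_sum, _cache, the source attribute) is hypothetical and much simpler than einx's tracing; it only shows how generated source becomes a callable and why later calls cost just a dictionary lookup.

```python
import numpy as np

_cache = {}  # hypothetical cache: (input shape, g) -> compiled function

def compile_grouped_sum(shape, g):
    """Generate and exec() source for a "... (g [c])" sum, mimicking the graph printed above."""
    key = (tuple(shape), g)
    if key in _cache:
        return _cache[key]  # repeated calls with the same signature: one dict lookup
    *batch, last = shape
    new_shape = (*batch, g, last // g)  # assumes the last axis is divisible by g
    source = (
        "def op0(i0):\n"
        f"    x0 = np.reshape(i0, {new_shape})\n"
        f"    x1 = np.sum(x0, axis={len(new_shape) - 1})\n"
        "    return x1\n"
    )
    namespace = {"np": np}
    exec(source, namespace)  # defines op0 inside `namespace`
    fn = namespace["op0"]
    fn.source = source  # keep the generated code around for inspection
    _cache[key] = fn
    return fn
```

Calling compile_grouped_sum((3, 10, 10), 2) and printing its source attribute yields the same function body as the graph printed above.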
See Just-in-time compilation for more details.
Download files
Source Distribution
Built Distribution
File details
Details for the file einx-0.2.1.tar.gz.
File metadata
- Download URL: einx-0.2.1.tar.gz
- Upload date:
- Size: 78.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | d5ebecdf54dc9327d761441c05fb80b8d299f2f425e5d032e627ac2bde61b531
MD5 | da59da055a546a35456032851c6e9865
BLAKE2b-256 | 261d8b9713ff42423032577b6d886bf7638fc856348ccb21cb2f8c6efe9bbf52
File details
Details for the file einx-0.2.1-py3-none-any.whl.
File metadata
- Download URL: einx-0.2.1-py3-none-any.whl
- Upload date:
- Size: 98.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | df66c68ab4eaf593b70252eaca34a48a44d711ce7733b4a779ce783387d062dc
MD5 | ed363288f3b7f670a2e9f1cf47019c69
BLAKE2b-256 | 4e42280fe2424e39b5611d017033567a065f79a2eaa43cf02f236aa07ccda448