Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`. A superset of `lucidrains/torch-einops-utils` with similar utilities from other lucidrains packages. Adds strict typing, extensive tests, and comprehensive docstrings.
Project description
torch_einops_kit
Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and einops.
This repository is a superset of lucidrains/torch-einops-utils. The upstream repository is a compact collection of small utilities that show up repeatedly in lucidrains model repositories. torch_einops_kit keeps that role. The main difference is emphasis. This fork adds roughly 6000 lines of tests, typing, and docstrings so the utility layer is easier to trust, easier to search, and easier to apply correctly.
torch_einops_kit is most useful when combined with other lucidrains repositories. Repositories such as dreamer4, metacontroller, mimic-video, pi-zero-pytorch, sdft-pytorch, and locoformer repeatedly need operations such as align_dims_left, shape_with_replace, lens_to_mask, pad_sequence, safe_cat, and pack_with_inverse. This package centralizes those operations in one typed import surface instead of re-implementing the same tensor utility layer in each model repository.
If you already know torch-einops-utils, torch_einops_kit began as a typed substitute for that package and has since grown into a superset. In addition to everything from torch-einops-utils, this repository centralizes small utility functions that appear repeatedly in other lucidrains model repositories but were never collected in one place, such as l2norm, once, pack_one, and unpack_one. The function family remains the same kind: small PyTorch and einops helpers for shape work, masks, padding, optional tensors, PyTree traversal, device routing, and checkpoint reconstruction. The import path is torch_einops_kit, not torch_einops_utils. The relationship is conceptual, not literal import-path compatibility.
Use torch_einops_kit when you want strict typing, a py.typed marker, focused modules, extensive tests, and docstrings written for both humans and AI assistants. Use upstream when you want the most compact possible version of the same idea.
At a glance
- Project name:
torch_einops_kit. - Import path:
torch_einops_kit. - Python requirement:
>=3.10. - Runtime dependencies:
torch,einops, andtyping-extensions. - Root package exports: helper functions, slicing helpers, rank-alignment helpers, mask helpers, safe concatenation helpers, padding helpers, normalization helpers, and PyTree /
einopshelpers. - Submodules with dedicated imports:
torch_einops_kit.device,torch_einops_kit.einops,torch_einops_kit.save_load, andtorch_einops_kit.scaleValues. - Typing status: the package ships a
py.typedmarker and the repository uses strict type checking. - Best fit: lucidrains-style model repositories that work with variable-length tensors,
einopspatterns, optional intermediate tensors, and nestedtorch.nn.Modulegraphs.
Installation
uv
uv add torch_einops_kit
pip
pip install torch_einops_kit
Import map
Import most tensor helpers from the package root:
from torch_einops_kit import (
align_dims_left,
and_masks,
broadcast_cat,
l2norm,
lens_to_mask,
masked_mean,
maybe,
once,
or_masks,
pad_sequence_and_cat,
pad_sequence,
safe_cat,
safe_stack,
shape_with_replace,
slice_at_dim,
tree_flatten_with_inverse,
tree_map_tensor,
)
Import einops pack and unpack helpers from torch_einops_kit.einops:
from torch_einops_kit.einops import (
pack_one,
pack_with_inverse,
unpack_one,
)
Import device decorators from torch_einops_kit.device:
from torch_einops_kit.device import (
module_device,
move_inputs_to_device,
move_inputs_to_module_device,
)
Import checkpoint decorators from torch_einops_kit.save_load:
from torch_einops_kit.save_load import (
dehydrate_config,
rehydrate_config,
save_load,
)
Quick examples
Batch variable-length tensors and build a mask
This example comes directly from the test suite.
import torch
from torch_einops_kit import lens_to_mask, pad_sequence
x = torch.randn(2, 4, 5)
y = torch.randn(2, 3, 5)
z = torch.randn(2, 1, 5)
packed, lens = pad_sequence([x, y, z], dim=1, return_lens=True)
mask = lens_to_mask(lens)
assert packed.shape == (3, 2, 4, 5)
assert lens.tolist() == [4, 3, 1]
assert torch.allclose(mask.sum(dim=-1), lens)
Pack with einops and keep an inverse function
This example also comes directly from the test suite.
import torch
from torch_einops_kit import pack_with_inverse
t = torch.randn(3, 12, 2, 2)
packed, inverse = pack_with_inverse(t, "b * d")
assert packed.shape == (3, 24, 2)
restored = inverse(packed)
assert restored.shape == (3, 12, 2, 2)
Decorate a module with save / load helpers
This example is adapted from tests/test_save_load.py.
from pathlib import Path
from torch import Tensor, nn
from torch_einops_kit.save_load import save_load
@save_load()
class SimpleNet(nn.Module):
def __init__(self, dim: int, hidden_dim: int) -> None:
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.net = nn.Linear(dim, hidden_dim)
def forward(self, x: Tensor) -> Tensor:
return self.net(x)
path = Path("model.pt")
model = SimpleNet(10, 20)
model.save(str(path))
restored = SimpleNet.init_and_load(str(path))
assert restored.dim == 10
assert restored.hidden_dim == 20
Root API reference
Optional-value and structure helpers
| Name | Contract |
|---|---|
exists(v) |
Returns True exactly when v is not None. Falsy values such as 0, False, and [] still count as existing values. The return type is a TypeGuard, so static analyzers narrow the type after the check. |
default(v, d) |
Returns v when v exists. Returns d when v is None. |
compact(arr) |
Removes every None value from arr and returns a list of the remaining values. |
divisible_by(num, den) |
Returns False when den == 0. Otherwise returns whether num % den == 0. |
identity(t, *args, **kwargs) |
Returns t unchanged and ignores every extra argument. Useful as a no-op callable. |
first(arr) |
Returns arr[0]. Use first when the sequence supports integer indexing and index 0 is valid. |
map_values(fn, v) |
Recursively applies fn to every leaf value inside nested list, tuple, and dict structures. Container shape is preserved. |
maybe(fn) |
Wraps fn so the wrapped callable returns None when the first argument is None. When fn itself is None, maybe(None) returns identity. |
once(fn) |
Wraps fn so the wrapped callable executes at most once. On the first call the function runs and returns its result. On every subsequent call the wrapper returns None without calling fn. |
safe(fn) |
Decorator for callables whose first argument is Sequence[Tensor]. The decorator removes None values from that first argument before the call. If no tensors remain, the decorator returns None instead of calling fn. |
Shape and slicing helpers
| Name | Contract |
|---|---|
shape_with_replace(t, replace_dict=None) |
Returns a derived torch.Size based on t.shape. The function does not modify t. Keys in replace_dict must be non-negative integers less than t.ndim. |
slice_at_dim(t, slc, dim=-1) |
Applies slc to one dimension and preserves every other dimension. Negative dim values are normalized before indexing. |
slice_left_at_dim(t, length, dim=-1) |
Keeps the first length values along dim. When length == 0, the function returns an empty slice along dim. |
slice_right_at_dim(t, length, dim=-1) |
Keeps the last length values along dim. When length == 0, the function returns an empty slice along dim. |
Rank-alignment and singleton-dimension helpers
These functions change shape by inserting singleton dimensions. These functions do not add numeric padding values.
| Name | Contract |
|---|---|
pad_ndim(t, (left, right)) |
Returns a reshaped view of t with left leading singleton dimensions and right trailing singleton dimensions. Raises ValueError when either count is negative. |
pad_left_ndim(t, ndims) |
Prepends ndims singleton dimensions. |
pad_right_ndim(t, ndims) |
Appends ndims singleton dimensions. |
pad_left_ndim_to(t, ndims) |
Ensures that t.ndim >= ndims by prepending singleton dimensions when needed. Returns t unchanged when t.ndim >= ndims. |
pad_right_ndim_to(t, ndims) |
Ensures that t.ndim >= ndims by appending singleton dimensions when needed. Returns t unchanged when t.ndim >= ndims. |
align_dims_left(tensors, ndim=None) |
Aligns every tensor to a shared rank by appending trailing singleton dimensions. Existing dimensions stay left-aligned. When ndim is None, the target rank is the maximum rank across the input tensors. |
Mask helpers
| Name | Contract |
|---|---|
lens_to_mask(lens, max_len=None) |
Converts integer length values to a boolean mask with shape (*lens.shape, max_len). Position i in the final axis is True when i < lens[...]. |
reduce_masks(masks, op) |
Filters out None values, then reduces the remaining masks left-to-right with op. Returns None when no non-None masks remain. |
and_masks(masks) |
Equivalent to reduce_masks(masks, torch.logical_and). Returns None when no active mask remains. |
or_masks(masks) |
Equivalent to reduce_masks(masks, torch.logical_or). Returns None when no active mask remains. |
Concatenation and stacking helpers
| Name | Contract |
|---|---|
broadcast_cat(tensors, dim=0) |
Broadcasts tensor groups before concatenation. |
safe_stack(tensors, dim=0) |
Removes None values, then applies torch.stack. Returns None when no active tensor remains. Even a single tensor is stacked, so the result gains a new dimension. |
safe_cat(tensors, dim=0) |
Removes None values, then applies torch.cat. Returns None when no active tensor remains. A single surviving tensor is returned unchanged because torch.cat over one tensor is identity. |
Numeric padding and batching helpers
These functions add numeric padding values along an existing tensor dimension.
| Name | Contract |
|---|---|
pad_at_dim(t, (left, right), dim=-1, value=0.0) |
Pads or trims t along one dimension. Positive values add elements. Negative values trim elements. |
pad_left_at_dim(t, pad, dim=-1, value=0.0) |
Prepends pad values along dim. |
pad_right_at_dim(t, pad, dim=-1, value=0.0) |
Appends pad values along dim. |
pad_left_at_dim_to(t, length, dim=-1, value=0.0) |
Ensures that t.shape[dim] >= length by left-padding when needed. Returns t unchanged when the target length is already satisfied. |
pad_right_at_dim_to(t, length, dim=-1, value=0.0) |
Ensures that t.shape[dim] >= length by right-padding when needed. Returns t unchanged when the target length is already satisfied. |
pad_sequence(tensors, ...) |
Pads every tensor in tensors to the maximum length along dim. The function can return a stacked tensor or a list of padded tensors. The function can also return original lengths or padding widths. Returns None for empty input. |
pad_sequence_and_cat(tensors, dim_cat=0, dim=-1, value=0.0, left=False) |
Equivalent to pad_sequence(..., return_stacked=False) followed by torch.cat(..., dim=dim_cat). Returns None for empty input. |
pad_sequence return modes
pad_sequence is overloaded. The return type depends on return_stacked and return_lens.
return_stacked |
return_lens |
Return type |
|---|---|---|
True |
False |
Tensor | None |
True |
True |
tuple[Tensor, Tensor] | None |
False |
False |
list[Tensor] | None |
False |
True |
tuple[list[Tensor], Tensor] | None |
When pad_lens=True and return_lens=True, the second tensor contains padding widths rather than original lengths.
Normalization and masked reduction helpers
| Name | Contract |
|---|---|
l2norm(t) |
Normalizes each vector in t to unit length along the last dimension by dividing by its L2 norm. Delegates to torch.nn.functional.normalize with p=2 and dim=-1. |
masked_mean(t, mask=None, dim=None, eps=1e-5) |
Computes a masked mean. When mask is None, the function falls back to t.mean(...). When no masked position is selected and dim is None, the function returns zero by summing over the empty selection. When mask.ndim < t.ndim, the function right-pads mask rank before broadcasting. |
PyTree helpers
| Name | Contract |
|---|---|
tree_map_tensor(fn, tree) |
Applies fn to every tensor leaf in a PyTree and leaves non-tensor leaves unchanged. |
tree_flatten_with_inverse(tree) |
Returns a flat list of leaves and an inverse function that reconstructs the original PyTree shape from a replacement iterable of leaves. |
scaleValues submodule reference
The torch_einops_kit.scaleValues submodule contains vector normalization, masked mean computation, and the RMSNorm layer. l2norm and masked_mean are also re-exported from the package root.
| Name | Contract |
|---|---|
l2norm(t) |
Normalizes each vector in t to unit length along the last dimension. Delegates to torch.nn.functional.normalize with p=2 and dim=-1. |
masked_mean(t, mask=None, dim=None, eps=1e-5) |
Computes a masked mean. When mask is None, the function falls back to t.mean(...). When no masked position is selected and dim is None, the function returns zero by summing over the empty selection. When mask.ndim < t.ndim, the function right-pads mask rank before broadcasting. |
RMSNorm(dim) |
torch.nn.Module that normalizes the last feature axis to unit length, multiplies by √dim, and applies a learned per-feature gamma parameter. Use as a pre-normalization layer before attention, feedforward, or linear projection sublayers in transformer-style modules. |
einops submodule reference
The torch_einops_kit.einops submodule contains pack and unpack utilities with paired inverse functions.
| Name | Contract |
|---|---|
pack_one(t, pattern) |
Packs one tensor using an einops pattern and returns the packed tensor and shape metadata for paired reconstruction with unpack_one. |
pack_with_inverse(t, pattern) |
Wraps einops.pack and stores the corresponding inverse. Input t may be one tensor or a list[Tensor]. The inverse returns the matching kind and optionally accepts a different inv_pattern for unpacking. |
unpack_one(t, ps, pattern) |
Unpacks one tensor using packed-shape metadata produced by pack_one. |
device submodule reference
The torch_einops_kit.device submodule contains three utilities for device inference and automatic tensor movement.
| Name | Contract |
|---|---|
module_device(m) |
Returns the device of the first parameter or registered buffer in m. Returns None when m has neither parameters nor buffers. |
move_inputs_to_device(device) |
Decorator that recursively moves every tensor inside positional and keyword arguments to device before calling the wrapped function. Non-tensor values pass through unchanged. |
move_inputs_to_module_device(fn) |
Decorator for methods whose first argument is a torch.nn.Module. The decorator infers the target device with module_device(self) and moves every tensor argument after self to that device. If module_device(self) returns None, the call is a no-op. |
save_load submodule reference
The torch_einops_kit.save_load submodule contains the checkpoint decorator and the two advanced configuration helpers that support nested decorated modules.
| Name | Contract |
|---|---|
save_load(...) |
Class decorator for torch.nn.Module subclasses. The decorator records constructor arguments on each instance and adds instance save / load methods plus a classmethod that reconstructs a new instance from a checkpoint. |
dehydrate_config(config, config_instance_var_name) |
Walks nested list, tuple, and dict structures and replaces decorated module instances with reconstruction records. |
rehydrate_config(config) |
Walks nested list, tuple, and dict structures and reconstructs decorated modules from the stored reconstruction records. |
Default save_load method names
The default decorator call is @save_load(). That default adds:
save(path, overwrite=True)load(path, strict=True)init_and_load(path, strict=True)
The decorator can rename all three methods and can rename the instance attribute that stores constructor configuration.
Checkpoint contents
The generated checkpoint payload stores:
model: the result ofstate_dict().config: a pickled constructor-configuration payload.version: an optional version string.
Nested module reconstruction
Nested module graphs are a first-class use case. If one decorated module is passed as a constructor argument to another decorated module, save_load can dehydrate that nested module graph during save and rehydrate that nested module graph during init_and_load.
Version behavior
When both the checkpoint version and the current decorator version exist and differ, load emits UserWarning and still loads the checkpoint state.
Edge-case conventions
These conventions matter for both humans and AI assistants.
existstests onlyNone.0,False, empty strings, empty lists, and empty dictionaries still count as existing values.maybe(fn)short-circuits on the first argument only.maybe(fn)(None, ...)returnsNonewithout callingfn.safe,safe_cat,safe_stack,reduce_masks,and_masks, andor_masksall treatNonevalues as absent values rather than false values.safe_cat,safe_stack,reduce_masks,and_masks,or_masks,pad_sequence, andpad_sequence_and_catcan returnNonefor empty effective input. Generated code must account for that return path.pad_left_at_dim_to,pad_right_at_dim_to,pad_left_ndim_to, andpad_right_ndim_toreturn the original tensor unchanged when no growth is needed.shape_with_replaceaccepts only non-negative dimension keys.slice_at_dim,slice_left_at_dim, andslice_right_at_dimdo accept negativedimvalues.module_devicereturnsNonefor stateless modules such asnn.Identity().move_inputs_to_module_devicebecomes a no-op whenmodule_device(self)returnsNone.save_load.loadmay warn on version mismatch but still restore model state.save_load.init_and_loadrequires a checkpoint that contains serialized constructor configuration.
Typing and AI-assistant notes
This repository is written to be readable by humans, machine translation systems, search tools, and AI assistants.
- The package ships a
py.typedmarker. - The repository uses strict type checking in
pyrightconfig.json. existsusesTypeGuard, so a type checker can narrowT | NonetoTafter anexists(...)check.pad_sequence,pack_with_inverse, andmaybeuse overloads so type checkers can preserve useful return information.- The documentation keeps identifier names stable.
pad_sequencealways meanspad_sequence.safe_catalways meanssafe_cat. This consistency helps retrieval and generated code.
If you are an AI assistant adapting code from torch-einops-utils, use these translation rules:
- Replace
from torch_einops_utils import ...withfrom torch_einops_kit import ...for root helpers. - Replace upstream
save_loadimports withfrom torch_einops_kit.save_load import save_load. - Replace upstream device decorator imports with
from torch_einops_kit.device import .... - Do not erase
None-return paths with unconditional casts. Several helpers intentionally returnNonefor empty effective input. - Do not assume
pad_sequencealways returns a stacked tensor. Readreturn_stackedandreturn_lensfirst.
Differences from upstream torch-einops-utils
This repository is not a repackaged mirror of upstream. This repository makes a different trade-off.
- Upstream is intentionally compact.
- This fork splits the implementation across focused modules such as
_helpers.py,_padding.py,device.py, andsave_load.pywhile still re-exporting most tensor helpers from the package root. - This fork adds strict typing, a
py.typedmarker, extensive tests, and detailed docstrings. - This fork is best treated as a typed, documented branch of the same utility idea rather than a literal import-path-compatible drop-in replacement.
Repository layout
src/torch_einops_kit/— package source.src/torch_einops_kit/device.py— device inference and input-routing decorators.src/torch_einops_kit/save_load.py— checkpoint save / load decorator and nested reconstruction helpers.src/torch_einops_kit/scaleValues.py— vector normalization, masked mean, and theRMSNormlayer.tests/— regression tests and usage examples for helpers, masks, padding, device routing, and checkpoint reconstruction.
Development
Set up the repository:
uv sync
Run the test suite:
pytest
Run static analysis:
pyright
ruff check .
My recovery
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
File details
Details for the file torch_einops_kit-0.1.1.tar.gz.
File metadata
- Download URL: torch_einops_kit-0.1.1.tar.gz
- Upload date:
- Size: 51.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
88ec73470ac73b9b36081d64cd0192b9ea18fa565959e7a677231e540595166e
|
|
| MD5 |
d382294759ea1b21b766e21a9cde2134
|
|
| BLAKE2b-256 |
c1d0f9f3eb99eecf642d7247144b1b3fe1dacf07f1c98485e4278ba2207a22d4
|