Strassen Attention

Implementation of Strassen attention, from Kozachinskiy et al. of the National Center of AI in Chile 🇨🇱

Install

$ pip install strassen-attention

Usage

import torch
from strassen_attention import strassen_attend

q = torch.randn(1, 8, 32, 16) # (batch, heads, seq, dim head)
k = torch.randn(1, 8, 32, 16)
v = torch.randn(1, 8, 32, 16)

attended = strassen_attend(
    q,
    k,          # first key
    k.clone(),  # second key
    v,          # first value
    v.clone()   # second value
)

assert attended.shape == q.shape
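`strassen_attend` takes five tensors: queries, two keys, and two values for the third-order attention. In the snippet above the second key and value are clones of the first, but they can just as well come from independent projections. Below is a minimal pure-PyTorch sketch of that setup; the projection layers and the `(batch, heads, seq, dim head)` reading of the shapes are illustrative assumptions, not part of the library's API.

```python
import torch
from torch import nn

# hypothetical hyperparameters for illustration
dim, heads, dim_head, seq = 64, 4, 32, 16
inner = heads * dim_head

x = torch.randn(1, seq, dim)

# five independent projections instead of cloning a single key / value
to_q  = nn.Linear(dim, inner, bias = False)
to_k1 = nn.Linear(dim, inner, bias = False)
to_k2 = nn.Linear(dim, inner, bias = False)
to_v1 = nn.Linear(dim, inner, bias = False)
to_v2 = nn.Linear(dim, inner, bias = False)

def split_heads(t):
    # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
    b, n, _ = t.shape
    return t.view(b, n, heads, dim_head).transpose(1, 2)

q, k1, k2, v1, v2 = map(split_heads, (to_q(x), to_k1(x), to_k2(x), to_v1(x), to_v2(x)))

# these five tensors match the layout of the example above and
# could be passed to strassen_attend(q, k1, k2, v1, v2)
assert q.shape == (1, heads, seq, dim_head)
```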

For the multi-head attention module

import torch
from strassen_attention.strassen_mha import StrassenMHA

mha = StrassenMHA(dim = 512, causal = True)

tokens = torch.randn(1, 256, 512)

assert mha(tokens).shape == tokens.shape

Strassen attention transformer

import torch
from strassen_attention.strassen_transformer import StrassenTransformer

transformer = StrassenTransformer(dim = 512, depth = 4)

x = torch.randn(1, 16 * 16, 512)
assert transformer(x).shape == x.shape

Citations

@misc{kozachinskiy2025strassenattentionunlockingcompositional,
    title   = {Strassen Attention: Unlocking Compositional Abilities in Transformers Based on a New Lower Bound Method}, 
    author  = {Alexander Kozachinskiy and Felipe Urrutia and Hector Jimenez and Tomasz Steifer and Germán Pizarro and Matías Fuentes and Francisco Meza and Cristian B. Calderon and Cristóbal Rojas},
    year    = {2025},
    eprint  = {2501.19215},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2501.19215}, 
}
@article{Peng2024OnLO,
    title     = {On Limitations of the Transformer Architecture},
    author    = {Binghui Peng and Srini Narayanan and Christos Papadimitriou},
    journal   = {ArXiv},
    year      = {2024},
    volume    = {abs/2402.08164},
    url       = {https://api.semanticscholar.org/CorpusID:267636545}
}
