Strassen Attention (wip)

An implementation of Strassen attention, proposed by Kozachinskiy et al. of the National Center for Artificial Intelligence in Chile.

Install

$ pip install strassen-attention

Usage

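import torch
from strassen_attention import strassen_attend

# example inputs of shape (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 32, 16)
k = torch.randn(1, 8, 32, 16)
v = torch.randn(1, 8, 32, 16)

# strassen_attend takes one set of queries, two sets of keys and two sets of values;
# here the second key and value sets are simply copies of the first, for illustration
attended = strassen_attend(
    q,
    k,
    k.clone(),
    v,
    v.clone()
)

# the output has the same shape as the queries
assert attended.shape == q.shape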
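To make the five-argument call above concrete, here is a minimal NumPy sketch of our reading of Strassen attention from the cited paper; it is not this package's code. Query i attends jointly to every pair (j, k), with score (q_i·k1_j + k1_j·k2_k + k2_k·q_i)/√d, a softmax over all pairs, and value v1_j ⊙ v2_k. The function names `strassen_attend_naive` and `strassen_attend_matmul`, the 1/√d scaling, and the unbatched (seq_len, dim) shapes are assumptions. The first function materializes the full (n, n, n) score tensor; the second computes the same quantity using only n × n matrix products, the kind of rewriting that lets fast matrix multiplication give the sub-cubic running time the paper refers to.

```python
import numpy as np


def strassen_attend_naive(q, k1, k2, v1, v2):
    """Reference O(n^3 d) version; all inputs have shape (n, d). Assumed formula."""
    scale = q.shape[-1] ** -0.5  # assumption: 1/sqrt(d) scaling
    logits = (np.einsum('id,jd->ij', q, k1)[:, :, None]     # q_i . k1_j
              + np.einsum('jd,kd->jk', k1, k2)[None, :, :]  # k1_j . k2_k
              + np.einsum('kd,id->ik', k2, q)[:, None, :]   # k2_k . q_i
              ) * scale                                     # shape (i, j, k)
    logits = logits - logits.max(axis=(1, 2), keepdims=True)
    w = np.exp(logits)
    w = w / w.sum(axis=(1, 2), keepdims=True)  # softmax over all (j, k) pairs
    # the value attached to pair (j, k) is the elementwise product v1_j * v2_k
    return np.einsum('ijk,jd,kd->id', w, v1, v2)


def strassen_attend_matmul(q, k1, k2, v1, v2):
    """Same result, built only from n x n matrix products (no (n, n, n) tensor)."""
    scale = q.shape[-1] ** -0.5
    a = (q @ k1.T) * scale    # a[i, j]
    b = (k1 @ k2.T) * scale   # b[j, k]
    c = (k2 @ q.T) * scale    # c[k, i]
    # shifts that are constant per query i (or global) cancel in num / den
    A = np.exp(a - a.max(axis=1, keepdims=True))
    B = np.exp(b - b.max())
    C = np.exp(c - c.max(axis=0, keepdims=True))
    # den[i] = sum_{j,k} A[i,j] B[j,k] C[k,i] = diag(A @ B @ C)
    den = ((A @ B) * C.T).sum(axis=1)
    # Y[d, j, i] = sum_k B[j,k] v2[k,d] C[k,i]: one n x n matmul per feature d
    Y = (B[None, :, :] * v2.T[:, None, :]) @ C
    # num[i, d] = sum_j A[i,j] v1[j,d] Y[d,j,i]
    num = np.einsum('ij,jd,dji->id', A, v1, Y)
    return num / den[:, None]


# usage: both versions agree on random inputs
rng = np.random.default_rng(0)
q, k1, k2, v1, v2 = (rng.standard_normal((32, 16)) for _ in range(5))
out = strassen_attend_naive(q, k1, k2, v1, v2)
assert out.shape == q.shape
assert np.allclose(out, strassen_attend_matmul(q, k1, k2, v1, v2))
```

The factorization works because the exponential of the three-term score splits into a product of three n × n matrices, so the normalizer is diag(A B C) and each feature of the numerator needs one further matrix product. Whether `strassen_attend` is organized this way internally is not stated on this page.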

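For the multi-head attention module:

import torch
from strassen_attention.strassen_mha import StrassenMHA

# dim is the model (embedding) dimension; causal = True enables causal masking
mha = StrassenMHA(dim = 512, causal = True)

# input of shape (batch, seq_len, dim)
tokens = torch.randn(1, 256, 512)

# the module returns a tensor with the same shape as its input
assert mha(tokens).shape == tokens.shape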
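This page does not document what `StrassenMHA` does internally. As a hedged illustration only, the NumPy sketch below shows one plausible way to wrap Strassen attention in a multi-head module: a single fused projection produces q, k1, k2, v1 and v2, the heads are split, attention is applied per head, and the heads are merged and projected back to `dim`. The class `TinyStrassenMHA`, the helper `causal_strassen_attend`, the random weight initialization, the assumed score formula, and the causal rule (query i may only use pairs (j, k) with j ≤ i and k ≤ i) are all assumptions, not the library's API.

```python
import numpy as np


def causal_strassen_attend(q, k1, k2, v1, v2):
    """Naive O(n^3) sketch; inputs have shape (heads, n, d). Hypothetical helper."""
    n, d = q.shape[-2], q.shape[-1]
    logits = (np.einsum('hid,hjd->hij', q, k1)[:, :, :, None]
              + np.einsum('hjd,hkd->hjk', k1, k2)[:, None, :, :]
              + np.einsum('hkd,hid->hik', k2, q)[:, :, None, :]) * d ** -0.5
    # assumed causal rule: query i may only use pairs (j, k) with j <= i and k <= i
    idx = np.arange(n)
    allowed = (idx[None, :, None] <= idx[:, None, None]) & (idx[None, None, :] <= idx[:, None, None])
    logits = np.where(allowed, logits, -np.inf)  # pair (0, 0) is always allowed
    logits = logits - logits.max(axis=(-2, -1), keepdims=True)
    w = np.exp(logits)  # masked pairs get exactly zero weight
    w = w / w.sum(axis=(-2, -1), keepdims=True)
    return np.einsum('hijk,hjd,hkd->hid', w, v1, v2)


class TinyStrassenMHA:
    """Hypothetical multi-head wrapper for illustration; not the library's StrassenMHA."""

    def __init__(self, dim, heads=8, dim_head=16, seed=0):
        rng = np.random.default_rng(seed)
        inner = heads * dim_head
        self.heads, self.dim_head = heads, dim_head
        # one fused projection producing q, k1, k2, v1, v2 (assumed layout)
        self.w_in = rng.standard_normal((dim, 5 * inner)) / np.sqrt(dim)
        self.w_out = rng.standard_normal((inner, dim)) / np.sqrt(inner)

    def __call__(self, x):
        # x: (n, dim), a single sequence without a batch axis to keep the sketch short
        n = x.shape[0]
        parts = np.split(x @ self.w_in, 5, axis=-1)  # five (n, inner) arrays
        q, k1, k2, v1, v2 = (p.reshape(n, self.heads, self.dim_head).transpose(1, 0, 2)
                             for p in parts)         # each (heads, n, dim_head)
        out = causal_strassen_attend(q, k1, k2, v1, v2)
        out = out.transpose(1, 0, 2).reshape(n, -1)  # merge heads -> (n, inner)
        return out @ self.w_out                      # (n, dim)


# usage
mha = TinyStrassenMHA(dim=64, heads=4, dim_head=8)
tokens = np.random.default_rng(1).standard_normal((10, 64))
assert mha(tokens).shape == tokens.shape
```

Because every projection acts on one token at a time and masked pairs receive exactly zero weight, perturbing the last token should leave all earlier outputs unchanged, which is a quick sanity check for any causal mask.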

Citations

@misc{kozachinskiy2025strassenattentionunlockingcompositional,
    title   = {Strassen Attention: Unlocking Compositional Abilities in Transformers Based on a New Lower Bound Method}, 
    author  = {Alexander Kozachinskiy and Felipe Urrutia and Hector Jimenez and Tomasz Steifer and Germán Pizarro and Matías Fuentes and Francisco Meza and Cristian B. Calderon and Cristóbal Rojas},
    year    = {2025},
    eprint  = {2501.19215},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2501.19215}, 
}


Download files

Source Distribution

strassen_attention-0.0.9.tar.gz (463.5 kB)

Built Distribution

strassen_attention-0.0.9-py3-none-any.whl (6.4 kB)

File details

Details for the file strassen_attention-0.0.9.tar.gz.

File metadata

  • Download URL: strassen_attention-0.0.9.tar.gz
  • Size: 463.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.23

File hashes

Hashes for strassen_attention-0.0.9.tar.gz
Algorithm Hash digest
SHA256 66b57c78bdee115520b1089a40b5046a3f9dc1e5c3b6b5bf36868bbdbee84d0c
MD5 788f8995551554f93dcda694833f15bc
BLAKE2b-256 1fa8f3e835a3c43dc68518d5506f127ad4a1794d47f2a9287c0d71f3b5b44b55

File details

Details for the file strassen_attention-0.0.9-py3-none-any.whl.

File hashes

Hashes for strassen_attention-0.0.9-py3-none-any.whl
Algorithm Hash digest
SHA256 9db8a8bf1bca287e72b57e4305b50239ddeffa5e1dc25573fb7c231cc256d205
MD5 368ec2ff260081b94d887e513a0f2368
BLAKE2b-256 e0180ae9eeb91ba49a22970c0ae5567294aeb7d46c4a764ebb7547619232e66a
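The published digests can be checked after downloading the files, for example with `pip download strassen-attention==0.0.9 --no-deps`. A minimal sketch using Python's standard `hashlib` follows; the helper names `sha256_of_file` and `verify_download` are illustrative, and the expected values are copied from the SHA256 rows above. pip can also enforce hashes itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).

```python
import hashlib
import os

# SHA256 digests copied from the "File hashes" tables on this page
EXPECTED_SHA256 = {
    "strassen_attention-0.0.9.tar.gz":
        "66b57c78bdee115520b1089a40b5046a3f9dc1e5c3b6b5bf36868bbdbee84d0c",
    "strassen_attention-0.0.9-py3-none-any.whl":
        "9db8a8bf1bca287e72b57e4305b50239ddeffa5e1dc25573fb7c231cc256d205",
}


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    """Return True if the file matches the published SHA256 for its file name."""
    name = os.path.basename(path)
    if name not in EXPECTED_SHA256:
        raise KeyError(f"no published hash for {name!r}")
    return sha256_of_file(path) == EXPECTED_SHA256[name]
```

For example, `verify_download("strassen_attention-0.0.9.tar.gz")` should return `True` for an untampered download in the current directory.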
