
Generalized Optimal Transport Attention with Trainable Priors (PyTorch).


GOAT Attention

This repository provides the code for Generalized Optimal Transport Attention with Trainable Priors (GOAT), packaged as a PyTorch multi-head attention module.
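As background for the name, here is one standard way to read attention as optimal transport. For a single query with scores s_j = q·k_j/√d, choose weights p over the keys that minimize Σ_j p_j(−s_j) + ε·KL(p ‖ π), where π is a prior over keys. The minimizer is p_j ∝ π_j·exp(s_j/ε). A uniform prior with ε = 1 recovers ordinary softmax attention, and a non-uniform prior acts as an additive log π_j bias on the logits. The plain-Python sketch below illustrates only this generic formulation. The function name ot_prior_attention_row is hypothetical, and GOAT's actual formulation and prior parameterization are not described on this page, so this code is not taken from the package.

```python
import math
from typing import List, Sequence


def ot_prior_attention_row(
    scores: Sequence[float], prior: Sequence[float], eps: float = 1.0
) -> List[float]:
    """Attention weights for one query, posed as one-sided entropic OT.

    Hypothetical helper (not part of goat). Minimizes
        sum_j p_j * (-scores[j]) + eps * KL(p || prior)
    over probability vectors p. The closed-form minimizer is
        p_j proportional to prior[j] * exp(scores[j] / eps).
    """
    if len(scores) != len(prior):
        raise ValueError("scores and prior must have the same length")
    if eps <= 0 or any(pi <= 0 for pi in prior):
        raise ValueError("eps and all prior entries must be positive")
    # The log-prior enters as an additive bias on the scaled logits.
    logits = [s / eps + math.log(pi) for s, pi in zip(scores, prior)]
    # Subtract the max before exponentiating (numerically stable softmax).
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    scores = [2.0, 0.5, -1.0]  # e.g. q.k_j / sqrt(d) for three keys
    print(ot_prior_attention_row(scores, [1 / 3] * 3))  # equals softmax(scores)
    print(ot_prior_attention_row(scores, [0.1, 0.1, 0.8]))  # more mass on key 3
```

Because normalization happens after adding log π_j, the prior does not need to sum to one. Only its relative sizes matter.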

Installation

  • From PyPI (recommended):
      uv add goat-attention
  • uv (editable, for development):
      uv pip install -e .
  • pip (editable, for development):
      pip install -e .

Quickstart

import torch
from goat import GoatAttention

B, L, S, E, H = 2, 5, 7, 64, 8
xq = torch.randn(B, L, E)
xk = torch.randn(B, S, E)
xv = torch.randn(B, S, E)

attn = GoatAttention(
    embed_dim=E,
    num_heads=H,
    batch_first=True,
    pos_rank=2,
    abs_rank=4,
    enable_key_bias=True,
)

out, weights = attn(xq, xk, xv, is_causal=False, need_weights=True)
print(out.shape, None if weights is None else weights.shape)

CLI

After installation, the goat command-line entry point is available:

goat info
goat smoke

Documentation

See the docs/ directory (Markdown sources ready to build with MkDocs):

  • docs/index.md
  • docs/usage.md
  • docs/api.md
  • docs/development.md

Development

uv pip install -e ".[dev]"
pytest



Download files


Source Distribution

goat_attention-0.1.0.tar.gz (18.6 kB)

Built Distribution


goat_attention-0.1.0-py3-none-any.whl (20.1 kB)

File details

Details for the file goat_attention-0.1.0.tar.gz.

File metadata

  • Download URL: goat_attention-0.1.0.tar.gz
  • Upload date:
  • Size: 18.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for goat_attention-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        b5c24577169f62d186e11c9e49c5db458cdddecaea7635a35cfccb0fbdad0386
MD5           815b9fea4f4fb9445990b6220cae4ada
BLAKE2b-256   b84566c3620d39e37783d6d9acaae850c4023b3859f8f99ffbc45cd412ee98c1
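The published digests can be checked locally before you install a downloaded file. The following is a minimal sketch that uses only the standard library's hashlib. It assumes goat_attention-0.1.0.tar.gz has already been downloaded, and the names sha256_of_file and verify are illustrative helpers, not part of the package.

```python
import hashlib

# SHA256 digest listed above for goat_attention-0.1.0.tar.gz.
SDIST_SHA256 = "b5c24577169f62d186e11c9e49c5db458cdddecaea7635a35cfccb0fbdad0386"


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns b"" at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected: str = SDIST_SHA256) -> bool:
    """True if the file's SHA256 digest matches `expected` (case-insensitive)."""
    return sha256_of_file(path) == expected.strip().lower()
```

For example, verify("goat_attention-0.1.0.tar.gz") should return True for an unmodified copy of the source distribution. To check the wheel, pass its SHA256 digest as expected.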


Provenance

The following attestation bundles were made for goat_attention-0.1.0.tar.gz:

Publisher: publish.yml on elonlit/goat

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file goat_attention-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: goat_attention-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 20.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for goat_attention-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        ec9bb4815a177d8027da8ed444da02bdc845114f75adb3306ca2763a5604aa71
MD5           607ffbfb11e39ada20064825679b3966
BLAKE2b-256   667dce555ca9107e95cbd2a5e06e88eeb4fea0faecb37534215b639595a92599


Provenance

The following attestation bundles were made for goat_attention-0.1.0-py3-none-any.whl:

Publisher: publish.yml on elonlit/goat

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
