
Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs

Project description

FloydNet

Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).

Figure: Pivotal Attention Mechanism for 2-Floyd/3-Floyd.

This repository serves two audiences:

  • Engineering users: reusable PyTorch components (functional attention APIs and Transformer-style blocks) under src/.
  • Research users: scripts/configs to reproduce paper experiments under example/.

Introduction

FloydNet is the official PyTorch implementation accompanying an ICLR paper (TODO).
The repository provides:

  1. Reusable components: a drop-in attention/Transformer-block interface intended for integration into existing projects.
  2. Reproduction code: end-to-end training/evaluation pipelines to reproduce the benchmarks reported in the paper.

For algorithmic details, hyperparameter choices, and analysis, please refer to the paper (TODO: link).


Repository Structure

  • src/floydnet/
    Library code for reuse
    Contains the functional attention API and module/block implementations.

  • example/
    Experiment reproduction code
    Includes benchmark-specific scripts, configs, and data preparation utilities.


Using the Attention / Transformer Block

This section targets engineering users who want to import FloydNet as a dependency.

Installation

Option A: Install from PyPI

pip install floydnet

Option B: Install from source

git clone git@github.com:ocx-lab/FloydNet.git
cd FloydNet
pip install -e .

Requirements: Python >= 3.9, PyTorch >= 2.1 (see pyproject.toml).

Public API

FloydNet re-exports the public API from src/floydnet/__init__.py, so you can import from the top-level package:

  • Functional API:
    • pivotal_attention (see src/floydnet/functional.py)
  • Module / block API:
    • PivotalAttentionBlock (see src/floydnet/transformer.py)

from floydnet import pivotal_attention, PivotalAttentionBlock

Minimal usage example

import torch
from floydnet import pivotal_attention, PivotalAttentionBlock

# -------------------------
# Module API (Transformer-style block)
# Input is a 2D grid: (B, N, N, C)
# -------------------------
B, N, C = 2, 16, 64
x = torch.randn(B, N, N, C)

m = PivotalAttentionBlock(embed_dim=C, num_heads=8, dropout=0.0)
out = m(x)  # (B, N, N, C)
print(out.shape)

# -------------------------
# Functional API
# All inputs are 5D: (B, H, N, N, D)
# -------------------------
B, H, N, D = 2, 8, 16, 64
q_ik = torch.randn(B, H, N, N, D)
k_ij = torch.randn(B, H, N, N, D)
k_jk = torch.randn(B, H, N, N, D)
v_ij = torch.randn(B, H, N, N, D)
v_jk = torch.randn(B, H, N, N, D)

y = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk)  # (B, H, N, N, D)
print(y.shape)
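To build intuition for what the functional API computes, here is a self-contained reference sketch in plain PyTorch. It is a guess at the semantics based only on the argument names (attend over a pivot index j connecting positions i and k); it is not the library's actual formula, so consult src/floydnet/functional.py and the paper for the real definition.

```python
import torch

def pivotal_attention_reference(q_ik, k_ij, k_jk, v_ij, v_jk):
    """Illustrative sketch of pivotal attention over a pivot index j.

    All inputs are (B, H, N, N, D). The score assigned to pivot j at
    position (i, k) is taken here as
        <q_ik[i, k], k_ij[i, j]> + <q_ik[i, k], k_jk[j, k]>, scaled by sqrt(D),
    softmaxed over j; the output aggregates v_ij[i, j] + v_jk[j, k].
    This is an assumed formulation, not FloydNet's exact one.
    """
    d = q_ik.shape[-1]
    # scores: (B, H, N_i, N_k, N_j), summing over the feature dimension d
    scores = (
        torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij)
        + torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)
    ) / d ** 0.5
    attn = scores.softmax(dim=-1)  # normalize over the pivot index j
    # aggregate the two value streams along j
    out = (
        torch.einsum("bhikj,bhijd->bhikd", attn, v_ij)
        + torch.einsum("bhikj,bhjkd->bhikd", attn, v_jk)
    )
    return out  # (B, H, N, N, D)
```

Because the attention weights sum to one over j, feeding constant value tensors returns that constant, which is a cheap way to sanity-check any such implementation.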

Reproducing Paper Results

This section targets research users who want to reproduce the experiments in the paper.

See example/README.md for a detailed description.

Environment setup

We recommend using uv to create an isolated environment for the reproduction code under example/.

cd /path/to/FloydNet

# 1) Create a uv virtual environment with Python 3.12
uv venv --python 3.12

# 2) Activate it
source .venv/bin/activate

# 3) Install extra dependencies for reproducing paper experiments
uv pip install -r example/requirements.txt

# 4) Install FloydNet (editable) for local development / imports
uv pip install -e .

Changelog (latest)

  • Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
  • Added pivotal_attention3 functional API for 3-Floyd attention.
  • Added additional configuration options in PivotalAttentionBlock.

The full changelog is in CHANGELOG.md.

Citation

If you use this code in your research, please cite the paper:

@inproceedings{TODO,
  title     = {TODO},
  author    = {TODO},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {TODO},
  url       = {TODO}
}

(Alternatively, see CITATION.cff.)


License

This project is licensed under the Apache License 2.0. See LICENSE.

