Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs

FloydNet


Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).

Figure: Pivotal Attention Mechanism for 2-Floyd / 3-Floyd.

This repository serves two audiences:

  • Engineering users: Reusable PyTorch components (functional attention APIs and Transformer-style blocks) under src/.
  • Research users: Scripts/configs to reproduce paper experiments (TSP, Graph Isomorphism, BREC) under example/.

Introduction

FloydNet is the official PyTorch implementation accompanying an ICLR paper (TODO).
The repository provides:

  1. Reusable components: a drop-in attention/Transformer-block interface intended for integration into existing projects.
  2. Reproduction code: end-to-end training/evaluation pipelines to reproduce the benchmarks reported in the paper.

For algorithmic details, hyperparameter choices, and analysis, please refer to the paper (TODO: link).


Repository Structure

  • src/floydnet/
    Library code for reuse
    Contains the functional attention API and module/block implementations.

  • example/
    Experiment reproduction code
    Includes benchmark-specific scripts, configs, and data preparation utilities.


Installation

Option A: Install from PyPI

pip install floydnet

Option B: Install from source

git clone git@github.com:ocx-lab/FloydNet.git
cd FloydNet
pip install -e .

Requirements: Python >= 3.9, PyTorch >= 2.1 (see pyproject.toml).

Public API

FloydNet re-exports the public API from src/floydnet/__init__.py, so you can import from the top-level package:

  • Functional API:
    • pivotal_attention (see src/floydnet/functional.py)
  • Module / block API:
    • PivotalAttentionBlock (see src/floydnet/transformer.py)

from floydnet import pivotal_attention, PivotalAttentionBlock

Minimal usage example

import torch
from floydnet import pivotal_attention, PivotalAttentionBlock

# -------------------------
# Module API (Transformer-style block)
# Input is a 2D grid: (B, N, N, C)
# -------------------------
B, N, C = 2, 16, 64
x = torch.randn(B, N, N, C)

m = PivotalAttentionBlock(embed_dim=C, num_heads=8, dropout=0.0)
out = m(x)  # (B, N, N, C)
print(out.shape)

# -------------------------
# Functional API
# All inputs are 5D: (B, H, N, N, D)
# -------------------------
B, H, N, D = 2, 8, 16, 64
q_ik = torch.randn(B, H, N, N, D)
k_ij = torch.randn(B, H, N, N, D)
k_jk = torch.randn(B, H, N, N, D)
v_ij = torch.randn(B, H, N, N, D)
v_jk = torch.randn(B, H, N, N, D)

y = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk)  # (B, H, N, N, D)
print(y.shape)
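For intuition about the tensor shapes above, here is a plain-PyTorch sketch of one plausible pivotal-attention form. For each output position (i, k), queries attend over a pivot index j, with scores and values built from the (i, j) and (j, k) grids. The additive combination rule and the function name `naive_pivotal_attention` are our assumptions for illustration only; the actual semantics of `pivotal_attention` are defined by the paper and the library, so treat this purely as a shape-level sketch, not as an equivalent implementation.

```python
import torch

def naive_pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk):
    """Illustrative sketch: attend over pivot j for each (i, k) cell.

    All inputs are (B, H, N, N, D). The additive score/value combination
    below is an assumption, not the library's definition.
    """
    d = q_ik.shape[-1]
    # Scores over pivot j, combining the two key grids additively (assumption)
    s_ij = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij)
    s_jk = torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)
    p = torch.softmax((s_ij + s_jk) / d ** 0.5, dim=-1)  # (B, H, N, N, j)
    # Weighted sum of the two value grids along the pivot dimension
    out = torch.einsum("bhikj,bhijd->bhikd", p, v_ij) \
        + torch.einsum("bhikj,bhjkd->bhikd", p, v_jk)
    return out

B, H, N, D = 1, 2, 4, 8
tensors = [torch.randn(B, H, N, N, D) for _ in range(5)]
y = naive_pivotal_attention(*tensors)
print(y.shape)  # torch.Size([1, 2, 4, 4, 8])
```

Note that the output keeps the same (B, H, N, N, D) layout as the inputs, matching the shape contract of `pivotal_attention` shown above.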

Reproducing Paper Results

This section targets research users who want to reproduce the experiments in the paper.

See example/README.md for a detailed description.

Environment setup

We recommend using uv to create an isolated environment for the reproduction code under example/.

cd /path/to/FloydNet

# 1) Create a uv virtual environment with Python 3.12
uv venv --python 3.12

# 2) Activate it
source .venv/bin/activate

# 3) Install extra dependencies for reproducing paper experiments
uv pip install -r example/requirements.txt

# 4) Install FloydNet (editable) for local development / imports
uv pip install -e .

Changelog (latest)

  • Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
  • Added pivotal_attention3 functional API for 3-Floyd attention.
  • Added additional configuration options in PivotalAttentionBlock.

The full changelog is in CHANGELOG.md.

Citation

If you use this code in your research, please cite the paper:

@inproceedings{TODO,
  title     = {TODO},
  author    = {TODO},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {TODO},
  url       = {TODO}
}

(Alternatively, see CITATION.cff.)


License

This project is licensed under the Apache License 2.0. See LICENSE.
