
Deep Cross Attention Language Model


Deep Cross Attention

Implementation of DeepCrossAttention, proposed by Mike Heddes while at Google Research, in PyTorch.

My take: although I still prefer Hyper Connections, this paper contains an important idea that I have been exploring concurrently, namely that the queries, keys, and values can each be routed from different earlier layers. This is appealing because it generalizes the recent value residual learning improvement. It may (or may not) also address an issue with neural memories.
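To make the routing idea concrete, here is a minimal, dependency-free sketch. It is not this package's code nor the paper's exact formulation: the helper names `combine`, `dynamic_weights` and `route_qkv`, the `gate` parameters, and the softmax over per-layer scores are all illustrative assumptions. Queries, keys and values each receive their own weighted combination of every earlier layer output, and the weights can depend on those outputs.

```python
import math

# Hypothetical sketch: names and the softmax weighting are illustrative assumptions,
# not the deep-cross-attention package's implementation.

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def combine(history, weights):
    # Weighted sum of earlier layer outputs; history[i] is layer i's output vector.
    dim = len(history[0])
    return [sum(w * h[j] for w, h in zip(weights, history)) for j in range(dim)]

def dynamic_weights(history, gate):
    # Input-dependent weights: layer i is scored by dot(gate[i], history[i]),
    # then the scores are softmax-normalized across layers.
    # `gate` stands in for learned parameters (assumed, for illustration only).
    scores = [sum(g * x for g, x in zip(gate_i, h)) for gate_i, h in zip(gate, history)]
    return softmax(scores)

def route_qkv(history, gates):
    # Queries, keys and values each get their own combination of past layers.
    return {name: combine(history, dynamic_weights(history, gate))
            for name, gate in gates.items()}

# Outputs of the embedding (index 0) and two earlier blocks, each of dimension 2.
history = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

# Hand-set routing: queries read only the latest layer, while values blend
# the embedding output with the latest layer.
q_in = combine(history, [0.0, 0.0, 1.0])  # [2.0, 2.0]
v_in = combine(history, [0.5, 0.0, 0.5])  # [1.5, 1.0]

# With all-zero gates the weights are uniform, so each input is a plain
# average of the three layer outputs.
zero_gate = [[0.0, 0.0] for _ in history]
routed = route_qkv(history, {"q": zero_gate, "k": zero_gate, "v": zero_gate})
```

Letting values lean on the embedding output while queries and keys read the latest layer is in the spirit of value residual learning; with learned, input-dependent weights the network can pick such routings itself rather than having them fixed by hand.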

Appreciation

  • Minh Hoang for spotting some issues with the GRN

Install

$ pip install deep-cross-attention

Usage

import torch
from deep_cross_attention import DCAGPT

gpt = DCAGPT(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    past_layers_k = 2
)

ids = torch.randint(0, 256, (2, 4096))

logits = gpt(ids) # (2, 4096, 256)

Example

First, install the dependencies for the example:

$ pip install .[examples]

Then run the training script:

$ python train.py

Citations

@inproceedings{Heddes2025DeepCrossAttentionST,
    title   = {DeepCrossAttention: Supercharging Transformer Residual Connections},
    author  = {Mike Heddes and Adel Javanmard and Kyriakos Axiotis and Gang Fu and MohammadHossein Bateni and Vahab S. Mirrokni},
    year    = {2025},
    url     = {https://api.semanticscholar.org/CorpusID:276250576}
}
