differential-transformer - PyTorch

Differential Transformer

An open-source community implementation of the model from the paper "Differential Transformer" by Microsoft. Paper: https://arxiv.org/abs/2410.05258

The paper describes the core idea as follows: "Differential attention takes the difference between two softmax attention functions to eliminate attention noise. The idea is analogous to differential amplifiers [19] proposed in electrical engineering, where the difference between two signals is used as output, so that we can null out the common-mode noise of the input. In addition, the design of noise-canceling headphones is based on a similar idea. We can directly reuse FlashAttention [8] as described in Appendix A, which significantly improves model efficiency."
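To make the quoted idea concrete, here is a minimal single-head sketch of differential attention in plain PyTorch: two softmax attention maps are computed from two separate query/key pairs, and the second map, scaled by λ, is subtracted from the first before it is applied to the values. The function name `differential_attention`, its arguments, and the tensor shapes are assumptions chosen for illustration and are not this package's API. The sketch also leaves out causal masking, multiple heads, normalization, and the learnable parameterization of λ, and it does not use FlashAttention.

```python
import math

import torch
import torch.nn.functional as F


def differential_attention(q1, k1, q2, k2, v, lam):
    """Illustrative single-head differential attention (hypothetical helper).

    Computes (softmax(q1 k1^T / sqrt(d)) - lam * softmax(q2 k2^T / sqrt(d))) @ v.
    q1, k1, q2, k2: (batch, seq_len, d); v: (batch, seq_len, d_v); lam: scalar.
    """
    scale = 1.0 / math.sqrt(q1.shape[-1])
    # Two independent softmax attention maps, each of shape (batch, seq_len, seq_len).
    a1 = F.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    # Subtracting the second map removes the component the two maps share.
    return (a1 - lam * a2) @ v


torch.manual_seed(0)
batch, seq_len, d = 2, 16, 8
q1, k1, q2, k2 = (torch.randn(batch, seq_len, d) for _ in range(4))
v = torch.randn(batch, seq_len, 2 * d)

out = differential_attention(q1, k1, q2, k2, v, lam=0.1)
print(out.shape)
```

Two limiting cases show the behavior: with `lam=0` the function reduces to ordinary softmax attention, and when both maps are identical and `lam=1` the output is exactly zero, which is the "common-mode" cancellation the quote refers to.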

Install

$ pip3 install differential-transformer

Usage Transformer

import torch
from differential_transformer.main import DifferentialTransformer
from loguru import logger

# Example dimensions
batch_size = 1
seq_len = 1024
embedding_dim = 64
h = 8  # number of attention heads
λ = 0.1
λinit = 0.05

# Random token IDs in [0, 256), shape (batch_size, seq_len)
x = torch.randint(0, 256, (batch_size, seq_len))

# Instantiate the model and run a forward pass
model = DifferentialTransformer(heads=h, dim=embedding_dim, λinit=λinit)
output = model(x, λ=λ)

logger.info(f"Output shape: {output.shape}")

License

MIT

Citation

@misc{ye2024differentialtransformer,
    title={Differential Transformer}, 
    author={Tianzhu Ye and Li Dong and Yuqing Xia and Yutao Sun and Yi Zhu and Gao Huang and Furu Wei},
    year={2024},
    eprint={2410.05258},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2410.05258}, 
}
