differential-transformer - PyTorch
Differential Transformer
An open-source community implementation of the model from the "Differential Transformer" paper by Microsoft. Paper link: https://arxiv.org/abs/2410.05258

From the paper: "Differential attention takes the difference between two softmax attention functions to eliminate attention noise. The idea is analogous to differential amplifiers [19] proposed in electrical engineering, where the difference between two signals is used as output, so that we can null out the common-mode noise of the input. In addition, the design of noise-canceling headphones is based on a similar idea. We can directly reuse FlashAttention [8] as described in Appendix A, which significantly improves model efficiency."
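The quoted idea can be made concrete with a small sketch. Following the paper's formulation as we understand it, differential attention computes two softmax attention maps and subtracts the second one, scaled by a scalar λ, before multiplying by the values: (softmax(Q1·K1ᵀ/√d) − λ·softmax(Q2·K2ᵀ/√d))·V. The sketch below is plain Python on nested lists; it is not this package's code and not FlashAttention-based, the function names are hypothetical, and it omits the per-head normalization the paper applies after this step.

```python
import math

def softmax(row):
    # Numerically stable softmax over one row of scores.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def transpose(m):
    # Swap rows and columns of a nested-list matrix.
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    # Plain nested-list matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def attention_map(q, k):
    # softmax(Q K^T / sqrt(d)), applied row by row.
    d = len(q[0])
    scores = matmul(q, transpose(k))
    return [softmax([s / math.sqrt(d) for s in row]) for row in scores]

def diff_attention(q1, k1, q2, k2, v, lam):
    # (softmax(Q1 K1^T / sqrt(d)) - lam * softmax(Q2 K2^T / sqrt(d))) V
    a1 = attention_map(q1, k1)
    a2 = attention_map(q2, k2)
    diff = [[x - lam * y for x, y in zip(r1, r2)] for r1, r2 in zip(a1, a2)]
    return matmul(diff, v)

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
# Identical maps with lam = 1: the shared ("common-mode") pattern cancels.
print(diff_attention(q, k, q, k, v, 1.0))
```

When both attention maps are identical and λ = 1, the output is all zeros, which mirrors the common-mode cancellation described in the quote; setting λ = 0 recovers ordinary scaled dot-product attention over (Q1, K1).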
Install
```bash
$ pip3 install differential-transformer
```
Usage Transformer
```python
import torch
from differential_transformer.main import DifferentialTransformer
from loguru import logger

# Example dimensions
batch_size = 32
seq_len = 128
embedding_dim = 64
h = 8         # number of attention heads
λ = 0.1       # scalar weight on the subtracted attention map (per the paper)
λinit = 0.05  # initial value of λ (per the paper)

# Create a batch of random token IDs in [0, 256)
x = torch.randint(0, 256, (batch_size, seq_len))

# Instantiate the model and run a forward pass
model = DifferentialTransformer(heads=h, dim=embedding_dim, λinit=λinit)
output = model(x, λ=λ)
logger.info(f"Output shape: {output.shape}")
```
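The usage example passes fixed values for λ and λinit. In the paper, as we read it, λ is learnable and re-parameterized as λ = exp(λq1·λk1) − exp(λq2·λk2) + λinit, where λq1, λk1, λq2, λk2 are learnable vectors, and λinit follows a depth-dependent schedule λinit = 0.8 − 0.6·exp(−0.3·(l − 1)) for layer index l starting at 1. The sketch below only illustrates those two formulas in plain Python; the function names are hypothetical, and we do not know whether `DifferentialTransformer` computes λ this way internally.

```python
import math

def lambda_init(layer_index):
    # Depth-dependent initial value (layer_index starts at 1):
    # lambda_init = 0.8 - 0.6 * exp(-0.3 * (l - 1))
    return 0.8 - 0.6 * math.exp(-0.3 * (layer_index - 1))

def reparameterized_lambda(lq1, lk1, lq2, lk2, lam_init):
    # lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init,
    # with the learnable vectors represented here as plain lists.
    dot1 = sum(a * b for a, b in zip(lq1, lk1))
    dot2 = sum(a * b for a, b in zip(lq2, lk2))
    return math.exp(dot1) - math.exp(dot2) + lam_init

for layer in (1, 2, 4, 8):
    print(layer, round(lambda_init(layer), 4))
```

Under this schedule the first layer starts at λinit = 0.2 and deeper layers approach 0.8. With all four vectors at zero the two exponentials cancel, so λ starts exactly at λinit.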
License
MIT
Citation
```bibtex
@misc{ye2024differentialtransformer,
  title={Differential Transformer},
  author={Tianzhu Ye and Li Dong and Yuqing Xia and Yutao Sun and Yi Zhu and Gao Huang and Furu Wei},
  year={2024},
  eprint={2410.05258},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2410.05258},
}
```
Download files
Source distribution: differential_transformer-0.0.3.tar.gz
Built distribution: differential_transformer-0.0.3-py3-none-any.whl
Hashes for differential_transformer-0.0.3.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | cb6acb67ad9ee80802c4e9a410de65c483e6bb25f7aa993ae255204b3de141f5 |
| MD5 | f713f28bd886136c55f691ac1e6278e5 |
| BLAKE2b-256 | 97ee2c9081c5b6cbf3a903d4d2f4f832fc87380bef8dff0c5e60bbd1cda753b5 |
Hashes for differential_transformer-0.0.3-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | c99c054b086f6dd668d6dda0b1f66b499e1ad80ad8824a9ca99d11efbf4b5ab2 |
| MD5 | 1fa3d1ec48797864b86eb050b52b42b5 |
| BLAKE2b-256 | c9d41eed50fd3298090d975e821df436a4dc572352d1ef978c9d118a4b26edba |
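The published digests can be used to check a manually downloaded file before installing it. The sketch below uses only Python's standard `hashlib`; the helper names are hypothetical, and it assumes the archive has already been saved to the local path you pass in.

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 20):
    # Stream the file in chunks so large archives need not fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_sha256(path, expected_hex):
    # hexdigest() is lowercase, so normalize the published value before comparing.
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, `verify_sha256("differential_transformer-0.0.3.tar.gz", "cb6acb67ad9ee80802c4e9a410de65c483e6bb25f7aa993ae255204b3de141f5")` should return `True` for an intact copy of the source distribution listed above.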