# dilated-attention-pytorch

(Unofficial) implementation of `DilatedAttention` from [LongNet: Scaling Transformers to 1,000,000,000 Tokens](https://arxiv.org/abs/2307.02486) in PyTorch.
## Install

**NOTE**: This library depends on [facebookresearch/xformers](https://github.com/facebookresearch/xformers). If you're not using `torch>=2.0.0`, you may need to install xformers from source. See their installation instructions.
PyPI (TODO: publish to PyPI):

```bash
pip install dilated-attention-pytorch
```
From source:

```bash
pip install "dilated-attention-pytorch @ git+ssh://git@github.com/fkodom/dilated-attention-pytorch.git"
```
For contributors:

```bash
# Install all dev dependencies (tests etc.)
pip install "dilated-attention-pytorch[all] @ git+ssh://git@github.com/fkodom/dilated-attention-pytorch.git"
# Setup pre-commit hooks
pre-commit install
```
## Usage

### `DilatedAttention`
The LongNet paper introduces a new attention mechanism called `DilatedAttention`. It is a drop-in replacement (see below) for "vanilla" attention that allows much longer sequences to be processed.

**NOTE**: `DilatedAttention` only supports `batch_first=True`. This is different from "vanilla" attention in PyTorch, which supports both `batch_first=True` and `batch_first=False`.
Arguments:
- `segment_lengths` (required, `list[int]`): Length of each attention segment. This is usually a geometric sequence increasing in powers of 2, such as `[2048, 4096, 8192]`.
- `dilation_rates` (required, `list[int]`): Dilation rate for each segment. Like `segment_lengths`, this is usually a geometric sequence increasing in powers of 2, such as `[1, 2, 4]`.
```python
import torch
from dilated_attention_pytorch.dilated_attention import DilatedAttention

dilated_attention = DilatedAttention(
    segment_lengths=[2048, 4096, 8192],
    dilation_rates=[1, 2, 4],
)

# shape: (batch_size, seq_len, num_heads, embed_dim)
# NOTE: 'seq_len' must be a multiple of 8192 (the largest segment length)
# NOTE: For best performance, use 'dtype=torch.float16' or 'dtype=torch.bfloat16'
query = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)

out = dilated_attention(query, key, value, is_causal=False)  # default: is_causal=False
print(out.shape)
# torch.Size([1, 8192, 8, 64])
```
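For intuition, the sparsification behind these two arguments can be pictured as splitting the sequence into segments of length `w` and keeping every `r`-th position within each segment; attention is then computed within each sparsified segment, which keeps the cost roughly linear in sequence length. The sketch below is only illustrative, not this library's internals (which dispatch to xformers attention ops):

```python
import torch

# Illustrative only: one (segment_length, dilation_rate) pair, w=2048, r=2.
b, n, h, d = 1, 8192, 8, 64            # batch, seq_len, num_heads, head_dim
w, r = 2048, 2                          # segment length, dilation rate

x = torch.randn(b, n, h, d)
segments = x.view(b, n // w, w, h, d)   # (b, num_segments, w, h, d)
sparse = segments[:, :, ::r]            # keep every r-th token per segment
print(sparse.shape)                     # torch.Size([1, 4, 1024, 8, 64])
```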
### `MultiheadDilatedAttention`

`MultiheadDilatedAttention` is a drop-in replacement (see below) for `nn.MultiheadAttention` that uses `DilatedAttention` instead of "vanilla" attention. It also incorporates improvements from the MAGNETO architecture (`nn.LayerNorm` placements), as mentioned in the LongNet paper.
**NOTE**: `MultiheadDilatedAttention` only supports `batch_first=True`. This is different from `nn.MultiheadAttention`, which supports both `batch_first=True` and `batch_first=False`.
Arguments:
- `segment_lengths` (required, `list[int]`): Length of each attention segment. This is usually a geometric sequence increasing in powers of 2, such as `[2048, 4096, 8192]`.
- `dilation_rates` (required, `list[int]`): Dilation rate for each segment. Like `segment_lengths`, this is usually a geometric sequence increasing in powers of 2, such as `[1, 2, 4]`.
- Many of the same arguments as `nn.MultiheadAttention`. See the `MultiheadDilatedAttention` class for more details.
```python
import torch
from dilated_attention_pytorch.dilated_attention import MultiheadDilatedAttention

device = torch.device("cuda")
dtype = torch.float16
embed_dim = 512

# NOTE: Omitting most of the optional arguments for brevity
mhda = MultiheadDilatedAttention(
    embed_dim=embed_dim,
    num_heads=8,
    segment_lengths=[2048, 4096, 8192],
    dilation_rates=[1, 2, 4],
    device=device,  # optional
    dtype=dtype,  # optional
)

# shape: (batch_size, seq_len, embed_dim)
# NOTE: 'seq_len' must be a multiple of 8192 (the largest segment length)
x = torch.randn(1, 8192, embed_dim, device=device, dtype=dtype)
y = mhda(x, x, x, is_causal=False)  # default: is_causal=False
print(y.shape)
# torch.Size([1, 8192, 512])
```
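The MAGNETO change mentioned above amounts to an extra `nn.LayerNorm` inside each sublayer. As a rough sketch of the idea (class and attribute names here are hypothetical, not this library's API), the attention output is normalized just before the output projection:

```python
import torch
from torch import nn

# Hypothetical sketch of MAGNETO-style "sub-LN" placement: an extra
# LayerNorm applied to the attention output before the output projection.
class SubLNOutput(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, attn_out: torch.Tensor) -> torch.Tensor:
        return self.out_proj(self.norm(attn_out))

out = SubLNOutput(512)(torch.randn(1, 8192, 512))
print(out.shape)  # torch.Size([1, 8192, 512])
```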
### `LongNet`

The LongNet paper culminates in a transformer architecture, which can be trained for language modeling with very long context windows. I have implemented two `LongNet` variants, based on the base configurations from the paper:
- `LongNetLM` - designed specifically for language modeling
- `LongNet` - a more general encoder-decoder architecture, which is not specific to language modeling

Based on these implementations, it is fairly straightforward to adapt `LongNet` to encoder- or decoder-only architectures, as needed for specific applications.
```python
import torch
from dilated_attention_pytorch.long_net import LongNetLM, LongNet

device = torch.device("cuda")
dtype = torch.float16

# NOTE: Showing all default values, which are described in the paper.
net = LongNet(
    d_model=768,
    nhead=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    dim_feedforward=3072,
    segment_lengths=[2048, 4096, 8192, 16384, 32768],
    dilation_rates=[1, 2, 4, 6, 12],
    dropout=0.0,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)

# shape: (batch_size, seq_len, d_model)
# NOTE: 'seq_len' must be a multiple of 32768 (the largest segment length)
x = torch.randn(1, 32768, 768, device=device, dtype=dtype)
with torch.no_grad():
    y = net.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 32768, 768])
```
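Since `seq_len` must be a multiple of the largest segment length, inputs of arbitrary length need to be padded first. A minimal sketch in plain PyTorch (the helper below is hypothetical; this library does not provide it):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int) -> torch.Tensor:
    """Right-pad the sequence dimension (dim=1) to a multiple of 'multiple'."""
    remainder = x.size(1) % multiple
    if remainder == 0:
        return x
    # F.pad pads from the last dim backwards: (0, 0) for dim=2, then dim=1
    return F.pad(x, (0, 0, 0, multiple - remainder))

x = torch.randn(1, 30000, 768)
x = pad_to_multiple(x, 32768)  # largest segment length from the example above
print(x.shape)  # torch.Size([1, 32768, 768])
```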
```python
num_tokens = 10000  # (required) usually obtained from the tokenizer
lm = LongNetLM(
    num_tokens=num_tokens,
    d_model=768,
    nhead=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    dim_feedforward=3072,
    segment_lengths=[2048, 4096, 8192, 16384, 32768],
    dilation_rates=[1, 2, 4, 6, 12],
    dropout=0.0,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)

# shape: (batch_size, seq_len)
x = torch.randint(0, num_tokens, (1, 32768), device=device, dtype=torch.long)
with torch.no_grad():
    y = lm.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 32768, 10000]), i.e. (batch_size, seq_len, num_tokens)
```
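Since `LongNetLM` returns next-token logits, a training step is ordinary next-token prediction. A minimal sketch reusing `lm`, `x`, and `num_tokens` from the example above (hyperparameters are illustrative, and a real setup would use mixed-precision tooling rather than pure-`float16` weights):

```python
import torch.nn.functional as F

optimizer = torch.optim.AdamW(lm.parameters(), lr=3e-4)

logits = lm(x, is_causal=True)  # (batch_size, seq_len, num_tokens)
# Shift by one position so each token predicts the token that follows it
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, num_tokens).float(),
    x[:, 1:].reshape(-1),
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```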