Simple implementations of attention modules adapted for the biological data domain
Install
Since PyTorch is a dependency of bio-attention, we recommend installing PyTorch independently first, as your system may require a specific version (e.g. for CUDA drivers).
After installing PyTorch, bio-attention can be installed using pip:
pip install bio-attention
Usage
Package roadmap
- Implement typing
LEGACY documentation
THIS REPO USED TO BE A 2D SLIDING WINDOW ATTENTION REPO
2D Sliding Window Attention
Stand-alone PyTorch implementation of 2D sliding window attention. Introduced as part of CpG Transformer, located at its own repository and detailed in our preprint paper.
Contents
sliding_window_attn.py contains three PyTorch modules: RelPositionalWindowEmbedding, MultiDimWindowAttention, and MultiDimWindowTransformerLayer. The modules are written so that they can perform 1D sliding window attention as well as sliding window attention over inputs with two or more sequence dimensions. In the multidimensional case, sliding window attention is applied over the first dimension following the batch dimension, and full self-attention is applied over all the others.
Sliding windows are efficiently obtained using the unfold operation.
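To illustrate the idea (a simplified sketch, not the exact code in sliding_window_attn.py; the padding scheme and shapes here are assumptions for the example), torch.Tensor.unfold can gather a window of keys around every query position:
import torch

# toy sequence of keys: batch = 1, length L = 10, hidden = 4
k = torch.randn(1, 10, 4)
window = 5                 # odd window: each query sees (window-1)/2 neighbors per side
pad = window // 2

# pad the sequence dimension so border positions also get a full window
k_padded = torch.nn.functional.pad(k, (0, 0, pad, pad))          # (1, L + 2*pad, hidden)

# unfold the sequence dimension into overlapping windows
k_windows = k_padded.unfold(dimension=1, size=window, step=1)    # (1, L, hidden, window)
k_windows = k_windows.permute(0, 1, 3, 2)                        # (1, L, window, hidden)

print(k_windows.shape)  # torch.Size([1, 10, 5, 4]): one window of 5 keys per query position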
Positional embeddings are relative sinusoidal ones, as described in Transformer-XL. Note that positional encodings are only applied along the dimension in which sliding windows are applied. To inform the model of position in the other dimensions, this should be encoded in the input itself.
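For reference, a minimal sketch of relative sinusoidal embeddings over the window offsets, assuming the standard Transformer-XL-style formulation (the actual RelPositionalWindowEmbedding may differ in details such as how it uses the supplied pos indices and projections):
import torch

def relative_sinusoidal_embedding(window: int, hidden_dim: int) -> torch.Tensor:
    # relative offsets a query can attend to: -(window//2) ... +(window//2)
    offsets = torch.arange(-(window // 2), window // 2 + 1).float()           # (window,)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, hidden_dim, 2).float() / hidden_dim))
    angles = offsets[:, None] * inv_freq[None, :]                             # (window, hidden_dim/2)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)                    # (window, hidden_dim)

emb = relative_sinusoidal_embedding(window=21, hidden_dim=64)
print(emb.shape)  # torch.Size([21, 64]): one embedding per relative offset in the window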
Usage
import torch
from torch import nn

from sliding_window_attn import MultiDimWindowTransformerLayer
# one layer:
layer = MultiDimWindowTransformerLayer(
hidden_dim=64, # number of input & output hidden dimensions (int)
head_dim=8, # hidden dimensionality of each SA head (int)
n_head=8, # number of SA heads (int)
ff_dim=256, # number of feed-forward hidden dimensions (int)
window=21, # window size of sliding window, should be odd. (int) (default=21)
dropout=0.20, # dropout rate on the self-attention matrix (float) (default=0.20)
activation='relu', # activation used in feed-forward, either 'relu' or 'gelu' (str) (default='relu')
layernorm=True # whether to apply layernorm after attn+res and ff+res (bool) (default=True)
)
# model consisting of 4 layers:
model = nn.Sequential(MultiDimWindowTransformerLayer(64, 8, 8, 256),
                      MultiDimWindowTransformerLayer(64, 8, 8, 256),
                      MultiDimWindowTransformerLayer(64, 8, 8, 256),
                      MultiDimWindowTransformerLayer(64, 8, 8, 256))
# 2D sequence input:
# batch size = 1
# sequence dim1 length = 512 (sliding window SA)
# sequence dim2 length = 4 (full SA)
# hidden = 64
x = torch.randn(1, 512, 4, 64)
pos = torch.cumsum(torch.randint(1, 7, (1, 512)), 1)
# if all positional indices follow each other by one: pos = torch.arange(512).unsqueeze(0)
x, pos = model((x, pos))
The same model can also be used for 1D sequence inputs:
# batch size = 1
# sequence dim1 length = 512 (sliding window SA)
# hidden = 64
x = torch.randn(1, 512, 64)
pos = torch.cumsum(torch.randint(1, 7, (1, 512)), 1)
x, pos = model((x, pos))
Or even 3D (or more) sequence input:
# batch size = 1
# sequence dim1 length = 512 (sliding window SA)
# sequence dim2 length = 4 (full SA)
# sequence dim3 length = 3 (full SA)
# hidden = 64
x = torch.randn(1, 512, 4, 3, 64)
pos = torch.cumsum(torch.randint(1, 7, (1, 512)), 1)
x, pos = model((x, pos))
Note that computational complexity scales quadratically with the size of each added dimension: every extra dimension multiplies both the number of query positions and the number of keys each query attends to.
For example, the attention matrix (per head) for the above 1D example has 512 * 21 entries.
For the 2D example this becomes (512*4) * (21*4).
And for the 3D example: (512*4*3) * (21*4*3).
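As a quick back-of-the-envelope check of these numbers, the plain-Python loop below (an illustration, not part of the package) reproduces the per-head attention-matrix sizes above:
window, seq_len = 21, 512

for extra_dims in [(), (4,), (4, 3)]:
    n_extra = 1
    for d in extra_dims:
        n_extra *= d
    rows = seq_len * n_extra        # number of query positions
    cols = window * n_extra         # keys each query attends to
    print(extra_dims, rows * cols)  # entries in the per-head attention matrix
# prints: () 10752, (4,) 172032, (4, 3) 1548288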
Citation
If you find this repository useful in your research, please cite our paper.
@article{dewaele2021cpg,
author = {Gaetan De Waele and Jim Clauwaert and Gerben Menschaert and Willem Waegeman},
title = {CpG Transformer for imputation of single-cell methylomes},
year = {2021},
doi = {10.1101/2021.06.08.447547},
URL = {https://www.biorxiv.org/content/early/2021/06/09/2021.06.08.447547}
}