
bio-attention

Simple implementations of attention modules adapted for the biological data domain.

Install

Since PyTorch is a dependency of bio-attention, we recommend installing PyTorch independently first, as your system may require a specific build (e.g. matching your CUDA drivers).

After installing PyTorch, bio-attention can be installed using pip:

pip install bio-attention

Usage

Package roadmap

  • Implement typing

LEGACY documentation

THIS REPO USED TO BE A 2D SLIDING WINDOW ATTENTION REPO

2D Sliding Window Attention

Stand-alone PyTorch implementation of 2D sliding window attention. Introduced in and part of CpG Transformer, located at this repo and detailed in our preprint.

Contents

sliding_window_attn.py contains three PyTorch modules: RelPositionalWindowEmbedding, MultiDimWindowAttention, and MultiDimWindowTransformerLayer. The modules are written so that they can perform 1D sliding window attention as well as sliding window attention in two or more dimensions. In the multidimensional case, sliding window attention is applied over the first dimension after the batch dimension, and full self-attention is applied over all the others.

Sliding windows are efficiently obtained using the unfold operation.
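
As an illustration only (not the repository's exact code), the following sketch shows how windows over the sequence dimension can be gathered with torch.Tensor.unfold; the shapes and the padding scheme are assumptions:

import torch
import torch.nn.functional as F

# hypothetical shapes: batch = 2, sequence length = 10, hidden = 16, window = 5
x = torch.randn(2, 10, 16)
window = 5
pad = window // 2

# pad the sequence dimension on both sides so every position keeps a full, centered window
x_padded = F.pad(x, (0, 0, pad, pad))

# unfold over the sequence dimension: (batch, seq_len, hidden, window)
windows = x_padded.unfold(dimension=1, size=window, step=1)
print(windows.shape)  # torch.Size([2, 10, 16, 5])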

Positional embeddings are relative sinusoidal ones as described in Transformer-XL. Note that positional encodings are only applied along the dimension in which the sliding window operates. To inform the model of position in the other dimensions, encode it in the input itself.
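
For reference, a minimal sketch of Transformer-XL-style sinusoidal embeddings of relative offsets (an illustrative stand-in, not the repository's RelPositionalWindowEmbedding; the function name and shapes are assumptions):

import torch

def relative_sinusoidal_embedding(window: int, dim: int) -> torch.Tensor:
    # relative offsets covered by one sliding window, e.g. -10 .. 10 for window=21
    offsets = torch.arange(-(window // 2), window // 2 + 1, dtype=torch.float)
    # inverse frequencies as in the standard sinusoidal encoding
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    angles = offsets[:, None] * inv_freq[None, :]            # (window, dim // 2)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)   # (window, dim)

emb = relative_sinusoidal_embedding(window=21, dim=64)
print(emb.shape)  # torch.Size([21, 64])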

Usage

import torch
from torch import nn

from sliding_window_attn import MultiDimWindowTransformerLayer

# one layer:
layer = MultiDimWindowTransformerLayer(
    hidden_dim=64,     # number of input & output hidden dimensions (int)
    head_dim=8,        # hidden dimensionality of each SA head (int)
    n_head=8,          # number of SA heads (int)
    ff_dim=256,        # number of feed-forward hidden dimensions (int)
    window=21,         # window size of sliding window, should be odd. (int) (default=21)
    dropout=0.20,      # dropout rate on the self-attention matrix (float) (default=0.20)
    activation='relu', # activation used in feed-forward, either 'relu' or 'gelu' (str) (default='relu')
    layernorm=True     # whether to apply layernorm after attn+res and ff+res (bool) (default=True)
)

# model consisting of 4 layers:
model = nn.Sequential(MultiDimWindowTransformerLayer(64, 8, 8, 256),
                      MultiDimWindowTransformerLayer(64, 8, 8, 256),
                      MultiDimWindowTransformerLayer(64, 8, 8, 256),
                      MultiDimWindowTransformerLayer(64, 8, 8, 256))



# 2D sequence input:
# batch size = 1
# sequence dim1 length = 512 (sliding window SA)
# sequence dim2 length = 4 (full SA)
# hidden = 64
x = torch.randn(1, 512, 4, 64)
pos = torch.cumsum(torch.randint(1, 7, (1, 512)), 1)
# if positional indices are consecutive (increment by one): pos = torch.arange(512).unsqueeze(0)

x, pos = model((x, pos))

The same model can also be used for 1D sequence inputs:

# batch size = 1
# sequence dim1 length = 512 (sliding window SA)
# hidden = 64
x = torch.randn(1, 512, 64)
pos = torch.cumsum(torch.randint(1, 7, (1, 512)), 1)

x, pos = model((x, pos))

Or even 3D (or more) sequence input:

# batch size = 1
# sequence dim1 length = 512 (sliding window SA)
# sequence dim2 length = 4 (full SA)
# sequence dim3 length = 3 (full SA)
# hidden = 64
x = torch.randn(1, 512, 4, 3, 64)
pos = torch.cumsum(torch.randint(1, 7, (1, 512)), 1)

x, pos = model((x, pos))

Note that computational complexity scales quadratically with the size of every added dimension. For example, the per-head attention matrix for the above 1D example is 512 * 21. For the 2D example this becomes (512*4) * (21*4), and for the 3D example (512*4*3) * (21*4*3).
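
As a quick back-of-the-envelope check (purely illustrative, not part of the package), the per-head attention matrix sizes quoted above can be reproduced as follows:

import math

def attn_matrix_size(seq_len, window, extra_dims=()):
    # rows: all positions attending; columns: window positions attended to per query
    factor = math.prod(extra_dims)
    return seq_len * factor, window * factor

print(attn_matrix_size(512, 21))          # (512, 21)    1D example
print(attn_matrix_size(512, 21, (4,)))    # (2048, 84)   2D example
print(attn_matrix_size(512, 21, (4, 3)))  # (6144, 252)  3D example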

Citation

If you find this repository useful in your research, please cite our paper.

@article{dewaele2021cpg,
	author = {Gaetan De Waele and Jim Clauwaert and Gerben Menschaert and Willem Waegeman},
	title = {CpG Transformer for imputation of single-cell methylomes},
	year = {2021},
	doi = {10.1101/2021.06.08.447547},
	URL = {https://www.biorxiv.org/content/early/2021/06/09/2021.06.08.447547}
}
