Reverse-complement parameter sharing (RCPS) layers for machine learning on DNA sequences in PyTorch

Project description

RCPS (Reverse-Complement Parameter Sharing)

RCPS layers for machine learning on DNA sequences in PyTorch, based on the fantastic work by Zhou et al. This version of RCPS acts as a near drop-in replacement for PyTorch's standard Conv1d layers, with the exception that additional RCPS-type layers must be inserted before and after the convolutions so that the extra channels are handled properly.

The benefit of this version of RCPS layers is that they store only one copy of the learned kernel parameters while still computing both the forward and reverse-complement convolutions on an input, emitting both outputs in a single channel-mirrored array. Additionally, RCPSBatchNorm provides a convenient and encoding-safe way to normalise within an RCPS block. For full details, please see the paper linked above.
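To make the parameter-sharing idea concrete, here is a minimal sketch of the core RCPS trick using plain `conv1d` calls — this is an illustration of the technique, not the library's actual internals. One kernel `W` is stored; the reverse-complement path reuses `W` with its input-channel and length axes flipped, and the two outputs are concatenated channel-mirrored:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4, 13)          # (batch, channel, length)
W = torch.randn(8, 4, 3)           # the single stored kernel: (out, in, kernel)

fwd = F.conv1d(x, W, padding=1)                # forward-strand features
rc = F.conv1d(x, W.flip(-1, -2), padding=1)    # reverse-complement features
y = torch.cat([fwd, rc.flip(1)], dim=1)        # channel-mirrored output, 16 channels

# Reverse-complementing the input simply mirrors the output (equivariance):
x_rc = x.flip(-1, -2)
y_rc = torch.cat(
    [F.conv1d(x_rc, W, padding=1),
     F.conv1d(x_rc, W.flip(-1, -2), padding=1).flip(1)],
    dim=1,
)
assert torch.allclose(y_rc, y.flip(-1, -2), atol=1e-5)
```

Because the second path is derived from the same `W`, the parameter count is that of a single convolution even though both strands are covered.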

Usage

RCPS layers can be used to carry out convolutions on DNA sequences and on downstream encodings from previous RCPS layers. The simplest approach is to use them as part of a torch Sequential object, which lets you chain the RCPS layers between the required RCPSInput and RCPSOutput layers.

To begin with, we can generate a one-hot-encoded DNA sequence to run through our model. Usually I would use the encoder provided in the excellent enformer_pytorch repository, but I have hard-coded a sequence here to avoid an extra dependency.

from rcps import RCPSInput, RCPS, RCPSBatchNorm, RCPSOutput
import torch.nn as nn
import torch

# AAATTATCCGGCG: one-hot encoded, stored in (batch, channel, length) order
fwd_seq = torch.Tensor(
    [
        [
            [1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1],
            [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        ]
    ]
)
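If you would rather build such tensors programmatically, a small helper like the hypothetical `one_hot_dna` below (not part of rcps or enformer_pytorch) does the job. The channel order A, C, G, T matters: it is what makes flipping the channel axis equivalent to complementing each base.

```python
import torch

def one_hot_dna(seq: str) -> torch.Tensor:
    """One-hot encode a DNA string into shape (1, 4, len(seq))."""
    # Channel order A, C, G, T: flipping the channel axis then maps each
    # base to its complement (A<->T, C<->G), which RCPS relies on.
    idx = {"A": 0, "C": 1, "G": 2, "T": 3}
    out = torch.zeros(1, 4, len(seq))
    for pos, base in enumerate(seq):
        out[0, idx[base], pos] = 1.0
    return out

fwd_seq = one_hot_dna("AAATTATCCGGCG")
```

Note that `one_hot_dna(s).flip(-2)` gives the one-hot encoding of the complemented string, and flipping both the channel and length axes gives the reverse complement.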

We can use a torch Sequential object to create our model, making sure to wrap all of our RCPS-layer logic with the -Input and -Output layers.

example_model = nn.Sequential(
    RCPSInput(out_kernels=8, kernel_size=3, padding="same"),
    RCPS(8, 16, 3, padding="same"),
    RCPSBatchNorm(16),
    RCPS(16, 3, 3, padding="same"),
    RCPSBatchNorm(3),
    RCPSOutput(),
)

Finally, we can verify that the model performs the same on both the forward and reverse-complement versions of the input. Here, reverse-complementing the one-hot encoded DNA string is performed by 'flipping' both the channel and length axes.

out_fwd = example_model(fwd_seq)
out_rc = example_model(fwd_seq.flip(-1, -2))
print(torch.isclose(out_fwd, out_rc, atol=1e-6).all())
# tensor(True)

Limitations

  • Currently, the nominal way to use RCPSInput is with an input representing one-hot-encoded DNA sequences, in the shape (batch, 4, length). In theory, other DNA encodings could exist that do not use 4 channels but still benefit from the easy 'reverse complement' action of flipping all non-batch dimensions.
  • Because PyTorch cannot reverse-index (i.e. flip) a tensor without making a copy, multiple copies of the input and weights are created during each forward pass. I would welcome a way to change this, but unfortunately I don't know how to do so without moving away from PyTorch.
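The copy behaviour mentioned above is easy to confirm with the public API — PyTorch does not support negative strides, so `flip` must allocate new storage rather than return a view:

```python
import torch

x = torch.arange(6.0).reshape(1, 2, 3)
y = x.flip(-1)

# flip() cannot be expressed as a stride trick (no negative strides),
# so y lives in freshly allocated storage rather than aliasing x:
assert y.data_ptr() != x.data_ptr()

# Mutating x therefore leaves y untouched, confirming the copy:
x[0, 0, 0] = 99.0
assert y[0, 0, 2] == 0.0
```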

Future Work

  • Find ways that the RCPS can be more smoothly integrated into Sequential models without the extra RCPSInput and RCPSOutput layers. (Pull Requests or Discussions welcome!)
  • Benchmark the performance of the RCPS technique against other reverse-complement preserving encodings/layers.
