
Reverse-complement parameter sharing (RCPS) layers for machine learning on DNA sequences in PyTorch

Project description

RCPS (Reverse-Complement Parameter Sharing)

RCPS layers for machine learning on DNA sequences in PyTorch, based on the fantastic work by Zhou et al. The benefit of RCPS layers is that they store only one copy of the learned kernel parameters while still computing both the forward and reverse-complement convolutions on an input, emitting both outputs in a single channel-mirrored array. Additionally, RCPSBatchNorm provides a convenient and encoding-safe way to normalise within an RCPS block. For full details, please see the paper linked above.
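To build intuition, here is a minimal, self-contained sketch of the sharing trick. This is an illustration of the general idea only, not the library's actual implementation: a single stored kernel tensor is concatenated with a fully flipped copy of itself, so the effective convolution covers the reverse-complement strand without extra parameters.

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch of parameter sharing, NOT the library's actual code:
# one stored kernel tensor serves both strands. Flipping the weight along
# every axis and concatenating it to itself yields an effective convolution
# whose output on a reverse-complemented input is the channel- and
# length-mirror of its output on the forward input.
torch.manual_seed(0)
weight = torch.randn(8, 4, 3)  # (stored kernels, DNA channels, kernel width)
x = torch.rand(1, 4, 13)       # toy (batch, channel, length) input

full_weight = torch.cat([weight, weight.flip(0, 1, 2)], dim=0)  # 16 kernels
out_fwd = F.conv1d(x, full_weight, padding="same")
out_rc = F.conv1d(x.flip(-1, -2), full_weight, padding="same")

print(torch.allclose(out_rc, out_fwd.flip(-1, -2), atol=1e-5))  # True
```

Only the 8 stored kernels would be trainable parameters in such a scheme; the mirrored half is derived on the fly, which is what halves the parameter count.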

Usage

RCPS layers can be used to carry out convolutions on DNA sequences and on downstream encodings from previous RCPS layers. This version of RCPS acts as a near drop-in replacement for PyTorch's standard Conv1d layers, with the exception that additional RCPS-type layers must be inserted before and after the convolutions to ensure the extra channels are handled properly. The easiest way to use them is as part of a torch Sequential object, which lets you chain the RCPS layers between the required RCPSInput and RCPSOutput layers.

To begin with, we can generate a one-hot encoded DNA sequence to run through our model. Usually I would use the encoder provided in the excellent tangermeme repository, but I have hard-coded one here to avoid an extra dependency.

from rcps import RCPSInput, RCPS, RCPSBatchNorm, RCPSOutput
import torch.nn as nn
import torch

# AAATTATCCGGCG: one-hot encoded, stored in (batch, channel, length) order
fwd_seq = torch.Tensor(
    [
        [
            [1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1],
            [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        ]
    ]
)
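If you want to encode other sequences instead of hard-coding the tensor, a minimal encoder in the same A/C/G/T channel order could look like the sketch below (`one_hot_dna` is a hypothetical helper, not part of this package or of tangermeme):

```python
import torch

def one_hot_dna(seq: str) -> torch.Tensor:
    """Encode a DNA string as a (1, 4, length) one-hot tensor, rows A/C/G/T."""
    index = {"A": 0, "C": 1, "G": 2, "T": 3}
    out = torch.zeros(1, 4, len(seq))
    for pos, base in enumerate(seq):
        out[0, index[base], pos] = 1.0
    return out

fwd_seq = one_hot_dna("AAATTATCCGGCG")  # matches the hand-written tensor above
print(fwd_seq.shape)  # torch.Size([1, 4, 13])
```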

We can use a torch Sequential object to create our model, making sure to wrap all of our RCPS-layer logic with the -Input and -Output layers.

example_model = nn.Sequential(
    RCPSInput(out_kernels=8, kernel_size=3, padding="same"),
    RCPS(8, 16, 3, padding="same"),
    RCPSBatchNorm(16),
    RCPS(16, 3, 3, padding="same"),
    RCPSBatchNorm(3),
    RCPSOutput(),
)

Finally, we can verify that the model performs the same on both the forward and reverse-complement versions of the input. Here, reverse-complementing the one-hot encoded DNA string is performed by 'flipping' both the channel and length axes.

out_fwd = example_model(fwd_seq)
out_rc = example_model(fwd_seq.flip(-1, -2))
print(torch.isclose(out_fwd, out_rc, atol=1e-6).all())
# tensor(True)
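That 'flip both axes equals reverse complement' claim can be checked directly with a small encode/decode pair (hypothetical helpers, assuming the A/C/G/T channel order used above):

```python
import torch

BASES = "ACGT"  # channel order assumed throughout this example

def encode(seq: str) -> torch.Tensor:
    """One-hot encode a DNA string into (1, 4, length)."""
    out = torch.zeros(1, 4, len(seq))
    for pos, base in enumerate(seq):
        out[0, BASES.index(base), pos] = 1.0
    return out

def decode(one_hot: torch.Tensor) -> str:
    """Map a (1, 4, length) one-hot tensor back to a DNA string."""
    return "".join(BASES[int(i)] for i in one_hot[0].argmax(dim=0))

seq = encode("AAATTATCCGGCG")
# Flipping the channel axis swaps A<->T and C<->G (complement); flipping the
# length axis reverses the sequence -- together, the reverse complement.
print(decode(seq.flip(-1, -2)))  # CGCCGGATAATTT
```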

Limitations

  • Currently, RCPSInput nominally expects one-hot encoded DNA sequences of shape (batch, 4, length). In principle, other DNA encodings that do not use 4 channels could still benefit from the easy 'reverse complement' action of flipping all non-batch dimensions.
  • Because PyTorch cannot reverse-index (i.e. flip) a tensor without making a copy, multiple copies of the input and weights are created during forward passes. I would welcome a way to change this, but unfortunately I don't know one that doesn't involve moving away from PyTorch.
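The copy behaviour behind the second point is easy to observe: `torch.flip` always materialises a new tensor, since PyTorch tensors cannot carry negative strides.

```python
import torch

x = torch.arange(6.0)
y = x.flip(0)  # flip returns a copy, never a view

y[0] = -1.0                          # mutate the flipped tensor...
print(x[0].item())                   # 0.0 -- ...the original is untouched
print(y.data_ptr() == x.data_ptr())  # False: separate storage
```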

Future Work

  • Find ways that the RCPS can be more smoothly integrated into Sequential models without the extra RCPSInput and RCPSOutput layers. (Pull Requests or Discussions welcome!)
  • Benchmark the performance of the RCPS technique against other reverse-complement preserving encodings/layers.

Download files

Download the file for your platform.

Source Distribution

rcps_pytorch-0.0.2.tar.gz (4.7 kB)

Uploaded Source

Built Distribution

rcps_pytorch-0.0.2-py3-none-any.whl (5.2 kB)

Uploaded Python 3

File details

Details for the file rcps_pytorch-0.0.2.tar.gz.

File metadata

  • Download URL: rcps_pytorch-0.0.2.tar.gz
  • Size: 4.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for rcps_pytorch-0.0.2.tar.gz
  • SHA256: 0c04d57c42d28b49df72fe778875ead25a8f357205e3bc37a9e067c828a08962
  • MD5: f9be33a996f2c3b8793cf7e65c543017
  • BLAKE2b-256: 6e8c14c861cde1abfc18b23f217db175a5f4a00c31cd1ee23f635b0e63b332ad


File details

Details for the file rcps_pytorch-0.0.2-py3-none-any.whl.

File metadata

File hashes

Hashes for rcps_pytorch-0.0.2-py3-none-any.whl
  • SHA256: 272467fdf0d9a38b64bba5c835c213b97e8824c903b8d8969c4981a000f21c87
  • MD5: c621fa6b41273079d21ec667518acd61
  • BLAKE2b-256: aba6752f8bb911fc39c660f82ec2ee5641ffef463e3d3204190e5c4d702d2132

