PyTorch implementation of Basenji2

Basenji2 in PyTorch

This repo provides a PyTorch re-implementation of the Basenji2 model published in "Cross-species regulatory sequence activity prediction" by David Kelley. The implementation was checked by verifying that the TensorFlow and PyTorch versions yield the same output on random data, up to small deviations. These deviations are likely due to differences in the underlying algorithms used by TensorFlow and PyTorch (e.g. different matrix multiplication algorithms). In addition, Qixiu Du kindly computed evaluation metrics and found that the PyTorch re-implementation achieves competitive performance on real data, further validating the port.
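A check of this kind can be sketched as follows. This is only an illustration, not the script used for this repo: the helper name `compare_outputs`, the tolerances, and the stand-in random tensors are all assumptions, and in practice the two inputs would be the outputs of the TensorFlow and PyTorch models on the same random batch.

```python
import torch


def compare_outputs(reference, candidate, rtol=1e-4, atol=1e-5):
    """Return (max_abs_diff, is_close) for two same-shaped arrays or tensors.

    Hypothetical helper; the tolerances are illustrative assumptions.
    """
    ref = torch.as_tensor(reference, dtype=torch.float64)
    cand = torch.as_tensor(candidate, dtype=torch.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {tuple(ref.shape)} vs {tuple(cand.shape)}")
    # Largest element-wise deviation between the two outputs.
    max_abs_diff = (ref - cand).abs().max().item()
    # True if every element agrees within the relative/absolute tolerance.
    is_close = torch.allclose(cand, ref, rtol=rtol, atol=atol)
    return max_abs_diff, is_close


# Stand-in data: a random "reference" output and a copy with tiny float noise,
# mimicking the small deviations expected between two frameworks.
torch.manual_seed(0)
reference = torch.randn(2, 8, 16)
candidate = reference + 1e-7 * torch.randn_like(reference)

max_abs_diff, is_close = compare_outputs(reference, candidate)
print(f"max abs diff: {max_abs_diff:.2e}, within tolerance: {is_close}")
```

Reporting the maximum absolute difference alongside the pass/fail flag makes it easier to judge whether a failure is a genuine porting bug or ordinary floating-point noise.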

Installation

pip install basenji2-pytorch

Usage

import torch
# Import PLBasenji2 instead of Basenji2 to also use the training parameters from Kelley et al. 2020.
from basenji2_pytorch import Basenji2, basenji2_params, basenji2_weights

# For a headless model (e.g. for transfer learning), uncomment the next line:
# basenji2_params["model"].pop("head_human", None)

basenji2 = Basenji2(basenji2_params["model"])
# strict=False tolerates state-dict keys that have no match in the model (e.g. a removed head).
basenji2.load_state_dict(torch.load(basenji2_weights()), strict=False)

  • basenji2_params is a dictionary of both training and model parameters matching the implementation in Kelley et al. 2020.
  • basenji2_weights is a function that uses pooch to download weights from Zenodo and return the path as a string.
  • Basenji2 is a PyTorch nn.Module that can be initialized from the model parameters of basenji2_params.
  • PLBasenji2 is a PyTorch Lightning module that can be initialized from basenji2_params to match both the training and architectural parameters of Basenji2.
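
The usage snippet loads weights with strict=False, which matters once the head has been removed from the parameters. The sketch below illustrates the general PyTorch behaviour on a toy module. `ToyModel` is a hypothetical stand-in and not part of basenji2_pytorch, and the actual key names in the Basenji2 checkpoint are not shown here.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a trunk with an optional output head."""

    def __init__(self, with_head=True):
        super().__init__()
        self.trunk = nn.Linear(4, 8)
        if with_head:
            self.head = nn.Linear(8, 3)

    def forward(self, x):
        x = self.trunk(x)
        # Return trunk features directly when no head is attached.
        return self.head(x) if hasattr(self, "head") else x


full = ToyModel(with_head=True)
headless = ToyModel(with_head=False)

# With strict=False, weights for the missing head are skipped and reported
# instead of raising a RuntimeError; trunk weights are still copied.
result = headless.load_state_dict(full.state_dict(), strict=False)
print("unexpected keys:", result.unexpected_keys)
print("missing keys:", result.missing_keys)

features = headless(torch.randn(2, 4))
print(tuple(features.shape))
```

Inspecting the returned `missing_keys` and `unexpected_keys` is a cheap way to confirm that only the intended layers were skipped, since strict=False would otherwise also hide genuine mismatches.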
