A light-weight deep learning architecture that uses modern optimization tricks to achieve strong predictive performance.

Requires Python 3.10+.

[!IMPORTANT] Cherimoya is still under active development and may change in ways that are not backward compatible. Please note the version you are using in case you need to return to it later.

Cherimoya is a lightweight genomic sequence-to-function (S2F) model for predicting genomic modalities such as transcription factor binding, chromatin accessibility, and transcription initiation. It builds on concepts first introduced by BPNet and ChromBPNet while adding architectural, algorithmic, and systems-level improvements to training stability, efficiency, and predictive performance. Despite having significantly fewer parameters than other architectures, Cherimoya achieves strong predictive performance across a range of tasks and runs ~5-15x faster when measured on an H200 GPU.

The secret to Cherimoya's success is a new Cheri Block, which adapts the ConvNeXt block to the domain of noisy high-throughput genomics experiments. The block consists of a dilated depth-wise convolution, a layer norm, a projection into a higher-dimensional space, a GeLU non-linearity, a projection back into the original dimensionality, and a residual connection scaled by a small fixed constant. Conceptually, each block first aggregates information spatially but independently for each feature/channel (the depth-wise convolution) and then aggregates information across features but independently for each position (the two projections). The dilated depth-wise convolution and the layer norm have been fused into an efficient custom GPU kernel that is ~2-3x faster than the native PyTorch implementation.
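The block described above can be sketched in plain PyTorch as follows. This is an illustration of the structure, not the library's actual implementation: the class name, kernel size, and expansion factor are assumptions, and the real code fuses the convolution and layer norm into a custom GPU kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CheriBlockSketch(nn.Module):
    """Minimal sketch of a Cheri-style block: dilated depth-wise conv,
    layer norm, up/down projections with GeLU, and a scaled residual."""

    def __init__(self, n_filters=96, kernel_size=5, dilation=1,
                 expansion=4, residual_scale=0.15):
        super().__init__()
        # Depth-wise (groups=n_filters) dilated conv: spatial mixing,
        # one channel at a time. Padding preserves length for odd kernels.
        self.conv = nn.Conv1d(n_filters, n_filters, kernel_size,
                              padding=dilation * (kernel_size // 2),
                              dilation=dilation, groups=n_filters, bias=False)
        self.norm = nn.LayerNorm(n_filters)
        # Two position-wise projections: channel mixing at each position.
        self.up = nn.Linear(n_filters, expansion * n_filters, bias=False)
        self.down = nn.Linear(expansion * n_filters, n_filters, bias=False)
        self.residual_scale = residual_scale

    def forward(self, x):
        # x: (batch, n_filters, length)
        y = self.conv(x).transpose(1, 2)   # -> (batch, length, n_filters)
        y = self.down(F.gelu(self.up(self.norm(y))))
        # Small residual scale keeps the block near identity at init.
        return x + self.residual_scale * y.transpose(1, 2)
```

Because the residual branch is multiplied by a small constant, a freshly initialized stack of these blocks behaves almost like the identity function, which is part of the training-stability story below.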

Installation

pip install cherimoya

Or, using uv:

uv pip install cherimoya

To install from source:

git clone https://github.com/jmschrei/cherimoya.git
cd cherimoya
pip install -e .    # or: uv pip install -e .

Key Features

Lightweight Architecture: Cherimoya employs a compact convolutional backbone that substantially reduces parameter count while slightly increasing predictive accuracy. This design enables efficient training, large-scale hyperparameter exploration, interactive use in the browser, and the use of dozens or hundreds of models simultaneously in complex design settings.

Stable training: Several design choices improve the stability of model training: layer norm in each layer; a small fixed scalar on each residual connection (configurable via the residual_scale argument on both CheriBlock and Cherimoya, default 0.15) that keeps each path close to the identity mapping at initialization; a cosine decay learning rate scheduler with a long warmup (5 epochs by default); no bias terms in the Cheri blocks; and, somewhat counterintuitively, no weight decay in the optimizers.
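A warmup-then-cosine schedule of the kind described above can be sketched as a pure function of the epoch. The function name and base learning rate are illustrative, not Cherimoya's API.

```python
import math

def warmup_cosine_lr(epoch, max_epochs, base_lr=1e-3, warmup_epochs=5):
    """Linear warmup for the first warmup_epochs, then cosine decay to 0."""
    if epoch < warmup_epochs:
        # Linear ramp: 1/warmup_epochs of base_lr per epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

A long warmup like this gives the learned loss weights and the near-identity residual paths time to settle before large updates are taken.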

Automatic Loss Weight Balancing: Profile and count losses are combined using learned weighting parameters rather than fixed hyperparameters. This approach replaces the heuristic developed for BPNet and ChromBPNet models and enables the models to scale to larger contexts and across modalities automatically, while also improving gradient stability across datasets with varying signal-to-noise characteristics.
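One standard way to combine losses with learned weights is uncertainty weighting (Kendall et al., 2018), sketched below. This illustrates the idea of replacing a fixed profile/count weight with trainable parameters; it is not necessarily Cherimoya's exact formulation.

```python
import torch
import torch.nn as nn

class WeightedLoss(nn.Module):
    """Combine two losses with learned log-variance weights."""

    def __init__(self):
        super().__init__()
        # One learnable log-variance per loss term, initialized to 0
        # (i.e., both losses start with weight 1).
        self.log_vars = nn.Parameter(torch.zeros(2))

    def forward(self, profile_loss, count_loss):
        losses = torch.stack([profile_loss, count_loss])
        # exp(-s) down-weights noisy terms; the +s term keeps the
        # weights from collapsing to zero.
        return (torch.exp(-self.log_vars) * losses + self.log_vars).sum()
```

Because the weights are parameters of the loss itself, they are updated by the same optimizer as the model and adapt automatically to datasets with different signal-to-noise characteristics.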

Muon Optimizer: Cherimoya uses the Muon optimizer when training the projection layers, and the AdamW optimizer for all other layers and terms. This has significantly accelerated training by reducing the number of epochs needed while modestly improving performance.
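Splitting parameters between two optimizers is a simple grouping exercise. A common convention for Muon is to route the 2-D projection matrices to it and everything else (convolutions, norms, scalars) to AdamW; the sketch below shows that partitioning with AdamW standing in for Muon so it runs with stock PyTorch.

```python
import torch
import torch.nn as nn

# Toy model: one conv layer plus two projection layers.
model = nn.Sequential(nn.Conv1d(4, 8, 3), nn.Linear(8, 8), nn.Linear(8, 2))

# 2-D weight matrices (the projections) go to the Muon-style optimizer;
# everything else goes to AdamW.
muon_params = [p for p in model.parameters() if p.ndim == 2]
other_params = [p for p in model.parameters() if p.ndim != 2]

optimizers = [
    torch.optim.AdamW(muon_params, lr=2e-2),   # stand-in for Muon here
    torch.optim.AdamW(other_params, lr=1e-3),
]

# During training, each step() is called on both optimizers.
```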

Model Compilation: Because of the architectural decisions made in the Cheri block, many operations can be automatically fused by torch.compile, so compilation is built into the forward pass. This appears to offer a ~50-75% speed improvement. Although compilation only needs to be done once and can then be re-used across models and sessions, it may need to be redone each time the batch size changes, e.g., for the last batch being processed.

Mixed Precision: When data is of reasonable read depth, Cherimoya models are best trained using mixed precision, which can offer a ~2x speed improvement (sometimes more when also compiling the model). However, mixed precision can hurt performance when the data is very low quality or low read depth, such as TF ChIP-seq experiments or pseudobulks for rare cell types. We recommend using float32 precision for BPNet-style models as a starting point unless you have particularly high-quality data.
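A minimal mixed-precision forward pass looks like the following. This sketch uses CPU/bfloat16 so it runs anywhere; on a GPU you would typically pass device_type="cuda" and, when using float16, wrap the backward pass with a torch.amp.GradScaler.

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
x = torch.randn(8, 16)

# Inside the autocast context, matmul-heavy ops run in reduced precision
# while numerically sensitive ops stay in float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)
```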

Deterministic Sampling: The peak/negative sampler in PeakGenerator is a pure function of (random_state, epoch, idx), so every peak appears exactly once per epoch and two runs with the same seed produce bit-identical training data. num_workers > 1 is purely a speed optimization — it produces the same sequence of batches as num_workers = 1, just faster.
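A sampler that is a pure function of (random_state, epoch, idx) can be sketched as below. The function name and internals are illustrative, not PeakGenerator's actual code; the point is that the permutation depends only on the seed and epoch, so every worker computes the same order.

```python
import numpy as np

def sample_index(random_state, epoch, idx, n_peaks):
    """Return the idx-th peak of a deterministic per-epoch permutation."""
    # Seed the generator from (seed, epoch) so each epoch is an
    # independent but fully reproducible shuffle of all peaks.
    rng = np.random.default_rng([random_state, epoch])
    return rng.permutation(n_peaks)[idx]
```

Because each epoch is one permutation of the peaks, every peak appears exactly once per epoch, and two runs with the same seed see identical batches regardless of how many workers compute them.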

Saving and Loading Models

Models are saved as a dictionary containing the constructor arguments and a state dict, rather than a pickled module. This format is robust to source-layout changes and is safe to load with PyTorch's weights_only=True setting:

from cherimoya import Cherimoya

model = Cherimoya(n_filters=96, n_layers=9, n_outputs=2)
# ... train ...
model.save("my_model.torch")

# Load on CPU (default)
model = Cherimoya.load("my_model.torch")

# Or load directly onto a GPU
model = Cherimoya.load("my_model.torch", device="cuda")

The CLI commands (evaluate, attribute, marginalize) and model.fit() use this format internally. Older checkpoints saved with torch.save(model, ...) are not compatible with Cherimoya.load and should be retrained.
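The checkpoint format described above amounts to a round trip like the following sketch. The dictionary keys ("kwargs", "state_dict") are illustrative, not necessarily Cherimoya's exact schema, and a plain nn.Linear stands in for the model.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(8, 2)

# Save constructor arguments alongside the weights, not the module itself.
checkpoint = {"kwargs": {"in_features": 8, "out_features": 2},
              "state_dict": model.state_dict()}
path = os.path.join(tempfile.mkdtemp(), "model.torch")
torch.save(checkpoint, path)

# weights_only=True refuses to unpickle arbitrary objects, so loading
# an untrusted checkpoint cannot execute code.
loaded = torch.load(path, weights_only=True)
restored = nn.Linear(**loaded["kwargs"])
restored.load_state_dict(loaded["state_dict"])
```

Because only plain data (ints, strings, tensors) is serialized, the file stays loadable even if the package's module layout changes between versions.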

End-to-End Pipeline

Cherimoya provides an integrated command-line pipeline that takes you directly from mapped reads, to model training and evaluation, to analysis results. This pipeline improves reproducibility by self-documenting the parameter settings for each step, and dramatically reduces the overhead of managing separate tooling for each stage. Specifically, it includes:

  • Conversion from BAM/SAM/fragment files to stranded or unstranded bigWig(s) using bam2bw
  • Peak calling using MACS3
  • Calling of GC-matched negatives
  • Model training and evaluation
  • Attribution scores using in silico saturation mutagenesis
  • Seqlet calling and annotation using tomtom-lite
  • De novo motif discovery using TF-MoDISco

A multi-step pipeline like this has many hyperparameters that can be customized at each step (e.g., the number of filters in the model, the number of seqlets to use for TF-MoDISco) and requires pointers to several input and output files. Rather than using a giant command-line call, Cherimoya uses JSONs to manage each step of the pipeline. An advantage of using JSONs is that they create a permanent record of the exact command that was run. Although there are many hyperparameters, the user-provided JSONs can be quite small in practice because they are internally merged with the default parameters for each step. The fastest way to begin is the pipeline-json command, which takes pointers to your data files and flags describing the data, and produces a valid JSON for the pipeline process. These data files usually include a reference genome, some number of input (and optionally control) BAM/SAM/tsv/tsv.gz files (the -i and -c arguments can be repeated), a BED file of positive loci, and a MEME-formatted motif database used for evaluation of the model.
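The merging of a small user-provided JSON onto per-step defaults can be sketched as a recursive dictionary overlay. This illustrates the behavior described above; Cherimoya's actual merge logic may differ in details.

```python
def merge_defaults(user, defaults):
    """Overlay user-provided settings onto a dictionary of defaults.

    Nested dictionaries are merged key by key; any other value in the
    user dict simply replaces the default.
    """
    merged = dict(defaults)
    for key, value in user.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_defaults(value, merged[key])
        else:
            merged[key] = value
    return merged
```

Under this scheme a user JSON only needs to name the settings that differ from the defaults, while the merged result still records every parameter that was actually used.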

For example, if you are working with ChIP-seq data that is stranded:

cherimoya pipeline-json -s hg38.fa -p peaks.bed.gz -i input1.bam -i input2.bam -c control1.bam -c control2.bam -n test -o pipeline.json -m JASPAR_2024.meme

If you are working with ATAC-seq data, which is unstranded and comes in the form of paired-end fragments that need to be shifted +4/-4 (as in the ChromBPNet work), you can use the following:

cherimoya pipeline-json -s hg38.fa -p peaks.bed.gz -i input1.bam -i input2.bam -n atac-test -o atac-pipeline.json -m JASPAR_2024.meme -ps 4 -ns -4 -u -f -pe

Note that any of these data pointers can point to remote files. This will stream the data through bam2bw and read the peak files remotely. Processing speed will then depend on the speed of your internet connection and whether the hosting site throttles your connection.

The resulting JSON stored at pipeline.json or atac-pipeline.json can then be executed using the pipeline command. These commands are separated because, although pipeline-json produces a valid JSON that pipeline can immediately use, one may wish to modify some of the many parameters in the JSON first. These parameters include the number of filters and layers in the model, the training and validation chromosomes, and the p-value threshold for calling seqlets. The defaults for most steps are reasonable in practice, but there is immense flexibility here, e.g., the ability to train the model using one reference genome and then make predictions or attributions on synthetic sequences or the reference genome of another species. In this manner, the JSON serves as documentation for the experiments that have been performed.

cherimoya pipeline -p pipeline.json

When running the pipeline, a JSON is produced for each of the steps (except for running TF-MoDISco and annotating the seqlets, which uses ttl). Each of these JSONs can be run by itself using the appropriate built-in command. Because some of the values in these JSONs are set programmatically when running the full pipeline, e.g., the filenames to read from and save to, being able to inspect each JSON can be handy for debugging.
