AlphaGenome Finetuning (alphagenome-ft)
A lightweight Python package for finetuning Google DeepMind's AlphaGenome model with custom prediction heads and parameter freezing capabilities, without modifying the original codebase.
Project leads - Alan Murphy, Masayuki (Moon) Nagai, Alejandro Buendia
Use cases
- If you want to apply AlphaGenome to your MPRA (or other perturbation) data of interest, see Encoder-only / short sequences (MPRA).
- If you want to apply AlphaGenome to your own genome-wide assay, start with Heads-only finetuning (frozen backbone); then add LoRA-style adapters or Full-model finetuning if needed.
Contents
- Overview
- Workflows
- Encoder-only / short sequences (MPRA) — finetune on short sequences (< 1 kb)
- Heads-only finetuning (frozen backbone) — train a new head on top of a frozen model
- LoRA-style adapters — low-rank adapter layers
- Full-model finetuning — unfreeze the backbone (e.g. progressive unfreezing)
- Attribution analysis — interpret predictions after training
- Reference
- Other
How to use this README: Sections are ordered by how you typically run things. Start with adapting to MPRA or heads-only finetuning; then add full-model workflows if needed. Run attribution and interpretation after you have a trained model. Step-by-step tutorials live in docs/.
Features
- Custom Prediction Heads: Define and register your own task-specific prediction heads
- Parameter Freezing: Path-based freeze helpers plus optimizer masking (create_optimizer with heads_only=True) for real backbone freezes during training
- Easy Integration: Works seamlessly with pretrained AlphaGenome models (simple wrapper classes)
- Parameter Inspection: Utilities to explore and count model parameters
- Attribution Analysis: Utilities to calculate attributions based on gradients or in silico mutagenesis (ISM)
- JAX/Haiku Native: Built on the same framework as AlphaGenome
Installation
This package depends on the AlphaGenome stack (alphagenome and alphagenome_research), which are not on PyPI and must be installed from GitHub. Use the following order.
Step 1: Install alphagenome-ft
From PyPI (recommended):
pip install alphagenome-ft
From source (development):
git clone https://github.com/genomicsxai/alphagenome_ft.git
cd alphagenome_ft
pip install -e .
This installs alphagenome-ft and its PyPI dependencies (JAX, Haiku, optax, etc.). It does not install the AlphaGenome model code.
Step 2: Install AlphaGenome and AlphaGenome Research
alphagenome_ft wraps AlphaGenome and AlphaGenome Research. Install both from GitHub:
pip install git+https://github.com/google-deepmind/alphagenome.git
pip install git+https://github.com/google-deepmind/alphagenome_research.git
Requirements
- Python ≥ 3.11
- All other runtime dependencies (JAX, Haiku, optax, orbax-checkpoint, etc.) are installed automatically with alphagenome-ft. See pyproject.toml for versions.
- AlphaGenome and AlphaGenome Research must be installed separately as above; they are not on PyPI.
Quick Start
There are two options to add new heads to AlphaGenome.
Option A: Use a Predefined AlphaGenome Head
You can reuse the predefined head kinds from alphagenome_research without importing HeadName by passing the string value. Supported strings are: atac, dnase, procap, cage, rna_seq, chip_tf, chip_histone, contact_maps, splice_sites_classification, splice_sites_usage, splice_sites_junction. For the details of each head, refer to alphagenome_research.
from alphagenome_ft import (
get_predefined_head_config,
register_predefined_head,
create_model_with_heads,
)
# 1. Build a predefined head config (num_tracks must match the number of your target tracks)
rna_config = get_predefined_head_config(
"rna_seq",
num_tracks=4,
)
# 2. Register it under an instance name you will train
register_predefined_head("K562_rna_seq", rna_config)
# 3. Create a model that uses the registered instance
model = create_model_with_heads("all_folds", heads=["K562_rna_seq"])
model.freeze_except_head("K562_rna_seq")
# 4. Optimizer for a *custom* training loop (frozen backbone in Optax, not only stop_gradient)
import optax
from alphagenome_ft import create_optimizer
optimizer = create_optimizer(
model._params,
trainable_head_names=("K562_rna_seq",),
learning_rate=1e-3,
weight_decay=1e-4,
heads_only=True, #NOTE: Adding this is key to avoiding base weight updates
)
opt_state = optimizer.init(model._params)
# Then: updates, opt_state = optimizer.update(grads, opt_state, model._params)
# model._params = optax.apply_updates(model._params, updates)
The built-in BigWig trainer (alphagenome_ft.finetune.train, heads_only=True) builds this optimizer for you. Pattern reference: Heads-only optimizer. freeze_except_head alone does not block backbone updates under jax.grad; heads_only=True in create_optimizer does.
Note: if you have a local copy of the AlphaGenome weights that you want to use instead of downloading them from Kaggle, use:
model = create_model_with_heads(
'all_folds',
heads=['my_head'],
checkpoint_path="full/path/to/weights",
)
Option B: Use Custom Heads with Guidance from Reference Templates
We provide template heads for guidance on accessing different embeddings, see ./alphagenome_ft/templates.py:
from alphagenome.models import dna_output
from alphagenome_research.model import dna_model
from alphagenome_ft import (
templates,
CustomHeadConfig,
CustomHeadType,
register_custom_head,
create_model_with_heads,
)
# 1. Register a template head (modify class for your task)
register_custom_head(
'my_head',
templates.StandardHead, # Choose template: StandardHead, TransformerHead, EncoderOnlyHead
CustomHeadConfig(
type=CustomHeadType.GENOME_TRACKS,
output_type=dna_output.OutputType.RNA_SEQ,
num_tracks=1,
)
)
# 2. Create model with custom head
model = create_model_with_heads(
'all_folds',
heads=['my_head'],
)
# 3. Freeze backbone for finetuning (optional hint on the model)
model.freeze_except_head('my_head')
# 4. Custom loop: heads-only optimizer (see docs/heads_only_optimizer.md)
from alphagenome_ft import create_optimizer
optimizer = create_optimizer(
model._params,
trainable_head_names=('my_head',),
learning_rate=1e-3,
weight_decay=1e-4,
heads_only=True,
)
opt_state = optimizer.init(model._params)
Available Templates:
- templates.StandardHead - Uses 1bp embeddings (decoder output: local + global features)
- templates.TransformerHead - Uses 128bp embeddings (transformer output: pure attention)
- templates.EncoderOnlyHead - Uses encoder output (CNN only, for short sequences < 1 kb)
All templates use simple architecture: Linear → ReLU → Linear
The key difference is which embeddings they access. See alphagenome_ft/templates.py for code.
Note: Template heads are a guide for how to set up your own custom head rather than a definitive 'best' or 'standard' option. Update them with your own layer and loss-function choices to fit your data.
To define your own head:
import jax
import jax.numpy as jnp
import haiku as hk
from alphagenome_ft import CustomHead
class MyCustomHead(CustomHead):
"""Your custom prediction head."""
def predict(self, embeddings, organism_index, **kwargs):
# Get embeddings at desired resolution
x = embeddings.get_sequence_embeddings(resolution=1) # or 128
# Add your prediction layers
x = hk.Linear(256, name='hidden')(x)
x = jax.nn.relu(x)
predictions = hk.Linear(self._num_tracks, name='output')(x)
return predictions
def loss(self, predictions, batch):
targets = batch.get('targets')
if targets is None:
return {'loss': jnp.array(0.0)}
mse = jnp.mean((predictions - targets) ** 2)
return {'loss': mse, 'mse': mse}
# Register and use in the same way as Option B
Note: Add Custom Head to Existing Model (Keep pre-trained Heads)
The three approaches above create models that include only the heads you explicitly provide. If you want to keep AlphaGenome's pre-trained heads (ATAC, RNA-seq, etc.) alongside your custom head:
from alphagenome.models import dna_output
from alphagenome_research.model import dna_model
from alphagenome_ft import (
templates,
CustomHeadConfig,
CustomHeadType,
register_custom_head,
add_heads_to_model,
)
# 1. Load pretrained model (includes standard heads)
base_model = dna_model.create_from_kaggle('all_folds')
# 2. Register custom or predefined head
register_custom_head(
'my_head',
templates.StandardHead,
CustomHeadConfig(
type=CustomHeadType.GENOME_TRACKS,
output_type=dna_output.OutputType.RNA_SEQ,
num_tracks=1,
)
)
# 3. Add custom head to model (keeps ALL standard heads)
model = add_heads_to_model(base_model, heads=['my_head'])
# 4. Freeze backbone + standard heads, train only custom head
model.freeze_except_head('my_head')
# 5. Loss + heads-only optimizer for custom training (finetune.train sets this if heads_only=True)
loss_fn = model.create_loss_fn_for_head('my_head')
from alphagenome_ft import create_optimizer
optimizer = create_optimizer(
model._params,
trainable_head_names=('my_head',),
learning_rate=1e-3,
weight_decay=1e-4,
heads_only=True,
)
opt_state = optimizer.init(model._params)
When to use each approach:
- create_model_with_heads() - Heads only (faster, smaller)
- add_heads_to_model() - Added heads + pre-trained heads (useful when referring to the original tracks)
Workflows
Workflow 1: Encoder-only / short sequences (MPRA)
When to use: Short sequences (< ~1 kb): MPRA, promoters, enhancers. Uses encoder (CNN) only.
Tutorial: Encoder-only finetuning. Use templates.EncoderOnlyHead and use_encoder_output=True in create_model_with_heads(...). Custom loops must use create_optimizer(..., heads_only=True), not a plain optax.adamw on the full tree. See the AlphaGenome MPRA application repo for more details.
Workflow 2: Heads-only finetuning (frozen backbone)
When to use: New task (ChIP-seq, gene expression, etc.) on standard-length sequences; train only a new head, backbone frozen.
Tutorial: Frozen backbone, new head. Built-in train(..., heads_only=True) applies optimizer masking; custom loops: heads_only_optimizer.md.
Workflow 3: LoRA-style adapters
When to use: Low-rank adapters on the backbone. Tutorial: LoRA-style adapters. For custom Optax loops, use create_optimizer(..., heads_only=True) with your LoRA head name so only head (including LoRA) weights update.
Workflow 4: Full-model finetuning
When to use: Adapt the backbone (e.g. after heads-only or for a different distribution).
Tutorial: Full-model finetuning (unfreezing the backbone). Start heads-only with train(..., heads_only=True) or create_optimizer(..., heads_only=True); then unfreeze via unfreeze_parameters(unfreeze_prefixes=[...]) or freeze_backbone(freeze_prefixes=[...]); save with save_checkpoint(..., save_full_model=True).
After training: Attribution analysis
Compute attributions after training to see which sequence features drive predictions.
Methods: DeepSHAP*, Gradient × Input, Gradient, ISM.
Load a checkpoint, then use compute_deepshap_attributions, compute_input_gradients, or compute_ism_attributions; visualize with plot_attribution_map and plot_sequence_logo.
Full API, examples (basic, visualization, single-sequence pipeline), method comparison, and multi-track output_index: Attribution analysis.
NOTE: DeepSHAP* - the implementation is DeepSHAP-like in that it uses a reference sequence, but it is not a faithful reimplementation.
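For intuition on the gradient-based methods, here is a toy gradient × input computation; the linear "model" below is made up for illustration and is not the package API:

```python
import jax
import jax.numpy as jnp

# Made-up linear scorer over a one-hot DNA sequence.
w = jnp.arange(12.0).reshape(4, 3)  # (alphabet=4, tracks=3)

def predict(onehot):
    """Scalar score for a (length, 4) one-hot sequence."""
    return jnp.sum(onehot @ w)

seq = jax.nn.one_hot(jnp.array([0, 2, 1, 3]), 4)  # bases A, G, C, T

# Gradient x input: nonzero only at the observed base of each position.
grads = jax.grad(predict)(seq)
attributions = grads * seq
per_base = attributions.sum(axis=-1)  # one attribution score per position
```

ISM instead re-scores every single-base substitution; both approaches yield a per-position score that can feed a sequence logo.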
References
AlphaGenome Architecture
Understanding the architecture helps design custom heads:
DNA Sequence (B, S, 4)
↓
┌─────────────────────────────────────┐
│ BACKBONE (can be frozen) │
│ ├─ SequenceEncoder │
│ ├─ TransformerTower (9 blocks) │
│ └─ SequenceDecoder │
└─────────────────────────────────────┘
↓
┌────────────────────────────────────────────────┐
│ EMBEDDINGS (multi-resolution) │
│ ├─ embeddings_1bp: (B, S, 1536) │ # from SequenceDecoder
│ ├─ embeddings_128bp: (B, S/128, 1536) │ # from TransformerTower
│ ├─ embeddings_pair: (B, S/2048, S/2048, 128) │ # from TransformerTower
│ └─ encoder_output: (B, S/128, 1536) │ # from SequenceEncoder
└────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────┐
│ HEADS (task-specific) │
│ ├─ Standard: ATAC, RNA-seq, etc. │
│ └─ Custom: YOUR_HEAD_HERE ← Add! │
└─────────────────────────────────────┘
Available Embedding Resolutions
# In your custom head's predict() method:
# 1bp resolution (highest detail, local + global features)
x_1bp = embeddings.get_sequence_embeddings(resolution=1)
# Shape: (batch, sequence_length, 1536)
# 128bp resolution (global attention features)
x_128bp = embeddings.get_sequence_embeddings(resolution=128)
# Shape: (batch, sequence_length//128, 3072)
# Encoder output (CNN features only, requires use_encoder_output=True)
x_encoder = embeddings.encoder_output
# Shape: (batch, sequence_length//128, D)
Parameter Management
# Freeze/unfreeze by path or prefix
model.freeze_parameters(freeze_paths=['alphagenome/encoder/...'])
model.unfreeze_parameters(unfreeze_prefixes=['alphagenome/head/'])
# Convenient presets
model.freeze_backbone() # Freeze all backbone components (default)
model.freeze_backbone(freeze_prefixes=['sequence_encoder']) # Freeze only encoder
model.freeze_backbone(freeze_prefixes=['transformer_tower']) # Freeze only transformer
model.freeze_backbone(freeze_prefixes=['sequence_decoder']) # Freeze only decoder
model.freeze_backbone(freeze_prefixes=['sequence_encoder', 'transformer_tower']) # Freeze encoder + transformer
model.freeze_all_heads(except_heads=['my_head']) # Freeze all heads except one
model.freeze_except_head('my_head') # Freeze everything except one head (sets hint for heads-only optimizers)
# Optimizer masking (true freeze during training — required for custom loops)
# Full pattern: docs/heads_only_optimizer.md
from alphagenome_ft import create_optimizer
opt = create_optimizer(
model._params,
trainable_head_names=('my_head',),
learning_rate=1e-3,
weight_decay=1e-4,
heads_only=True,
)
opt_state = opt.init(model._params)
# Inspection
paths = model.get_parameter_paths() # All parameter paths
head_paths = model.get_head_parameter_paths() # Just head parameters
backbone_paths = model.get_backbone_parameter_paths() # Just backbone
count = model.count_parameters() # Total parameter count
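The path-based helpers operate on Haiku's flat module-path → name → array mapping. A minimal sketch of how prefix selection and parameter counting work (the paths below are made up, not the model's real ones):

```python
import jax
import jax.numpy as jnp

# Made-up Haiku-style parameter tree: module path -> param name -> array.
params = {
    "alphagenome/encoder/conv": {"w": jnp.zeros((4, 8)), "b": jnp.zeros((8,))},
    "alphagenome/head/my_head/linear": {"w": jnp.zeros((8, 1))},
}

# Prefix selection, as the freeze/unfreeze helpers do by parameter path.
head_paths = [p for p in params if p.startswith("alphagenome/head/")]

# Total parameter count across all leaves.
count = sum(x.size for x in jax.tree_util.tree_leaves(params))
```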
Modular Backbone Freezing
The freeze_backbone() method now supports modular freezing of backbone components. This allows you to selectively freeze only specific parts of the backbone (encoder, transformer, or decoder) while keeping others trainable. This is useful for progressive finetuning strategies:
# Example: Progressive finetuning strategy
# 1. Start with only head trainable
model.freeze_backbone() # Freeze all backbone
# 2. Unfreeze decoder for fine-grained adaptation
model.unfreeze_parameters(unfreeze_prefixes=['sequence_decoder'])
# 3. Later, unfreeze transformer for global context adaptation
model.unfreeze_parameters(unfreeze_prefixes=['transformer_tower'])
# 4. Finally, unfreeze encoder for full finetuning
model.unfreeze_parameters(unfreeze_prefixes=['sequence_encoder'])
# Or use freeze_backbone with specific prefixes from the start:
model.freeze_backbone(freeze_prefixes=['sequence_encoder', 'transformer_tower']) # Only decoder trainable
Saving a Checkpoint
- save_full_model=False (default): saves only the custom head parameters (~MBs)
  - Recommended for finetuning - much smaller checkpoints
  - Requires loading the base model when restoring
- save_full_model=True: saves the entire model including the backbone (~GBs)
  - Self-contained checkpoint
  - Larger file size but no need for the base model
  - Use if unfreezing the base model
Loading a Checkpoint
from alphagenome_ft import load_checkpoint
# Load a heads-only checkpoint (requires base model)
model = load_checkpoint(
'checkpoints/my_model',
base_model_version='all_folds' # Which base model to use
)
# Now use for inference or continue training
predictions = model.predict(...)
Important: Before loading, you must register the custom head classes:
# Import and register your custom head class
from your_module import MyCustomHead
from alphagenome_ft import register_custom_head
register_custom_head('my_head', MyCustomHead, config)
# Now load checkpoint
model = load_checkpoint('checkpoints/my_model')
Testing
The package includes a comprehensive test suite using pytest.
Kaggle credentials (model download)
Several tests load AlphaGenome via create_from_kaggle. Credentials are detected if either:
- KAGGLE_USERNAME and KAGGLE_KEY are set in the environment, or
- ~/.kaggle/kaggle.json exists and contains "username" and "key" (the default location when you use the Kaggle CLI/API).
No extra export step is required if kaggle.json is already in place. To run only tests that do not download the model:
pytest tests/test_optimizer_masking.py::test_heads_only_optimizer_fake_param_tree -q
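The detection logic described above amounts to something like the following (an illustrative helper, not the package's actual code):

```python
import json
import os
from pathlib import Path

def kaggle_credentials_available() -> bool:
    """True if Kaggle credentials are discoverable (illustrative sketch)."""
    # Environment variables take precedence.
    if os.environ.get("KAGGLE_USERNAME") and os.environ.get("KAGGLE_KEY"):
        return True
    # Fall back to the Kaggle CLI/API default location.
    cfg = Path.home() / ".kaggle" / "kaggle.json"
    if not cfg.exists():
        return False
    try:
        data = json.loads(cfg.read_text())
    except (OSError, json.JSONDecodeError):
        return False
    return "username" in data and "key" in data
```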
Run Tests
# Install test dependencies
pip install -e ".[test]"
# Run all tests
pytest
# Run with coverage
pytest --cov=alphagenome_ft --cov-report=html
# Run specific test file
pytest tests/test_custom_heads.py
pytest tests/test_checkpoint.py
See tests/README.md for detailed testing documentation.
Contributing
Contributions welcome! Please:
- Fork the repository
- Create a feature branch
- Add tests for new functionality (see tests/README.md)
- Ensure tests pass: pytest
- Submit a pull request
Publishing (maintainers)
Releases are published to PyPI via GitHub Actions using Trusted Publishing (no API token stored in the repo).
One-time setup on PyPI
- Open PyPI → alphagenome-ft → Publishing.
- Add a trusted publisher:
  - Owner: your GitHub user or org (e.g. genomicsxai)
  - Repository: alphagenome_ft
  - Workflow name: publish.yml
  - Environment (optional): pypi, if you create that environment in the repo
If you use the pypi environment, create it in the repo under Settings → Environments so the workflow can run.
Releasing a new version
- Bump version in pyproject.toml (e.g. 0.1.2).
- Commit and push.
- Create and push a tag matching the version: git tag v0.1.2 && git push origin v0.1.2
- The Publish to PyPI workflow runs and uploads to PyPI.
License
MIT License - see LICENSE file for details.
This project extends AlphaGenome, which has its own license terms.