A deep learning framework for hierarchical taxonomic classification of 16S rRNA gene sequences.
DeepTaxa
DeepTaxa is a deep learning framework for hierarchical taxonomic classification of 16S rRNA gene sequences. It classifies sequences into all seven taxonomic ranks (Domain through Species) in a single forward pass, achieving 92.96% species-level accuracy (3-seed mean) on the Greengenes2 2024.09 test set.
Table of Contents
- Performance
- Installation
- Quick Start
- Data and Pre-Trained Models
- Training
- Experimentation
- Scripts
- Tutorials
- License
- Citation
- Contact
- Acknowledgements
Performance
The published HybridCNNBERT checkpoint achieves the following on 69,335 held-out test sequences from Greengenes2 2024.09 (3-seed mean across seeds 42, 123, 456):
| Rank | Accuracy | F1 | ECE |
|---|---|---|---|
| Domain | 99.98% | 99.98% | 0.0001 |
| Phylum | 99.69% | 99.68% | 0.0023 |
| Class | 99.63% | 99.59% | 0.0024 |
| Order | 99.07% | 98.97% | 0.0056 |
| Family | 98.61% | 98.41% | 0.0075 |
| Genus | 96.90% | 96.48% | 0.0144 |
| Species | 92.96% | 92.12% | 0.0242 |
Cross-seed standard deviation is at most 0.0008 F1 (0.08 percentage points) at every rank; at the species rank it is 0.0008 F1 and 0.07 percentage points of accuracy. The results are therefore highly reproducible across seeds.
Architecture
| Component | Configuration |
|---|---|
| CNN | embed_dim=896, 256 filters, kernels [3, 5, 7], 1 conv layer |
| BERT | 4 layers, 7 heads, hidden=896, FFN=3584, GELU, random init |
| Fusion | Learnable alpha/beta weights + BERT residual connection |
| Training | Cross-entropy loss, LR=5e-4, batch=64, dropout=0.20, 10 epochs |
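As a rough sanity check on the table above, the size of the BERT encoder and the CNN convolutions can be estimated by hand. This sketch assumes each of the 4 layers is a standard BERT encoder layer (Q/K/V and output projections, a two-layer FFN, two LayerNorms, all with biases) and that each kernel size is a separate Conv1d from embed_dim input channels to 256 filters. Neither layout is confirmed here, embeddings, fusion weights, and classifier heads are not counted, and the helper names are hypothetical.

```python
def bert_layer_params(hidden: int, ffn: int) -> int:
    """Parameters in one standard BERT encoder layer (weights + biases), assumed layout."""
    attention = 4 * (hidden * hidden + hidden)                      # Q, K, V and output projections
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)   # up- and down-projection
    layer_norms = 2 * (2 * hidden)                                  # two LayerNorms (scale + shift)
    return attention + feed_forward + layer_norms


def multi_kernel_conv_params(embed_dim: int, filters: int, kernels: list[int]) -> int:
    """Parameters of parallel Conv1d layers, one per kernel size (assumed layout)."""
    return sum(embed_dim * filters * k + filters for k in kernels)


HIDDEN, FFN, LAYERS, HEADS = 896, 3584, 4, 7
assert HIDDEN % HEADS == 0, "hidden size must split evenly across attention heads"

print("head dim:", HIDDEN // HEADS)
print("BERT encoder params:", LAYERS * bert_layer_params(HIDDEN, FFN))
print("CNN conv params:", multi_kernel_conv_params(896, 256, [3, 5, 7]))
```

Under these assumptions, the 896 hidden units split into 7 heads of 128 dimensions, and the 4-layer encoder holds roughly 38.6 M parameters.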
Three architectures are available:
- HybridCNNBERTClassifier (default): Fuses CNN local motif features with BERT global context. Used for the published checkpoints.
- CNNClassifier: Multi-kernel convolutional network only. Faster training, slightly lower species accuracy.
- BERTClassifier: Transformer encoder only. Trained from scratch on its own, it underperforms substantially at the species rank; it is provided mainly for ablation.
Pre-Trained Checkpoints
Two checkpoints are hosted on Hugging Face:
| Checkpoint | Training data | Species accuracy | Parameters |
|---|---|---|---|
| deeptaxa-full-length-v1.pt | Full-length 16S (277,336 sequences, ~1,500 bp) | 92.96% (3-seed mean) | 76.4 M |
| deeptaxa-v3v4-v1.pt | In-silico V3-V4 amplicons (273,003 amplicons, ~420 bp) | 87.55% (seed 42) | 75.8 M |
Both checkpoints share the same compact architecture. The small parameter difference comes from the smaller per-rank classifier heads of the V3-V4 model, whose species vocabulary is smaller (8,347 vs 16,909 species labels). A config.json with full model metadata is also available.
Installation
DeepTaxa requires Python 3.10 or later. We recommend using a Conda environment:
```bash
git clone https://github.com/systems-genomics-lab/deeptaxa.git
cd deeptaxa
conda create --name deeptaxa_env python=3.10 -y
conda activate deeptaxa_env
pip install .
deeptaxa --version
```
Dependencies (torch, transformers, pandas, numpy, scikit-learn, h5py, etc.) are specified in pyproject.toml and installed automatically.
Note: For GPU support, install a CUDA-compatible PyTorch build before running `pip install .`. See the PyTorch installation guide.
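After installation, a quick check like the following can confirm that the interpreter meets the Python 3.10 requirement and that the installed PyTorch build can see a GPU. It relies only on the standard library plus torch.cuda.is_available(); the function name is hypothetical and not part of DeepTaxa.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 10)):
    """Report whether Python meets the minimum version and whether PyTorch sees a GPU."""
    report = {
        "python_ok": sys.version_info[:2] >= min_python,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs before PyTorch is installed
        report["cuda_available"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    print(check_environment())
```

If cuda_available is False on a GPU machine, the CPU-only PyTorch wheel was most likely installed.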
Quick Start
Predict with the pre-trained model (no training data needed):
```bash
# Download the checkpoint
mkdir -p ../deeptaxa-data/models
wget -P ../deeptaxa-data/models \
  https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/deeptaxa-full-length-v1.pt

# Classify sequences
deeptaxa predict \
  --fasta-file your_sequences.fna \
  --checkpoint ../deeptaxa-data/models/deeptaxa-full-length-v1.pt \
  --output-dir ../deeptaxa-outputs/predictions
```
Evaluate against known labels (adds per-rank accuracy, F1, ECE to the output):
```bash
deeptaxa predict \
  --fasta-file ../deeptaxa-data/greengenes/gg_2024_09_testing.fna.gz \
  --taxonomy-file ../deeptaxa-data/greengenes/gg_2024_09_testing.tsv.gz \
  --checkpoint ../deeptaxa-data/models/deeptaxa-full-length-v1.pt \
  --output-dir ../deeptaxa-outputs/evaluation
```
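To drive prediction from a Python pipeline, the same documented flags (--fasta-file, --checkpoint, --output-dir, and the optional --taxonomy-file) can be passed to the CLI through subprocess. This is a minimal sketch assuming the deeptaxa executable is on PATH; the helper names are hypothetical, and it does not use any DeepTaxa Python API.

```python
import shutil
import subprocess


def predict_command(fasta, checkpoint, output_dir, taxonomy=None):
    """Build the argv for `deeptaxa predict`; adding taxonomy switches on evaluation metrics."""
    cmd = [
        "deeptaxa", "predict",
        "--fasta-file", str(fasta),
        "--checkpoint", str(checkpoint),
        "--output-dir", str(output_dir),
    ]
    if taxonomy is not None:
        cmd += ["--taxonomy-file", str(taxonomy)]
    return cmd


def run_predict(*args, **kwargs):
    """Run the CLI and raise CalledProcessError if it exits with a non-zero status."""
    if shutil.which("deeptaxa") is None:
        raise RuntimeError("deeptaxa is not on PATH; install the package first")
    subprocess.run(predict_command(*args, **kwargs), check=True)
```

For example, run_predict("your_sequences.fna", "../deeptaxa-data/models/deeptaxa-full-length-v1.pt", "../deeptaxa-outputs/predictions") mirrors the first command above.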
Inspect a checkpoint:
```bash
deeptaxa describe \
  --checkpoint ../deeptaxa-data/models/deeptaxa-full-length-v1.pt
```
Tip: Run `deeptaxa train --help` or `deeptaxa predict --help` for a full list of options.
Data and Pre-Trained Models
Datasets and checkpoints are hosted on Hugging Face. Store them in a sibling directory outside the codebase:
```
working_directory/
├── deeptaxa/                 # This repository
├── deeptaxa-data/            # Datasets and checkpoints
│   ├── greengenes/
│   │   ├── gg_2024_09_training.fna.gz   (277,336 sequences, ~96 MB)
│   │   ├── gg_2024_09_training.tsv.gz   (taxonomy labels, ~2.6 MB)
│   │   ├── gg_2024_09_testing.fna.gz    (69,335 sequences, ~24 MB)
│   │   └── gg_2024_09_testing.tsv.gz    (taxonomy labels, ~0.8 MB)
│   └── models/
│       ├── deeptaxa-full-length-v1.pt
│       └── deeptaxa-v3v4-v1.pt
└── deeptaxa-outputs/         # Training and prediction outputs
```
DeepTaxa uses the Greengenes2 database (2024.09 release), reformatted and hosted on Hugging Face.
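The exact column layout of the reformatted TSV files is not described here. Assuming each sequence carries a Greengenes-style lineage string (seven semicolon-separated ranks with d__, p__, c__, o__, f__, g__, s__ prefixes), the seven per-rank targets could be extracted as follows. The function and its handling of empty labels are illustrative only.

```python
RANKS = ("domain", "phylum", "class", "order", "family", "genus", "species")
PREFIXES = ("d__", "p__", "c__", "o__", "f__", "g__", "s__")


def parse_lineage(lineage: str) -> dict:
    """Split a Greengenes-style lineage into one label per taxonomic rank."""
    parts = [part.strip() for part in lineage.split(";")]
    if len(parts) != len(RANKS):
        raise ValueError(f"expected {len(RANKS)} ranks, got {len(parts)}: {lineage!r}")
    labels = {}
    for rank, prefix, part in zip(RANKS, PREFIXES, parts):
        if not part.startswith(prefix):
            raise ValueError(f"{rank} field {part!r} lacks prefix {prefix!r}")
        # An empty name such as "s__" is kept as None (unassigned at that rank).
        labels[rank] = part[len(prefix):] or None
    return labels
```

Each sequence then yields one target per rank, which is what lets a single forward pass predict Domain through Species.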
Download
```bash
# Run from working_directory/ (the parent of the deeptaxa/ repository)

# Dataset
mkdir -p deeptaxa-data/greengenes && cd deeptaxa-data/greengenes
for f in gg_2024_09_training.fna.gz gg_2024_09_training.tsv.gz \
         gg_2024_09_testing.fna.gz gg_2024_09_testing.tsv.gz; do
  wget https://huggingface.co/datasets/systems-genomics-lab/greengenes/resolve/main/$f
done

# Checkpoints
mkdir -p ../models && cd ../models
wget https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/deeptaxa-full-length-v1.pt
wget https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/deeptaxa-v3v4-v1.pt
wget https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/config.json
```
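The same downloads can be scripted in Python with only the standard library, for example on machines without wget. This sketch reuses the URLs and directory layout shown above; the function names are hypothetical, and it skips files that already exist rather than verifying them.

```python
from pathlib import Path
from urllib.request import urlretrieve

DATASET_BASE = "https://huggingface.co/datasets/systems-genomics-lab/greengenes/resolve/main/"
MODEL_BASE = "https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/"
DATASET_FILES = [
    "gg_2024_09_training.fna.gz", "gg_2024_09_training.tsv.gz",
    "gg_2024_09_testing.fna.gz", "gg_2024_09_testing.tsv.gz",
]
MODEL_FILES = ["deeptaxa-full-length-v1.pt", "deeptaxa-v3v4-v1.pt", "config.json"]


def download_plan(root="deeptaxa-data"):
    """Pair each remote URL with its destination in the directory layout above."""
    root = Path(root)
    plan = [(DATASET_BASE + name, root / "greengenes" / name) for name in DATASET_FILES]
    plan += [(MODEL_BASE + name, root / "models" / name) for name in MODEL_FILES]
    return plan


def download_all(root="deeptaxa-data"):
    """Fetch every missing file; an interrupted download must be deleted by hand."""
    for url, dest in download_plan(root):
        dest.parent.mkdir(parents=True, exist_ok=True)
        if not dest.exists():
            urlretrieve(url, dest)
```

Calling download_all() from working_directory/ reproduces the layout the other commands expect.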
Note: Checkpoint files use PyTorch's pickle-based serialization, and loading a pickle file can execute arbitrary code. Download checkpoints only from the official Hugging Face repository.
Training
All architecture hyperparameters default to the published (compact) configuration, so a minimal training command reproduces the published checkpoint:
```bash
deeptaxa train \
  --fasta-file ../deeptaxa-data/greengenes/gg_2024_09_training.fna.gz \
  --taxonomy-file ../deeptaxa-data/greengenes/gg_2024_09_training.tsv.gz \
  --model-type hybridcnnbert \
  --output-dir ../deeptaxa-outputs/
```
Training takes approximately 1 h 20 m on an NVIDIA RTX 4090 (or 2 h 35 m on an NVIDIA A40) for 10 epochs.
Output
Each training run produces:
- `checkpoints/deeptaxa_<uuid>_epoch<N>.pt`: Model weights, optimizer state, scheduler state, and label encoders for each epoch.
- `metrics/deeptaxa_<uuid>_epoch<N>.json`: Per-epoch validation loss, accuracy, F1, precision, and recall at each rank.
- `deeptaxa_uuid.txt`: The unique run identifier.
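Because the epoch number is embedded in each checkpoint filename, finding the most recent epoch only requires parsing those names. A minimal sketch, assuming the deeptaxa_<uuid>_epoch<N>.pt pattern listed above; sorting the names as plain strings would put epoch10 before epoch2, so the epoch is compared as an integer. The helper is hypothetical.

```python
import re
from pathlib import Path

CHECKPOINT_NAME = re.compile(r"deeptaxa_(?P<uuid>.+)_epoch(?P<epoch>\d+)\.pt")


def latest_checkpoint(paths):
    """Return the checkpoint path with the highest epoch number, or None if none match."""
    best_epoch, best_path = -1, None
    for path in paths:
        match = CHECKPOINT_NAME.fullmatch(Path(path).name)
        if match and int(match["epoch"]) > best_epoch:
            best_epoch, best_path = int(match["epoch"]), path
    return best_path
```

For example, latest_checkpoint(Path("../deeptaxa-outputs/checkpoints").glob("*.pt")). If several runs share one output directory, filter by the uuid in deeptaxa_uuid.txt first.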
Early Stopping
To stop training when validation loss plateaus:
```bash
deeptaxa train \
  --fasta-file ../deeptaxa-data/greengenes/gg_2024_09_training.fna.gz \
  --taxonomy-file ../deeptaxa-data/greengenes/gg_2024_09_training.tsv.gz \
  --model-type hybridcnnbert \
  --epochs 20 \
  --early-stopping-patience 3 \
  --output-dir ../deeptaxa-outputs/
```
Setting --early-stopping-patience 0 (the default) disables early stopping.
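The patience rule can be illustrated with a common formulation: training stops once the validation loss has failed to improve on its best value for `patience` consecutive epochs. This is a sketch of that general technique, not DeepTaxa's code; the actual implementation may, for example, require a minimum improvement. The function name is hypothetical.

```python
def stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None if it never stops early."""
    if patience <= 0:  # patience 0 disables early stopping
        return None
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


print(stopping_epoch([1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.5], patience=3))  # → 6
```

In this example the loss stops improving after epoch 3, so with patience 3 training halts after epoch 6 and never sees the later drop to 0.5.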
Experimentation
The default configuration uses DNABERT-2 tokenization, cross-entropy loss, and uniform rank weighting. Each choice can be varied independently for ablation studies.
Encoding comparison
```bash
# Default: DNABERT-2 BPE tokenization
deeptaxa train --model-type cnn --encoding dnabert ...

# Ablation: one-hot nucleotide encoding (4-channel, no pretrained tokenizer)
deeptaxa train --model-type cnn --encoding onehot ...
```
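For intuition about the one-hot ablation, a 4-channel encoding maps each nucleotide to an indicator vector. The channel order (A, C, G, T) and the all-zero treatment of ambiguous bases such as N are assumptions in this sketch, not necessarily what --encoding onehot does internally.

```python
CHANNEL = {"A": 0, "C": 1, "G": 2, "T": 3}  # assumed channel order


def one_hot_encode(seq):
    """Encode a DNA string as a list of 4-element indicator rows, one per position."""
    rows = []
    for base in seq.upper():
        row = [0, 0, 0, 0]
        if base in CHANNEL:
            row[CHANNEL[base]] = 1
        # Ambiguous IUPAC codes (N, R, Y, ...) stay as all-zero rows in this sketch.
        rows.append(row)
    return rows


print(one_hot_encode("ACGTN"))
```

Unlike BPE tokenization, this keeps one position per nucleotide, so a ~1,500 bp read becomes a ~1,500 x 4 input.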
Loss function comparison
```bash
# Default: cross-entropy
deeptaxa train --model-type hybridcnnbert --loss-type cross_entropy ...

# Ablation: focal loss (gamma=2.0)
deeptaxa train --model-type hybridcnnbert --loss-type focal --focal-gamma 2.0 ...
```
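Focal loss rescales the cross-entropy term by (1 - p_t)^gamma, where p_t is the predicted probability of the true class, so confidently correct examples contribute less to the gradient. A minimal single-example sketch of the standard formulation follows; it omits the optional class-weighting factor alpha, which the options above do not mention, and DeepTaxa's per-rank implementation may differ.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target class."""
    return -math.log(softmax(logits)[target])


def focal_loss(logits, target, gamma=2.0):
    """Focal loss for one example: -(1 - p_t)**gamma * log(p_t)."""
    p_t = softmax(logits)[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


logits = [2.0, 0.5, -1.0]
print(cross_entropy(logits, 0), focal_loss(logits, 0, gamma=2.0))
```

With gamma=0 the modulating factor is 1 and focal loss reduces exactly to cross-entropy, which is a convenient consistency check when comparing the two settings.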
Architecture comparison
Train CNN-only, BERT-only, or the hybrid under the same data and hyperparameters using --model-type cnn, --model-type bert, or --model-type hybridcnnbert.
Calibration
When --taxonomy-file is provided at prediction time, DeepTaxa computes Expected Calibration Error (ECE) alongside accuracy, F1, precision, recall, and AUC. ECE groups predictions into 10 equal-width confidence bins and takes the average, weighted by the number of predictions in each bin, of the absolute gap between mean predicted confidence and observed accuracy. All metrics are saved to metrics.json.
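The ECE definition above can be written out directly. This sketch assumes confidence is the top-class probability at a given rank and uses bins [0, 0.1), [0.1, 0.2), ..., [0.9, 1.0], with a confidence of exactly 1.0 placed in the last bin; DeepTaxa's exact bin-edge handling is not stated here, and the function name is hypothetical.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE over equal-width confidence bins, weighted by the fraction of predictions per bin."""
    if not confidences or len(confidences) != len(correct):
        raise ValueError("need equal-length, non-empty inputs")
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        index = min(int(conf * n_bins), n_bins - 1)  # confidence 1.0 falls into the top bin
        bins[index].append((conf, bool(hit)))
    total = len(confidences)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        mean_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(h for _, h in members) / len(members)
        ece += (len(members) / total) * abs(mean_conf - accuracy)
    return ece


# Four predictions at 90% confidence, three of them correct: |0.90 - 0.75| = 0.15
print(expected_calibration_error([0.9, 0.9, 0.9, 0.9], [True, True, True, False]))
```

A perfectly calibrated model scores 0; the per-rank ECE values in the Performance table grow from Domain to Species, in line with accuracy falling at finer ranks.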
Scripts
The scripts/ directory contains reusable tools for common workflows:
| Script | Purpose |
|---|---|
| deeptaxa_workflow.sh | End-to-end workflow: train, resume, describe, predict |
| run_experiment.sh | Central experiment runner with logging and timing |
| run_ablation.sh | Ablation study: architecture, encoding, and loss variants |
| run_amplicon_eval.sh | Simulated amplicon evaluation (V3-V4, V4) |
| run_similarity_eval.sh | Similarity-stratified evaluation using vsearch |
| calibration_diagnosis.sh | A/B comparison of temperature configurations |
| calibration_sweep.sh | Multi-configuration temperature sweep |
| simulate_amplicons.py | Extract amplicon regions via in-silico PCR |
| sequence_similarity.py | Compute train-test nearest-neighbor identity |
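For intuition about how V3-V4 reads can be derived from full-length sequences, in-silico PCR reduces to locating a forward primer site and the reverse complement of a reverse primer, with IUPAC degenerate bases expanded into regex character classes. This sketch assumes the widely used 341F/805R primer pair and exact matches only; the primers, mismatch tolerance, and primer trimming used by simulate_amplicons.py are not stated here, and all names are hypothetical.

```python
import re

# IUPAC nucleotide codes -> regex character classes
IUPAC = {
    "A": "A", "C": "C", "G": "G", "T": "T",
    "R": "[AG]", "Y": "[CT]", "S": "[CG]", "W": "[AT]",
    "K": "[GT]", "M": "[AC]", "B": "[CGT]", "D": "[AGT]",
    "H": "[ACT]", "V": "[ACG]", "N": "[ACGT]",
}
COMPLEMENT = str.maketrans("ACGTRYSWKMBDHVN", "TGCAYRSWMKVHDBN")

# Assumed V3-V4 primer pair (341F / 805R); the script's actual primers may differ.
FORWARD_341F = "CCTACGGGNGGCWGCAG"
REVERSE_805R = "GACTACHVGGGTATCTAATCC"


def primer_regex(primer):
    """Translate a degenerate primer into a regular expression."""
    return "".join(IUPAC[base] for base in primer.upper())


def reverse_complement(seq):
    """Reverse complement, including IUPAC ambiguity codes."""
    return seq.upper().translate(COMPLEMENT)[::-1]


def extract_amplicon(seq, forward=FORWARD_341F, reverse=REVERSE_805R, trim_primers=True):
    """Return the region between the primer sites (exact IUPAC matches only), or None."""
    seq = seq.upper()
    fwd = re.search(primer_regex(forward), seq)
    if fwd is None:
        return None
    # The reverse primer anneals to the opposite strand, so search for its reverse complement.
    rev_site = re.compile(primer_regex(reverse_complement(reverse)))
    rev = rev_site.search(seq, fwd.end())
    if rev is None:
        return None
    return seq[fwd.end():rev.start()] if trim_primers else seq[fwd.start():rev.end()]
```

Sequences without an exact match to either primer return None; practical in-silico PCR tools usually tolerate a few mismatches.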
Tutorials
Interactive tutorials with executable code are published at systems-genomics-lab.github.io/deeptaxa:
- Prediction: Classify sequences with the pre-trained model
- Training: Train from scratch on Greengenes2
- Analysis: Evaluate performance, calibration, and error patterns
- Architecture: Model internals and extensibility
License
- Code and models: MIT License
- Greengenes dataset: Modified BSD License
Citation
If DeepTaxa contributes to your research, please cite our paper in Bioinformatics Advances: https://doi.org/10.1093/bioadv/vbag166
```bibtex
@article{salah2026deeptaxa,
  title={{DeepTaxa}: A Hybrid {CNN}-{BERT} Framework for {16S} {rRNA} Taxonomic Classification},
  author={Salah, Rana and AbdElaal, Khlood R. and Ghonaim, Lobna and Awe, Olaitan I. and Moustafa, Ahmed},
  journal={Bioinformatics Advances},
  year={2026},
  doi={10.1093/bioadv/vbag166},
  publisher={Oxford University Press}
}
```
For the Greengenes dataset:
```bibtex
@article{mcdonald2024greengenes,
  title={Greengenes2 unifies microbial data in a single reference tree},
  author={McDonald, Daniel and Jiang, Yueyu and Balaban, Metin and others},
  journal={Nature Biotechnology},
  volume={42},
  pages={715--718},
  year={2024},
  doi={10.1038/s41587-023-01845-1}
}
```
Contact
To report bugs, suggest features, or contribute code, open an issue on GitHub.
Acknowledgements
- Ahmed A. El Hosseiny and the High-Performance Computing Team of the School of Sciences and Engineering at the American University in Cairo for GPU access that enabled this work.
- Hugging Face for hosting datasets and models.
Download files
File details
Details for the file deeptaxa_rrna-1.0.1.tar.gz.
File metadata
- Download URL: deeptaxa_rrna-1.0.1.tar.gz
- Upload date:
- Size: 120.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ef69f222733f7b0d322e06959f3d2f6b0c1ca11b304f1cf18baf46b4f09e2342
|
|
| MD5 |
810d19396633a039f48175245b12556e
|
|
| BLAKE2b-256 |
11efc33c41d514113eeafdb16ff4df02ae1bb08193e9bbfbfce56f9149a4b1a3
|
File details
Details for the file deeptaxa_rrna-1.0.1-py3-none-any.whl.
File metadata
- Download URL: deeptaxa_rrna-1.0.1-py3-none-any.whl
- Upload date:
- Size: 59.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
9e8eb2e9e354050c514a106d32fdd726a15e557c11745b1a7b1e5cee0b791014
|
|
| MD5 |
f255789a38941fe2dd7f260cf82c9a80
|
|
| BLAKE2b-256 |
b00cf1f716069519c36e6213b331cca6003181fc1d5abfa9b33fd56f815e5a47
|
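A downloaded distribution file can be checked against the SHA256 digests listed above before installation. A minimal standard-library sketch; the helper names are hypothetical.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in chunks to keep memory use flat."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected):
    """True if the file's digest matches the expected hex string (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

For example, verify_sha256("deeptaxa_rrna-1.0.1.tar.gz", "ef69f222733f7b0d322e06959f3d2f6b0c1ca11b304f1cf18baf46b4f09e2342") should return True for an intact source distribution.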