
A deep learning framework for hierarchical taxonomic classification of 16S rRNA gene sequences.


DeepTaxa


DeepTaxa is a deep learning framework for hierarchical taxonomic classification of 16S rRNA gene sequences. It classifies sequences into all seven taxonomic ranks (Domain through Species) in a single forward pass, achieving 92.96% species-level accuracy (3-seed mean) on the Greengenes2 2024.09 test set.


Table of Contents

  1. Performance
  2. Installation
  3. Quick Start
  4. Data and Pre-Trained Models
  5. Training
  6. Experimentation
  7. Scripts
  8. Tutorials
  9. License
  10. Citation
  11. Contact
  12. Acknowledgements

Performance

The published HybridCNNBERT checkpoint achieves the following on 69,335 held-out test sequences from Greengenes2 2024.09 (3-seed mean across seeds 42, 123, 456):

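| Rank    | Accuracy | F1     | ECE    |
|---------|----------|--------|--------|
| Domain  | 99.98%   | 99.98% | 0.0001 |
| Phylum  | 99.69%   | 99.68% | 0.0023 |
| Class   | 99.63%   | 99.59% | 0.0024 |
| Order   | 99.07%   | 98.97% | 0.0056 |
| Family  | 98.61%   | 98.41% | 0.0075 |
| Genus   | 96.90%   | 96.48% | 0.0144 |
| Species | 92.96%   | 92.12% | 0.0242 |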

Cross-seed standard deviation is at most 0.0008 F1 at every rank (species std 0.0008 F1 / 0.07 percentage points accuracy), demonstrating high reproducibility.

Architecture

| Component | Configuration |
|-----------|---------------|
| CNN       | embed_dim=896, 256 filters, kernels [3, 5, 7], 1 conv layer |
| BERT      | 4 layers, 7 heads, hidden=896, FFN=3584, GELU, random init |
| Fusion    | Learnable alpha/beta weights + BERT residual connection |
| Training  | Cross-entropy loss, LR=5e-4, batch=64, dropout=0.20, 10 epochs |
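As a rough sanity check on the BERT row above, the snippet below counts the parameters of a standard BERT encoder layer: query, key, value and output projections with biases, a two-layer feed-forward block, and two LayerNorms, which is the layout of Hugging Face's `BertLayer`. It assumes DeepTaxa's encoder follows that layout. It does not count token embeddings, the CNN branch, the fusion weights, or the per-rank classifier heads, so it covers only part of the full model.

```python
def bert_layer_params(hidden: int, ffn: int) -> int:
    """Parameter count of one standard BERT encoder layer (weights + biases)."""
    attention = 4 * (hidden * hidden + hidden)  # Q, K, V and attention output projections
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)  # up- and down-projection
    layer_norms = 2 * (2 * hidden)  # two LayerNorms, each with a gain and a bias vector
    return attention + feed_forward + layer_norms


def head_dim(hidden: int, heads: int) -> int:
    """Per-head width; multi-head attention needs hidden to divide evenly by heads."""
    if hidden % heads:
        raise ValueError(f"hidden={hidden} is not divisible by heads={heads}")
    return hidden // heads


HIDDEN, FFN, HEADS, LAYERS = 896, 3584, 7, 4  # values from the table above
per_layer = bert_layer_params(HIDDEN, FFN)
print(f"head dim: {head_dim(HIDDEN, HEADS)}")     # 128
print(f"per layer: {per_layer:,}")                 # 9,645,440
print(f"{LAYERS} layers: {LAYERS * per_layer:,}")  # 38,581,760
```

FFN=3584 is 4 × 896, the conventional BERT feed-forward expansion ratio, and 896 splits evenly into 7 heads of width 128.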

Three architectures are available:

  • HybridCNNBERTClassifier (default): Fuses CNN local motif features with BERT global context. Used for the published checkpoints.
  • CNNClassifier: Multi-kernel convolutional network only. Faster training, slightly lower species accuracy.
  • BERTClassifier: Transformer encoder only. Trained from scratch without the CNN branch, it underperforms substantially at the species rank and is provided mainly for ablation.
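The hybrid model fuses the two branches with learnable alpha/beta weights and a BERT residual connection, but the exact formula is not spelled out here. The plain-Python sketch below shows one plausible reading, `fused = alpha * cnn + beta * bert + bert`, applied element-wise to already-pooled feature vectors of equal width. The function name `fuse` and the formula itself are assumptions for illustration, not DeepTaxa's code; in a real model `alpha` and `beta` would be trainable parameters updated by backpropagation.

```python
from typing import Sequence


def fuse(cnn: Sequence[float], bert: Sequence[float],
         alpha: float, beta: float) -> list[float]:
    """Hypothetical weighted fusion with a BERT residual connection.

    Computes alpha * cnn + beta * bert + bert element-wise. Both inputs are
    assumed to be pooled feature vectors of the same width.
    """
    if len(cnn) != len(bert):
        raise ValueError("CNN and BERT features must have the same width")
    return [alpha * c + beta * b + b for c, b in zip(cnn, bert)]


# Toy 3-dimensional features instead of 896-dimensional ones.
print(fuse([1.0, 0.0, 2.0], [0.5, 1.0, -1.0], alpha=0.5, beta=0.25))
# [1.125, 1.25, -0.25]
```

One appeal of the residual term in this reading is that, with alpha and beta near zero, the fused representation starts out close to the BERT features alone.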

Pre-Trained Checkpoints

Two checkpoints are hosted on Hugging Face:

| Checkpoint | Training data | Species accuracy | Parameters |
|------------|---------------|------------------|------------|
| deeptaxa-full-length-v1.pt | Full-length 16S (277,336 sequences, ~1,500 bp) | 92.96% (3-seed mean) | 76.4 M |
| deeptaxa-v3v4-v1.pt | In-silico V3-V4 amplicons (~420 bp, 273,003 amplicons) | 87.55% (seed 42) | 75.8 M |

Both checkpoints share the same compact architecture (the small parameter difference reflects smaller per-rank classifier heads on the V3-V4 model, which has a smaller species vocabulary: 8,347 vs 16,909). A config.json with full model metadata is also available.


Installation

DeepTaxa requires Python 3.10 or later. We recommend using a Conda environment:

git clone https://github.com/systems-genomics-lab/deeptaxa.git
cd deeptaxa
conda create --name deeptaxa_env python=3.10 -y
conda activate deeptaxa_env
pip install .
deeptaxa --version

Dependencies (torch, transformers, pandas, numpy, scikit-learn, h5py, etc.) are specified in pyproject.toml and installed automatically.

Note: For GPU support, install a CUDA-compatible PyTorch build before running `pip install .` (see the PyTorch installation guide).


Quick Start

Predict with the pre-trained model (no training data needed):

# Download the checkpoint
mkdir -p ../deeptaxa-data/models
wget -P ../deeptaxa-data/models \
  https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/deeptaxa-full-length-v1.pt

# Classify sequences
deeptaxa predict \
  --fasta-file your_sequences.fna \
  --checkpoint ../deeptaxa-data/models/deeptaxa-full-length-v1.pt \
  --output-dir ../deeptaxa-outputs/predictions

Evaluate against known labels (adds per-rank accuracy, F1, and ECE to the output):

deeptaxa predict \
  --fasta-file ../deeptaxa-data/greengenes/gg_2024_09_testing.fna.gz \
  --taxonomy-file ../deeptaxa-data/greengenes/gg_2024_09_testing.tsv.gz \
  --checkpoint ../deeptaxa-data/models/deeptaxa-full-length-v1.pt \
  --output-dir ../deeptaxa-outputs/evaluation

Inspect a checkpoint:

deeptaxa describe \
  --checkpoint ../deeptaxa-data/models/deeptaxa-full-length-v1.pt

Tip: Run deeptaxa train --help or deeptaxa predict --help for a full list of options.


Data and Pre-Trained Models

Datasets and checkpoints are hosted on Hugging Face. Store them in a sibling directory outside the codebase:

working_directory/
├── deeptaxa/              # This repository
├── deeptaxa-data/         # Datasets and checkpoints
│   ├── greengenes/
│   │   ├── gg_2024_09_training.fna.gz    (277,336 sequences, ~96 MB)
│   │   ├── gg_2024_09_training.tsv.gz    (taxonomy labels, ~2.6 MB)
│   │   ├── gg_2024_09_testing.fna.gz     (69,335 sequences, ~24 MB)
│   │   └── gg_2024_09_testing.tsv.gz     (taxonomy labels, ~0.8 MB)
│   └── models/
│       ├── deeptaxa-full-length-v1.pt
│       └── deeptaxa-v3v4-v1.pt
└── deeptaxa-outputs/      # Training and prediction outputs

DeepTaxa uses the Greengenes2 database (2024.09 release), reformatted and hosted on Hugging Face.
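Each sequence's label spans the seven ranks DeepTaxa predicts. Assuming the taxonomy TSV stores each lineage as a Greengenes2-style, semicolon-separated string with rank prefixes (`d__`, `p__`, ..., `s__`), the sketch below splits it into a rank-to-name mapping. The column layout of the reformatted TSV and the helper name `parse_lineage` are assumptions; check the file before relying on this.

```python
RANKS = ("domain", "phylum", "class", "order", "family", "genus", "species")
PREFIXES = ("d__", "p__", "c__", "o__", "f__", "g__", "s__")


def parse_lineage(lineage: str) -> dict[str, str]:
    """Split a Greengenes2-style lineage into a rank -> name mapping.

    Missing or unnamed ranks (e.g. a bare 's__') map to an empty string.
    """
    labels = {rank: "" for rank in RANKS}
    for part in (p.strip() for p in lineage.split(";")):
        if not part:
            continue
        for rank, prefix in zip(RANKS, PREFIXES):
            if part.startswith(prefix):
                labels[rank] = part[len(prefix):]
                break
        else:
            raise ValueError(f"unrecognised rank prefix in {part!r}")
    return labels
```

For an illustrative partially resolved lineage such as `"d__Bacteria; p__Bacillota; c__Bacilli; o__; f__; g__; s__"`, this returns `"Bacteria"` for domain, `"Bacilli"` for class, and empty strings from order down.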

Download

# Dataset (run from working_directory/, the parent of the deeptaxa repository)
mkdir -p deeptaxa-data/greengenes && cd deeptaxa-data/greengenes
for f in gg_2024_09_training.fna.gz gg_2024_09_training.tsv.gz \
         gg_2024_09_testing.fna.gz gg_2024_09_testing.tsv.gz; do
  wget https://huggingface.co/datasets/systems-genomics-lab/greengenes/resolve/main/$f
done

# Checkpoints
mkdir -p ../models && cd ../models
wget https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/deeptaxa-full-length-v1.pt
wget https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/deeptaxa-v3v4-v1.pt
wget https://huggingface.co/systems-genomics-lab/deeptaxa/resolve/main/config.json

Note: Checkpoint files use PyTorch's pickle-based serialization. Download them only from the official Hugging Face repository.


Training

All architecture hyperparameters default to the published (compact) configuration, so a minimal training command reproduces the published checkpoint:

deeptaxa train \
  --fasta-file ../deeptaxa-data/greengenes/gg_2024_09_training.fna.gz \
  --taxonomy-file ../deeptaxa-data/greengenes/gg_2024_09_training.tsv.gz \
  --model-type hybridcnnbert \
  --output-dir ../deeptaxa-outputs/

Training takes approximately 1 h 20 m on an NVIDIA RTX 4090 (or 2 h 35 m on an NVIDIA A40) for 10 epochs.

Output

Each training run produces:

  • checkpoints/deeptaxa_<uuid>_epoch<N>.pt: Model weights, optimizer state, scheduler state, and label encoders for each epoch.
  • metrics/deeptaxa_<uuid>_epoch<N>.json: Per-epoch validation loss, accuracy, F1, precision, and recall at each rank.
  • deeptaxa_uuid.txt: The unique run identifier.

Early Stopping

To stop training when validation loss plateaus:

deeptaxa train \
  --fasta-file ../deeptaxa-data/greengenes/gg_2024_09_training.fna.gz \
  --taxonomy-file ../deeptaxa-data/greengenes/gg_2024_09_training.tsv.gz \
  --model-type hybridcnnbert \
  --epochs 20 \
  --early-stopping-patience 3 \
  --output-dir ../deeptaxa-outputs/

Setting --early-stopping-patience 0 (the default) disables early stopping.
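Patience-based early stopping usually works like this: track the best validation loss seen so far and stop once it has failed to improve for `patience` consecutive epochs. The sketch below replays that rule over a list of per-epoch validation losses to show when a run would halt. It illustrates the general technique and assumes, without confirmation from DeepTaxa's code, a strict comparison against the best loss with no minimum-improvement threshold; `epochs_run` is a hypothetical name.

```python
def epochs_run(val_losses: list[float], patience: int) -> int:
    """Number of epochs trained before patience-based early stopping fires.

    patience=0 disables early stopping, mirroring the CLI default.
    """
    best = float("inf")
    stale = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, stale = loss, 0
        else:
            stale += 1
            if patience > 0 and stale >= patience:
                return epoch
    return len(val_losses)
```

With losses `[1.0, 0.8, 0.85, 0.9, 0.95, 0.7]` and `patience=3`, training stops after epoch 5 and never sees the improvement at epoch 6, which is the trade-off a small patience value makes.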


Experimentation

The default configuration uses DNABERT-2 tokenization, cross-entropy loss, and uniform rank weighting. Each choice can be varied independently for ablation studies.

Encoding comparison

# Default: DNABERT-2 BPE tokenization
deeptaxa train --model-type cnn --encoding dnabert ...

# Ablation: one-hot nucleotide encoding (4-channel, no pretrained tokenizer)
deeptaxa train --model-type cnn --encoding onehot ...
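The one-hot ablation represents each nucleotide as a 4-channel vector rather than as DNABERT-2 BPE tokens. A minimal sketch follows; the channel order (A, C, G, T) and the choice to map ambiguous bases such as N to an all-zero row are assumptions, not necessarily what `--encoding onehot` does.

```python
CHANNELS = "ACGT"  # assumed channel order


def one_hot(seq: str) -> list[list[int]]:
    """Encode a DNA sequence as a length x 4 matrix; non-ACGT bases become zero rows."""
    index = {base: i for i, base in enumerate(CHANNELS)}
    rows = []
    for base in seq.upper():
        row = [0, 0, 0, 0]
        if base in index:
            row[index[base]] = 1
        rows.append(row)
    return rows


print(one_hot("ACgN"))  # [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]]
```

Unlike BPE tokenization, this keeps one position per nucleotide, so a ~1,500 bp sequence yields roughly 1,500 input positions for the CNN.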

Loss function comparison

# Default: cross-entropy
deeptaxa train --model-type hybridcnnbert --loss-type cross_entropy ...

# Ablation: focal loss (gamma=2.0)
deeptaxa train --model-type hybridcnnbert --loss-type focal --focal-gamma 2.0 ...
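In its standard form, focal loss down-weights examples the model already classifies confidently: for the probability p_t assigned to the true class, FL = -(1 - p_t)^γ · log(p_t), which reduces to cross-entropy when γ = 0. The scalar sketch below makes the effect of `--focal-gamma 2.0` concrete. It is the textbook formulation; DeepTaxa's implementation may differ in details such as a class-balancing α term or how losses are averaged over a batch.

```python
import math


def cross_entropy(p_true: float) -> float:
    """Cross-entropy for one example, given the probability of the true class."""
    return -math.log(p_true)


def focal_loss(p_true: float, gamma: float = 2.0) -> float:
    """Focal loss for one example: -(1 - p_t)**gamma * log(p_t)."""
    return (1.0 - p_true) ** gamma * cross_entropy(p_true)


for p in (0.9, 0.5, 0.1):
    print(f"p_t={p}: CE={cross_entropy(p):.4f}  focal={focal_loss(p):.4f}")
```

With γ = 2, a confident correct prediction (p_t = 0.9) has its cross-entropy scaled by (0.1)² = 0.01, while a hard example (p_t = 0.1) keeps (0.9)² = 81% of it, shifting training effort toward rare or confusable taxa.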

Architecture comparison

Train CNN-only, BERT-only, or the hybrid under the same data and hyperparameters using --model-type cnn, --model-type bert, or --model-type hybridcnnbert.

Calibration

When --taxonomy-file is provided at prediction time, DeepTaxa computes Expected Calibration Error (ECE) alongside accuracy, F1, precision, recall, and AUC. ECE measures the gap between predicted confidence and observed accuracy across 10 equal-width bins. All metrics are saved to metrics.json.
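Concretely, with predictions grouped into M = 10 equal-width confidence bins B_1, ..., B_M over n predictions, ECE = Σ_m (|B_m| / n) · |acc(B_m) − conf(B_m)|, where acc is the fraction correct in a bin and conf is its mean top-class confidence. The sketch below computes this for one rank. Bin-edge handling (here, a confidence of exactly 1.0 falls in the top bin) is an assumption and can shift results slightly relative to the values DeepTaxa writes to metrics.json.

```python
def expected_calibration_error(confidences: list[float], correct: list[bool],
                               n_bins: int = 10) -> float:
    """ECE over equal-width confidence bins on [0, 1]."""
    if len(confidences) != len(correct):
        raise ValueError("confidences and correct must have the same length")
    n = len(confidences)
    if n == 0:
        return 0.0
    bins: list[list[tuple[float, bool]]] = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        idx = min(int(conf * n_bins), n_bins - 1)  # 1.0 goes in the top bin
        bins[idx].append((conf, ok))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(ok for _, ok in members) / len(members)
        ece += len(members) / n * abs(accuracy - avg_conf)  # weighted calibration gap
    return ece
```

For example, two predictions at confidence 0.95 with one correct, plus two at 0.55 that are both correct, give an ECE of 0.45: each bin holds half the data and is off by 0.45 in opposite directions, and ECE adds the absolute gaps rather than letting them cancel.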


Scripts

The scripts/ directory contains reusable tools for common workflows:

| Script | Purpose |
|--------|---------|
| deeptaxa_workflow.sh | End-to-end workflow: train, resume, describe, predict |
| run_experiment.sh | Central experiment runner with logging and timing |
| run_ablation.sh | Ablation study: architecture, encoding, and loss variants |
| run_amplicon_eval.sh | Simulated amplicon evaluation (V3-V4, V4) |
| run_similarity_eval.sh | Similarity-stratified evaluation using vsearch |
| calibration_diagnosis.sh | A/B comparison of temperature configurations |
| calibration_sweep.sh | Multi-configuration temperature sweep |
| simulate_amplicons.py | Extract amplicon regions via in-silico PCR |
| sequence_similarity.py | Compute train-test nearest-neighbor identity |

Tutorials

Interactive tutorials with executable code are published at systems-genomics-lab.github.io/deeptaxa:

  • Prediction: Classify sequences with the pre-trained model
  • Training: Train from scratch on Greengenes2
  • Analysis: Evaluate performance, calibration, and error patterns
  • Architecture: Model internals and extensibility

License


Citation

If DeepTaxa contributes to your research, please cite our paper in Bioinformatics Advances: https://doi.org/10.1093/bioadv/vbag166

@article{salah2026deeptaxa,
  title={{DeepTaxa}: A Hybrid {CNN}-{BERT} Framework for {16S} {rRNA} Taxonomic Classification},
  author={Salah, Rana and AbdElaal, Khlood R. and Ghonaim, Lobna and Awe, Olaitan I. and Moustafa, Ahmed},
  journal={Bioinformatics Advances},
  year={2026},
  doi={10.1093/bioadv/vbag166},
  publisher={Oxford University Press}
}

For the Greengenes dataset:

@article{mcdonald2024greengenes,
  title={Greengenes2 unifies microbial data in a single reference tree},
  author={McDonald, Daniel and Jiang, Yueyu and Balaban, Metin and others},
  journal={Nature Biotechnology},
  volume={42},
  pages={715--718},
  year={2024},
  doi={10.1038/s41587-023-01845-1}
}

Contact

To report bugs, suggest features, or contribute code, open an issue on GitHub.


Acknowledgements
