An Ensemble Machine Learning algorithm for classification of bacterial strains.

StrainFish

StrainFish is a weighted ensemble machine learning algorithm with multiple DNA sequence encoders, designed specifically for classifying marker sequences.

Conceived and built by Kranti Konganti, HFP

Latest version: 0.2.2

  • Multiple DNA sequence encoders for NVIDIA GPU-accelerated training.
  • Weighted ensemble machine learning model generation with sensible defaults.
  • NVIDIA GPU-accelerated training and prediction only: a supported NVIDIA GPU is required.
  • Important note: This software is under active development, so some features are experimental. Results should be thoroughly validated and independently verified before use in critical applications or publications.

Table of Contents

  1. Installation
  2. Quick Start
  3. Training Models
  4. Making Predictions
  5. Configuration Options
  6. Test Data and Examples
  7. Dependencies
  8. License

Installation

StrainFish requires Python 3.12 or newer and NVIDIA GPU support for its core machine learning processes.

Step 1: Install StrainFish

First, install the base StrainFish package from PyPI:

pip install strainfish

This command installs StrainFish but not the necessary cuML (GPU) libraries. The package will not be fully functional until cuML is installed in Step 2.

Step 2: Install cuML

Choose the command that matches your system's CUDA version. Each one installs the cuML build compatible with that CUDA release alongside StrainFish (the quotes stop shells such as zsh from interpreting the square brackets):

  • For systems with CUDA 12.x:

    pip install "strainfish[cuda-12]"

  • For systems with CUDA 13.x:

    pip install "strainfish[cuda-13]"
    

Verify cuML and CUDA Installation

After completing Step 2, confirm that the NVIDIA driver is reachable and check which CUDA version it supports. The following command uses the nvidia-ml-py dependency (imported as pynvml and included with StrainFish) to query the driver:

python -c "import pynvml; pynvml.nvmlInit(); v = pynvml.nvmlSystemGetCudaDriverVersion(); print(f'\nNVIDIA CUDA driver version: {v // 1000}.{(v % 1000) // 10}'); pynvml.nvmlShutdown()"

This prints the highest CUDA version supported by your installed NVIDIA driver (e.g., 12.4). It must be equal to or higher than the CUDA version of the extra you installed in Step 2 (12.x for cuda-12, 13.x for cuda-13). This check queries the driver only; to confirm that cuML itself can be imported, also run:

python -c "import cuml; print(cuml.__version__)"
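The driver reports its CUDA version as a single integer (for example 12040 for CUDA 12.4), which the command above splits into major and minor parts. A minimal sketch of that decoding, following the arithmetic of NVIDIA's NVML_CUDA_DRIVER_VERSION_MAJOR/MINOR macros, and of the "equal or higher" check is shown below. The function names are hypothetical, and mapping the cuda-12/cuda-13 extras to required major versions 12 and 13 is an assumption based on their names.

```python
# Hypothetical helpers for illustration; not part of the StrainFish API.


def decode_cuda_driver_version(version: int) -> tuple[int, int]:
    """Split an NVML CUDA driver version integer (e.g. 12040) into (major, minor)."""
    major = version // 1000
    minor = (version % 1000) // 10
    return major, minor


# Assumed mapping from each pip extra to the CUDA major version it targets.
EXTRA_TO_CUDA_MAJOR = {"cuda-12": 12, "cuda-13": 13}


def driver_supports_extra(version: int, extra: str) -> bool:
    """True if the driver's CUDA major version is at least the extra's major version."""
    major, _ = decode_cuda_driver_version(version)
    return major >= EXTRA_TO_CUDA_MAJOR[extra]


if __name__ == "__main__":
    print(decode_cuda_driver_version(12040))         # (12, 4)
    print(driver_supports_extra(12040, "cuda-13"))   # False
```

Only the major version is compared here; finer-grained compatibility rules between drivers and CUDA minor releases are outside the scope of this sketch.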

Quick Start

Training a Model

To train a model on your DNA sequences:

strainfish train run \
  -f path/to/sequences.fasta \
  -l path/to/labels.csv \
  -o /path/to/models_output_dir/model_prefix

Predicting using a Model

To perform prediction using a trained model:

strainfish predict run \
  -f path/to/predict_sequences.fasta \
  -m /path/to/models_output_dir/model_prefix \
  -o path/to/results_directory

Training Models

StrainFish uses a weighted ensemble of three classifiers (XGBoost, RandomForest, and Naive Bayes) for both training and prediction. It includes multiple DNA sequence encodings, but only one encoding can be used per training run.
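One common way to combine several classifiers into a weighted ensemble is weighted soft voting: each model's class-probability vector is multiplied by a model weight, the results are summed, and the highest-scoring class wins. The plain-Python sketch below illustrates that idea with made-up probabilities, weights, and labels. It is an assumption for illustration only, not StrainFish's actual weighting scheme, which is not described here.

```python
def weighted_soft_vote(
    model_probs: dict[str, dict[str, float]],
    weights: dict[str, float],
) -> tuple[str, dict[str, float]]:
    """Combine per-model class probabilities using per-model weights.

    Returns the winning label and the normalized combined scores.
    """
    combined: dict[str, float] = {}
    for model, probs in model_probs.items():
        w = weights[model]
        for label, p in probs.items():
            combined[label] = combined.get(label, 0.0) + w * p
    total = sum(combined.values())
    # Normalize so the combined scores sum to 1 (the weights need not sum to 1).
    scores = {label: s / total for label, s in combined.items()}
    best = max(scores, key=scores.get)
    return best, scores


# Illustrative numbers only: three models voting on two hypothetical strain labels.
probs = {
    "xgboost": {"strain_A": 0.70, "strain_B": 0.30},
    "random_forest": {"strain_A": 0.40, "strain_B": 0.60},
    "naive_bayes": {"strain_A": 0.20, "strain_B": 0.80},
}
weights = {"xgboost": 0.5, "random_forest": 0.3, "naive_bayes": 0.2}

label, scores = weighted_soft_vote(probs, weights)
print(label, scores)
```

With these numbers the weighted scores are about 0.51 for strain_A and 0.49 for strain_B, so strain_A wins. With equal weights, strain_B would win instead (mean probability about 0.57 vs 0.43), which shows how the weights can change the ensemble's decision.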

Basic Training Command

# -f: input FASTA file
# -l: labels CSV (id,label)
# -o: output directory and prefix for the trained models
strainfish train run \
  -f training_sequences.fasta \
  -l labels.csv \
  -o /path/to/models_output_dir/model_prefix
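Before training, it can help to confirm that every FASTA record has a label. The sketch below is a hypothetical pre-flight check using only the Python standard library. It assumes the record ID is the first whitespace-separated token of each FASTA header and that the labels CSV has two columns, id,label, optionally with a header row; StrainFish's own parser may apply different rules.

```python
import csv
import io


def fasta_ids(fasta_text: str) -> list[str]:
    """Return record IDs: the first token after '>' on each header line."""
    ids = []
    for line in fasta_text.splitlines():
        if line.startswith(">"):
            header = line[1:].strip()
            ids.append(header.split()[0] if header else "")
    return ids


def read_labels(csv_text: str) -> dict[str, str]:
    """Read id,label rows, skipping blank lines and an optional 'id,label' header."""
    labels = {}
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:
            continue
        if len(row) < 2:
            raise ValueError(f"expected id,label but got {row!r}")
        if [c.strip().lower() for c in row[:2]] == ["id", "label"]:
            continue  # header row
        labels[row[0].strip()] = row[1].strip()
    return labels


def check_inputs(fasta_text: str, csv_text: str) -> tuple[list[str], list[str]]:
    """Return (FASTA IDs with no label, labeled IDs missing from the FASTA)."""
    ids = fasta_ids(fasta_text)
    labels = read_labels(csv_text)
    id_set = set(ids)
    unlabeled = [i for i in ids if i not in labels]
    unused = [i for i in labels if i not in id_set]
    return unlabeled, unused


if __name__ == "__main__":
    fasta = ">seq1 sample A\nACGT\n>seq2\nGGCC\n"
    labels_csv = "id,label\nseq1,strain_A\nseq3,strain_B\n"
    print(check_inputs(fasta, labels_csv))  # (['seq2'], ['seq3'])
```

To check real files, read both (for example with pathlib.Path(...).read_text()) and pass their text to check_inputs; two empty lists mean every sequence has a label and every label has a sequence.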

Advanced Configuration

Additional options you can set during training:

# --encode-method     Encoding method: sm, sp, or tf
# --kmer              K-mer size for hashing
# --num-hashes        Number of hashes per sequence
# --factor            Sequence overlap factor
# --chunk-size        Size of DNA chunks
# --pseknc-weight     Weight for PseKNC encoding
# --xgb-n-estimators  Number of XGBoost boosting rounds
# --rf-n-estimators   Number of RandomForest trees
strainfish train run \
  -f training_sequences.fasta \
  -l labels.csv \
  -o model_output_dir \
  --encode-method tf \
  --kmer 7 \
  --num-hashes 100 \
  --factor 21 \
  --chunk-size 200 \
  --pseknc-weight 0.1 \
  --xgb-n-estimators 300 \
  --rf-n-estimators 100
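The --kmer and --num-hashes options in the example above point to a MinHash-style sketch, the technique behind the sm (SOMH) encoder listed under Encoding Methods. The plain-Python sketch below shows the general MinHash idea: split a sequence into overlapping k-mers, hash every k-mer with num_hashes differently salted hash functions, and keep the minimum value per function. Two sequences' Jaccard similarity can then be estimated as the fraction of matching minima. Using salted BLAKE2b as the hash family is an assumption chosen for simplicity; StrainFish lists sourmash among its dependencies, and its actual encoder may differ.

```python
import hashlib


def kmers(seq: str, k: int) -> set[str]:
    """All overlapping k-mers of a DNA sequence (uppercased first)."""
    seq = seq.upper()
    return {seq[i : i + k] for i in range(len(seq) - k + 1)}


def minhash_signature(seq: str, k: int = 7, num_hashes: int = 100) -> list[int]:
    """One minimum hash value per salted hash function."""
    kmer_set = kmers(seq, k)
    if not kmer_set:
        raise ValueError("sequence is shorter than k")
    signature = []
    for i in range(num_hashes):
        salt = i.to_bytes(8, "little")  # each salt defines a different hash function
        signature.append(
            min(
                int.from_bytes(
                    hashlib.blake2b(km.encode(), digest_size=8, salt=salt).digest(),
                    "little",
                )
                for km in kmer_set
            )
        )
    return signature


def estimate_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """Fraction of slots where two equal-length signatures agree."""
    matches = sum(a == b for a, b in zip(sig_a, sig_b))
    return matches / len(sig_a)
```

For example, estimate_jaccard(minhash_signature(seq_a), minhash_signature(seq_b)) approximates the k-mer overlap of two sequences; more hashes give a less noisy estimate at the cost of more computation.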

Encoding Methods

StrainFish supports three DNA sequence encoding methods:

  • tf (TF-IDF): TF-IDF vectorization (Default)
  • sp (SentencePiece): Subword tokenization using SentencePiece models (Experimental)
  • sm (SOMH): MinHash-based approach with PseKNC and sequence composition weights (AT/GC ratio) (Experimental)
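As a rough illustration of what TF-IDF means for DNA, the sketch below turns each sequence into a bag of overlapping k-mers and weights each k-mer by its count in the sequence times an inverse document frequency, so k-mers shared by every sequence count for less than distinctive ones. The k-mer size, the smoothed IDF formula (the default used by scikit-learn's TfidfVectorizer), and L2 row normalization are assumptions; StrainFish's own tf vectorizer settings are not documented here.

```python
import math
from collections import Counter


def kmer_counts(seq: str, k: int) -> Counter:
    """Count overlapping k-mers in an uppercased sequence."""
    seq = seq.upper()
    return Counter(seq[i : i + k] for i in range(len(seq) - k + 1))


def tfidf_kmers(seqs: list[str], k: int = 4) -> tuple[list[str], list[list[float]]]:
    """Return (vocabulary, rows); each row is an L2-normalized TF-IDF vector."""
    counts = [kmer_counts(s, k) for s in seqs]
    vocab = sorted(set().union(*counts))
    n = len(seqs)
    # Document frequency: how many sequences contain each k-mer.
    df = {km: sum(1 for c in counts if km in c) for km in vocab}
    # Smoothed IDF: ln((1 + n) / (1 + df)) + 1.
    idf = {km: math.log((1 + n) / (1 + df[km])) + 1 for km in vocab}
    rows = []
    for c in counts:
        row = [c[km] * idf[km] for km in vocab]
        norm = math.sqrt(sum(v * v for v in row)) or 1.0
        rows.append([v / norm for v in row])
    return vocab, rows
```

For instance, tfidf_kmers(["ACGTACGT", "ACGTTTTT"], k=3) gives ACG and CGT (present in both sequences) an IDF of 1, while k-mers unique to one sequence, such as TTT, get a higher IDF of about 1.41.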

Making Predictions

Basic Prediction Command

# -f: input FASTA file(s)
# -m: path and prefix of the trained model
# -o: output directory for predictions
strainfish predict run \
  -f prediction_sequences.fasta \
  -m /path/to/models_output_dir/model_prefix \
  -o results_dir

Model Management

List available models:

strainfish predict list-models
# Or list models stored at a particular models directory:
strainfish predict list-models -md /path/to/models_dir

Configuration Options

Tunable parameters for StrainFish.

XGBoost Parameters

View all configurable XGBoost parameters:

strainfish train show-xgb-params

Key parameters:

  • --xgb-n-estimators: Number of boosting rounds
  • --xgb-max-depth: Maximum tree depth
  • --xgb-learning-rate: Learning rate for boosting
  • --xgb-subsample: Subsample ratio of the training instances

RandomForest Parameters

View all configurable RandomForest parameters:

strainfish train show-rf-params

Key parameters:

  • --rf-n-estimators: Number of trees in the forest
  • --rf-max-depth: Maximum depth of each tree
  • --rf-random-state: Random seed for reproducibility
  • --rf-min-samples-leaf: Minimum samples required at a leaf node

SentencePiece Parameters

View all configurable SentencePiece parameters:

strainfish train show-sp-params

Key parameters:

  • --sp-vocab-size: Vocabulary size for tokenization
  • --sp-max-sentence-length: Maximum sentence length
  • --sp-char-cov: Character coverage ratio

Imbalance Handling Parameters

View all imbalance handling parameters:

strainfish train show-imb-params

Key parameters:

  • --imb-smote-k-neighbors: Number of neighbors for SMOTE
  • --imb-enn-n-neighbors: Number of neighbors for ENN cleaning
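SMOTE oversamples a minority class by creating synthetic points on the line segment between a minority sample and one of its k nearest minority-class neighbors, which is the neighbor count that --imb-smote-k-neighbors refers to; ENN then removes samples whose label disagrees with their nearest neighbors. The sketch below shows only the SMOTE interpolation step on small numeric feature vectors, in plain Python with a fixed seed. It is a simplified illustration of the technique, not the implementation StrainFish uses; the function names are hypothetical.

```python
import math
import random


def k_nearest(points: list[list[float]], idx: int, k: int) -> list[int]:
    """Indices of the k nearest other points to points[idx] (Euclidean distance)."""
    others = [j for j in range(len(points)) if j != idx]
    others.sort(key=lambda j: math.dist(points[idx], points[j]))
    return others[:k]


def smote(
    minority: list[list[float]], n_new: int, k: int = 5, seed: int = 0
) -> list[list[float]]:
    """Generate n_new synthetic minority samples by interpolating toward neighbors."""
    if len(minority) < 2:
        raise ValueError("need at least two minority samples")
    k = min(k, len(minority) - 1)
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        i = rng.randrange(len(minority))
        j = rng.choice(k_nearest(minority, i, k))
        gap = rng.random()  # position along the segment from sample i to neighbor j
        synthetic.append([a + gap * (b - a) for a, b in zip(minority[i], minority[j])])
    return synthetic
```

A larger k lets synthetic samples reach further across the minority class, while a small k keeps them close to existing samples; the same trade-off applies to the neighbor count used for ENN cleaning.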

Test Data and Examples

This repository includes test data in the tests/test_input/ directory:

  • test.train.fasta: Training sequences in FASTA format
  • test.train.csv: Labels file with id,label columns
  • predict.fasta: Sequences for prediction using trained models

You can use these to test StrainFish functionality:

# Train a model using test data
strainfish train run \
  -f tests/test_input/test.train.fasta \
  -l tests/test_input/test.train.csv \
  -o test_output/test_model

# Make predictions using the trained model
strainfish predict run \
  -f tests/test_input/predict.fasta \
  -m test_output/test_model \
  -o prediction_results

Dependencies

StrainFish has the following main dependencies:

  • Core ML libraries: numpy, pandas, scikit-learn, xgboost, and cuML (GPU-accelerated; mandatory for functionality)
  • Sequence Processing: biopython, sourmash, sentencepiece
  • CLI Interface: rich, rich-click
  • Utilities: joblib, psutil, humanize, pynvml
  • Testing: pytest, pytest-cov

Note on cuML: The cuML library is essential for StrainFish's GPU-accelerated computations. It must be installed explicitly as an extra (strainfish[cuda-12] or strainfish[cuda-13], matching your CUDA version; see Installation, Step 2). For a complete, version-specific list of all dependencies, including those for cuML, see pyproject.toml.

License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Download files

Source Distribution

strainfish-0.2.2.tar.gz (23.3 MB)

Built Distribution

strainfish-0.2.2-py3-none-any.whl (27.6 MB)

File details

Details for the file strainfish-0.2.2.tar.gz.

File metadata

  • Download URL: strainfish-0.2.2.tar.gz
  • Upload date:
  • Size: 23.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.10 (Rocky Linux 9.6)

File hashes

Hashes for strainfish-0.2.2.tar.gz
Algorithm Hash digest
SHA256 4c16d7cc76a6bbf298e22c71dfc9179de456aa906a47314b631b7ab62c89d551
MD5 69d93d36644e5c62fbdd153d110162a7
BLAKE2b-256 8777674eebb2e07df3653b8fd9ba3c536a8a4b5120304ff5c29fafbc1fc3a635

File details

Details for the file strainfish-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: strainfish-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 27.6 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.10 (Rocky Linux 9.6)

File hashes

Hashes for strainfish-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 07a321bbdeb86d2c6945377fb8f5152bf5b201a24c85f16e264cc66cb1e1c67b
MD5 dc8366bf3d380c5d6d9abbd879d2442d
BLAKE2b-256 f97a83e930c6d702e00d0f60e0931aa5b3c7c72b58d47ee0de540b118eae3222
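To check a downloaded file against the SHA256 digests listed above, hash it locally and compare. A minimal sketch using Python's hashlib follows; the helper names are hypothetical. As an alternative, pip can enforce digests itself in hash-checking mode (--require-hashes with --hash=sha256:... entries in a requirements file).

```python
import hashlib

# Published SHA256 digests for the 0.2.2 release files.
EXPECTED_SHA256 = {
    "strainfish-0.2.2.tar.gz": "4c16d7cc76a6bbf298e22c71dfc9179de456aa906a47314b631b7ab62c89d551",
    "strainfish-0.2.2-py3-none-any.whl": "07a321bbdeb86d2c6945377fb8f5152bf5b201a24c85f16e264cc66cb1e1c67b",
}


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks (1 MiB by default) so large archives need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: str, expected_hex: str) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, with the wheel saved in the current directory, verify("strainfish-0.2.2-py3-none-any.whl", EXPECTED_SHA256["strainfish-0.2.2-py3-none-any.whl"]) should return True; False means the file is corrupted or not the published artifact.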
