
MaldiDeepKit

A catalog of sklearn-compatible deep learning classifiers for MALDI-TOF binned spectra

Installation · Features · Quick Start · Documentation · Tutorials · MaldiSuite · Contributing · Citing · License

MaldiDeepKit is part of the MaldiSuite ecosystem and complements MaldiAMRKit and MaldiBatchKit: where MaldiAMRKit handles preprocessing, alignment and AMR-aware evaluation, and MaldiBatchKit harmonises multi-centre spectra, MaldiDeepKit focuses on the classification step, providing four PyTorch architectures wrapped in a unified scikit-learn estimator API with defaults calibrated for 6000-bin MALDI-TOF input.

Installation

pip install maldideepkit

maldiamrkit is a core dependency and is installed automatically; MaldiDeepKit duck-types on the MaldiSet data model and reuses maldiamrkit.alignment.Warping for leak-safe spectral warping.

Development Installation

git clone https://github.com/EttoreRocchi/MaldiDeepKit.git
cd MaldiDeepKit
pip install -e ".[dev]"
pre-commit install

See CONTRIBUTING.md for coding conventions, testing, and PR guidelines.

Features

  • Unified sklearn API (BaseEstimator + ClassifierMixin) for every classifier. Each one implements fit / predict / predict_proba / score / get_params / set_params and plugs into Pipeline, cross_val_score, and GridSearchCV with no glue code.
  • Four PyTorch architectures sharing the same base class and hyperparameter surface:
    • MaldiMLPClassifier - MLP with optional sigmoid-gated attention, interpretable per-bin gates.
    • MaldiCNNClassifier - stacked Conv1d + BatchNorm + ReLU + MaxPool blocks for local pattern learning.
    • MaldiResNetClassifier - 1-D ResNet-18-style residual blocks for a deeper convolutional backbone.
    • MaldiTransformerClassifier - 1-D Vision Transformer with global self-attention, pre-norm, LayerScale, and stochastic depth.
  • MALDI-TOF defaults: kernel sizes, depths, patch widths, and warmup / cosine-annealing schedules are tuned for 6000-bin spectra in the 2000-20000 Da range.
  • Auto-scaling for non-default layouts: every classifier ships a from_spectrum(bin_width, input_dim, **overrides) factory that rescales conv kernels and patches when the user trims the m/z range or picks a different bin width. See the Spectrum scaling guide.
  • Training recipes: automatic Adam-to-AdamW dispatch when weight_decay > 0, gradient clipping, linear warmup + cosine annealing, focal loss, label smoothing, mixed precision (AMP), Stochastic Weight Averaging, Sharpness-Aware Minimization, post-hoc threshold tuning, and temperature scaling - all exposed as classifier kwargs.
  • Leak-safe spectral warping: pass any sklearn-style transformer (maldiamrkit.alignment.Warping) via warping=; it is fitted on the training fold only and applied to both splits during training and to new spectra at predict time, before per-feature standardization.
  • MaldiSet integration: pass a maldiamrkit.MaldiSet directly to fit / predict; MaldiDeepKit duck-types on the DataFrame-like .X attribute, so MaldiSuite's data model flows end-to-end.
  • Persistence: save() writes a state-dict .pt plus a hyperparameter .json (and a sibling .warper.pkl if a warper was fitted); load() fails fast on class or input_dim mismatches.
  • CPU-friendly: every classifier runs on CPU (the configuration the project's CI tests against); CUDA significantly speeds up training.

Documentation

Full documentation is available at maldideepkit.readthedocs.io.

Quick Start

Fit a Classifier

Every MaldiDeepKit classifier exposes the standard scikit-learn estimator API. Swapping architectures is a one-line change:

import numpy as np
from maldideepkit import MaldiMLPClassifier

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6000)).astype("float32")  # 200 binned spectra
y = rng.integers(0, 2, size=200)

clf = MaldiMLPClassifier(random_state=0)
clf.fit(X, y)

proba = clf.predict_proba(X)
preds = clf.predict(X)
acc = clf.score(X, y)

# Inspect attention weights (MLP only)
weights = clf.get_attention_weights(X[:10])  # (10, hidden_dim)

Inside a scikit-learn Pipeline

from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from maldideepkit import MaldiCNNClassifier

pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", MaldiCNNClassifier(random_state=0)),
])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipe, X, y, cv=cv, scoring="accuracy")
print(f"CV accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")

MaldiSet Integration

Integration with MaldiAMRKit is first-class: pass a MaldiSet directly.

from maldiamrkit import MaldiSet
from maldideepkit import MaldiCNNClassifier

ds = MaldiSet.from_directory(
    "spectra/", "metadata.csv",
    aggregate_by={"antibiotics": "Ciprofloxacin"},
    n_jobs=-1,
)
clf = MaldiCNNClassifier(random_state=0).fit(ds, ds.y.squeeze())
preds = clf.predict(ds)

Auto-Scaling for Custom Layouts

When the spectrum layout deviates from the reference 6000-bin / 3 Da default, from_spectrum rescales conv kernels and patches:

from maldideepkit import MaldiCNNClassifier, MaldiTransformerClassifier

# Reference layout (kernel_size=7, patch_size=4)
cnn = MaldiCNNClassifier.from_spectrum(bin_width=3, input_dim=6000)

# Wider bins -> smaller kernel
cnn_coarse = MaldiCNNClassifier.from_spectrum(bin_width=6, input_dim=3000)

# Transformer is scale-agnostic; only input_dim is recorded
tr = MaldiTransformerClassifier.from_spectrum(bin_width=1, input_dim=18000)

See the Spectrum scaling guide for the semantics behind each knob.
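The exact rescaling rule is internal to from_spectrum, but the underlying idea, keeping a conv kernel's m/z span roughly constant as bins change width, can be sketched in plain Python. The scaled_kernel helper and its formula below are hypothetical illustrations, not the library's API; only the reference values (kernel_size=7 at 3 Da bins) come from the example above.

```python
# Illustrative sketch only: a plausible rescaling rule in the spirit of
# from_spectrum. Wider bins mean each kernel tap covers more Daltons,
# so fewer taps are needed to span the same m/z window.

REFERENCE_BIN_WIDTH = 3   # Da, the 6000-bin default layout
REFERENCE_KERNEL = 7      # reference CNN kernel size at 3 Da bins

def scaled_kernel(bin_width: float,
                  ref_kernel: int = REFERENCE_KERNEL,
                  ref_bin_width: float = REFERENCE_BIN_WIDTH) -> int:
    """Rescale a conv kernel to span a constant m/z window (hypothetical rule)."""
    k = round(ref_kernel * ref_bin_width / bin_width)
    k = max(3, k)                        # keep a minimum receptive field
    return k if k % 2 == 1 else k + 1    # conv kernels are conventionally odd

print(scaled_kernel(3))   # 7  (reference layout)
print(scaled_kernel(6))   # 5  (wider bins -> smaller kernel)
print(scaled_kernel(1))   # 21 (finer bins -> larger kernel)
```

The same span-preserving logic would apply to patch widths; the library's actual knobs are documented in the Spectrum scaling guide.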

Save and Load

clf.save("my_model")
# -> my_model.pt, my_model.json, my_model.warper.pkl (if warping was used)
restored = MaldiCNNClassifier.load("my_model")

For more examples covering training recipes, calibration, attention inspection, and ensembles, see the Quickstart Guide and the API Reference.

Algorithms

Classifier | Backbone | Typical use case
MaldiMLPClassifier | MLP + optional sigmoid-gated attention | Fast baseline with interpretable feature gates
MaldiCNNClassifier | Stacked Conv1d + BatchNorm + ReLU + MaxPool blocks | Local pattern learning from binned spectra
MaldiResNetClassifier | 1-D ResNet-18-style residual blocks | Deeper convolutional backbone
MaldiTransformerClassifier | 1-D Vision Transformer (LayerScale, stochastic depth) | Long-range peak combinations via global self-attention

All four inherit from BaseSpectralClassifier and share the same hyperparameter surface for optimisation, device placement, early stopping, calibration, and persistence.

Shared Training Knobs

Feature | Kwarg | Notes
Decoupled weight decay | weight_decay | Switches Adam to AdamW when > 0. Default 0 (MLP/CNN), 1e-4 (ResNet), 0.05 (Transformer).
Gradient clipping | grad_clip_norm | Applies clip_grad_norm_ before every optimiser step. Default on (1.0) for the deep models.
Warmup + cosine annealing | warmup_epochs | Replaces the plateau scheduler. Default 5 (deep models), 0 (MLP/CNN).
Stochastic depth (Transformer) | drop_path_rate | Linearly ramped across blocks. Default 0.1.
LayerScale (Transformer) | layerscale_init | Per-channel residual scaling initialised near zero; crucial on small cohorts.
Focal loss | loss="focal" + focal_gamma | For imbalanced binary problems.
Label smoothing | label_smoothing | Passed to both cross-entropy and focal paths.
Stochastic Weight Averaging | swa_start_epoch | The AveragedModel weights replace the best-validation checkpoint at the end of fit.
Threshold tuning | tune_threshold | Binary only; sweeps balanced-accuracy / F1 / Youden criteria on the validation split.
Temperature scaling | calibrate_temperature | One-parameter LBFGS calibration on validation logits.
Sharpness-Aware Minimization | use_sam + sam_rho | Two-pass training, ~2x compute.
Spectral warping | warping | Any Warping-like sklearn transformer; fitted on train only, applied before standardization.
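The knobs above compose freely as constructor kwargs. A hedged sketch follows: the kwarg names mirror the table, but the specific values chosen (e.g. swa_start_epoch=40) are illustrative, and the classifier instantiation is shown commented out rather than asserted as the exact call signature.

```python
# Sketch: combining several shared training knobs into one recipe.
# The names follow the Shared Training Knobs table; the values are
# illustrative choices, not recommended defaults.

recipe = dict(
    weight_decay=1e-4,           # > 0 switches the optimiser from Adam to AdamW
    grad_clip_norm=1.0,          # clip_grad_norm_ before every optimiser step
    warmup_epochs=5,             # linear warmup, then cosine annealing
    loss="focal",                # focal loss for imbalanced binary problems
    focal_gamma=2.0,
    label_smoothing=0.05,        # applied on both cross-entropy and focal paths
    swa_start_epoch=40,          # Stochastic Weight Averaging from epoch 40 on
    calibrate_temperature=True,  # post-hoc temperature scaling on val logits
)

# from maldideepkit import MaldiResNetClassifier
# clf = MaldiResNetClassifier(random_state=0, **recipe).fit(X, y)
```

Because all four classifiers share the same hyperparameter surface, the same recipe dict can be passed unchanged to any of them.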

Utilities

maldideepkit.utils exposes:

  • find_lr(clf, X, y) - learning-rate finder.
  • tune_threshold / fit_temperature - post-hoc calibrators usable standalone.
  • FocalLoss, SAMOptimizer, DropPath - building blocks for custom training loops.
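maldideepkit's FocalLoss is a PyTorch module; as an illustration of the math it is built on, here is a NumPy sketch of the standard binary focal loss (Lin et al., 2017). The focal_loss function below is illustrative only, not the library's class.

```python
import numpy as np

# Standard binary focal loss: FL(p_t) = -(1 - p_t)^gamma * log(p_t),
# which down-weights well-classified examples so training focuses on
# hard ones. Setting gamma=0 recovers plain cross-entropy.

def focal_loss(p: np.ndarray, y: np.ndarray, gamma: float = 2.0) -> float:
    """Mean binary focal loss for predicted probabilities p and 0/1 labels y."""
    p = np.clip(p, 1e-7, 1 - 1e-7)       # numerical safety near 0 and 1
    p_t = np.where(y == 1, p, 1 - p)     # probability assigned to the true class
    return float(np.mean(-((1 - p_t) ** gamma) * np.log(p_t)))

p = np.array([0.9, 0.6, 0.2])            # predicted P(y = 1)
y = np.array([1, 1, 0])
print(focal_loss(p, y))                  # confident correct predictions contribute little
print(focal_loss(p, y, gamma=0.0))       # gamma = 0 reduces to cross-entropy
```

The modulating factor (1 - p_t)^gamma is why the feature list recommends loss="focal" for imbalanced binary problems: the abundant, easily classified majority class contributes far less to the gradient.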

Tutorials

For more detailed examples, see the notebooks:

  • Quick Start - Fit a MaldiMLPClassifier and explore the sklearn-compatible API.
  • Model Comparison - Train all four classifiers on the same dataset and compare accuracy.
  • Attention Interpretation - Visualise the sigmoid-gated attention learned by MaldiMLPClassifier.
  • Full Pipeline - End-to-end template: MaldiAMRKit preprocessing + MaldiDeepKit classification.

MaldiSuite Ecosystem

MaldiDeepKit is the third package of the MaldiSuite ecosystem:

  • MaldiAMRKit - preprocessing, alignment, peak detection, differential analysis, and classical-ML evaluation for MALDI-TOF AMR workflows.
  • MaldiBatchKit - batch-effect correction and harmonisation for multi-centre / multi-instrument MALDI-TOF spectra.
  • MaldiDeepKit (this package) - sklearn-compatible deep learning classifiers.

The three packages share the MaldiSet / MaldiSpectrum data model and are designed to compose in a single end-to-end pipeline.

Requirements

CPU fallback is supported for every model; training benefits significantly from a CUDA-capable GPU.

Contributing

Pull requests, bug reports, and feature ideas are welcome. See the Contributing Guide for how to get started.

Citing

If you use MaldiDeepKit, please cite this repository until the companion paper is available.

Related publications from the MaldiSuite ecosystem:

Rocchi, E., Nicitra, E., Calvo, M. et al. Combining mass spectrometry and machine learning models for predicting Klebsiella pneumoniae antimicrobial resistance: a multicenter experience from clinical isolates in Italy. BMC Microbiol (2026). doi:10.1186/s12866-025-04657-2

See the full publications list for more papers using the MaldiSuite.

License

This project is licensed under the MIT License. See the LICENSE file for details.

Acknowledgements

The architectures and training recipes bundled in MaldiDeepKit are 1-D adaptations of well-established networks. In particular:

ResNet - He K, Zhang X, Ren S, Sun J (2016). Deep Residual Learning for Image Recognition. CVPR. doi:10.1109/CVPR.2016.90

Vision Transformer - Dosovitskiy A, Beyer L, Kolesnikov A, et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR. arXiv:2010.11929

LayerScale - Touvron H, Cord M, Sablayrolles A, et al. (2021). Going deeper with Image Transformers. ICCV. arXiv:2103.17239

Stochastic Depth - Huang G, Sun Y, Liu Z, Sedra D, Weinberger K (2016). Deep Networks with Stochastic Depth. ECCV. arXiv:1603.09382

Temperature Scaling - Guo C, Pleiss G, Sun Y, Weinberger KQ (2017). On Calibration of Modern Neural Networks. ICML. arXiv:1706.04599

Sharpness-Aware Minimization - Foret P, Kleiner A, Mobahi H, Neyshabur B (2021). Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR. arXiv:2010.01412
