MaldiDeepKit
A catalog of sklearn-compatible deep learning classifiers for MALDI-TOF binned spectra
MaldiDeepKit is part of the MaldiSuite ecosystem and complements MaldiAMRKit and MaldiBatchKit. MaldiAMRKit handles preprocessing, alignment, and AMR-aware evaluation; MaldiBatchKit harmonises multi-centre spectra; MaldiDeepKit focuses on the classification step. It provides four PyTorch architectures wrapped in a unified scikit-learn estimator API, with defaults calibrated for 6000-bin MALDI-TOF input.
Installation
pip install maldideepkit
`maldiamrkit` is a core dependency and is installed automatically. MaldiDeepKit duck-types on the `MaldiSet` data model and reuses `maldiamrkit.alignment.Warping` for leak-safe spectral warping.
Development Installation
git clone https://github.com/EttoreRocchi/MaldiDeepKit.git
cd MaldiDeepKit
pip install -e ".[dev]"
pre-commit install
See CONTRIBUTING.md for coding conventions, testing, and PR guidelines.
Features
- Unified sklearn API (`BaseEstimator` + `ClassifierMixin`) for every classifier. Each one implements `fit` / `predict` / `predict_proba` / `score` / `get_params` / `set_params` and plugs into `Pipeline`, `cross_val_score`, and `GridSearchCV` with no glue code.
- Four PyTorch architectures sharing the same base class and hyperparameter surface:
  - `MaldiMLPClassifier` - MLP with optional sigmoid-gated attention, interpretable per-bin gates.
  - `MaldiCNNClassifier` - 1-D Conv1D + BatchNorm + ReLU + MaxPool blocks for local pattern learning.
  - `MaldiResNetClassifier` - 1-D ResNet-18-style residual blocks for a deeper convolutional backbone.
  - `MaldiTransformerClassifier` - 1-D Vision Transformer with global self-attention, pre-norm, LayerScale, and stochastic depth.
- MALDI-TOF defaults: kernel sizes, depths, patch widths, and warmup / cosine-annealing schedules are tuned for 6000-bin spectra in the 2000-20000 Da range.
- Auto-scaling for non-default layouts: every classifier ships a `from_spectrum(bin_width, input_dim, **overrides)` factory that rescales conv kernels and patches when the user trims the m/z range or picks a different bin width. See the Spectrum scaling guide.
- Training recipes: AdamW-on-`weight_decay` dispatch, gradient clipping, linear warmup + cosine annealing, focal loss, label smoothing, mixed precision (AMP), Stochastic Weight Averaging, Sharpness-Aware Minimization, post-hoc threshold tuning, and temperature scaling - all exposed as classifier kwargs.
- Leak-safe spectral warping: pass any sklearn-style transformer (`maldiamrkit.alignment.Warping`) via `warping=`; it is fitted on the training fold only and applied to both splits during training and to new spectra at `predict` time, before per-feature standardization.
- MaldiSet integration: pass a `maldiamrkit.MaldiSet` directly to `fit` / `predict`; MaldiDeepKit duck-types on the DataFrame-like `.X` attribute, so MaldiSuite's data model flows end-to-end.
- Persistence: `save()` writes a state-dict `.pt` plus a hyperparameter `.json` (and a sibling `.warper.pkl` if a warper was fitted); `load()` fails fast on class or `input_dim` mismatches.
- CPU-friendly: every classifier runs on CPU, which is what the project's CI tests against; CUDA speeds up the models' training significantly.
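The 6000-bin default follows directly from the 2000-20000 Da range at 3 Da per bin. The sketch below works through that arithmetic; `n_bins` and `bin_index` are hypothetical helper names used only for illustration and are not part of MaldiDeepKit.

```python
import math


def n_bins(mz_min: float, mz_max: float, bin_width: float) -> int:
    """Number of fixed-width bins needed to cover [mz_min, mz_max)."""
    return math.ceil((mz_max - mz_min) / bin_width)


def bin_index(mz: float, mz_min: float, bin_width: float) -> int:
    """Zero-based index of the bin that contains an m/z value."""
    return int((mz - mz_min) // bin_width)


# Reference layout: 2000-20000 Da at 3 Da per bin.
reference_bins = n_bins(2000, 20000, 3)
# Changing the bin width over the same range changes the input dimension.
coarse_bins = n_bins(2000, 20000, 6)
fine_bins = n_bins(2000, 20000, 1)
# A peak at 2003.5 Da falls into the second bin (index 1) of the reference layout.
peak_bin = bin_index(2003.5, 2000, 3)
```

The resulting bin count is the value that would be passed as `input_dim` when the layout differs from the reference.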
Documentation
Full documentation is available at maldideepkit.readthedocs.io.
Quick Start
Fit a Classifier
Every MaldiDeepKit classifier exposes the standard scikit-learn estimator API. Swapping architectures is a one-line change:
import numpy as np
from maldideepkit import MaldiMLPClassifier
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6000)).astype("float32") # 200 binned spectra
y = rng.integers(0, 2, size=200)
clf = MaldiMLPClassifier(random_state=0)
clf.fit(X, y)
proba = clf.predict_proba(X)
preds = clf.predict(X)
acc = clf.score(X, y)
# Inspect attention weights (MLP only)
weights = clf.get_attention_weights(X[:10]) # (10, hidden_dim)
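As a rough intuition for what sigmoid-gated attention means, the sketch below multiplies each feature by a gate in (0, 1). It is a generic illustration in plain Python, not the layer MaldiDeepKit implements; `gated` and the score values are made up, and in a trained model the scores would come from a learned layer.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gated(features, scores):
    """Scale each feature by a gate in (0, 1) derived from its score."""
    gates = [sigmoid(s) for s in scores]
    return [f * g for f, g in zip(features, gates)], gates


features = [0.8, 0.1, 0.5]
scores = [4.0, -4.0, 0.0]  # hypothetical gate logits, one per feature
out, gates = gated(features, scores)
# A gate near 1 passes its feature through, a gate near 0 suppresses it,
# so the gate values themselves can be read as importance weights.
```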
Inside an Sklearn Pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from maldideepkit import MaldiCNNClassifier
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", MaldiCNNClassifier(random_state=0)),
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipe, X, y, cv=cv, scoring="accuracy")
print(f"CV accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
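Wrapping the scaler in a `Pipeline` matters because cross-validation then refits it on each training fold, so no test-fold statistics leak into training. A minimal plain-Python sketch of that idea follows; `fit_scaler` and `transform` are hypothetical stand-ins for what `StandardScaler` does per fold, and the numbers are toy data.

```python
from statistics import fmean, pstdev


def fit_scaler(rows):
    """Per-feature mean and std computed from the training rows only."""
    cols = list(zip(*rows))
    means = [fmean(c) for c in cols]
    stds = [pstdev(c) or 1.0 for c in cols]  # guard against constant features
    return means, stds


def transform(rows, means, stds):
    """Standardize rows using previously fitted statistics."""
    return [[(v - m) / s for v, m, s in zip(r, means, stds)] for r in rows]


train = [[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]]
test = [[3.0, 12.0]]

means, stds = fit_scaler(train)        # statistics come from the train fold
train_z = transform(train, means, stds)
test_z = transform(test, means, stds)  # the test fold reuses them unchanged
```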
MaldiSet Integration
Integration with MaldiAMRKit is first-class: pass a MaldiSet directly.
from maldiamrkit import MaldiSet
from maldideepkit import MaldiCNNClassifier
ds = MaldiSet.from_directory(
    "spectra/", "metadata.csv",
    aggregate_by={"antibiotics": "Ciprofloxacin"},
    n_jobs=-1,
)
clf = MaldiCNNClassifier(random_state=0).fit(ds, ds.y.squeeze())
preds = clf.predict(ds)
Auto-Scaling for Custom Layouts
When the spectrum layout deviates from the reference 6000-bin / 3 Da default, from_spectrum rescales conv kernels and patches:
from maldideepkit import MaldiCNNClassifier, MaldiTransformerClassifier
# Reference layout (kernel_size=7, patch_size=4)
cnn = MaldiCNNClassifier.from_spectrum(bin_width=3, input_dim=6000)
# Wider bins -> smaller kernel
cnn_coarse = MaldiCNNClassifier.from_spectrum(bin_width=6, input_dim=3000)
# Transformer is scale-agnostic; only input_dim is recorded
tr = MaldiTransformerClassifier.from_spectrum(bin_width=1, input_dim=18000)
See the Spectrum scaling guide for the semantics behind each knob.
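One way to picture "wider bins -> smaller kernel" is to keep the kernel's width in Da roughly constant. The rule below is an assumption made for illustration only: `scaled_kernel` is a hypothetical helper, and the actual rescaling performed by `from_spectrum` may differ (the Spectrum scaling guide is the authority).

```python
def scaled_kernel(bin_width: float, ref_kernel: int = 7,
                  ref_bin_width: float = 3.0) -> int:
    """Hypothetical rule: keep the kernel's span in Da roughly constant."""
    k = max(3, round(ref_kernel * ref_bin_width / bin_width))
    # Use an odd size so the kernel has a central bin.
    return k if k % 2 == 1 else k + 1


reference = scaled_kernel(3)  # reference layout keeps kernel_size=7
coarse = scaled_kernel(6)     # wider bins -> fewer bins per kernel
fine = scaled_kernel(1)       # narrower bins -> more bins per kernel
```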
Save and Load
clf.save("my_model")
# -> my_model.pt, my_model.json, my_model.warper.pkl (if warping was used)
restored = MaldiCNNClassifier.load("my_model")
For more examples covering training recipes, calibration, attention inspection, and ensembles, see the Quickstart Guide and the API Reference.
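To make the calibration step concrete, here is a minimal sketch of temperature scaling: logits are divided by a single scalar T chosen to minimise negative log-likelihood on validation data. It uses a plain grid search rather than the LBFGS fit the library describes, and `fit_temperature_grid` plus the toy logits are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nll(logit_rows, labels, temperature):
    """Mean negative log-likelihood at a given temperature."""
    losses = [-math.log(softmax(row, temperature)[y])
              for row, y in zip(logit_rows, labels)]
    return sum(losses) / len(losses)


def fit_temperature_grid(logit_rows, labels, grid):
    """Pick the temperature on the grid with the lowest validation NLL."""
    return min(grid, key=lambda t: nll(logit_rows, labels, t))


# Toy validation logits from an over-confident binary model (made-up values).
val_logits = [[6.0, 0.0], [5.0, 0.0], [0.0, 6.0], [6.0, 0.0]]
val_labels = [0, 0, 1, 1]  # the last prediction is confidently wrong
grid = [0.5 + 0.25 * i for i in range(39)]  # 0.5 .. 10.0
best_t = fit_temperature_grid(val_logits, val_labels, grid)
```

Dividing by a positive temperature never changes the predicted class; it only softens or sharpens the probabilities.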
Algorithms
| Classifier | Backbone | Typical use case |
|---|---|---|
| `MaldiMLPClassifier` | MLP + optional sigmoid-gated attention | Fast baseline with interpretable feature gates |
| `MaldiCNNClassifier` | 1-D Conv1D + BatchNorm + ReLU + MaxPool blocks | Local pattern learning from binned spectra |
| `MaldiResNetClassifier` | 1-D ResNet-18-style residual blocks | Deeper convolutional backbone |
| `MaldiTransformerClassifier` | 1-D Vision Transformer (LayerScale, stochastic depth) | Long-range peak combinations via global self-attention |
All four inherit from BaseSpectralClassifier and share the same hyperparameter surface for optimisation, device placement, early stopping, calibration, and persistence.
Shared Training Knobs
| Feature | Kwarg | Notes |
|---|---|---|
| Decoupled weight decay | `weight_decay` | Switches Adam to AdamW when > 0. Default 0 (MLP/CNN), 1e-4 (ResNet), 0.05 (Transformer). |
| Gradient clipping | `grad_clip_norm` | `clip_grad_norm_` before every step. Default on (1.0) for the deep models. |
| Warmup + cosine annealing | `warmup_epochs` | Replaces plateau scheduler. Default 5 (deep models), 0 (MLP/CNN). |
| Stochastic depth (Transformer) | `drop_path_rate` | Linearly ramped across blocks. Default 0.1. |
| LayerScale (Transformer) | `layerscale_init` | Per-channel residual scaling initialised near zero - crucial on small cohorts. |
| Focal loss | `loss="focal"` + `focal_gamma` | For imbalanced binary problems. |
| Label smoothing | `label_smoothing` | Passed to both cross-entropy and focal paths. |
| Stochastic Weight Averaging | `swa_start_epoch` | `AveragedModel` replaces best-val at end of fit. |
| Threshold tuning | `tune_threshold` | Binary only; sweeps balanced-accuracy / F1 / Youden on val. |
| Temperature scaling | `calibrate_temperature` | One-parameter LBFGS calibration on val logits. |
| Sharpness-Aware Minimization | `use_sam` + `sam_rho` | Two-pass training, ~2× compute. |
| Spectral warping | `warping` | Any `Warping`-like sklearn transformer; fitted on train only, applied before standardization. |
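The warmup + cosine annealing row can be read as a per-epoch multiplier on the base learning rate. The sketch below shows one common formulation; the exact schedule inside MaldiDeepKit may differ, and `lr_factor` and the 50-epoch budget are assumptions made for illustration.

```python
import math


def lr_factor(epoch: int, warmup_epochs: int, total_epochs: int) -> float:
    """Multiplier applied to the base learning rate at a given epoch."""
    if epoch < warmup_epochs:
        # Linear warmup: ramp from 1/warmup_epochs up to 1.
        return (epoch + 1) / warmup_epochs
    # Cosine annealing from 1 down towards 0 over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


# warmup_epochs=5 mirrors the documented default for the deep models;
# total_epochs=50 is an arbitrary budget chosen for this example.
schedule = [lr_factor(e, warmup_epochs=5, total_epochs=50) for e in range(50)]
```

With `warmup_epochs=0` the warmup branch is never taken and the schedule is pure cosine decay.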
Utilities
maldideepkit.utils exposes:
- `find_lr(clf, X, y)` - learning-rate finder.
- `tune_threshold` / `fit_temperature` - post-hoc calibrators usable standalone.
- `FocalLoss`, `SAMOptimizer`, `DropPath` - building blocks for custom training loops.
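For a sense of what post-hoc threshold tuning does, the sketch below sweeps candidate thresholds on validation probabilities and keeps the one with the highest Youden's J (sensitivity + specificity - 1). It is a plain-Python illustration with toy data; `youden_j` and `best_threshold` are hypothetical names and do not reflect the signature of the library's `tune_threshold`.

```python
def youden_j(y_true, y_prob, threshold):
    """Sensitivity + specificity - 1 for a given decision threshold."""
    pairs = list(zip(y_true, y_prob))
    tp = sum(1 for y, p in pairs if y == 1 and p >= threshold)
    fn = sum(1 for y, p in pairs if y == 1 and p < threshold)
    tn = sum(1 for y, p in pairs if y == 0 and p < threshold)
    fp = sum(1 for y, p in pairs if y == 0 and p >= threshold)
    return tp / (tp + fn) + tn / (tn + fp) - 1.0


def best_threshold(y_true, y_prob, candidates):
    """Candidate threshold with the highest Youden's J (first one on ties)."""
    return max(candidates, key=lambda t: youden_j(y_true, y_prob, t))


# Imbalanced toy validation set: positives receive low probabilities,
# so the default 0.5 cut-off would miss both of them.
y_val = [0, 0, 0, 0, 1, 1]
p_val = [0.05, 0.10, 0.20, 0.35, 0.30, 0.45]
candidates = [i / 20 for i in range(1, 20)]  # 0.05 .. 0.95
threshold = best_threshold(y_val, p_val, candidates)
```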
Tutorials
For more detailed examples, see the notebooks:
- Quick Start - Fit a `MaldiMLPClassifier` and explore the sklearn-compatible API.
- Model Comparison - Train all four classifiers on the same dataset and compare accuracy.
- Attention Interpretation - Visualise the sigmoid-gated attention learned by `MaldiMLPClassifier`.
- Full Pipeline - End-to-end template: MaldiAMRKit preprocessing + MaldiDeepKit classification.
MaldiSuite Ecosystem
MaldiDeepKit is the third package of the MaldiSuite ecosystem:
- MaldiAMRKit - preprocessing, alignment, peak detection, differential analysis, and classical-ML evaluation for MALDI-TOF AMR workflows.
- MaldiBatchKit - batch-effect correction and harmonisation for multi-centre / multi-instrument MALDI-TOF spectra.
- MaldiDeepKit (this package) - sklearn-compatible deep learning classifiers.
The three packages share the MaldiSet / MaldiSpectrum data model and are designed to compose in a single end-to-end pipeline.
Requirements
The models benefit significantly from CUDA; CPU fallback is supported for all models.
Contributing
Pull requests, bug reports, and feature ideas are welcome. See the Contributing Guide for how to get started.
Citing
If you use MaldiDeepKit, please cite this repository until the companion paper is available.
Related publications from the MaldiSuite ecosystem:
Rocchi, E., Nicitra, E., Calvo, M. et al. Combining mass spectrometry and machine learning models for predicting Klebsiella pneumoniae antimicrobial resistance: a multicenter experience from clinical isolates in Italy. BMC Microbiol (2026). doi:10.1186/s12866-025-04657-2
See the full publications list for more papers using the MaldiSuite.
License
This project is licensed under the MIT License. See the LICENSE file for details.
Acknowledgements
The architectures and training recipes bundled in MaldiDeepKit are 1-D adaptations of well-established networks. In particular:
ResNet - He K, Zhang X, Ren S, Sun J (2016). Deep Residual Learning for Image Recognition. CVPR. doi:10.1109/CVPR.2016.90
Vision Transformer - Dosovitskiy A, Beyer L, Kolesnikov A, et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR. arXiv:2010.11929
LayerScale - Touvron H, Cord M, Sablayrolles A, et al. (2021). Going deeper with Image Transformers. ICCV. arXiv:2103.17239
Stochastic Depth - Huang G, Sun Y, Liu Z, Sedra D, Weinberger K (2016). Deep Networks with Stochastic Depth. ECCV. arXiv:1603.09382
Temperature Scaling - Guo C, Pleiss G, Sun Y, Weinberger KQ (2017). On Calibration of Modern Neural Networks. ICML. arXiv:1706.04599
Sharpness-Aware Minimization - Foret P, Kleiner A, Mobahi H, Neyshabur B (2021). Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR. arXiv:2010.01412