MassSpecGym: A benchmark for the discovery and identification of molecules


MassSpecGym provides three challenges for benchmarking the discovery and identification of new molecules from MS/MS spectra. The provided challenges abstract the process of scientific discovery from biological and environmental samples into well-defined machine learning problems.

📣 The paper will be available soon!

Installation

Installation steps:

conda create -n massspecgym python=3.11
conda activate massspecgym
git clone https://github.com/pluskal-lab/MassSpecGym.git; cd MassSpecGym
pip install -e ".[dev,notebooks]"

For AMD GPUs, you may need to install PyTorch for ROCm:

pip install -U torch==2.3.0 --index-url https://download.pytorch.org/whl/rocm6.0
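After installation, it can help to confirm which PyTorch build the environment actually picked up. The snippet below is a minimal sketch and not part of MassSpecGym; the helper name `describe_torch_backend` is hypothetical. It relies on ROCm builds of PyTorch exposing GPUs through the `torch.cuda` API and setting `torch.version.hip`.

```python
import importlib.util


def describe_torch_backend() -> str:
    """Return a short label for the PyTorch build visible to this interpreter.

    Hypothetical helper for illustration; not provided by MassSpecGym.
    """
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch

    if not torch.cuda.is_available():
        return "no CUDA/ROCm GPU visible"
    # ROCm builds reuse the torch.cuda API and set torch.version.hip
    if getattr(torch.version, "hip", None):
        return "gpu (rocm)"
    return "gpu (cuda)"


print(describe_torch_backend())
```

If no GPU is reported, training can still run on the CPU by passing accelerator="cpu" to the Trainer.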

📣 Easier installation via pip will be available soon!

MassSpecGym infrastructure

Train and evaluate your model 🚀

MassSpecGym allows you to implement, train, validate, and test your model with a few lines of code. Built on top of PyTorch Lightning, MassSpecGym abstracts data preparation and splitting while eliminating boilerplate code for training and evaluation loops. To train and evaluate your model, you only need to implement your custom architecture and prediction logic.

Below is an example of how to implement a simple model based on DeepSets for the molecule retrieval task. The model is trained to predict the fingerprint of a molecule from its spectrum and then retrieves the most similar molecules from a set of candidates based on fingerprint similarity. For more examples, please see notebooks/demo.ipynb.

  1. Import necessary modules:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from massspecgym.data import RetrievalDataset, MassSpecDataModule
from massspecgym.data.transforms import SpecTokenizer, MolFingerprinter
from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
  2. Implement your model:
class MyDeepSetsRetrievalModel(RetrievalMassSpecGymModel):
    def __init__(
        self,
        hidden_channels: int = 128,
        out_channels: int = 4096,  # fingerprint size
        *args,
        **kwargs
    ):
        """Implement your architecture."""
        super().__init__(*args, **kwargs)

        self.phi = nn.Sequential(
            nn.Linear(2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
        )
        self.rho = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Implement your prediction logic."""
        x = self.phi(x)
        x = x.sum(dim=-2)  # sum over peaks
        x = self.rho(x)
        return x

    def step(
        self, batch: dict, stage: Stage
    ) -> dict[str, torch.Tensor]:
        """Implement your custom logic of using predictions for training and inference."""
        # Unpack inputs
        x = batch["spec"]  # input spectra
        fp_true = batch["mol"]  # true fingerprints
        cands = batch["candidates"]  # candidate fingerprints concatenated for a batch
        batch_ptr = batch["batch_ptr"]  # number of candidates per sample in a batch

        # Predict fingerprint
        fp_pred = self.forward(x)

        # Calculate loss
        loss = nn.functional.mse_loss(fp_pred, fp_true)

        # Calculate final similarity scores between predicted fingerprints and retrieval candidates
        fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
        scores = nn.functional.cosine_similarity(fp_pred_repeated, cands)

        return dict(loss=loss, scores=scores)
  3. Train and validate your model:
# Init hyperparameters
n_peaks = 60
fp_size = 4096
batch_size = 32

# Load dataset
dataset = RetrievalDataset(
    spec_transform=SpecTokenizer(n_peaks=n_peaks),
    mol_transform=MolFingerprinter(fp_size=fp_size),
)

# Init data module
data_module = MassSpecDataModule(
    dataset=dataset,
    batch_size=batch_size,
    num_workers=4
)

# Init model
model = MyDeepSetsRetrievalModel(out_channels=fp_size)

# Init trainer
trainer = Trainer(accelerator="cpu", devices=1, max_epochs=5)

# Train
trainer.fit(model, datamodule=data_module)
  4. Test your model (the test API will be available soon):
# Test
trainer.test(model, datamodule=data_module)
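The flattened candidate layout used in the step method above can be easier to follow on a toy batch. The sketch below is not part of MassSpecGym: it uses plain Python lists in place of tensors, the fingerprints are made-up values, and the helper names `score_candidates` and `split_by_sample` are hypothetical. It mirrors the repeat_interleave and cosine_similarity calls, then cuts the flat scores back into per-sample groups using batch_ptr, assuming batch_ptr holds the number of candidates per sample as described in the comments above.

```python
import math
from itertools import accumulate


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def score_candidates(fp_pred, cands, batch_ptr):
    """Return one similarity score per candidate row (hypothetical helper)."""
    # Repeat each predicted fingerprint once per candidate of its sample,
    # which is what repeat_interleave(batch_ptr, dim=0) does for tensors
    repeated = [fp for fp, n in zip(fp_pred, batch_ptr) for _ in range(n)]
    return [cosine(p, c) for p, c in zip(repeated, cands)]


def split_by_sample(scores, batch_ptr):
    """Cut the flat score list back into one list per sample (hypothetical helper)."""
    ends = list(accumulate(batch_ptr))
    starts = [0] + ends[:-1]
    return [scores[s:e] for s, e in zip(starts, ends)]


# Toy batch: 2 spectra with 3 and 2 candidates respectively
fp_pred = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
cands = [
    [1.0, 0.0, 1.0],  # sample 0, identical to its prediction
    [0.0, 1.0, 0.0],  # sample 0, orthogonal to its prediction
    [1.0, 0.0, 0.0],  # sample 0, partial overlap
    [1.0, 0.0, 0.0],  # sample 1, orthogonal to its prediction
    [0.0, 1.0, 0.0],  # sample 1, identical to its prediction
]
batch_ptr = [3, 2]

scores = score_candidates(fp_pred, cands, batch_ptr)
per_sample = split_by_sample(scores, batch_ptr)
# Index of the top-scoring candidate within each sample
best = [max(range(len(s)), key=s.__getitem__) for s in per_sample]
print(best)  # → [0, 1]
```

Keeping the candidates concatenated avoids padding when samples have different numbers of candidates; batch_ptr is the only bookkeeping needed to recover the grouping.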

TODO

  • Croissant.
  • Testing API.
  • Optimize de novo evaluation metrics to run in parallel, using workers initialized in the constructor of the corresponding pl.LightningModule.
  • Link to documentation.
  • Link to Papers With Code leaderboard (requires url to paper).
