SMILES-based Matryoshka Representation Learning Embedding Model

CHEM-MRL

Chem-MRL is a SMILES embedding transformer model that leverages Matryoshka Representation Learning (MRL) to generate efficient, truncatable embeddings for downstream tasks such as classification, clustering, and database querying.

The model employs SentenceTransformers' (SBERT) 2D Matryoshka Sentence Embeddings (Matryoshka2dLoss) to enable truncatable embeddings with minimal accuracy loss, improving query performance and flexibility in downstream applications.
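To make "truncatable" concrete, here is a minimal sketch that keeps only the first k dimensions of an embedding, rescales it to unit length, and compares the cosine similarity before and after truncation. The vectors and the helper names (`truncate`, `cosine`) are invented for illustration; real embeddings would come from a trained Chem-MRL model.

```python
import math


def truncate(embedding, k):
    """Keep the first k dimensions and rescale to unit length."""
    head = embedding[:k]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head] if norm else head


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Toy 8-dimensional "embeddings" (hypothetical values, not model output)
emb_a = [0.9, 0.1, 0.3, 0.2, 0.05, 0.02, 0.01, 0.03]
emb_b = [0.8, 0.2, 0.4, 0.1, 0.01, 0.04, 0.02, 0.01]

full = cosine(emb_a, emb_b)
short = cosine(truncate(emb_a, 4), truncate(emb_b, 4))
print(f"full (8 dims): {full:.4f}, truncated (4 dims): {short:.4f}")
```

In this toy case most of the signal sits in the leading dimensions, so the truncated similarity stays close to the full one; that is the behavior MRL training aims to encourage.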

Datasets should consist of SMILES pairs and their corresponding Morgan fingerprint Tanimoto similarity scores. Currently, datasets must be in Parquet format.
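As a rough illustration of the expected row layout, the sketch below computes Tanimoto similarity over sets of "on" fingerprint bits and assembles one row per SMILES pair. The column names match those used in the configuration examples in this README; the bit sets are hand-written placeholders, since real Morgan fingerprints would be generated with a cheminformatics toolkit, which this sketch deliberately does not call.

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of 'on' fingerprint bits."""
    union = len(bits_a | bits_b)
    return len(bits_a & bits_b) / union if union else 0.0


# Hypothetical on-bit sets standing in for Morgan fingerprints
fingerprints = {
    "CCO": {1, 5, 9},
    "CCN": {1, 5, 12},
    "c1ccccc1": {3, 7, 20, 41},
}

pairs = [("CCO", "CCN"), ("CCO", "c1ccccc1")]
rows = [
    {"smiles_a": a, "smiles_b": b, "similarity": tanimoto(fingerprints[a], fingerprints[b])}
    for a, b in pairs
]
print(rows)
# With pandas and a Parquet engine installed, the rows could be saved with:
#   pandas.DataFrame(rows).to_parquet("train.parquet")
```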

Hyperparameter tuning indicates that a custom Tanimoto similarity loss function, TanimotoSentLoss, based on CoSENTLoss, outperforms Tanimoto similarity, CoSENTLoss, AnglELoss, and cosine similarity.
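The source does not spell out how TanimotoSentLoss is computed. As a sketch of the general idea only, the code below assumes it keeps the CoSENT ranking objective and swaps cosine similarity for a continuous Tanimoto similarity between embedding vectors. The function names and the `scale` default are illustrative assumptions, and the library's actual implementation may differ.

```python
import math


def tanimoto(u, v):
    """Continuous Tanimoto similarity: u.v / (|u|^2 + |v|^2 - u.v)."""
    dot = sum(a * b for a, b in zip(u, v))
    denom = sum(a * a for a in u) + sum(b * b for b in v) - dot
    return dot / denom if denom else 0.0


def cosent_style_loss(scores, labels, scale=20.0):
    """log(1 + sum of exp(scale * (s_j - s_i)) over pairs where label_i > label_j)."""
    terms = [0.0]  # exp(0) supplies the "1 +" inside the log
    for i in range(len(scores)):
        for j in range(len(scores)):
            if labels[i] > labels[j]:
                terms.append(scale * (scores[j] - scores[i]))
    peak = max(terms)  # log-sum-exp shift for numerical stability
    return peak + math.log(sum(math.exp(t - peak) for t in terms))


# Two toy embedding pairs (hypothetical values) and their predicted scores
embedding_pairs = [([1.0, 0.2], [0.9, 0.3]), ([1.0, 0.0], [0.0, 1.0])]
scores = [tanimoto(u, v) for u, v in embedding_pairs]

aligned = cosent_style_loss(scores, labels=[0.8, 0.1])  # ranking matches labels
swapped = cosent_style_loss(scores, labels=[0.1, 0.8])  # ranking contradicts labels
print(f"aligned: {aligned:.6f}, swapped: {swapped:.4f}")
```

The loss only depends on whether predicted scores are ordered the same way as the labels, which is why the aligned case is near zero and the swapped case is large.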

Installation

Install with pip

pip install chem-mrl

Install from source code

pip install -e .

Usage

Hydra & Training Scripts

Hydra configuration files are in chem_mrl/conf. The base config defines shared arguments, while model-specific configs are located in chem_mrl/conf/model. Use chem_mrl_config.yaml or classifier_config.yaml to run specific models.

The scripts directory provides training scripts with Hydra for parameter management:

  • Train Chem-MRL model:
    python scripts/train_chem_mrl.py train_dataset_path=/path/to/training.parquet val_dataset_path=/path/to/val.parquet
    
  • Train a linear classifier:
    python scripts/train_classifier.py train_dataset_path=/path/to/training.parquet val_dataset_path=/path/to/val.parquet
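
The commands above pass Hydra overrides as `key=value` arguments. When launching runs from Python (for example, to loop over several datasets), the same command line could be assembled as in this minimal sketch. `build_command` is a hypothetical helper, and only the override keys shown above are used.

```python
import shlex
import sys


def build_command(script, overrides):
    """Return an argv list: interpreter, script, then key=value Hydra overrides."""
    return [sys.executable, script] + [f"{key}={value}" for key, value in overrides.items()]


cmd = build_command(
    "scripts/train_chem_mrl.py",
    {
        "train_dataset_path": "/path/to/training.parquet",
        "val_dataset_path": "/path/to/val.parquet",
    },
)
print(shlex.join(cmd))
# To launch the run: subprocess.run(cmd, check=True)
```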
    

Basic Training Workflow

To train a model, initialize the configuration with dataset paths and model parameters, then pass it to ChemMRLTrainer for training.

from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

# Define training configuration
config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,  # Predefined model name; can be any transformer model name or path compatible with sentence-transformers
        smiles_a_column_name="smiles_a",  # Column with first molecule SMILES representation
        smiles_b_column_name="smiles_b",  # Column with second molecule SMILES representation
        label_column_name="similarity",  # Similarity score between molecules
        n_dims_per_step=3,  # Model-specific hyperparameter
        use_2d_matryoshka=True,  # Enable 2d MRL
        # Additional parameters specific to 2D MRL models
        n_layers_per_step=2,
        kl_div_weight=0.7,  # Weight for KL divergence regularization
        kl_temperature=0.5,  # Temperature parameter for KL loss
    ),
    train_dataset_path="train.parquet",  # Path to training data
    val_dataset_path="val.parquet",  # Path to validation data
    test_dataset_path="test.parquet",  # Optional test dataset
)

# Initialize trainer and start training
trainer = ChemMRLTrainer(config)
# Returns the test evaluation metric if a test dataset is provided; otherwise the final validation metric
test_eval_metric = trainer.train()

Custom Evaluation Callbacks

You can provide a callback function that is executed every evaluation_steps steps, allowing custom logic such as logging, early stopping, or model checkpointing.

from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer


# Define a callback function for logging evaluation metrics
def eval_callback(score: float, epoch: int, steps: int):
    print(f"Step {steps}, Epoch {epoch}: Evaluation Score = {score}")


config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        smiles_a_column_name="smiles_a",
        smiles_b_column_name="smiles_b",
        label_column_name="similarity",
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
)

# Train with callback
trainer = ChemMRLTrainer(config)
val_eval_metric = trainer.train(
    eval_callback=eval_callback
)  # Callback executed every `evaluation_steps`
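
A callback can also carry state. The sketch below is a callable object with the same `(score, epoch, steps)` signature that tracks the best score and flags when it has stopped improving. It assumes a higher score is better, and the trainer is not known to read `should_stop`; acting on the flag is left to the caller. `BestScoreTracker` is a hypothetical name, not part of chem-mrl.

```python
class BestScoreTracker:
    """Callable matching the (score, epoch, steps) callback signature."""

    def __init__(self, patience=3):
        self.patience = patience  # evaluations allowed without improvement
        self.best = None
        self.best_steps = None
        self.stale = 0
        self.should_stop = False

    def __call__(self, score: float, epoch: int, steps: int):
        if self.best is None or score > self.best:  # assumes higher is better
            self.best, self.best_steps, self.stale = score, steps, 0
        else:
            self.stale += 1
        self.should_stop = self.stale >= self.patience
        print(f"Step {steps}, Epoch {epoch}: score={score}, best={self.best}")


# Simulated evaluation scores (made-up numbers) to show the bookkeeping
tracker = BestScoreTracker(patience=3)
for step, score in enumerate([0.50, 0.60, 0.58, 0.59, 0.57], start=1):
    tracker(score, epoch=0, steps=step * 1000)
print("stop suggested:", tracker.should_stop)
```

An instance would be passed in place of the plain function, e.g. `trainer.train(eval_callback=tracker)`.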

W&B Integration

This library includes a WandBTrainerExecutor class for seamless Weights & Biases (W&B) integration. It handles authentication, initialization, and logging at the frequency specified by evaluation_steps.

from chem_mrl.schemas import BaseConfig, ChemMRLConfig, WandbConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer, WandBTrainerExecutor

# Define W&B configuration for experiment tracking
wandb_config = WandbConfig(
    project_name="chem_mrl_test",  # W&B project name
    run_name="test",  # Name for the experiment run
    use_watch=True,  # Enables model watching for tracking gradients
    watch_log="all",  # Logs all model parameters and gradients
    watch_log_freq=1000,  # Logging frequency
    watch_log_graph=True,  # Logs model computation graph
)

# Configure training with W&B integration
config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        smiles_a_column_name="smiles_a",
        smiles_b_column_name="smiles_b",
        label_column_name="similarity",
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    evaluation_steps=1000,
    wandb=wandb_config,
)

# Initialize trainer and W&B executor
trainer = ChemMRLTrainer(config)
executor = WandBTrainerExecutor(trainer)
executor.execute()  # Handles training and W&B logging

Classifier

This repository includes code for training a linear classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features.

Hyperparameter tuning shows that cross-entropy loss (softmax option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.

Usage

Basic Classification Training

To train a classifier, configure the model with dataset paths and column names, then initialize ClassifierTrainer to start training.

from chem_mrl.schemas import BaseConfig, ClassifierConfig
from chem_mrl.trainers import ClassifierTrainer

# Define classification training configuration
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",  # Pretrained MRL model path
        smiles_column_name="smiles",  # Column containing SMILES representations of molecules
        label_column_name="label",  # Column containing classification labels
    ),
    train_dataset_path="train_classification.parquet",  # Path to training dataset
    val_dataset_path="val_classification.parquet",  # Path to validation dataset
)

# Initialize and train the classifier
trainer = ClassifierTrainer(config)
trainer.train()

Training with Dice Loss

For imbalanced classification tasks, Dice Loss can improve performance by focusing on hard-to-classify samples. Below is a ClassifierConfig that selects the self-adjusting Dice loss and sets its additional hyperparameters.

from chem_mrl.schemas import BaseConfig, ClassifierConfig
from chem_mrl.trainers import ClassifierTrainer
from chem_mrl.schemas.Enums import ClassifierLossFctOption, DiceReductionOption

# Define classification training configuration with Dice Loss
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",
        smiles_column_name="smiles",
        label_column_name="label",
        loss_fct=ClassifierLossFctOption.selfadjdice,
        dice_reduction=DiceReductionOption.sum,  # Reduction method for Dice Loss (e.g., 'mean' or 'sum')
        dice_gamma=1.0,  # Smoothing factor hyperparameter
    ),
    train_dataset_path="train_classification.parquet",  # Path to training dataset
    val_dataset_path="val_classification.parquet",  # Path to validation dataset
)

# Initialize and train the classifier with Dice Loss
trainer = ClassifierTrainer(config)
trainer.train()
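
For intuition about the two loss options, the sketch below computes cross-entropy and a self-adjusting Dice loss for single examples in plain Python. The Dice formula follows the general form in Li et al. (see References) as an assumption; `alpha` is a hypothetical parameter not shown in the configuration above, and the library's implementation may differ in detail.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target class."""
    return -math.log(softmax(logits)[target])


def self_adjusting_dice(logits, target, alpha=1.0, gamma=1.0):
    """1 - (2*(1-p)^alpha*p + gamma) / ((1-p)^alpha*p + 1 + gamma), p = target prob."""
    p = softmax(logits)[target]
    weighted = ((1.0 - p) ** alpha) * p  # down-weights confident predictions
    return 1.0 - (2.0 * weighted + gamma) / (weighted + 1.0 + gamma)


def reduce(losses, reduction="mean"):
    """Combine per-example losses by 'mean' or 'sum'."""
    total = sum(losses)
    return total / len(losses) if reduction == "mean" else total


# Toy batch of (logits, target class) pairs with made-up values
batch = [([2.0, 0.0], 0), ([0.0, 0.0], 1)]
ce = reduce([cross_entropy(l, t) for l, t in batch], "mean")
dice = reduce([self_adjusting_dice(l, t) for l, t in batch], "sum")
print(f"cross-entropy (mean): {ce:.4f}, self-adjusting dice (sum): {dice:.4f}")
```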

References:

  • Chithrananda, Seyone, et al. "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." arXiv [cs.LG], 2020.
  • Ahmad, Walid, et al. "ChemBERTa-2: Towards Chemical Foundation Models." arXiv [cs.LG], 2022.
  • Kusupati, Aditya, et al. "Matryoshka Representation Learning." arXiv [cs.LG], 2022.
  • Li, Xianming, et al. "2D Matryoshka Sentence Embeddings." arXiv [cs.CL], 2024.
  • Bajusz, Dávid, et al. "Why is the Tanimoto Index an Appropriate Choice for Fingerprint-Based Similarity Calculations?" J Cheminform, 7, 20 (2015).
  • Li, Xiaoya, et al. "Dice Loss for Data-imbalanced NLP Tasks." arXiv [cs.CL], 2020.
  • Reimers, Nils, and Gurevych, Iryna. "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks." Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing, 2019.


