SMILES-based Matryoshka Representation Learning Embedding Model

CHEM-MRL

Chem-MRL is a SMILES embedding transformer model that leverages Matryoshka Representation Learning (MRL) to generate efficient, truncatable embeddings for downstream tasks such as classification, clustering, and database querying.

The model employs SentenceTransformers' (SBERT) 2D Matryoshka Sentence Embeddings (Matryoshka2dLoss) to enable truncatable embeddings with minimal accuracy loss, improving query performance and flexibility in downstream applications.
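As a rough illustration of what "truncatable" means here, the sketch below keeps only the leading dimensions of an embedding and re-normalizes the result before comparing vectors. It uses plain Python lists and made-up values; `truncate_and_normalize` and `cosine_similarity` are hypothetical helpers for illustration and are not part of chem-mrl.

```python
import math


def truncate_and_normalize(embedding, dim):
    """Keep the first `dim` values of an embedding and rescale to unit length."""
    head = list(embedding[:dim])
    norm = math.sqrt(sum(x * x for x in head))
    if norm == 0.0:
        return head  # nothing to rescale for an all-zero vector
    return [x / norm for x in head]


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


# Made-up 8-dimensional embeddings standing in for model output.
emb_a = [0.9, 0.1, 0.3, 0.2, 0.05, 0.01, 0.02, 0.03]
emb_b = [0.8, 0.2, 0.4, 0.1, 0.04, 0.03, 0.01, 0.02]

# Compare the full vectors and a truncated 4-dimensional prefix.
full_score = cosine_similarity(emb_a, emb_b)
short_score = cosine_similarity(
    truncate_and_normalize(emb_a, 4),
    truncate_and_normalize(emb_b, 4),
)
print(full_score, short_score)
```

With a model trained using an MRL objective, the truncated score is expected to stay close to the full score; the vectors above are arbitrary and only demonstrate the mechanics.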

Datasets should consist of SMILES pairs and their corresponding Morgan fingerprint Tanimoto similarity scores.
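The label for each pair can be sketched as follows, assuming each fingerprint is represented as the set of its "on" bit indices. The bit sets below are invented toy values, not real Morgan fingerprints (a cheminformatics toolkit would normally generate those), and `tanimoto_similarity` is a hypothetical helper. The column names match the ones used in the training examples.

```python
def tanimoto_similarity(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity of two sets of on-bit indices."""
    union = len(bits_a | bits_b)
    if union == 0:
        return 0.0  # convention chosen here for two empty fingerprints
    return len(bits_a & bits_b) / union


# Toy on-bit sets standing in for the Morgan fingerprints of two molecules.
fp_a = {12, 80, 222, 807}
fp_b = {12, 80, 294, 807}

# One dataset row: a SMILES pair plus its similarity label.
row = {
    "smiles_a": "CCO",
    "smiles_b": "CCN",
    "similarity": tanimoto_similarity(fp_a, fp_b),
}
print(row)
```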

Hyperparameter tuning indicates that TanimotoSentLoss, a custom Tanimoto similarity loss function based on CoSENTLoss, outperforms the Tanimoto similarity, CoSENTLoss, AnglELoss, and cosine similarity losses.
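To give a feel for a CoSENT-style ranking objective paired with a Tanimoto similarity, here is a minimal pure-Python sketch. It assumes that the predicted score for a pair is the continuous Tanimoto similarity of the two embeddings and that the ranking term follows the CoSENT formulation; the package's actual TanimotoSentLoss may differ in detail. The function names and the `scale` default are illustrative assumptions.

```python
import math


def continuous_tanimoto(a, b):
    """Tanimoto similarity for real-valued vectors: a.b / (|a|^2 + |b|^2 - a.b)."""
    dot = sum(x * y for x, y in zip(a, b))
    denom = sum(x * x for x in a) + sum(y * y for y in b) - dot
    return dot / denom if denom != 0.0 else 0.0


def cosent_style_loss(pred_scores, labels, scale=20.0):
    """log(1 + sum of exp(scale * (s_j - s_i)) over pairs where labels[i] > labels[j]).

    Each term penalizes a pair that should rank lower (j) but is scored
    close to or above a pair that should rank higher (i).
    """
    total = 1.0
    for i in range(len(labels)):
        for j in range(len(labels)):
            if labels[i] > labels[j]:
                total += math.exp(scale * (pred_scores[j] - pred_scores[i]))
    return math.log(total)


# Two SMILES pairs: the first should be scored as more similar than the second.
labels = [0.9, 0.2]
well_ordered = cosent_style_loss([0.8, 0.1], labels)
mis_ordered = cosent_style_loss([0.1, 0.8], labels)
print(well_ordered, mis_ordered)
```

The loss depends only on the relative ordering of predicted scores, which is why it is compared against ground-truth Tanimoto labels rather than regressed onto them directly.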

Installation

Install with pip

pip install chem-mrl

Install from source code

pip install -e .

Usage

Hydra & Training Scripts

Hydra configuration files are in chem_mrl/conf. The base config (base.yaml) defines shared arguments and includes model-specific configurations from chem_mrl/conf/model. Supported models: chem_mrl, chem_2d_mrl, classifier, and dice_loss_classifier.

Training Examples:

# Default (chem_mrl model)
python scripts/train_chem_mrl.py

# Specify model type
python scripts/train_chem_mrl.py model=chem_2d_mrl
python scripts/train_chem_mrl.py model=classifier

# Override parameters
python scripts/train_chem_mrl.py model=chem_mrl training_args.num_train_epochs=5 datasets[0].train_dataset.name=/path/to/data.parquet

# Use a different custom config, also located in `chem_mrl/conf`
python scripts/train_chem_mrl.py --config-name=my_custom_config.yaml

Configuration Options:

  • Command line overrides: Use model=<type> and parameter overrides as shown above
  • Modify base.yaml: Edit the - /model: chem_mrl line in the defaults section to change the default model, or modify any other parameters directly
  • Override config file: Use --config-name=<config_name> to specify a different base configuration file instead of the default base.yaml

Basic Training Workflow

To train a model, initialize the configuration with dataset paths and model parameters, then pass it to ChemMRLTrainer for training.

from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer

dataset_config = DatasetConfig(
    key="my_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",  # Use "train" for local files
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    test_dataset=SplitConfig(
        name="test.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,  # Predefined model name - Can be any transformer model name or path that is compatible with sentence-transformers
        n_dims_per_step=3,  # Model-specific hyperparameter
        use_2d_matryoshka=True,  # Enable 2d MRL
        # Additional parameters specific to 2D MRL models
        n_layers_per_step=2,
        kl_div_weight=0.7,  # Weight for KL divergence regularization
        kl_temperature=0.5,  # Temperature parameter for KL loss
    ),
    datasets=[dataset_config],  # List of dataset configurations
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Initialize trainer and start training
trainer = ChemMRLTrainer(config)
# Returns the test evaluation metric if a test dataset is provided;
# otherwise returns the final validation evaluation metric.
test_eval_metric = trainer.train()

Experimental

Train a Query Model

To train a querying model, configure the model to utilize the specialized query tokenizer.

The query tokenizer supports the following query types:

  • similar: Computes the similarity between two molecular structures; used to retrieve similar SMILES.
  • substructure: Determines the presence of a substructure within the second SMILES string.

Supported query formats for smiles_a column:

  • similar {smiles}
  • substructure {smiles}

from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer

dataset_config = DatasetConfig(
    key="query_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="query",
    smiles_b_column_name="target_smiles",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        use_query_tokenizer=True,  # Train a query model
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)
trainer = ChemMRLTrainer(config)
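The query formats listed above can be produced with a small helper like the one below. This is only a sketch: `build_query` and `QUERY_TYPES` are hypothetical names, and the rows are made-up examples that reuse the column names from the dataset configuration (query, target_smiles, similarity).

```python
QUERY_TYPES = ("similar", "substructure")


def build_query(query_type, smiles):
    """Format a query string as '<query_type> <smiles>'."""
    if query_type not in QUERY_TYPES:
        raise ValueError(f"Unsupported query type: {query_type!r}")
    return f"{query_type} {smiles}"


# Illustrative rows; the similarity labels are placeholders, not computed values.
rows = [
    {
        "query": build_query("similar", "CCO"),
        "target_smiles": "CCN",
        "similarity": 0.6,
    },
    {
        "query": build_query("substructure", "c1ccccc1"),
        "target_smiles": "Cc1ccccc1",
        "similarity": 1.0,
    },
]
for row in rows:
    print(row["query"], "->", row["target_smiles"])
```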

Latent Attention Layer

The Latent Attention Layer model is an experimental component designed to enhance the representation learning of transformer-based models by introducing a trainable latent dictionary. This mechanism applies cross-attention between token embeddings and a set of learnable latent vectors before pooling. The output of this layer contributes to both 1D Matryoshka loss (as the final layer output) and 2D Matryoshka loss (by integrating into all-layer outputs). Note: initial tests suggest that, with the default configuration, the latent attention layer leads to overfitting.

from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import (
    BaseConfig,
    ChemMRLConfig,
    DatasetConfig,
    LatentAttentionConfig,
    SplitConfig,
)
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer

dataset_config = DatasetConfig(
    key="latent_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        latent_attention_config=LatentAttentionConfig(
            hidden_dim=768,  # Transformer hidden size
            num_latents=512,  # Number of learnable latents
            num_cross_heads=8,  # Number of attention heads
            cross_head_dim=32,  # Dimensionality of each head
            output_normalize=True,  # Apply L2 normalization to outputs
        ),
        use_2d_matryoshka=True,
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Train a model with latent attention
trainer = ChemMRLTrainer(config)
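The idea of cross-attention against a latent dictionary followed by pooling can be sketched in a few lines of plain Python. This is a simplified illustration, not the package's layer: it uses a single head, no learned projections or feed-forward block, fixed toy latents instead of trained ones, and it assumes the token embeddings act as queries while the latents act as keys and values. All function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def latent_cross_attention(tokens, latents):
    """Each token attends over the latents and becomes a weighted mix of them."""
    dim = len(latents[0])
    attended = []
    for token in tokens:
        scores = [
            sum(t * l for t, l in zip(token, latent)) / math.sqrt(dim)
            for latent in latents
        ]
        weights = softmax(scores)
        attended.append(
            [sum(w * latent[d] for w, latent in zip(weights, latents)) for d in range(dim)]
        )
    return attended


def mean_pool(vectors):
    """Average a list of equal-length vectors into one vector."""
    count = len(vectors)
    return [sum(v[d] for v in vectors) / count for d in range(len(vectors[0]))]


def l2_normalize(vector):
    """Scale a vector to unit length (unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm else list(vector)


# Tiny toy sizes: 3 tokens, 2 latents, hidden size 4.
tokens = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5]]
latents = [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]]

attended = latent_cross_attention(tokens, latents)
embedding = l2_normalize(mean_pool(attended))
print(embedding)
```

The final normalization step corresponds to the output_normalize option in the configuration above; the real layer uses many more latents and multiple heads.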

Custom Callbacks

You can provide a list of transformers.TrainerCallback instances to execute while training.

from typing import Any

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainingArguments,
)
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer


# Define a callback class for logging evaluation metrics
class EvalCallback(TrainerCallback):
    def on_evaluate(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, Any],
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        """
        Event called after an evaluation phase.
        """
        pass


dataset_config = DatasetConfig(
    key="callback_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Train with callback
trainer = ChemMRLTrainer(config)
val_eval_metric = trainer.train(callbacks=[EvalCallback()])

Classifier

This repository includes code for training a linear classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features.

Hyperparameter tuning shows that cross-entropy loss (softmax option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.

Usage

Basic Classification Training

To train a classifier, configure the model with dataset paths and column names, then initialize ClassifierTrainer to start training.

from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.schemas import BaseConfig, ClassifierConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ClassifierTrainer

dataset_config = DatasetConfig(
    key="classification_dataset",
    train_dataset=SplitConfig(
        name="train_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles",
    smiles_b_column_name=None,  # Not needed for classification
    label_column_name="label",
)

# Define classification training configuration
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",  # Pretrained MRL model path
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Initialize and train the classifier
trainer = ClassifierTrainer(config)
trainer.train()

Training with Dice Loss

For imbalanced classification tasks, Dice Loss can improve performance by focusing on hard-to-classify samples. Below is a configuration using ClassifierConfig with the self-adjusting dice loss option, which introduces additional hyperparameters.

from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.schemas import BaseConfig, ClassifierConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import ClassifierLossFctOption, DiceReductionOption, FieldTypeOption
from chem_mrl.trainers import ClassifierTrainer

dataset_config = DatasetConfig(
    key="dice_loss_dataset",
    train_dataset=SplitConfig(
        name="train_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles",
    smiles_b_column_name=None,  # Not needed for classification
    label_column_name="label",
)

# Define classification training configuration with Dice Loss
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",
        loss_func=ClassifierLossFctOption.selfadjdice,
        dice_reduction=DiceReductionOption.sum,  # Reduction method for Dice Loss (e.g., 'mean' or 'sum')
        dice_gamma=1.0,  # Smoothing factor hyperparameter
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Initialize and train the classifier with Dice Loss
trainer = ClassifierTrainer(config)
trainer.train()
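For intuition about the hyperparameters above, the following is a minimal sketch of a self-adjusting dice loss in the style of Li et al. (2020), using a common simplification that only scores the probability assigned to the true class. It is not the package's implementation. Here `gamma` plays the role of the smoothing factor and `reduction` mirrors the mean/sum choice; `alpha` is an extra exponent taken from the paper and is an assumption, since it does not appear in the configuration shown above.

```python
def self_adjusting_dice_loss(true_class_probs, alpha=1.0, gamma=1.0, reduction="mean"):
    """Self-adjusting dice loss over the probabilities assigned to the true classes.

    Per sample: 1 - (2 * w + gamma) / (w + 1 + gamma), with w = (1 - p)^alpha * p.
    """
    losses = []
    for p in true_class_probs:
        weighted = ((1.0 - p) ** alpha) * p
        losses.append(1.0 - (2.0 * weighted + gamma) / (weighted + 1.0 + gamma))
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"Unsupported reduction: {reduction!r}")


# Probabilities a classifier assigned to the correct class for three samples.
probs = [0.5, 0.7, 0.2]
print(self_adjusting_dice_loss(probs, reduction="mean"))
print(self_adjusting_dice_loss(probs, reduction="sum"))
```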

References:

  • Chithrananda, Seyone, et al. "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." arXiv [cs.LG], 2020.
  • Ahmad, Walid, et al. "ChemBERTa-2: Towards Chemical Foundation Models." arXiv [cs.LG], 2022.
  • Kusupati, Aditya, et al. "Matryoshka Representation Learning." arXiv [cs.LG], 2022.
  • Li, Xianming, et al. "2D Matryoshka Sentence Embeddings." arXiv [cs.CL], 2024.
  • Bajusz, Dávid, et al. "Why is the Tanimoto Index an Appropriate Choice for Fingerprint-Based Similarity Calculations?" J Cheminform, 7, 20 (2015).
  • Li, Xiaoya, et al. "Dice Loss for Data-imbalanced NLP Tasks." arXiv [cs.CL], 2020.
  • Reimers, Nils, and Gurevych, Iryna. "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks." Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing, 2019.
  • Lee, Chankyu, et al. "NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models." arXiv [cs.CL], 2025.
