SMILES-based Matryoshka Representation Learning Embedding Model
CHEM-MRL
Chem-MRL is a SMILES embedding transformer model that leverages Matryoshka Representation Learning (MRL) to generate efficient, truncatable embeddings for downstream tasks such as classification, clustering, and database querying.
The model employs SentenceTransformers' (SBERT) 2D Matryoshka Sentence Embeddings (Matryoshka2dLoss) to enable truncatable embeddings with minimal accuracy loss, improving query performance and flexibility in downstream applications.
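MRL trains nested prefixes of the embedding to be useful on their own, so a truncated embedding is simply the first `dim` components of the full vector. The sketch below is a minimal, library-free illustration. It assumes cosine similarity as the comparison metric, and the toy 8-dimensional vectors and helper names are hypothetical. When working with sentence-transformers directly, the `SentenceTransformer` class also accepts a `truncate_dim` argument for the same purpose.

```python
import math


def truncate_and_normalize(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components of an embedding and rescale them to unit L2 norm."""
    if not 0 < dim <= len(embedding):
        raise ValueError(f"dim must be between 1 and {len(embedding)}, got {dim}")
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return head if norm == 0.0 else [x / norm for x in head]


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# Toy 8-dimensional vectors standing in for full-size Chem-MRL embeddings (hypothetical values).
query = [0.9, 0.3, 0.1, 0.05, 0.2, -0.1, 0.0, 0.3]
target = [0.8, 0.4, 0.0, 0.10, -0.2, 0.1, 0.3, -0.1]

# Compare the same pair at progressively smaller embedding sizes.
for dim in (8, 4, 2):
    score = cosine(truncate_and_normalize(query, dim), truncate_and_normalize(target, dim))
    print(f"dim={dim}: cosine={score:.3f}")
```

Smaller prefixes make storage and nearest-neighbour search cheaper. The trade-off is how much ranking quality survives at each size, which is what the Matryoshka training objective tries to preserve.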
Datasets should consist of SMILES pairs and their corresponding Morgan fingerprint Tanimoto similarity scores. Currently, datasets must be in Parquet format.
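For two binary fingerprints A and B, the Tanimoto similarity is |A ∩ B| / |A ∪ B|. The sketch below shows how a label column could be computed for SMILES pairs, assuming each fingerprint is represented as a set of on-bit indices. The bit indices are made up for illustration; real data would use Morgan fingerprints (for example, generated with RDKit). The column names match the example configurations in this README. Writing the rows to Parquet (for example with pandas' `DataFrame.to_parquet`) is omitted to keep the sketch dependency-free.

```python
from itertools import combinations


def tanimoto(bits_a: set[int], bits_b: set[int]) -> float:
    """Tanimoto similarity of two binary fingerprints given as sets of on-bit indices."""
    union = len(bits_a | bits_b)
    if union == 0:
        return 0.0  # convention chosen for this sketch: two empty fingerprints score 0
    return len(bits_a & bits_b) / union


# Hypothetical on-bit sets; in practice these would be Morgan fingerprints of each molecule.
fingerprints = {
    "CCO": {1, 5, 9, 42},        # ethanol
    "CCCO": {1, 5, 9, 42, 77},   # 1-propanol
    "c1ccccc1": {3, 18, 64},     # benzene
}

# One row per unordered pair, using the column names from the example configs.
rows = [
    {"smiles_a": a, "smiles_b": b, "similarity": tanimoto(fingerprints[a], fingerprints[b])}
    for a, b in combinations(fingerprints, 2)
]
for row in rows:
    print(row)
```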
Hyperparameter tuning indicates that a custom Tanimoto similarity loss function, TanimotoSentLoss, based on CoSENTLoss, outperforms losses based on Tanimoto similarity, CoSENTLoss, AnglELoss, and cosine similarity.
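CoSENT (sentence-transformers' `CoSENTLoss`) is a pairwise ranking loss. For every two training pairs whose labels say pair i should be more similar than pair j, it penalizes the model when its predicted similarity for j exceeds the one for i: loss = log(1 + Σ exp(λ(s_j − s_i))). The README does not spell out TanimotoSentLoss. Purely as an assumption for illustration, the sketch below keeps the CoSENT objective but scores embedding pairs with a continuous Tanimoto similarity instead of cosine similarity. The helper names and toy vectors are hypothetical, and the scale λ = 20 mirrors `CoSENTLoss`'s default.

```python
import math


def continuous_tanimoto(a: list[float], b: list[float]) -> float:
    """Tanimoto similarity generalized to real vectors: a.b / (|a|^2 + |b|^2 - a.b)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (sum(x * x for x in a) + sum(y * y for y in b) - dot)


def cosent_style_loss(scores: list[float], labels: list[float], scale: float = 20.0) -> float:
    """log(1 + sum over pairs with labels[i] > labels[j] of exp(scale * (s_j - s_i)))."""
    exponents = [0.0]  # exp(0) = 1 supplies the "1 +" inside the log
    for s_i, y_i in zip(scores, labels):
        for s_j, y_j in zip(scores, labels):
            if y_i > y_j:  # pair i should be scored higher than pair j
                exponents.append(scale * (s_j - s_i))
    # Numerically stable log-sum-exp.
    m = max(exponents)
    return m + math.log(sum(math.exp(e - m) for e in exponents))


# Hypothetical embedding pairs and their fingerprint Tanimoto labels.
pairs = [
    ([1.0, 0.2, 0.5], [0.9, 0.3, 0.4]),  # near-duplicate embeddings
    ([1.0, 0.2, 0.5], [0.1, 1.0, 0.0]),  # dissimilar embeddings
]
labels = [0.85, 0.10]
scores = [continuous_tanimoto(a, b) for a, b in pairs]
print(scores, cosent_style_loss(scores, labels))
```

Because CoSENT only compares pairs against each other, it rewards getting the ranking of similarities right rather than matching the label values exactly.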
Installation
Install with pip
```bash
pip install chem-mrl
```
Install from source code
```bash
pip install -e .
```
Usage
Hydra & Training Scripts
Hydra configuration files are in chem_mrl/conf. The base config defines shared arguments, while model-specific configs are located in chem_mrl/conf/model. Use chem_mrl_config.yaml or classifier_config.yaml to run specific models.
The scripts directory provides training scripts with Hydra for parameter management:
- Train Chem-MRL model:
  ```bash
  python scripts/train_chem_mrl.py train_dataset_path=/path/to/training.parquet val_dataset_path=/path/to/val.parquet
  ```
- Train a linear classifier:
  ```bash
  python scripts/train_classifier.py train_dataset_path=/path/to/training.parquet val_dataset_path=/path/to/val.parquet
  ```
Basic Training Workflow
To train a model, initialize the configuration with dataset paths and model parameters, then pass it to ChemMRLTrainer for training.
```python
from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

# Define training configuration
config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,  # Predefined model name; can be any transformer model name or path compatible with sentence-transformers
        n_dims_per_step=3,  # Model-specific hyperparameter
        use_2d_matryoshka=True,  # Enable 2D MRL
        # Additional parameters specific to 2D MRL models
        n_layers_per_step=2,
        kl_div_weight=0.7,  # Weight for KL divergence regularization
        kl_temperature=0.5,  # Temperature parameter for KL loss
    ),
    train_dataset_path="train.parquet",  # Path to training data
    val_dataset_path="val.parquet",  # Path to validation data
    test_dataset_path="test.parquet",  # Optional test dataset
    smiles_a_column_name="smiles_a",  # Column with first molecule SMILES representation
    smiles_b_column_name="smiles_b",  # Column with second molecule SMILES representation
    label_column_name="similarity",  # Similarity score between molecules
)

# Initialize trainer and start training
trainer = ChemMRLTrainer(config)
# Returns the test evaluation metric if a test dataset is provided;
# otherwise returns the final validation evaluation metric.
test_eval_metric = trainer.train()
```
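As we read the sentence-transformers documentation for `Matryoshka2dLoss`, each training step evaluates the wrapped loss on a random subset of earlier transformer layers (`n_layers_per_step`, always alongside the final layer) and a random subset of the Matryoshka dimensions (`n_dims_per_step`), with `-1` meaning "use all". `kl_div_weight` and `kl_temperature` control a KL-divergence term that pulls the earlier layers toward the final layer. The sketch below illustrates only the sampling part. The dimension ladder, layer count and helper names are hypothetical, and the loss computation itself is omitted.

```python
import random

MATRYOSHKA_DIMS = [768, 512, 256, 128, 64, 32, 16, 8]  # hypothetical dimension ladder
NUM_LAYERS = 6  # hypothetical number of transformer layers


def _subset(items: list[int], k: int, rng: random.Random) -> list[int]:
    """Return all items when k == -1 (or k covers everything), else a random sample of size k."""
    if k == -1 or k >= len(items):
        return list(items)
    return rng.sample(items, k)


def sample_training_step(n_layers_per_step: int, n_dims_per_step: int, rng: random.Random):
    """Pick the (layers, dims) grid on which the base loss would be evaluated this step."""
    final_layer = NUM_LAYERS - 1
    # The final layer is always used; earlier layers are sampled.
    prior_layers = _subset(list(range(final_layer)), n_layers_per_step, rng)
    layers = sorted(prior_layers) + [final_layer]
    dims = sorted(_subset(MATRYOSHKA_DIMS, n_dims_per_step, rng), reverse=True)
    return layers, dims


rng = random.Random(0)
for step in range(3):
    layers, dims = sample_training_step(n_layers_per_step=2, n_dims_per_step=3, rng=rng)
    print(f"step {step}: layers={layers}, dims={dims}")
```

Sampling a few (layer, dimension) combinations per step keeps each step's cost bounded while, over many steps, every truncation size and layer depth still receives training signal.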
Experimental
Train a Query Model
To train a querying model, configure the model to utilize the specialized query tokenizer.
The query tokenizer supports the following query types:
- similar: Computes the SMILES similarity between two molecular structures; used to retrieve similar SMILES.
- substructure: Determines whether a substructure is present within the second SMILES string.
- families: Identifies the presence of RDKit’s base chemical feature families in the second SMILES string.
Supported query formats for the smiles_a column:
```text
similar {smiles}
substructure {smiles}
families {space delimited list of feature families}
```
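To make the formats concrete, the sketch below builds and parses smiles_a query strings in the `<query type> <payload>` shape listed above. `build_query` and `parse_query` are hypothetical helpers, not part of chem_mrl. "Donor", "Acceptor", "Aromatic" and "Hydrophobe" are feature family names from RDKit's base feature definitions.

```python
QUERY_TYPES = ("similar", "substructure", "families")


def build_query(query_type: str, payload) -> str:
    """Format a smiles_a value as '<query type> <payload>' (hypothetical helper)."""
    if query_type not in QUERY_TYPES:
        raise ValueError(f"unknown query type: {query_type!r}")
    if query_type == "families":
        # Accept either a list of family names or an already space-delimited string.
        names = payload.split() if isinstance(payload, str) else list(payload)
        return "families " + " ".join(names)
    return f"{query_type} {payload}"


def parse_query(query: str) -> tuple[str, str]:
    """Split a query string back into (query type, payload)."""
    query_type, _, payload = query.partition(" ")
    if query_type not in QUERY_TYPES:
        raise ValueError(f"unknown query type: {query_type!r}")
    return query_type, payload


queries = [
    build_query("similar", "CCO"),
    build_query("substructure", "c1ccccc1"),
    build_query("families", ["Donor", "Acceptor", "Aromatic"]),
]
print(queries)
```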
```python
from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        use_query_tokenizer=True,  # Train a query model
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    smiles_a_column_name="query",
    smiles_b_column_name="target_smiles",
    label_column_name="similarity",
)

trainer = ChemMRLTrainer(config)
trainer.train()
```
Latent Attention Layer
The Latent Attention Layer model is an experimental component designed to enhance the representation learning of transformer-based models by introducing a trainable latent dictionary. This mechanism applies cross-attention between token embeddings and a set of learnable latent vectors before pooling. The output of this layer contributes to both the 1D Matryoshka loss (as the final layer output) and the 2D Matryoshka loss (by integrating into all-layer outputs). Note: initial tests suggest that, with the default configuration, the latent attention layer leads to overfitting.
```python
from chem_mrl.models import LatentAttentionLayer
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, LatentAttentionConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        latent_attention_config=LatentAttentionConfig(
            hidden_dim=768,  # Transformer hidden size
            num_latents=512,  # Number of learnable latents
            num_cross_heads=8,  # Number of attention heads
            cross_head_dim=32,  # Dimensionality of each head
            output_normalize=True,  # Apply L2 normalization to outputs
        ),
        use_2d_matryoshka=True,
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
)

# Train a model with latent attention
trainer = ChemMRLTrainer(config)
trainer.train()
```
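The sketch below illustrates the cross-attention-then-pool idea in the style of NV-Embed's latent attention (Lee et al.): each token embedding acts as a query, the latent dictionary supplies both keys and values, and the results are mean-pooled. It is a deliberately simplified, hypothetical version, not chem_mrl's `LatentAttentionLayer`. It uses a single head with no learned projections (so it ignores `num_cross_heads` and `cross_head_dim`), omits the MLP that NV-Embed applies after attention, and uses toy sizes. The optional L2 normalization stands in for `output_normalize=True`.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def latent_attention_pool(tokens: list[list[float]], latents: list[list[float]], normalize: bool = True) -> list[float]:
    """Cross-attend each token (query) to a latent dictionary (keys = values), then mean-pool."""
    d = len(tokens[0])
    scale = 1.0 / math.sqrt(d)
    attended = []
    for q in tokens:
        # Scaled dot-product attention weights of this token over all latent vectors.
        weights = softmax([scale * sum(qi * ki for qi, ki in zip(q, k)) for k in latents])
        # The token's new representation is the weighted sum of latent vectors.
        attended.append([sum(w * v[c] for w, v in zip(weights, latents)) for c in range(d)])
    # Mean pooling over tokens gives one vector per molecule.
    pooled = [sum(vec[c] for vec in attended) / len(attended) for c in range(d)]
    if normalize:  # analogous to output_normalize=True
        norm = math.sqrt(sum(x * x for x in pooled))
        if norm > 0.0:
            pooled = [x / norm for x in pooled]
    return pooled


# Hypothetical toy sizes: 3 tokens, 2 latents, hidden size 4.
tokens = [[0.1, 0.4, -0.2, 0.3], [0.5, -0.1, 0.2, 0.0], [0.0, 0.2, 0.2, -0.4]]
latents = [[1.0, 0.0, 0.5, -0.5], [-0.3, 0.8, 0.1, 0.2]]
print(latent_attention_pool(tokens, latents))
```

Since every output is a mixture of the latent vectors, the latent dictionary acts as a learned basis for the pooled embedding. That extra capacity may also help explain the overfitting noted above.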
Custom Evaluation Callbacks
You can provide a callback function that is executed every evaluation_steps steps, allowing custom logic such as logging, early stopping, or model checkpointing.
```python
from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

# Define a callback function for logging evaluation metrics
def eval_callback(score: float, epoch: int, steps: int):
    print(f"Step {steps}, Epoch {epoch}: Evaluation Score = {score}")

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

# Train with callback
trainer = ChemMRLTrainer(config)
val_eval_metric = trainer.train(
    eval_callback=eval_callback
)  # Callback executed every `evaluation_steps`
```
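A callback can also carry state between evaluations. The hypothetical class below has the same `(score, epoch, steps)` call signature as `eval_callback`, records every evaluation, and remembers the best step. It assumes that a higher score is better. The README does not describe a way for a callback to halt training, so `should_stop` is only a signal for your own logic, such as deciding which checkpoint to keep.

```python
class BestScoreTracker:
    """Hypothetical eval callback that records scores and flags when improvement stalls.

    Assumes a higher evaluation score is better.
    """

    def __init__(self, patience: int = 3):
        self.patience = patience
        self.history: list[tuple[int, int, float]] = []
        self.best_score = float("-inf")
        self.best_step = None
        self.evals_without_improvement = 0

    def __call__(self, score: float, epoch: int, steps: int) -> None:
        self.history.append((epoch, steps, score))
        if score > self.best_score:
            self.best_score, self.best_step = score, steps
            self.evals_without_improvement = 0
        else:
            self.evals_without_improvement += 1

    @property
    def should_stop(self) -> bool:
        return self.evals_without_improvement >= self.patience


# Simulated evaluations; during training you would pass trainer.train(eval_callback=tracker).
tracker = BestScoreTracker(patience=2)
for steps, score in [(1000, 0.61), (2000, 0.67), (3000, 0.66), (4000, 0.65)]:
    tracker(score, epoch=0, steps=steps)
print(tracker.best_step, tracker.best_score, tracker.should_stop)
```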
W&B Integration
This library includes a WandBTrainerExecutor class for seamless Weights & Biases (W&B) integration. It handles authentication, initialization, and logging at the frequency specified by evaluation_steps.
```python
from chem_mrl.schemas import BaseConfig, WandbConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer, WandBTrainerExecutor
from chem_mrl.schemas.Enums import WatchLogOption

# Define W&B configuration for experiment tracking
wandb_config = WandbConfig(
    project_name="chem_mrl_test",  # W&B project name
    run_name="test",  # Name for the experiment run
    use_watch=True,  # Enables model watching for tracking gradients
    watch_log=WatchLogOption.all,  # Logs all model parameters and gradients
    watch_log_freq=1000,  # Logging frequency
    watch_log_graph=True,  # Logs model computation graph
)

# Configure training with W&B integration
config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
    evaluation_steps=1000,
    wandb=wandb_config,
)

# Initialize trainer and W&B executor
trainer = ChemMRLTrainer(config)
executor = WandBTrainerExecutor(trainer)
executor.execute()  # Handles training and W&B logging
```
Classifier
This repository includes code for training a linear classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features.
Hyperparameter tuning shows that cross-entropy loss (softmax option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.
Usage
Basic Classification Training
To train a classifier, configure the model with dataset paths and column names, then initialize ClassifierTrainer to start training.
```python
from chem_mrl.schemas import BaseConfig, ClassifierConfig
from chem_mrl.trainers import ClassifierTrainer

# Define classification training configuration
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",  # Pretrained MRL model path
    ),
    train_dataset_path="train_classification.parquet",  # Path to training dataset
    val_dataset_path="val_classification.parquet",  # Path to validation dataset
    smiles_a_column_name="smiles",  # Column containing SMILES representations of molecules
    label_column_name="label",  # Column containing classification labels
)

# Initialize and train the classifier
trainer = ClassifierTrainer(config)
trainer.train()
```
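The README does not describe the classifier's internals. The sketch below shows what a linear classification head with dropout and softmax cross-entropy typically computes on top of a molecule embedding, assuming that is roughly what "linear classifier with optional dropout" and "cross-entropy loss (softmax option)" refer to. All weights, sizes and helper names are hypothetical.

```python
import math
import random


def linear_logits(x: list[float], weights: list[list[float]], bias: list[float]) -> list[float]:
    """One logit per class: weights[c] . x + bias[c]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def dropout(x: list[float], p: float, rng: random.Random) -> list[float]:
    """Inverted dropout: zero each feature with probability p and rescale survivors by 1 / (1 - p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep = 1.0 - p
    return [xi / keep if rng.random() < keep else 0.0 for xi in x]


def softmax_cross_entropy(logits: list[float], label: int) -> float:
    """-log softmax(logits)[label], computed with a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Hypothetical 4-dimensional embedding and a 3-class linear head.
embedding = [0.2, -0.1, 0.4, 0.3]
weights = [[0.5, 0.1, -0.2, 0.0], [0.0, 0.3, 0.6, 0.2], [-0.4, 0.2, 0.1, 0.1]]
bias = [0.0, 0.1, -0.1]

rng = random.Random(0)
train_loss = softmax_cross_entropy(linear_logits(dropout(embedding, 0.1, rng), weights, bias), label=1)
eval_loss = softmax_cross_entropy(linear_logits(embedding, weights, bias), label=1)  # no dropout at eval time
print(train_loss, eval_loss)
```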
Training with Dice Loss
For imbalanced classification tasks, Dice loss can improve performance by focusing on hard-to-classify samples. The configuration below selects the self-adjusting Dice loss through ClassifierConfig and sets its additional hyperparameters.
```python
from chem_mrl.schemas import BaseConfig, ClassifierConfig
from chem_mrl.trainers import ClassifierTrainer
from chem_mrl.schemas.Enums import ClassifierLossFctOption, DiceReductionOption

# Define classification training configuration with Dice Loss
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",
        loss_func=ClassifierLossFctOption.selfadjdice,
        dice_reduction=DiceReductionOption.sum,  # Reduction method for Dice Loss (e.g., 'mean' or 'sum')
        dice_gamma=1.0,  # Smoothing factor hyperparameter
    ),
    train_dataset_path="train_classification.parquet",  # Path to training dataset
    val_dataset_path="val_classification.parquet",  # Path to validation dataset
    smiles_a_column_name="smiles",
    label_column_name="label",
)

# Initialize and train the classifier with Dice Loss
trainer = ClassifierTrainer(config)
trainer.train()
```
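For reference, the self-adjusting Dice loss from Li et al. (2020) scores each example as 1 − (2(1 − p)^α·p·y + γ) / ((1 − p)^α·p + y + γ). Here p is the predicted probability, y ∈ {0, 1} is the target, and γ is the smoothing factor that `dice_gamma` appears to correspond to. The (1 − p)^α factor down-weights examples the model is already confident about. The sketch below is a binary, per-example version with `mean`/`sum` reduction. α = 1.0 is a hypothetical default, and how chem_mrl extends the loss to multiple classes is not covered here.

```python
def self_adjusting_dice_loss(
    probs: list[float],
    targets: list[int],
    alpha: float = 1.0,
    gamma: float = 1.0,
    reduction: str = "mean",
) -> float:
    """Self-adjusting Dice loss for binary targets (per-example formula from Li et al., 2020)."""
    losses = []
    for p, y in zip(probs, targets):
        weighted_p = (1.0 - p) ** alpha * p  # (1 - p)^alpha down-weights confident predictions
        dsc = (2.0 * weighted_p * y + gamma) / (weighted_p + y + gamma)
        losses.append(1.0 - dsc)
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Hypothetical predicted positive-class probabilities and gold labels.
probs = [0.9, 0.2, 0.6, 0.05]
targets = [1, 0, 1, 0]
print(self_adjusting_dice_loss(probs, targets, reduction="mean"))
print(self_adjusting_dice_loss(probs, targets, reduction="sum"))
```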
References:
- Chithrananda, Seyone, et al. "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." arXiv [Cs.LG], 2020.
- Ahmad, Walid, et al. "ChemBERTa-2: Towards Chemical Foundation Models." arXiv [Cs.LG], 2022.
- Kusupati, Aditya, et al. "Matryoshka Representation Learning." arXiv [Cs.LG], 2022.
- Li, Xianming, et al. "2D Matryoshka Sentence Embeddings." arXiv [Cs.CL], 2024.
- Bajusz, Dávid, et al. "Why is the Tanimoto Index an Appropriate Choice for Fingerprint-Based Similarity Calculations?" J Cheminform, 7, 20 (2015).
- Li, Xiaoya, et al. "Dice Loss for Data-imbalanced NLP Tasks." arXiv [Cs.CL], 2020.
- Reimers, Nils, and Gurevych, Iryna. "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks." Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing, 2019.
- Lee, Chankyu, et al. "NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models." arXiv [Cs.CL], 2025.