SMILES-based Matryoshka Representation Learning Embeddings
CHEM-MRL
CHEM-MRL is a SMILES embedding transformer model that leverages Matryoshka Representation Learning (MRL) to generate efficient, truncatable embeddings for downstream tasks such as classification, clustering, and database querying.
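As a rough illustration of what "truncatable" means in practice, the sketch below keeps the leading components of an embedding and rescales them to unit length before comparing. The vectors are made-up toy values standing in for model output, and the helper names are hypothetical; this is not code from the package.

```python
import math


def truncate_and_normalize(embedding, dim):
    """Keep the first `dim` components and rescale them to unit length."""
    if not 0 < dim <= len(embedding):
        raise ValueError("dim must be between 1 and len(embedding)")
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    if norm == 0.0:
        return list(head)  # nothing to rescale for an all-zero prefix
    return [x / norm for x in head]


def cosine(a, b):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Toy 8-dimensional vectors (made-up values, not real model output).
emb_a = [0.9, 0.1, 0.3, -0.2, 0.05, 0.02, -0.01, 0.03]
emb_b = [0.8, 0.2, 0.25, -0.1, -0.04, 0.01, 0.02, -0.02]

full = cosine(emb_a, emb_b)
short = cosine(truncate_and_normalize(emb_a, 4), truncate_and_normalize(emb_b, 4))
print(f"cosine at 8 dims: {full:.4f}, at 4 dims: {short:.4f}")
```

With an MRL-trained model the most informative components are meant to sit at the front of the vector, so the truncated comparison should stay close to the full one while storing and scanning fewer numbers.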
The dataset, split 75%/15%/10% into train/validation/test sets, consists of SMILES pairs labeled with the Tanimoto similarity of their Morgan fingerprints (8192-bit vectors). The model employs SentenceTransformers' (SBERT) 2D Matryoshka Sentence Embeddings (Matryoshka2dLoss) so that embeddings can be truncated with minimal accuracy loss, improving query performance in downstream applications.
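To make the label concrete, here is a minimal sketch of the Tanimoto similarity between two binary fingerprints, each represented as a set of on-bit indices. The bit indices are invented for illustration; in a real pipeline they would come from a cheminformatics toolkit's Morgan fingerprint generator. The column names in the example row are the script defaults listed in the help output further down.

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity of two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = len(a | b)
    if union == 0:
        return 0.0  # convention chosen here for two empty fingerprints
    return len(a & b) / union


# Hypothetical on-bit positions within an 8192-bit fingerprint.
fp_a = {3, 17, 256, 4095, 8000}
fp_b = {3, 17, 512, 4095}

# One training row: a SMILES pair plus its similarity label.
# The SMILES strings are placeholders and do not correspond to the bits above.
row = {
    "smiles_a": "CCO",
    "smiles_b": "CCN",
    "fingerprint_similarity": tanimoto(fp_a, fp_b),
}
print(row["fingerprint_similarity"])  # 3 shared bits / 6 distinct bits = 0.5
```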
Hyperparameter tuning indicates that a custom Tanimoto similarity loss function, based on CoSENTLoss, outperforms Tanimoto similarity, CoSENTLoss, AnglELoss, and cosine similarity.
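One way to picture a CoSENT-style loss built on Tanimoto similarity is sketched below. It assumes the custom loss keeps CoSENT's ranking objective and swaps cosine similarity for the continuous Tanimoto similarity between embeddings; the package's actual implementation may differ. The function names are hypothetical, and the scale of 20.0 follows the default of SentenceTransformers' CoSENTLoss.

```python
import math


def tanimoto_sim(u, v):
    """Continuous Tanimoto similarity: u.v / (|u|^2 + |v|^2 - u.v)."""
    dot = sum(x * y for x, y in zip(u, v))
    denom = sum(x * x for x in u) + sum(y * y for y in v) - dot
    return dot / denom if denom else 0.0


def cosent_style_loss(pairs, labels, sim=tanimoto_sim, scale=20.0):
    """log(1 + sum of exp(s_j - s_i)) over all i, j with labels[i] > labels[j].

    `pairs` holds (embedding_a, embedding_b) tuples and `labels` their target
    similarities. The loss shrinks when higher-labeled pairs score higher.
    """
    scores = [scale * sim(u, v) for u, v in pairs]
    terms = [0.0]  # exp(0) supplies the "1 +" inside the logarithm
    for i, label_i in enumerate(labels):
        for j, label_j in enumerate(labels):
            if label_i > label_j:
                terms.append(scores[j] - scores[i])
    peak = max(terms)  # log-sum-exp trick for numerical stability
    return peak + math.log(sum(math.exp(t - peak) for t in terms))


# Toy embeddings: the first pair is identical, the second is orthogonal.
pairs = [([1.0, 0.0], [1.0, 0.0]), ([1.0, 0.0], [0.0, 1.0])]
print(cosent_style_loss(pairs, [0.9, 0.1]))  # ranking agrees with labels: near 0
print(cosent_style_loss(pairs, [0.1, 0.9]))  # ranking contradicts labels: large
```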
Classifier
This repository includes code for training a linear SBERT classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features. While demonstrated on the Isomer Design dataset, it is generalizable to any dataset containing smiles and label columns. The training scripts (see below) allow users to specify these column names.
Currently, the dataset must be in Parquet format.
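As a hedged example of pointing the classifier script at a dataset with non-default column names, the snippet below only assembles the command line and prints it. The flags are taken from the help output further down; the file paths, the column names, and the helper itself are hypothetical.

```python
import shlex


def build_classifier_command(train_path, val_path,
                             smiles_column="smiles", label_column="label"):
    """Assemble (but do not run) a train_classifier.py invocation."""
    args = [
        "python", "scripts/train_classifier.py",
        "--train_dataset_path", train_path,
        "--val_dataset_path", val_path,
        "--smiles_column_name", smiles_column,
        "--label_column_name", label_column,
    ]
    return shlex.join(args)  # quotes any argument that needs it


# Hypothetical Parquet files whose columns are named differently from the defaults.
print(build_classifier_command(
    "data/train.parquet",
    "data/val.parquet",
    smiles_column="canonical_smiles",
    label_column="category",
))
```

The printed string can be pasted into a shell, or the underlying argument list can be handed to `subprocess.run` if the training should be launched from Python.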
Hyperparameter tuning shows that cross-entropy loss (softmax option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.
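For readers unfamiliar with the softmax option, the following is a minimal sketch of softmax cross-entropy for a single example, written in plain Python. The logits are made-up numbers standing in for the output of the linear classifier; this illustrates the standard formula and is not the package's code.

```python
import math


def softmax(logits):
    """Convert raw class scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target_index):
    """Negative log-probability assigned to the correct class."""
    return -math.log(softmax(logits)[target_index])


# Hypothetical scores for three categories; the true class is index 0.
logits = [2.0, 0.5, -1.0]
print(softmax(logits))
print(cross_entropy(logits, 0))
```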
Scripts
The scripts directory contains two training scripts:
- scripts/train_chem_mrl.py: Trains a Chem-MRL model
- scripts/train_classifier.py: Trains a linear classifier
train_chem_mrl.py
For usage details, run:
python scripts/train_chem_mrl.py -h
Example output:
usage: train_chem_mrl.py [-h] --train_dataset_path TRAIN_DATASET_PATH --val_dataset_path VAL_DATASET_PATH
[--test_dataset_path TEST_DATASET_PATH] [--n_train_samples N_TRAIN_SAMPLES] [--n_val_samples N_VAL_SAMPLES]
[--n_test_samples N_TEST_SAMPLES] [--n_dataloader_workers N_DATALOADER_WORKERS] [--generate_dataset_examples_at_init]
[--model_name MODEL_NAME] [--train_batch_size TRAIN_BATCH_SIZE] [--num_epochs NUM_EPOCHS] [--lr_base LR_BASE]
[--scheduler {warmupconstant,warmuplinear,warmupcosine,warmupcosinewithhardrestarts}]
[--warmup_steps_percent WARMUP_STEPS_PERCENT] [--use_fused_adamw] [--use_tf32] [--use_amp] [--seed SEED]
[--model_output_path MODEL_OUTPUT_PATH] [--evaluation_steps EVALUATION_STEPS]
[--checkpoint_save_steps CHECKPOINT_SAVE_STEPS] [--checkpoint_save_total_limit CHECKPOINT_SAVE_TOTAL_LIMIT]
[--return_eval_metric] [--use_wandb] [--wandb_api_key WANDB_API_KEY] [--wandb_project_name WANDB_PROJECT_NAME]
[--wandb_run_name WANDB_RUN_NAME] [--wandb_use_watch] [--wandb_watch_log {gradients,parameters,all}]
[--wandb_watch_log_freq WANDB_WATCH_LOG_FREQ] [--wandb_watch_log_graph] [--smiles_a_column_name SMILES_A_COLUMN_NAME]
[--smiles_b_column_name SMILES_B_COLUMN_NAME] [--label_column_name LABEL_COLUMN_NAME]
[--loss_func {tanimotosentloss,tanimotosimilarityloss,cosentloss,angleloss}]
[--tanimoto_similarity_loss_func {mse,l1,smooth_l1,huber,bin_cross_entropy,kldiv,cosine_embedding_loss}]
[--eval_similarity_fct {cosine,tanimoto}] [--eval_metric {spearman,pearson}]
[--mrl_dimensions MRL_DIMENSIONS [MRL_DIMENSIONS ...]] [--dim_weights DIM_WEIGHTS [DIM_WEIGHTS ...]]
[--n_dims_per_step N_DIMS_PER_STEP] [--use_2d_matryoshka] [--n_layers_per_step N_LAYERS_PER_STEP]
[--last_layer_weight LAST_LAYER_WEIGHT] [--prior_layers_weight PRIOR_LAYERS_WEIGHT] [--kl_div_weight KL_DIV_WEIGHT]
[--kl_temperature KL_TEMPERATURE]
Train SMILES-based MRL embeddings model
options:
-h, --help show this help message and exit
--train_dataset_path TRAIN_DATASET_PATH
--val_dataset_path VAL_DATASET_PATH
--test_dataset_path TEST_DATASET_PATH
--n_train_samples N_TRAIN_SAMPLES
Number of training samples to load. Uses seeded sampling if a seed is set. (default: None)
--n_val_samples N_VAL_SAMPLES
Number of evaluation samples to load. Uses seeded sampling if a seed is set. (default: None)
--n_test_samples N_TEST_SAMPLES
Number of testing samples to load. Uses seeded sampling if a seed is set. (default: None)
--n_dataloader_workers N_DATALOADER_WORKERS
How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
--generate_dataset_examples_at_init
If set, then all `sentence_transformers.InputExample` examples will be generated at at initialization. If not set, the
`sentence_transformers.InputExample` examples are generated on the fly by the dataloader. (default: False)
--model_name MODEL_NAME
Name of the model to use. Must be either a file path or a hugging-face model name. (default:
seyonec/ChemBERTa-zinc-base-v1)
--train_batch_size TRAIN_BATCH_SIZE
Training batch size (default: 32)
--num_epochs NUM_EPOCHS
Number of epochs to train (default: 3)
--lr_base LR_BASE Base learning rate. Will be scaled by the square root of the batch size (default: 1.1190785944700813e-05)
--scheduler {warmupconstant,warmuplinear,warmupcosine,warmupcosinewithhardrestarts}
Learning rate scheduler (default: warmuplinear)
--warmup_steps_percent WARMUP_STEPS_PERCENT
Number of warmup steps that the scheduler will use (default: 0.0)
--use_fused_adamw Use cuda-optimized FusedAdamW optimizer. ~10% faster than torch.optim.AdamW (default: False)
--use_tf32 Use TensorFloat-32 for matrix multiplication and convolutions (default: False)
--use_amp Use automatic mixed precision (default: False)
--seed SEED Omit to not set a seed during training. Used to seed the dataloader sampling and the transformer. (default: 42)
--model_output_path MODEL_OUTPUT_PATH
Path to save model (default: output)
--evaluation_steps EVALUATION_STEPS
Run evaluator every evaluation_steps (default: 0)
--checkpoint_save_steps CHECKPOINT_SAVE_STEPS
Save checkpoint every checkpoint_save_steps (default: 0)
--checkpoint_save_total_limit CHECKPOINT_SAVE_TOTAL_LIMIT
Save total limit (default: 20)
--return_eval_metric Return the final evaluation metric after training (default: False)
--use_wandb Use W&B for logging. Must be enabled for other W&B features to work. (default: False)
--wandb_api_key WANDB_API_KEY
W&B API key. Can be omitted if W&B cli is installed and logged in (default: None)
--wandb_project_name WANDB_PROJECT_NAME
--wandb_run_name WANDB_RUN_NAME
--wandb_use_watch Enable W&B watch (default: False)
--wandb_watch_log {gradients,parameters,all}
Specify which logs to W&B should watch (default: all)
--wandb_watch_log_freq WANDB_WATCH_LOG_FREQ
How often to log (default: 1000)
--wandb_watch_log_graph
Specify if graphs should be logged by W&B (default: False)
--smiles_a_column_name SMILES_A_COLUMN_NAME
SMILES A column name (default: smiles_a)
--smiles_b_column_name SMILES_B_COLUMN_NAME
SMILES B column name (default: smiles_b)
--label_column_name LABEL_COLUMN_NAME
Label column name (default: fingerprint_similarity)
--loss_func {tanimotosentloss,tanimotosimilarityloss,cosentloss,angleloss}
Loss function (default: tanimotosentloss)
--tanimoto_similarity_loss_func {mse,l1,smooth_l1,huber,bin_cross_entropy,kldiv,cosine_embedding_loss}
Base loss function for tanimoto similarity loss function (only for tanimotosimilarityloss) (default: None)
--eval_similarity_fct {cosine,tanimoto}
Similarity function to use for evaluation (default: tanimoto)
--eval_metric {spearman,pearson}
Metric to use for evaluation (default: spearman)
--mrl_dimensions MRL_DIMENSIONS [MRL_DIMENSIONS ...]
A list of embedding dimensions to be used for the loss function. Each value must be less than equal to the base
transformer's hidden dimension. (default: [768, 512, 256, 128, 64, 32, 16, 8])
--dim_weights DIM_WEIGHTS [DIM_WEIGHTS ...]
A list of weights to be used for the loss function. The number of dimension weights must match that of the MRL
dimensions. (default: [1, 1, 1, 1, 1, 1, 1, 1])
--n_dims_per_step N_DIMS_PER_STEP
The number of dimensions to use per step. If -1, then all dimensions are used. If > 0, then a random sample of
n_dims_per_step dimensions are used per step. (default: 1)
--use_2d_matryoshka Use 2D Matryoshka to train over layers in addition to embedding dimensions. (default: False)
--n_layers_per_step N_LAYERS_PER_STEP
The number of layers to use per step. If -1, then all layers are used. If > 0, then a random sample of
n_layers_per_step layers are used per step. (only for 2D MRL) (default: 1)
--last_layer_weight LAST_LAYER_WEIGHT
The weight to use for the loss of the final layer. Increase this to focus more on the performance when using all
layers. (only for 2D MRL) (default: 1.0)
--prior_layers_weight PRIOR_LAYERS_WEIGHT
The weight to use for the loss of the prior layers. Increase this to focus more on the performance when using fewer
layers. (only for 2D MRL) (default: 1.0)
--kl_div_weight KL_DIV_WEIGHT
The weight to use for the KL-div loss that is used to make the prior layers match that of the last layer. Increase
this to focus more on the performance when using fewer layers. (only for 2D MRL) (default: 1.0)
--kl_temperature KL_TEMPERATURE
The temperature to use for the KL-divergence loss. If 0, then the KL-divergence loss is not used. (only for 2D MRL)
(default: 0.3)
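The interaction of --mrl_dimensions, --dim_weights, and --n_dims_per_step described above can be sketched as a weighted sum of a base loss evaluated on truncated embeddings. This follows the behavior described in the help text rather than the package's source; the squared-error base loss, the toy embeddings, and all function names are assumptions made for illustration.

```python
import math
import random


def cosine(a, b):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def squared_error_loss(emb_a, emb_b, label):
    """Toy base loss: squared gap between cosine similarity and the label."""
    return (cosine(emb_a, emb_b) - label) ** 2


def matryoshka_loss(base_loss, emb_a, emb_b, label, dims, weights,
                    n_dims_per_step=-1, rng=random):
    """Weighted sum of `base_loss` over embeddings truncated to each dimension."""
    if len(dims) != len(weights):
        raise ValueError("dims and weights must have the same length")
    indices = list(range(len(dims)))
    if 0 < n_dims_per_step < len(dims):
        # Use only a random subset of the dimensions for this step.
        indices = rng.sample(indices, n_dims_per_step)
    total = 0.0
    for i in indices:
        d = dims[i]
        total += weights[i] * base_loss(emb_a[:d], emb_b[:d], label)
    return total


# Toy 4-dimensional embeddings with a target similarity of 1.0.
emb_a = [1.0, 0.0, 1.0, 0.0]
emb_b = [1.0, 0.0, 0.0, 1.0]
all_dims = matryoshka_loss(squared_error_loss, emb_a, emb_b, 1.0,
                           dims=[4, 2], weights=[2.0, 1.0])
one_dim = matryoshka_loss(squared_error_loss, emb_a, emb_b, 1.0,
                          dims=[4, 2], weights=[2.0, 1.0],
                          n_dims_per_step=1, rng=random.Random(0))
print(all_dims, one_dim)
```

Sampling a subset of dimensions per step trades a noisier loss for less computation, which is the trade-off the --n_dims_per_step option exposes.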
train_classifier.py
For usage details, run:
python scripts/train_classifier.py -h
Example output:
usage: train_classifier.py [-h] --train_dataset_path TRAIN_DATASET_PATH --val_dataset_path VAL_DATASET_PATH
[--test_dataset_path TEST_DATASET_PATH] [--n_train_samples N_TRAIN_SAMPLES] [--n_val_samples N_VAL_SAMPLES]
[--n_test_samples N_TEST_SAMPLES] [--n_dataloader_workers N_DATALOADER_WORKERS]
[--generate_dataset_examples_at_init] [--model_name MODEL_NAME] [--train_batch_size TRAIN_BATCH_SIZE]
[--num_epochs NUM_EPOCHS] [--lr_base LR_BASE]
[--scheduler {warmupconstant,warmuplinear,warmupcosine,warmupcosinewithhardrestarts}]
[--warmup_steps_percent WARMUP_STEPS_PERCENT] [--use_fused_adamw] [--use_tf32] [--use_amp] [--seed SEED]
[--model_output_path MODEL_OUTPUT_PATH] [--evaluation_steps EVALUATION_STEPS]
[--checkpoint_save_steps CHECKPOINT_SAVE_STEPS] [--checkpoint_save_total_limit CHECKPOINT_SAVE_TOTAL_LIMIT]
[--return_eval_metric] [--use_wandb] [--wandb_api_key WANDB_API_KEY] [--wandb_project_name WANDB_PROJECT_NAME]
[--wandb_run_name WANDB_RUN_NAME] [--wandb_use_watch] [--wandb_watch_log {gradients,parameters,all}]
[--wandb_watch_log_freq WANDB_WATCH_LOG_FREQ] [--wandb_watch_log_graph] [--smiles_column_name SMILES_COLUMN_NAME]
[--label_column_name LABEL_COLUMN_NAME] [--eval_metric {accuracy}] [--loss_func {softmax,selfadjdice}]
[--classifier_hidden_dimension CLASSIFIER_HIDDEN_DIMENSION] [--dropout_p DROPOUT_P] [--freeze_model]
[--dice_reduction {mean,sum}] [--dice_gamma DICE_GAMMA]
Train SMILES-based classifier model
options:
-h, --help show this help message and exit
--train_dataset_path TRAIN_DATASET_PATH
--val_dataset_path VAL_DATASET_PATH
--test_dataset_path TEST_DATASET_PATH
--n_train_samples N_TRAIN_SAMPLES
Number of training samples to load. Uses seeded sampling if a seed is set. (default: None)
--n_val_samples N_VAL_SAMPLES
Number of evaluation samples to load. Uses seeded sampling if a seed is set. (default: None)
--n_test_samples N_TEST_SAMPLES
Number of testing samples to load. Uses seeded sampling if a seed is set. (default: None)
--n_dataloader_workers N_DATALOADER_WORKERS
How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
--generate_dataset_examples_at_init
If set, then all `sentence_transformers.InputExample` examples will be generated at at initialization. If not set, the
`sentence_transformers.InputExample` examples are generated on the fly by the dataloader. (default: False)
--model_name MODEL_NAME
Name of the model to use. Must be either a file path or a hugging-face model name. (default:
seyonec/ChemBERTa-zinc-base-v1)
--train_batch_size TRAIN_BATCH_SIZE
Training batch size (default: 32)
--num_epochs NUM_EPOCHS
Number of epochs to train (default: 3)
--lr_base LR_BASE Base learning rate. Will be scaled by the square root of the batch size (default: 1.1190785944700813e-05)
--scheduler {warmupconstant,warmuplinear,warmupcosine,warmupcosinewithhardrestarts}
Learning rate scheduler (default: warmuplinear)
--warmup_steps_percent WARMUP_STEPS_PERCENT
Number of warmup steps that the scheduler will use (default: 0.0)
--use_fused_adamw Use cuda-optimized FusedAdamW optimizer. ~10% faster than torch.optim.AdamW (default: False)
--use_tf32 Use TensorFloat-32 for matrix multiplication and convolutions (default: False)
--use_amp Use automatic mixed precision (default: False)
--seed SEED Omit to not set a seed during training. Used to seed the dataloader sampling and the transformer. (default: 42)
--model_output_path MODEL_OUTPUT_PATH
Path to save model (default: output)
--evaluation_steps EVALUATION_STEPS
Run evaluator every evaluation_steps (default: 0)
--checkpoint_save_steps CHECKPOINT_SAVE_STEPS
Save checkpoint every checkpoint_save_steps (default: 0)
--checkpoint_save_total_limit CHECKPOINT_SAVE_TOTAL_LIMIT
Save total limit (default: 20)
--return_eval_metric Return the final evaluation metric after training (default: False)
--use_wandb Use W&B for logging. Must be enabled for other W&B features to work. (default: False)
--wandb_api_key WANDB_API_KEY
W&B API key. Can be omitted if W&B cli is installed and logged in (default: None)
--wandb_project_name WANDB_PROJECT_NAME
--wandb_run_name WANDB_RUN_NAME
--wandb_use_watch Enable W&B watch (default: False)
--wandb_watch_log {gradients,parameters,all}
Specify which logs to W&B should watch (default: all)
--wandb_watch_log_freq WANDB_WATCH_LOG_FREQ
How often to log (default: 1000)
--wandb_watch_log_graph
Specify if graphs should be logged by W&B (default: False)
--smiles_column_name SMILES_COLUMN_NAME
SMILES column name (default: smiles)
--label_column_name LABEL_COLUMN_NAME
Label column name (default: label)
--eval_metric {accuracy}
Metric to use for evaluation (default: accuracy)
--loss_func {softmax,selfadjdice}
Loss function (default: softmax)
--classifier_hidden_dimension CLASSIFIER_HIDDEN_DIMENSION
Classifier hidden dimension. Must be less than equal to the ChemMRL transformer's hidden dimension. Note, the base
model will be truncated to this dimension. (default: 768)
--dropout_p DROPOUT_P
Dropout probability for linear layer regularization (default: 0.15)
--freeze_model Freeze internal base SMILES model (default: False)
--dice_reduction {mean,sum}
Dice loss reduction. Used if loss_func=selfadjdice (default: mean)
--dice_gamma DICE_GAMMA
Dice loss gamma. Used if loss_func=selfadjdice (default: 1.0)
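Both scripts share the --lr_base and --warmup_steps_percent options. The sketch below shows one plausible reading of them: the effective learning rate is the base rate multiplied by the square root of the batch size, and the warmup length is a fraction of the total number of optimizer steps. Both interpretations are assumptions drawn from the help text, not from the source code, and the helper names are hypothetical.

```python
import math


def scaled_learning_rate(lr_base, batch_size):
    """Assumed square-root scaling: lr_base * sqrt(batch_size)."""
    return lr_base * math.sqrt(batch_size)


def warmup_steps(n_samples, batch_size, num_epochs, warmup_steps_percent):
    """Assumed warmup length: a fraction of all training steps, rounded down."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    return int(steps_per_epoch * num_epochs * warmup_steps_percent)


# Defaults from the help output: lr_base, batch size 32, 3 epochs.
print(scaled_learning_rate(1.1190785944700813e-05, 32))
# Hypothetical dataset of 1000 samples with 10% warmup.
print(warmup_steps(1000, 32, 3, 0.1))
```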
References:
- Chithrananda, Seyone, et al. "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." arXiv [Cs.LG], 2020.
- Ahmad, Walid, et al. "ChemBERTa-2: Towards Chemical Foundation Models." arXiv [Cs.LG], 2022.
- Kusupati, Aditya, et al. "Matryoshka Representation Learning." arXiv [Cs.LG], 2022.
- Li, Xianming, et al. "2D Matryoshka Sentence Embeddings." arXiv [Cs.CL], 2024.
- Bajusz, Dávid, et al. "Why is the Tanimoto Index an Appropriate Choice for Fingerprint-Based Similarity Calculations?" J Cheminform, 7, 20 (2015).
- Li, Xiaoya, et al. "Dice Loss for Data-imbalanced NLP Tasks." arXiv [Cs.CL], 2020.