
MUBen: Molecular Uncertainty Benchmark

Code associated with the paper MUBen: Benchmarking the Uncertainty of Pre-Trained Models for Molecular Property Prediction.

The code is written to expose implementation details as much as possible and to be easily extensible. Questions and suggestions are welcome if you run into any issues while using it.

0. ABOUT

MUBen is a benchmark that investigates the performance of uncertainty quantification (UQ) methods built upon backbone molecular representation models. It implements 6 backbone models (4 pre-trained and 2 non-pre-trained), 8 UQ methods (all 8 compatible with classification and 6 with regression), and 14 datasets from MoleculeNet (8 for classification and 6 for regression). We are actively expanding the benchmark to include more backbones, UQ methods, and datasets. This is an arduous task, and we welcome contributions or collaboration in any form.

Backbones

| Backbone Model | Paper | Official Repo | Our Implementation |
|----------------|-------|---------------|--------------------|
| Pre-Trained Backbones | | | |
| ChemBERTa | link | link | link |
| GROVER | link | link | link |
| Uni-Mol | link | link | link |
| TorchMD-NET | Architecture; Pre-training | link | link |
| Non-Pre-Trained Backbones | | | |
| DNN | - | - | link |
| GIN | link | pyg | link |

Uncertainty Quantification Methods

| UQ Method | Classification | Regression | Paper |
|-----------|----------------|------------|-------|
| Included | | | |
| Deterministic | ✅︎ | ✅︎ | - |
| Temperature Scaling | ✅︎ | - | link |
| Focal Loss | ✅︎ | - | link |
| Deep Ensembles | ✅︎ | ✅︎ | link |
| SWAG | ✅︎ | ✅︎ | link |
| Bayes by Backprop | ✅︎ | ✅︎ | link |
| SGLD | ✅︎ | ✅︎ | link |
| MC Dropout | ✅︎ | ✅︎ | link |

Data

Please check MoleculeNet for a detailed description. We use a subset of the MoleculeNet benchmark, including BBBP, Tox21, ToxCast, SIDER, ClinTox, BACE, MUV, HIV, ESOL, FreeSolv, Lipophilicity, QM7, QM8, and QM9.

1. DATA

A set of partitioned datasets is already included in this repo. You can find them under the ./data/ folder: [scaffold split]; [random split].

We utilize the datasets prepared by Uni-Mol. You can find the data here or download it directly through this link. We place the unzipped files into ./data/UniMol by default. For convenience, we suggest renaming the qm7dft, qm8dft, and qm9dft folders to qm7, qm8, and qm9.

Afterwards, you can transfer the dataset format into ours by running

PYTHONPATH="." python ./assist/dataset_build_from_unimol.py

assuming you are in the project root directory. You can specify the input (Uni-Mol) and output data directories with the --unimol_data_dir and --output_dir arguments. By default, the script converts all datasets (excluding PCBA). To convert only a subset, pass the target dataset names, in lowercase, to the --dataset_names argument.
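For instance, a call like the following (assuming --dataset_names accepts a space-separated list; the output directory shown is illustrative) converts only BBBP and Tox21:

PYTHONPATH="." python ./assist/dataset_build_from_unimol.py \
  --unimol_data_dir ./data/UniMol \
  --output_dir ./data/files \
  --dataset_names bbbp tox21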

Notice: if you plan to run the Uni-Mol model, we suggest keeping the original Uni-Mol data, since its pre-defined molecular conformations are reused. Otherwise, it is safe to remove the original data.

Other Options

If you prefer not to use the Uni-Mol data, you can try the scripts in the legacy folder, including build_dgllife_datasets.py and build_qm[7,8,9]_dataset.py. Note that these may produce training/validation/test partitions different from those used in our experiments.

Using Customized Datasets

If you want to test the UQ methods on your own dataset, you can organize your data as pandas.DataFrame structures with the following keys:

{
  "smiles": list of `str`,
  "labels": list of list of int/float,
  "masks": list of list of int/float (with values within {0,1})
}

and store them as train.csv, valid.csv, and test.csv files. mask=1 indicates an informative label at that position, and mask=0 indicates a missing label. You can check the prepared datasets included in our program for reference. We recommend putting the dataset files in the ./data/file/<dataset name> directory, but you can of course choose another location and specify it through the --data_folder argument.
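As a minimal sketch, the following builds one such file with pandas (the molecules, labels, and dataset name are placeholders):

import os
import pandas as pd

# Two placeholder molecules with two binary classification tasks each.
os.makedirs("./data/file/my_dataset", exist_ok=True)
df = pd.DataFrame({
    "smiles": ["CCO", "c1ccccc1"],
    "labels": [[0, 1], [1, 0]],
    "masks": [[1, 1], [1, 0]],  # 0 marks the missing second label of the second molecule
})
df.to_csv("./data/file/my_dataset/train.csv", index=False)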

The .csv files should be accompanied by a meta.json file in the same directory. It stores constant dataset properties, e.g., task_type (classification or regression), n_tasks, and classes ([0, 1] for all our classification datasets). For a customized dataset, one required property is the eval_metric used for validation and testing (e.g., roc-auc, rmse, etc.), since it is not specified in the macro file. Please refer to ./assist/dataset_build_roe.py for an example (unfortunately, we are not allowed to release that dataset).
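Continuing the placeholder dataset above, a meta.json could look like this (using only the properties named above; check the prepared datasets for the full set of keys):

{
  "task_type": "classification",
  "n_tasks": 2,
  "classes": [0, 1],
  "eval_metric": "roc-auc"
}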

2. REQUIREMENTS

Please find the required packages in requirements.txt. Our code is developed with Python 3.10 and does not work with Python versions earlier than 3.9. We recommend creating a new conda environment with

conda create --name <env_name> --file requirements.txt
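If you prefer pip, the usual equivalent (assuming you have already created and activated a Python 3.10 environment) is:

pip install -r requirements.txt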

Docker

Alternatively, you can run this project in a Docker container. You can build the image with

docker build -t muben ./docker

and run your container in an interactive shell with

docker run --gpus all -it --rm muben
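If your data and checkpoints live on the host, you can expose them to the container with bind mounts; the container-side paths below are assumptions and should be adjusted to the image's actual working directory:

docker run --gpus all -it --rm \
  -v "$(pwd)/data":/workspace/MUBen/data \
  -v "$(pwd)/models":/workspace/MUBen/models \
  muben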

External Dependencies

The backbone models GROVER and Uni-Mol require loading pre-trained model checkpoints.

  • The GROVER-base checkpoint is available at GROVER's project repo or can be directly downloaded through this link. Unzip the downloaded .tar.gz file to get the .pt checkpoint.
  • The Uni-Mol checkpoint is available at Uni-Mol's project repo or can be directly downloaded through this link.

By default, the code looks for the models at ./models/grover_base.pt and ./models/unimol_base.pt, respectively. You need to specify the --checkpoint_path argument if you prefer other locations or checkpoint names.
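For example, assuming a GROVER run script analogous to ./run/dnn.py, a custom checkpoint location could be passed as (the path is illustrative):

PYTHONPATH="." python ./run/grover.py \
  --checkpoint_path /path/to/grover_base.pt \
  # other arguments...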

3. RUN

A simple demo of running our project can be found at ./demo/demo.ipynb.

To run each of the backbone models with the uncertainty estimation methods, you can check the run_*.py files in the root directory. Example shell scripts are provided in the ./scripts folder as .sh files. You can use them through

./scripts/run_dnn_rdkit.sh <CUDA_VISIBLE_DEVICES>

as an example. Note that to skip training on certain datasets, you need to comment out the corresponding train_on_<dataset name> variables in the .sh files, as illustrated below; setting their values to false does not work.
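A hypothetical script excerpt (the variable names follow the train_on_<dataset name> pattern; the datasets shown are arbitrary):

train_on_bbbp=true
# train_on_tox21=true    # commented out, so training on Tox21 is skipped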

Another way of specifying arguments is through .json configuration files, for example:

PYTHONPATH="." CUDA_VISIBLE_DEVICES=0 python ./run/dnn.py ./scripts/config_dnn.json

This approach can be helpful when debugging the code through VS Code.

To get a detailed description of each argument, you can use --help:

PYTHONPATH="." python ./run/dnn.py --help

Logging and WandB

By default, this project uses local logging files (*.log) and WandB to track training status.

The log files are stored as ./logs/<dataset>/<model>/<uncertainty>/<running_time>.log. You can change the file path by specifying the --log_path argument, or disable log saving by setting --log_path="disabled".

To use WandB, you first need to register an account and sign in on your machine with wandb login. If you are running the code on a public device, you can instead sign in per run by specifying the --wandb_api_key argument; you can find your API key in your browser at https://wandb.ai/authorize. To disable WandB, use --disable_wandb [true]. By default, we use MUBen-<dataset> as the WandB project name and <model>-<uncertainty> as the model name. You can change this behavior by specifying the --wandb_project and --wandb_name arguments.
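For instance (the backbone script, API key, and names below are placeholders):

PYTHONPATH="." python ./run/dnn.py \
  --wandb_api_key <your API key> \
  --wandb_project MUBen-bbbp \
  --wandb_name DNN-MCDropout \
  # other arguments...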

Data Loading

The program will automatically create the necessary features (molecular descriptors) required by the backbone models from the SMILES strings if they are loaded properly. The processed features are stored as <train/valid/test>.pt files in the <bottom-level data folder>/processed/ directory by default, and will be automatically loaded the next time you apply the same backbone model to the same dataset. You can change this behavior with --disable_dataset_saving to disable dataset saving, or --ignore_preprocessed_dataset to avoid loading from the saved (processed) dataset.

Constructing Morgan fingerprints, RDKit features, or 3D conformations for Uni-Mol may take a while. You can accelerate this process by using multiple threads, i.e., setting --num_preprocess_workers to a value larger than 1 (the default is 8). For 3D conformations, we directly take advantage of the results from Uni-Mol, but we keep the option of generating them ourselves in case the Uni-Mol data files are not found.
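For example (the worker count is arbitrary):

PYTHONPATH="." python ./run/dnn.py \
  --num_preprocess_workers 16 \
  # other arguments...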

Calculating Metrics

During training, we only calculate the metrics necessary for early stopping and a simple evaluation of prediction performance. To get the other metrics, you need to use the ./assist/results_get_metrics.py script.

Specifically, you need to save the model predictions by leaving --disable_dataset_saving unset. The results are saved as ./<result_folder>/<dataset_name>/<model_name>/<uncertainty_method>/seed-<seed>/preds/<test_idx>.pt files. Once training has finished, you can run ./assist/results_get_metrics.py to generate all metrics for your model predictions. For example:

PYTHONPATH="." python ./assist/results_get_metrics.py ./scripts/config_metrics.json

Make sure the hyper-parameters in the configuration file match your needs.

The metrics will be saved in the ./<result_folder>/RESULTS/<model_name>-<dataset_name>.csv files. Note that these files already exist in the repo if you keep the default --result_folder=./output argument, so check that they have been updated to reflect your experiment results.

Results

We provide a more comprehensive copy of the experiment results presented in the tables of our paper's appendix here. We hope it saves some effort if you want to further analyze the behavior of our backbone models and uncertainty quantification methods.

4. ONGOING WORKS

4.1. Active Learning

We are developing code to integrate active learning (AL) into the pipeline. Specifically, we assume a small set of labeled data points (--n_init_instances) at the beginning. In each active learning iteration, we use the labeled dataset to fine-tune the model parameters and select a batch of data points (--n_al_select) from the unlabeled set with the lowest predicted certainty (i.e., the highest predicted entropy for classification and the highest predicted variance for regression). The process is repeated for several loops (--n_al_loops), and the intermediate performance is tracked.
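As a minimal sketch of the selection criterion for classification (not the repo's implementation), the batch is chosen by maximum predictive entropy:

import numpy as np

def select_most_uncertain(probs, n_select):
    # probs: array of shape (n_unlabeled, n_classes) with predicted probabilities.
    entropy = -np.sum(probs * np.log(probs + 1e-12), axis=1)
    # Return the indices of the n_select points with the highest entropy.
    return np.argsort(entropy)[-n_select:]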

The code is still under construction and is currently only available on the dev branch. In addition, several points are worth noting:

  • Currently, only the DNN and ChemBERTa backbones are supported (./run/dnn_al.py and ./run/chemberta_al.py). Migrating AL to other backbones is not difficult but requires updating some Trainer functions if they are overridden.
  • To enable active learning, make sure you set --enable_active_learning to true.
  • Currently, Deep Ensembles is not supported for AL.
  • We cannot guarantee the correctness of our implementation. If you notice any abnormalities in the code, please do not hesitate to post an issue.

One example is

PYTHONPATH="." python ./run/dnn_al.py \
  --enable_active_learning \
  --n_init_instances 100 \
  --n_al_loops 20 \
  --n_al_select 20 \
  # other model and training hyper-parameters...

5. CITATION

If you find our work helpful, please consider citing it as

@misc{li2023muben,
    title={MUBen: Benchmarking the Uncertainty of Pre-Trained Models for Molecular Property Prediction},
    author={Yinghao Li and Lingkai Kong and Yuanqi Du and Yue Yu and Yuchen Zhuang and Wenhao Mu and Chao Zhang},
    year={2023},
    eprint={2306.10060},
    archivePrefix={arXiv},
    primaryClass={physics.chem-ph}
}
