Implementations of deep generative models of molecules.
MoLeR: A Model for Molecule Generation
This repository contains training and inference code for the MoLeR model introduced in Learning to Extend Molecular Scaffolds with Structural Motifs. We also include our implementation of CGVAE, but it currently lacks integration with the high-level model interface, and is provided mostly for reference.
Quick start
The molecule_generation package depends on rdkit, which has to be installed separately. One simple approach is to do it via conda:
conda create --name moler-env python=3.7
conda activate moler-env
conda install rdkit==2020.09.1.0 -c conda-forge
Then, to install molecule_generation, simply run pip install -e . within the root folder.
Note that in the instructions above we pinned the rdkit version, as this is the version the code has been tested with. However, our code is likely to work with other modern versions of rdkit as well.
A MoLeR checkpoint trained using the default hyperparameters is available here. If you save it under MODEL_DIR, then you can sample 10 molecules by running
molecule_generation sample MODEL_DIR 10
See the next sections for how to train your own model and run more advanced inference.
Workflow
Working with MoLeR can be roughly divided into three stages:
- data preprocessing, where a plain text list of SMILES strings is turned into *.pkl files containing descriptions of the molecular graphs and generation traces;
- training, where MoLeR is trained on the preprocessed data until convergence; and
- inference, where one loads the model and performs batched encoding, decoding or sampling.
Additionally, you can visualise the decoding traces and internal action probabilities of the model, which can be useful for debugging.
Data Preprocessing
To run preprocessing, your data has to follow a simple GuacaMol format (files train.smiles, valid.smiles and test.smiles, each containing SMILES strings, one per line). Then, you can preprocess the data by running
molecule_generation preprocess INPUT_DIR OUTPUT_DIR TRACE_DIR
where INPUT_DIR is the directory containing the three *.smiles files, OUTPUT_DIR is used for intermediate results, and TRACE_DIR for final preprocessed files containing the generation traces. Additionally, the preprocess command accepts command-line arguments to override various preprocessing hyperparameters (notably, the size of the motif vocabulary).
This step roughly corresponds to applying Algorithm 2 from our paper to each molecule in the input data.
After running the above, you should see an output similar to
2022-03-10 11:22:15,927 preprocess.py:239 INFO 1273104 train datapoints, 79568 validation datapoints, 238706 test datapoints loaded, beginning featurization.
2022-03-10 11:22:15,927 preprocess.py:245 INFO Featurising data...
2022-03-10 11:22:15,927 molecule_dataset_utils.py:261 INFO Turning smiles into mol
2022-03-10 11:22:15,927 molecule_dataset_utils.py:79 INFO Initialising feature extractors and motif vocabulary.
2022-03-10 11:44:17,864 motif_utils.py:158 INFO Motifs in total: 99751
2022-03-10 11:44:25,755 motif_utils.py:182 INFO Removing motifs with less than 3 atoms
2022-03-10 11:44:25,755 motif_utils.py:183 INFO Motifs remaining: 99653
2022-03-10 11:44:25,764 motif_utils.py:190 INFO Truncating the list of motifs to 128 most common
2022-03-10 11:44:25,764 motif_utils.py:192 INFO Motifs remaining: 128
2022-03-10 11:44:25,764 motif_utils.py:199 INFO Finished creating the motif vocabulary
2022-03-10 11:44:25,764 motif_utils.py:200 INFO | Number of motifs: 128
2022-03-10 11:44:25,764 motif_utils.py:203 INFO | Min frequency: 3602
2022-03-10 11:44:25,764 motif_utils.py:204 INFO | Max frequency: 1338327
2022-03-10 11:44:25,764 motif_utils.py:205 INFO | Min num atoms: 3
2022-03-10 11:44:25,764 motif_utils.py:206 INFO | Max num atoms: 10
2022-03-10 11:44:25,862 preprocess.py:255 INFO Completed initializing feature extractors; featurising and saving data now.
Wrote 1273104 datapoints to /guacamol/output/train.jsonl.gz.
Wrote 79568 datapoints to /guacamol/output/valid.jsonl.gz.
Wrote 238706 datapoints to /guacamol/output/test.jsonl.gz.
Wrote metadata to /guacamol/output/metadata.pkl.gz.
(...proceeds to compute generation traces...)
After the preprocessed graphs are saved into OUTPUT_DIR, they will be turned into concrete generation traces, which is typically the most compute-intensive part of preprocessing. During that part, the preprocessing code may print errors, noting molecules that could not be parsed or that failed other assertions; MoLeR's preprocessing is robust to such cases, and will simply skip any problematic samples.
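The input layout described at the start of this section (three *.smiles files, one SMILES string per line) can be produced with a few lines of Python. The sketch below is only an illustration: the helper name write_smiles_splits, the split fractions, and the example molecules are assumptions made for this example and are not part of molecule_generation.

```python
import random
import tempfile
from pathlib import Path


def write_smiles_splits(smiles, output_dir, valid_fraction=0.05, test_fraction=0.15, seed=0):
    """Shuffle SMILES strings and write train/valid/test files, one SMILES per line.

    Hypothetical helper: the fractions are arbitrary defaults, not MoLeR settings.
    """
    shuffled = list(smiles)
    random.Random(seed).shuffle(shuffled)

    num_valid = int(len(shuffled) * valid_fraction)
    num_test = int(len(shuffled) * test_fraction)
    splits = {
        "valid": shuffled[:num_valid],
        "test": shuffled[num_valid:num_valid + num_test],
        "train": shuffled[num_valid + num_test:],
    }

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for name, entries in splits.items():
        # File names follow the train.smiles / valid.smiles / test.smiles convention.
        (output_dir / f"{name}.smiles").write_text("".join(f"{s}\n" for s in entries))
    return {name: len(entries) for name, entries in splits.items()}


# Made-up example data: 20 SMILES strings written to a temporary INPUT_DIR.
example_smiles = ["c1ccccc1", "CNC=O", "CCO", "CC(=O)O"] * 5
with tempfile.TemporaryDirectory() as input_dir:
    counts = write_smiles_splits(example_smiles, input_dir, valid_fraction=0.25, test_fraction=0.25)
    written = sorted(path.name for path in Path(input_dir).iterdir())
print(counts, written)
```

The resulting directory can then be passed as INPUT_DIR to the preprocess command.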
Training
Once some preprocessed data is stored under TRACE_DIR, MoLeR can be trained by running
molecule_generation train MoLeR TRACE_DIR
The train command accepts many command-line arguments to override training and architectural hyperparameters, most of which are accessed through passing --model-params-override. For example, the following trains a MoLeR model using GGNN-style message passing (instead of the default GNN_Edge_MLP) and using fewer layers in both the encoder and the decoder GNNs:
molecule_generation train MoLeR TRACE_DIR \
--model GGNN \
--model-params-override '{"gnn_num_layers": 6, "decoder_gnn_num_layers": 6}'
As tf2-gnn is highly flexible, MoLeR supports a vast space of architectural configurations.
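When launching many configurations from a script, hand-writing the JSON passed to --model-params-override is easy to get wrong. The following sketch assembles the same command as above from a Python dictionary. The helper name build_train_command is hypothetical; only the command, flags, and parameter names that appear in the example above are taken from this README.

```python
import json
import shlex


def build_train_command(trace_dir, gnn_type=None, overrides=None):
    """Assemble the `molecule_generation train` command line as a list of arguments.

    Hypothetical helper for illustration; it only formats arguments.
    """
    command = ["molecule_generation", "train", "MoLeR", trace_dir]
    if gnn_type is not None:
        command += ["--model", gnn_type]
    if overrides:
        # json.dumps guarantees valid JSON (double quotes, no trailing commas).
        command += ["--model-params-override", json.dumps(overrides)]
    return command


command = build_train_command(
    "TRACE_DIR",
    gnn_type="GGNN",
    overrides={"gnn_num_layers": 6, "decoder_gnn_num_layers": 6},
)
# shlex.quote protects the JSON argument so the line can be pasted into a POSIX shell.
print(" ".join(shlex.quote(argument) for argument in command))
```

The argument list can also be handed directly to subprocess.run, which avoids shell quoting altogether.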
After running molecule_generation train, you should see an output similar to
(...tensorflow messages, hyperparameter dump...)
Initial valid metric:
Avg weighted sum. of graph losses: 122.1728
Avg weighted sum. of prop losses: 0.4712
Avg node class. loss: 35.9361
Avg first node class. loss: 27.4681
Avg edge selection loss: 1.7522
Avg edge type loss: 3.8963
Avg attachment point selection loss: 1.1227
Avg KL divergence: 7335960.5000
Property results: sa_score: MAE 11.23, MSE 1416.26 (norm MAE: 13.89) | clogp: MAE 10.87, MSE 4620.69 (norm MAE: 5.98) | mol_weight: MAE 407.42, MSE 185524.38 (norm MAE: 3.70).
(Stored model metadata and weights to trained_model/GNN_Edge_MLP_MoLeR__2022-03-01_18-15-14_best.pkl).
(...training proceeds...)
By default, training proceeds until there is no improvement in validation loss for 3 consecutive mini-epochs, where a mini-epoch is defined as 5000 training steps; this can be controlled through the --patience flag and the num_train_steps_between_valid model parameter, respectively.
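The stopping rule can be pictured with a small stand-alone sketch. This is not the training loop used by MoLeR or tf2-gnn, only a minimal illustration of patience-based early stopping, assuming that "no improvement" means the validation loss did not drop below the best value seen so far. The loss values are made up.

```python
def stopping_mini_epoch(valid_losses, patience=3):
    """Return the 1-based mini-epoch at which training would stop, or None if it never does."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for mini_epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter.
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return mini_epoch
    return None


# Made-up validation losses, one per mini-epoch.
losses = [122.2, 80.5, 61.3, 61.9, 60.8, 61.0, 61.2, 60.9]
stop = stopping_mini_epoch(losses, patience=3)
print(f"Training would stop after mini-epoch {stop}")
```

In this example the best loss (60.8) is reached at mini-epoch 5, and the three following mini-epochs fail to improve on it, so training stops after mini-epoch 8.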
Inference
After a model has been trained and saved under MODEL_DIR, we provide a simple API to load it.
To sample molecules from the model, simply run
molecule_generation sample MODEL_DIR NUM_SAMPLES
and, similarly, to encode a list of SMILES stored under SMILES_PATH into latent vectors and store them under OUTPUT_PATH, run
molecule_generation encode MODEL_DIR SMILES_PATH OUTPUT_PATH
In all cases MODEL_DIR denotes the directory containing the model checkpoint, not the path to the checkpoint itself. The model loader will expect that MODEL_DIR contains exactly one MoLeR checkpoint, which is recognized automatically using the filename.
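Before running a long job, it can be useful to check that a directory satisfies the "exactly one checkpoint" requirement. The sketch below is a hypothetical pre-flight check, not the loader's actual logic: the README does not state which filename pattern the loader uses, so the _best.pkl suffix is an assumption based on the checkpoint name printed in the training log above.

```python
import tempfile
from pathlib import Path


def find_single_checkpoint(model_dir, suffix="_best.pkl"):
    """Return the only file in model_dir ending with suffix, or raise ValueError.

    The suffix is an assumption for illustration; the real loader may match differently.
    """
    candidates = sorted(path for path in Path(model_dir).iterdir() if path.name.endswith(suffix))
    if len(candidates) != 1:
        raise ValueError(f"Expected exactly one checkpoint in {model_dir}, found {len(candidates)}")
    return candidates[0]


# Demonstration on a temporary directory holding one checkpoint-like file and one unrelated file.
with tempfile.TemporaryDirectory() as model_dir:
    (Path(model_dir) / "GNN_Edge_MLP_MoLeR__2022-03-01_18-15-14_best.pkl").touch()
    (Path(model_dir) / "notes.txt").touch()
    checkpoint_name = find_single_checkpoint(model_dir).name
print(checkpoint_name)
```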
You can also load a trained MoLeR model directly from Python via
from molecule_generation import ModelWrapper

model_dir = "./example_model_directory"
example_smiles = ["c1ccccc1", "CNC=O"]

with ModelWrapper(model_dir) as model:
    # Encode the SMILES strings into latent vectors.
    embeddings = model.encode(example_smiles)
    print(f"Embedding shape: {embeddings[0].shape}")

    # Decode the latent vectors back into SMILES strings.
    decoded_smiles = model.decode(embeddings)
    print(f"Encoded: {example_smiles}, decoded: {decoded_smiles}")
As shown above, MoLeR is loaded through a context manager. Behind the scenes, entering the context spawns parallel processes which await queries for encoding/decoding; these processes continue to live as long as the context is active. The degree of parallelism can be configured by passing a num_workers argument to ModelWrapper.
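The lifecycle described above (workers created on entry, kept alive while the context is active, shut down on exit) follows a common Python pattern. The toy class below only illustrates that pattern and is not how ModelWrapper is implemented: ToyModelWrapper is a hypothetical stand-in, it uses threads instead of processes to stay portable, and its "encoding" is just the length of each string.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyModelWrapper:
    """Hypothetical stand-in for illustration; NOT the molecule_generation API."""

    def __init__(self, num_workers=2):
        self._num_workers = num_workers
        self._pool = None

    def __enter__(self):
        # The worker pool is created on entry and stays available inside the context.
        self._pool = ThreadPoolExecutor(max_workers=self._num_workers)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Leaving the context shuts the workers down.
        self._pool.shutdown(wait=True)
        self._pool = None

    def encode(self, smiles_list):
        if self._pool is None:
            raise RuntimeError("encode() called outside of the context")
        # Placeholder "encoding": the string length stands in for a latent vector.
        return list(self._pool.map(len, smiles_list))


with ToyModelWrapper(num_workers=2) as model:
    lengths = model.encode(["c1ccccc1", "CNC=O"])
print(lengths)
```

One consequence of this design is that queries are only valid inside the with block; keeping a single context open for many encode or decode calls avoids paying the worker start-up cost repeatedly.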
Visualisation
We support two subtly different modes of visualisation: decoding a given latent vector, and decoding a latent vector created by encoding a given SMILES string. In the former case, the decoder runs as normal during inference; in the latter case we know the ground-truth input, so we teacher-force the correct decoding decisions.
To enter the visualiser, run either
molecule_generation visualise cli MODEL_DIR SMILES_OR_SAMPLES_PATH
to get the result printed as plain text in the CLI, or
molecule_generation visualise html MODEL_DIR SMILES_OR_SAMPLES_PATH OUTPUT_DIR
to get the result saved under OUTPUT_DIR
as a static HTML webpage.
Code Structure
All of our models are implemented in Tensorflow 2, and are meant to be easy to extend and build upon. We use tf2-gnn for the core Graph Neural Network components.
The MoLeR model itself is implemented as a MoLeRVae class, inheriting from GraphTaskModel in tf2-gnn; that base class encapsulates the encoder GNN. The decoder GNN is instantiated as an external MoLeRDecoder layer; it also includes batched inference code, which forces the maximum likelihood choice at every step.
Authors
- Krzysztof Maziarz
- Henry Jackson-Flux
- Marc Brockschmidt
- Pashmina Cameron
- Sarah Lewis
- Marwin Segler
- Megan Stanley
- Paweł Czyż
- Ashok Thillaisundaram
Note: as git history was truncated at the point of open-sourcing, GitHub's statistics do not reflect the degree of contribution from some of the authors. All listed above had an impact on the code, and are (approximately) ordered by decreasing contribution.
The code is maintained by the Generative Chemistry group at Microsoft Research, Cambridge, UK. We are hiring.
MoLeR was created as part of our collaboration with Novartis Research. In particular, its design was guided by Nadine Schneider, Finton Sirockin, Nikolaus Stiefl, as well as others from Novartis.
Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
Style Guide
- For code style, use black and flake8.
- For commit messages, use imperative style and follow the semantic commit messages template; e.g.
feat(moler_decoder): Improve masking of invalid actions
Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.