A scoring, benchmarking and evaluation framework for goal directed generative models

SMILES-RNN

This repo contains code for a SMILES-based recurrent neural network used for de novo molecule generation, with several reinforcement learning algorithms available for molecule optimization. It was written to be used in conjunction with MolScore, although any other scoring function can also be used.

Installation

This code can be installed via pip.

pip install smiles-rnn

Alternatively, clone this repository and set up an environment with mamba.

git clone https://github.com/MorganCThomas/SMILES-RNN.git
cd SMILES-RNN
mamba env create -f environment.yml
pip install ./

Usage

Arguments to any of the scripts can be printed by running

python <script> --help

Training a prior

To train a prior, run the train_prior.py script. As the help text below shows, several other grammars are also implemented, including DeepSMILES, SELFIES, atomInSmiles, and SAFE; these are generated by conversion from SMILES. When using randomization (which can be done at train time), the SMILES are first randomized and then each random SMILES is converted to the alternative grammar. You can optionally pass in validation or test SMILES, whose log likelihood will be compared during training and can be monitored via TensorBoard. Note: choosing a specific GPU device does not currently work; training will run on the default GPU device (i.e., index 0).

Train an initial prior model based on smiles data

positional arguments:
  {RNN,Transformer,GTr}
                        Model architecture
    RNN                 Use simple forward RNN with GRU or LSTM
    Transformer         TransformerEncoder model
    GTr                 StableTransformerEncoder model

optional arguments:
  -h, --help            show this help message and exit
  --grammar {SMILES,deepSMILES,deepSMILES_r,deepSMILES_cr,deepSMILES_c,deepSMILES_cb,deepSMILES_b,SELFIES,AIS,SAFE,SmiZip}
                        Choice of grammar to use, SMILES will be encoded and decoded via grammar (default: SMILES)
  --randomize           Training smiles will be randomized using default arguments (10 restricted) (default: False)
  --n_jobs N_JOBS       If randomizing use multiple cores (default: 1)
  --smizip-ngrams SMIZIP_NGRAMS
                        SmiZip JSON file containing the list of n-grams (default: None)
  --valid_smiles VALID_SMILES
                        Validation smiles (default: None)
  --test_smiles TEST_SMILES
                        Test smiles (default: None)
  --validate_frequency VALIDATE_FREQUENCY
                        (default: 500)
  --n_epochs N_EPOCHS   (default: 5)
  --batch_size BATCH_SIZE
                        (default: 128)
  -d DEVICE, --device DEVICE
                        cpu/gpu or device number (default: gpu)

required arguments:
  -i TRAIN_SMILES, --train_smiles TRAIN_SMILES
                        Path to smiles file (default: None)
  -o OUTPUT_DIRECTORY, --output_directory OUTPUT_DIRECTORY
                        Output directory to save model (default: None)
  -s SUFFIX, --suffix SUFFIX
                        Suffix to name files (default: None)
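
As a rough illustration of how the options above fit together, the following sketch assembles a train_prior.py command line from Python. The file names, suffix, and option values are hypothetical placeholders. It also assumes the architecture sub-command (here RNN) comes after the shared options, as is usual for argparse sub-commands; any architecture-specific options are not listed in the help above and are omitted.

```python
import shlex
import sys

# Hypothetical file names; substitute your own data and output locations.
train_file = "train.smi"
valid_file = "valid.smi"
output_dir = "prior_output"

# Every flag below appears in the help text above; the values are examples.
cmd = [
    sys.executable, "train_prior.py",
    "-i", train_file,             # required: training SMILES
    "-o", output_dir,             # required: where the model is saved
    "-s", "demo",                 # required: suffix used to name files
    "--grammar", "SELFIES",       # SMILES are converted to this grammar
    "--randomize",                # randomize SMILES before conversion
    "--n_jobs", "4",              # cores used for randomization
    "--valid_smiles", valid_file, # optional validation set
    "--n_epochs", "5",
    "--batch_size", "128",
    "RNN",                        # assumed position of the architecture sub-command
]

# Print the command so it can be inspected or pasted into a shell.
print(shlex.join(cmd))

# To actually launch training, pass the list to subprocess.run(cmd, check=True).
```

Building the command as a list rather than a single string avoids quoting problems when paths contain spaces.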

Sampling from a trained prior

You can sample a trained model by running the sample_model.py script.

Sample smiles from model

optional arguments:
  -h, --help            show this help message and exit
  -p PATH, --path PATH  Path to checkpoint (.ckpt) (default: None)
  -m {RNN,Transformer,GTr}, --model {RNN,Transformer,GTr}
                        Choice of architecture (default: None)
  -o OUTPUT, --output OUTPUT
                        Path to save file (e.g. Data/Prior_10k.smi) (default: None)
  -d DEVICE, --device DEVICE
                        (default: gpu)
  -n NUMBER, --number NUMBER
                        (default: 10000)
  -t TEMPERATURE, --temperature TEMPERATURE
                        Temperature to sample (1: multinomial, <1: Less random, >1: More random) (default: 1.0)
  --psmiles PSMILES     Either scaffold smiles labelled with decoration points (*) or fragments for linking with connection points (*) and separated by a period (.)
                        (default: None)
  --unique              Keep sampling until n unique canonical molecules have been sampled (default: False)
  --native              If trained using an alternative grammar (e.g., SELFIES), don't convert back to SMILES (default: False)
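
The effect of the temperature option can be illustrated with a small, self-contained sketch of temperature-scaled sampling over next-token scores. This is a generic illustration of the technique, not code from this repository; the function names and the example logits are hypothetical.

```python
import math
import random


def temperature_probs(logits, temperature=1.0):
    """Turn raw token scores into probabilities after temperature scaling."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    highest = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - highest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature=1.0, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    probs = temperature_probs(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


# Hypothetical scores for three candidate tokens.
logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in temperature_probs(logits, t)])
```

A temperature below 1 concentrates probability on the highest-scoring token (less random), while a temperature above 1 flattens the distribution (more random).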

Fine-tuning

You can also fine-tune a trained model with a smaller dataset of SMILES by running the fine_tune.py script. If the pre-trained model was trained with an alternative grammar, these SMILES will also be converted at train time, i.e., you always input molecules as SMILES.

Fine-tune a pre-trained prior model based on a smaller dataset

optional arguments:
  -h, --help            show this help message and exit

Required arguments:
  -p PRIOR, --prior PRIOR
                        Path to prior file (default: None)
  -i TUNE_SMILES, --tune_smiles TUNE_SMILES
                        Path to fine-tuning smiles file (default: None)
  -o OUTPUT_DIRECTORY, --output_directory OUTPUT_DIRECTORY
                        Output directory to save model (default: None)
  -s SUFFIX, --suffix SUFFIX
                        Suffix to name files (default: None)
  --model {RNN,Transformer,GTr}
                        Choice of architecture (default: None)

Optional arguments:
  --randomize           Training smiles will be randomized using default arguments (10 restricted) (default: False)
  --valid_smiles VALID_SMILES
                        Validation smiles (default: None)
  --test_smiles TEST_SMILES
                        Test smiles (default: None)
  --n_epochs N_EPOCHS   (default: 10)
  --batch_size BATCH_SIZE
                        (default: 128)
  -d DEVICE, --device DEVICE
                        cpu/gpu or device number (default: gpu)
  -f FREEZE, --freeze FREEZE
                        Number of RNN layers to freeze (default: None)

Reinforcement learning

Finally, reinforcement learning can be run with the reinforcement_learning.py script. Note that this script is written to work with MolScore, which handles the objective task (i.e., molecule scoring). However, you can also use the underlying ReinforcementLearning class in the model/RL.py module, to which another scoring function can be provided. This class has methods for several reinforcement learning algorithms, including:

  • Reinforce
  • REINVENT
  • BAR
  • Hill-Climb
  • Augmented Hill-Climb
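
To give a feel for what "another scoring function" might look like, here is a toy sketch. The interface that ReinforcementLearning expects is not documented in this README, so the signature below (a list of SMILES in, a list of floats in [0, 1] out) is an assumption, and toy_nitrogen_score is a hypothetical name. A real scoring function would normally parse molecules with a cheminformatics toolkit rather than count characters.

```python
from typing import List


def toy_nitrogen_score(smiles: List[str], target: int = 2) -> List[float]:
    """Score each SMILES by how close its nitrogen count is to `target`.

    Assumed interface: one float in [0, 1] per input SMILES.
    The count is a crude character match on 'N'/'n', so elements such as
    Na or Sn would be miscounted; this is for illustration only.
    """
    scores = []
    for smi in smiles:
        n_count = sum(1 for ch in smi if ch in "Nn")
        scores.append(min(n_count, target) / target)
    return scores


print(toy_nitrogen_score(["CCO", "CCN", "NCCN", "c1ccncc1"]))
```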

Generic arguments can be viewed by running python reinforcement_learning.py --help

Optimize an RNN towards a reward via reinforcement learning

optional arguments:
  -h, --help            show this help message and exit

Required arguments:
  -p PRIOR, --prior PRIOR
                        Path to prior checkpoint (.ckpt) (default: None)
  -m MOLSCORE_CONFIG, --molscore_config MOLSCORE_CONFIG
                        Path to molscore config (.json) (default: None)
  --model {RNN,Transformer,GTr}
                        Choice of architecture (default: None)

Optional arguments:
  -a AGENT, --agent AGENT
                        Path to agent checkpoint (.ckpt) (default: None)
  -d DEVICE, --device DEVICE
                        (default: gpu)
  -f FREEZE, --freeze FREEZE
                        Number of RNN layers to freeze (default: None)
  --save_freq SAVE_FREQ
                        How often to save models (default: 100)
  --verbose             Whether to print loss (default: False)
  --psmiles PSMILES     Either scaffold smiles labelled with decoration points (*) or fragments for linking with connection points (*) and separated by a period (.)
                        (default: None)
  --psmiles_multi       Whether to conduct multiple updates (1 per decoration) (default: False)
  --psmiles_canonical   Whether to attach decorations one at a time, based on attachment point with lowest NLL, otherwise attachment points will be shuffled within a
                        batch (default: False)
  --psmiles_optimize    Whether to optimize the SMILES prompts during sampling (default: False)
  --psmiles_lr_decay PSMILES_LR_DECAY
                        Amount to decay the learning rate at the beginning of iterative prompting (1=no decay) (default: 1)
  --psmiles_lr_epochs PSMILES_LR_EPOCHS
                        Number of epochs before the decayed learning rate returns to normal (default: 10)

RL strategy:
  {RV,RV2,BAR,AHC,HC,HC-reg,RF,RF-reg}
                        Which reinforcement learning algorithm to use

RL algorithm-specific arguments can be viewed by running, e.g., python reinforcement_learning.py AHC --help

Augmented Hill-Climb

optional arguments:
  -h, --help            show this help message and exit
  --n_steps N_STEPS     (default: 500)
  --batch_size BATCH_SIZE
                        (default: 64)
  -s SIGMA, --sigma SIGMA
                        Scaling coefficient of score (default: 60)
  -k [0-1], --topk [0-1]
                        Fraction of top molecules to keep (default: 0.5)
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        Adam learning rate (default: 0.0005)
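
To show how sigma and topk interact, the sketch below follows Augmented Hill-Climb as it is commonly described: a REINVENT-style squared loss between the augmented and agent log-likelihoods, computed only on the top-scoring fraction of each batch. It is a plain-Python illustration, not this repository's implementation; the function name, the example numbers, and the guard that keeps at least one molecule are assumptions.

```python
from typing import List, Tuple


def ahc_loss(
    prior_ll: List[float],
    agent_ll: List[float],
    scores: List[float],
    sigma: float = 60.0,
    topk: float = 0.5,
) -> Tuple[float, List[int]]:
    """Return the mean loss over the top-k molecules and their indices."""
    # Augmented log-likelihood: prior log-likelihood shifted by the scaled score.
    augmented = [p + sigma * s for p, s in zip(prior_ll, scores)]
    # REINVENT-style squared difference for every molecule in the batch.
    per_mol = [(aug - a) ** 2 for aug, a in zip(augmented, agent_ll)]
    # Rank molecules by score (best first) and keep the top fraction.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    n_keep = max(1, int(len(scores) * topk))  # assumed guard against empty batches
    kept = ranked[:n_keep]
    return sum(per_mol[i] for i in kept) / n_keep, kept


# Hypothetical batch of four molecules.
loss, kept = ahc_loss(
    prior_ll=[-20.0, -22.0, -25.0, -30.0],
    agent_ll=[-19.0, -21.0, -26.0, -28.0],
    scores=[0.1, 0.9, 0.5, 0.3],
)
print(kept, loss)
```

A larger sigma gives the score more weight relative to the prior, and a smaller topk restricts each update to fewer, higher-scoring molecules.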
