Deep learning packages for molecular discovery with a simple sklearn-style interface

torch-molecule is a package that facilitates molecular discovery through deep learning, featuring a user-friendly, sklearn-style interface. It includes model checkpoints for efficient deployment and benchmarking across a range of molecular tasks. The package focuses on three main components: Predictive Models, Generative Models, and Representation Models, which make molecular AI models easy to implement and deploy.

[Figure: scikit-learn vs torch-molecule comparison]

See the List of Supported Models section for all available models.

Installation

  1. Create a Conda environment:

    conda create --name torch_molecule python=3.11.7
    conda activate torch_molecule
    
  2. Install using pip:

    pip install torch-molecule
    
  3. Install from source for the latest version:

    Clone the repository:

    git clone https://github.com/liugangcode/torch-molecule
    cd torch-molecule
    

    Install:

    pip install .
    

Additional Packages

Model | Required Packages
HFPretrainedMolecularEncoder | transformers
BFGNNMolecularPredictor | torch-scatter
GRINMolecularPredictor | torch-scatter
GRINMolecularPredictor (when repetition_augmentation=True) | CombineMols

For models that require torch-scatter, install it with pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html, replacing ${TORCH} and ${CUDA} with your PyTorch and CUDA versions, e.g.,

pip install torch-scatter -f https://data.pyg.org/whl/torch-2.7.1+cu128.html

For models that require transformers: pip install transformers
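
The ${TORCH} and ${CUDA} placeholders in the torch-scatter command can be read off the installed PyTorch version string. The following is a minimal sketch, not part of torch-molecule: the helper name pyg_wheel_index is hypothetical, and treating a version string without a "+suffix" as a CPU build is an assumption of this example.

```python
def pyg_wheel_index(torch_version: str) -> str:
    """Build the wheel index URL for a torch version string such as '2.7.1+cu128'."""
    # Split "2.7.1+cu128" into the release ("2.7.1") and the build tag ("cu128").
    release, _, build = torch_version.partition("+")
    # Assumption of this sketch: no build tag means a CPU-only build.
    return f"https://data.pyg.org/whl/torch-{release}+{build or 'cpu'}.html"


# Example with a literal version string; in practice pass torch.__version__.
print("pip install torch-scatter -f " + pyg_wheel_index("2.7.1+cu128"))
```

Pre-release or nightly PyTorch builds may not have a matching wheel index, so it is worth checking that the resulting URL exists before installing.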

Usage

More examples can be found in the examples and tests folders.

torch-molecule supports applications across chemistry, biology, and materials science. To get started, you can load prepared datasets from torch_molecule.datasets (updated after v0.1.3):

Dataset | Description | Function
qm9 | Quantum chemical properties (DFT level) | load_qm9
chembl2k | Bioactive molecules with drug-like properties | load_chembl2k
broad6k | Bioactive molecules with drug-like properties | load_broad6k
toxcast | Toxicity of chemical compounds | load_toxcast
admet | Chemical absorption, distribution, metabolism, excretion, and toxicity | load_admet
gasperm | Six gas permeability properties for polymeric materials | load_gasperm
zinc250k | A common subset of the ZINC dataset; it has no labels and can be used for unconditional generation or virtual screening | load_zinc250k

from torch_molecule.datasets import load_qm9

# local_dir is the local path where the dataset will be saved
molecular_data = load_qm9(local_dir='torchmol_data')
smiles_list, property_np_array = molecular_data.data, molecular_data.target

# len(smiles_list): 133885
# Property array shape: (133885, 1)

# load_qm9 returns the target "gap" by default, but you can adjust it by passing new target_cols
target_cols = ['homo', 'lumo', 'gap']
molecular_data = load_qm9(local_dir='torchmol_data', target_cols=target_cols)
smiles_list, property_np_array = molecular_data.data, molecular_data.target

# the target could be None if loading an unlabeled dataset
from torch_molecule.datasets import load_zinc250k
molecular_data = load_zinc250k(local_dir='torchmol_data')
smiles_list = molecular_data.data
assert molecular_data.target is None

(We are actively adding more datasets. Suggestions and dataset contributions are welcome!)

Fit a Model

After preparing the dataset, fit a model much as you would with sklearn. The code is arguably simpler, because sklearn still requires feature engineering to convert molecule SMILES into vectors, whereas here the SMILES strings are passed in directly:

from torch_molecule import GREAMolecularPredictor

split = int(0.8 * len(smiles_list))
# One task per target column; property_np_array has shape (n_samples, n_targets)
num_task = property_np_array.shape[1]

grea = GREAMolecularPredictor(
    num_task=num_task,
    task_type="regression",
    evaluate_higher_better=False,
    verbose="progress_bar",  # or "print_statement" (recommended for Jupyter notebooks), or "none"
)

# Fit with automatic hyperparameter tuning (10 trials), or call .fit() to use the default/manual hyperparameters
grea.autofit(
    X_train=smiles_list[:split],
    y_train=property_np_array[:split],
    X_val=smiles_list[split:],
    y_val=property_np_array[split:],
    n_trials=10,
)
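
The example above splits the data in its stored order. If the dataset is sorted in some way, a shuffled split may give a more representative validation set. Below is a minimal sketch using only the standard library; the helper shuffled_split_indices and the toy data are hypothetical and not part of torch-molecule.

```python
import random


def shuffled_split_indices(n_samples, train_fraction=0.8, seed=0):
    """Return (train_idx, val_idx): disjoint, shuffled index lists covering range(n_samples)."""
    indices = list(range(n_samples))
    # A local RNG with a fixed seed keeps the split reproducible.
    random.Random(seed).shuffle(indices)
    cut = int(train_fraction * n_samples)
    return indices[:cut], indices[cut:]


# Toy stand-ins for smiles_list and the property array.
smiles = ["C", "CC", "CCC", "CCO", "CCN", "c1ccccc1", "CC(=O)O", "CO", "CN", "C=C"]
targets = [[float(i)] for i in range(len(smiles))]

train_idx, val_idx = shuffled_split_indices(len(smiles))
X_train = [smiles[i] for i in train_idx]
y_train = [targets[i] for i in train_idx]
X_val = [smiles[i] for i in val_idx]
y_val = [targets[i] for i in val_idx]
print(len(X_train), len(X_val))
```

With a NumPy target array, the same index lists can be used directly, e.g. property_np_array[train_idx], while the SMILES list still needs the list comprehension.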

Checkpoints

torch-molecule provides functions for saving checkpoints to, and loading them from, Hugging Face:

from torch_molecule import GREAMolecularPredictor

repo_id = "user/repo_id"  # replace with your own Hugging Face username and repo_id

# Save the trained model to Hugging Face
grea.save_to_hf(
    repo_id=repo_id,
    task_id="qm9_grea",
    commit_message="Upload qm9_grea",
    private=False
)

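# Load a pretrained checkpoint from Hugging Face
model_dir = "checkpoints"  # placeholder: local directory for the cached checkpoint
task_name = "qm9"  # placeholder: task name used in the cached file name
model = GREAMolecularPredictor()
model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")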

# Adjust model parameters and make predictions
model.set_params(verbose='none')
predictions = model.predict(smiles_list)

Or you can save the model to a local path:

grea.save_to_local("qm9_grea.pt")

new_model = GREAMolecularPredictor()
new_model.load_from_local("qm9_grea.pt")
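
After reloading a regression model, one quick sanity check is to compare its predictions against held-out targets with an error metric such as mean absolute error. The sketch below is a generic illustration, not part of torch-molecule: it assumes the predicted values have already been extracted into a flat sequence of numbers, since the structure of the object returned by predict is not shown here.

```python
def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between two equal-length sequences of numbers."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("at least one sample is required")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy values standing in for held-out targets and model predictions.
print(mean_absolute_error([1.0, 2.0, 4.0], [1.5, 2.0, 3.0]))  # prints 0.5
```

A lower value is better, which matches the evaluate_higher_better=False setting used for the regression model above. sklearn.metrics.mean_absolute_error offers the same metric if scikit-learn is installed.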

List of Supported Models

Predictive Models

Model | Reference
GRIN | Learning Repetition-Invariant Representations for Polymer Informatics. NeurIPS 2025
BFGNN | Graph neural networks extrapolate out-of-distribution for shortest paths. March 2025
SGIR | Semi-Supervised Graph Imbalanced Regression. KDD 2023
GREA | Graph Rationalization with Environment-based Augmentations. KDD 2022
DIR | Discovering Invariant Rationales for Graph Neural Networks. ICLR 2022
SSR | SizeShiftReg: a Regularization Method for Improving Size-Generalization in Graph Neural Networks. NeurIPS 2022
IRM | Invariant Risk Minimization (2019)
RPGNN | Relational Pooling for Graph Representations. ICML 2019
GNNs | Graph Convolutional Networks. ICLR 2017 and Graph Isomorphism Network. ICLR 2019
Transformer (SMILES) | Transformer (Attention is All You Need. NeurIPS 2017) based on SMILES strings
LSTM (SMILES) | Long short-term memory (Neural Computation 1997) based on SMILES strings

Generative Models

Model | Reference
DeFoG | DeFoG: Discrete Flow Matching for Graph Generation. ICML 2025
Graph DiT | Graph Diffusion Transformers for Multi-Conditional Molecular Generation. NeurIPS 2024
DiGress | DiGress: Discrete Denoising Diffusion for Graph Generation. ICLR 2023
GDSS | Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations. ICML 2022
MolGPT | MolGPT: Molecular Generation Using a Transformer-Decoder Model. Journal of Chemical Information and Modeling 2021
JTVAE | Junction Tree Variational Autoencoder for Molecular Graph Generation. ICML 2018
GraphGA | A Graph-Based Genetic Algorithm and Its Application to the Multiobjective Evolution of Median Molecules. Journal of Chemical Information and Computer Sciences 2004
LSTM (SMILES) | Long short-term memory (Neural Computation 1997) based on SMILES strings

Representation Models

Model | Reference
MoAMa | Motif-aware Attribute Masking for Molecular Graph Pre-training. LoG 2024
GraphMAE | GraphMAE: Self-Supervised Masked Graph Autoencoders. KDD 2022
AttrMasking | Strategies for Pre-training Graph Neural Networks. ICLR 2020
ContextPred | Strategies for Pre-training Graph Neural Networks. ICLR 2020
EdgePred | Strategies for Pre-training Graph Neural Networks. ICLR 2020
InfoGraph | InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. ICLR 2020
Supervised | Supervised pretraining
Pretrained | GPT2-ZINC-87M: GPT-2 based model (87M parameters) pretrained on ZINC dataset with ~480M SMILES strings.
           | RoBERTa-ZINC-480M: RoBERTa based model (102M parameters) pretrained on ZINC dataset with ~480M SMILES strings.
           | UniKi/bert-base-smiles: BERT model pretrained on SMILES strings.
           | ChemBERTa-zinc-base-v1: RoBERTa model pretrained on ZINC dataset with ~100k SMILES strings.
           | ChemBERTa series: Available in multiple sizes and training objectives (MLM/MTR). ChemBERTa-5M-MLM, ChemBERTa-5M-MTR, ChemBERTa-10M-MLM, ChemBERTa-10M-MTR, ChemBERTa-77M-MLM, ChemBERTa-77M-MTR.
           | ChemGPT series: GPT-Neo based models pretrained on PubChem10M dataset with SELFIES strings. ChemGPT-1.2B, ChemGPT-4.7B, ChemGPT-19B.

Acknowledgements

The project template was adapted from https://github.com/lwaekfjlk/python-project-template. We thank the authors for their contribution to the open-source community.
