Deep learning packages for molecular discovery with a simple sklearn-style interface


torch-molecule is a package that facilitates molecular discovery through deep learning, featuring a user-friendly, sklearn-style interface. It includes model checkpoints for efficient deployment and benchmarking across a range of molecular tasks. Currently, the package focuses on three main components: Predictive Models, Generative Models, and Representation Models. See the List of Supported Models section for all available models.

API Comparison

| Functionality | scikit-learn | torch-molecule |
| --- | --- | --- |
| Property Prediction | predictor.fit/predict(...) | predictor.fit/autofit/predict(...) |
| Representation Learning | Not supported | encoder.fit/encode(...) |
| Molecular Generation | Not supported | generator.fit/generate(...) |
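
To make the fit/predict calling convention concrete, here is a minimal sketch using a toy estimator. It is not torch-molecule code: MeanBaselinePredictor is a hypothetical stand-in that predicts the per-task mean of the training labels. Returning a dict with a 'prediction' key is an assumption that mirrors how the Checkpoints example in this README reads its predictions.

```python
from statistics import fmean


class MeanBaselinePredictor:
    """Toy stand-in (NOT part of torch-molecule) with a fit/predict interface."""

    def __init__(self, num_task=1):
        self.num_task = num_task
        self.task_means_ = None

    def fit(self, X_train, y_train):
        # X_train: list of SMILES strings; y_train: one row per molecule,
        # one value per task.
        if len(X_train) != len(y_train):
            raise ValueError("X_train and y_train must have the same length")
        self.task_means_ = [
            fmean(row[t] for row in y_train) for t in range(self.num_task)
        ]
        return self  # returning self allows chaining, as in scikit-learn

    def predict(self, X):
        if self.task_means_ is None:
            raise RuntimeError("call fit() before predict()")
        # One row of per-task means for every input molecule.
        return {"prediction": [list(self.task_means_) for _ in X]}


model = MeanBaselinePredictor(num_task=1).fit(["C1=CC=CC=C1", "CCO"], [[0.5], [1.5]])
output = model.predict(["CCN"])
print(output["prediction"])
```

The real predictors follow the same two-step shape (construct, then fit or autofit, then predict), with SMILES strings as inputs.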

Installation

  1. Create a Conda environment:

    conda create --name torch_molecule python=3.11.7
    conda activate torch_molecule
    
  2. Install using pip (0.1.2):

    pip install torch-molecule
    
  3. Install from source for the latest version:

    Clone the repository:

    git clone https://github.com/liugangcode/torch-molecule
    cd torch-molecule
    

    Install:

    pip install .
    

Additional Packages

| Model | Required Packages |
| --- | --- |
| HFPretrainedMolecularEncoder | transformers |
| BFGNNMolecularPredictor | torch-scatter |
| GRINMolecularPredictor | torch-scatter |

For models that require torch-scatter, install it with pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html, replacing ${TORCH} and ${CUDA} with your PyTorch and CUDA versions, e.g.,

pip install torch-scatter -f https://data.pyg.org/whl/torch-2.7.1+cu128.html
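
The URL pattern above can be filled in programmatically. The sketch below is only an illustration: pyg_wheel_index is a hypothetical helper, not part of torch-molecule or PyG. In a real environment the two inputs would typically come from torch.__version__ and torch.version.cuda. Wheels are published only for specific PyTorch and CUDA combinations, so check that the resulting index page exists before relying on it.

```python
def pyg_wheel_index(torch_version, cuda_version=None):
    """Build a wheel index URL following the pattern shown above.

    torch_version: e.g. "2.7.1" (a local suffix such as "+cu128" is dropped).
    cuda_version: e.g. "12.8", or None for a CPU-only build.
    """
    base = torch_version.split("+")[0]
    # "12.8" -> "cu128"; CPU-only builds use the tag "cpu".
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"


print(pyg_wheel_index("2.7.1", "12.8"))
```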

For models that require transformers: pip install transformers

Usage

More examples can be found in the examples and tests folders.

Python API Example

The following example demonstrates how to use the GREAMolecularPredictor class from torch_molecule:

import numpy as np
from torch_molecule import GREAMolecularPredictor

# Number of prediction targets per molecule (columns in y)
num_task = 1

# Configure the GREA model
grea_model = GREAMolecularPredictor(
    num_task=num_task,
    task_type="regression",
    model_name="GREA_multitask",
    evaluate_criterion='r2',
    evaluate_higher_better=True,
    verbose=True
)

# Toy data: X is a list of SMILES strings, y has shape [n_samples, n_tasks]
X_train = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_train = np.array([[0.5], [1.5]])
X_val = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_val = np.array([[0.5], [1.5]])
N_trial = 10

# Fit the model; X_train and X_val are already lists, so no .tolist() is needed
grea_model.autofit(
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    n_trials=N_trial,
)
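
The example relies on a particular data layout: a list of SMILES strings and a label table with one row per molecule and one column per task. A pre-flight check along these lines can catch mismatches before training starts. This is a sketch under that assumption only; check_training_inputs is a hypothetical helper and not part of torch-molecule.

```python
def check_training_inputs(smiles, labels, num_task):
    """Hypothetical pre-flight check; NOT part of torch-molecule."""
    if not all(isinstance(s, str) and s for s in smiles):
        raise ValueError("every molecule must be a non-empty SMILES string")
    if len(smiles) != len(labels):
        raise ValueError(f"{len(smiles)} molecules but {len(labels)} label rows")
    for i, row in enumerate(labels):
        if len(row) != num_task:
            raise ValueError(
                f"label row {i} has {len(row)} values, expected {num_task}"
            )
    # Returns (n_samples, n_tasks) once everything is consistent.
    return len(smiles), num_task


shape = check_training_inputs(['C1=CC=CC=C1', 'C1=CC=CC=C1'], [[0.5], [1.5]], 1)
print(shape)
```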

Checkpoints

torch-molecule provides functions for saving model checkpoints to Hugging Face and loading them back.

from torch_molecule import GREAMolecularPredictor
from sklearn.metrics import mean_absolute_error

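# This snippet assumes the caller has already defined X, X_val, X_test (arrays of
# SMILES strings with a .tolist() method), y_train, y_val, y_test (label arrays),
# task_name, model_dir, and smiles_list.

# Define the repository ID for Hugging Face
repo_id = "user/repo_id"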

# Initialize the GREAMolecularPredictor model
model = GREAMolecularPredictor()

# Train the model using autofit
model.autofit(
    X_train=X.tolist(),  # List of SMILES strings for training
    y_train=y_train,     # numpy array [n_samples, n_tasks] for training labels
    X_val=X_val.tolist(),# List of SMILES strings for validation
    y_val=y_val,         # numpy array [n_samples, n_tasks] for validation labels
)

# Make predictions on the test set
output = model.predict(X_test.tolist())  # dict; output['prediction'] has shape (n_samples, n_tasks)

# Calculate the mean absolute error
mae = mean_absolute_error(y_test, output['prediction'])
metrics = {'MAE': mae}

# Save the trained model to Hugging Face
model.save_to_hf(
    repo_id=repo_id,
    task_id=f"{task_name}",
    metrics=metrics,
    commit_message=f"Upload GREA_{task_name} model with metrics: {metrics}",
    private=False
)

# Load a pretrained checkpoint from Hugging Face
model = GREAMolecularPredictor()
model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")

# Set model parameters
model.set_params(verbose=True)

# Make predictions using the loaded model
predictions = model.predict(smiles_list)

List of Supported Models

Predictive Models

| Model | Reference |
| --- | --- |
| GRIN | Learning Repetition-Invariant Representations for Polymer Informatics. May 2025 |
| BFGNN | Graph neural networks extrapolate out-of-distribution for shortest paths. March 2025 |
| SGIR | Semi-Supervised Graph Imbalanced Regression. KDD 2023 |
| GREA | Graph Rationalization with Environment-based Augmentations. KDD 2022 |
| DIR | Discovering Invariant Rationales for Graph Neural Networks. ICLR 2022 |
| SSR | SizeShiftReg: a Regularization Method for Improving Size-Generalization in Graph Neural Networks. NeurIPS 2022 |
| IRM | Invariant Risk Minimization (2019) |
| RPGNN | Relational Pooling for Graph Representations. ICML 2019 |
| GNNs | Graph Convolutional Networks. ICLR 2017 and Graph Isomorphism Network. ICLR 2019 |
| Transformer (SMILES) | Transformer (Attention is All You Need. NeurIPS 2017) based on SMILES strings |
| LSTM (SMILES) | Long short-term memory (Neural Computation 1997) based on SMILES strings |

Generative Models

| Model | Reference |
| --- | --- |
| Graph DiT | Graph Diffusion Transformers for Multi-Conditional Molecular Generation. NeurIPS 2024 |
| DiGress | DiGress: Discrete Denoising Diffusion for Graph Generation. ICLR 2023 |
| GDSS | Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations. ICML 2022 |
| MolGPT | MolGPT: Molecular Generation Using a Transformer-Decoder Model. Journal of Chemical Information and Modeling 2021 |
| JTVAE | Junction Tree Variational Autoencoder for Molecular Graph Generation. ICML 2018 |
| GraphGA | A Graph-Based Genetic Algorithm and Its Application to the Multiobjective Evolution of Median Molecules. Journal of Chemical Information and Computer Sciences 2004 |
| LSTM (SMILES) | Long short-term memory (Neural Computation 1997) based on SMILES strings |

Representation Models

| Model | Reference |
| --- | --- |
| MoAMa | Motif-aware Attribute Masking for Molecular Graph Pre-training. LoG 2024 |
| GraphMAE | GraphMAE: Self-Supervised Masked Graph Autoencoders. KDD 2022 |
| AttrMasking | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| ContextPred | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| EdgePred | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| InfoGraph | InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. ICLR 2020 |
| Supervised | Supervised pretraining |
| Pretrained | More than ten pretrained models from Hugging Face |

Project Structure

See the structure of torch_molecule with the command tree -L 2 torch_molecule -I '__pycache__|*.pyc|*.pyo|.git|old*'
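
Where the tree command is not available, a similar listing can be produced in Python. The following is a rough sketch that approximates the depth limit (-L 2) and ignore patterns (-I) of the command above. tree_lines is a hypothetical helper, and its output format is simplified compared with tree.

```python
import fnmatch
from pathlib import Path

# Same patterns as the -I option in the tree command above.
IGNORE = ["__pycache__", "*.pyc", "*.pyo", ".git", "old*"]


def tree_lines(root, max_depth=2, ignore=IGNORE):
    """Return an indented listing of root, at most max_depth levels deep."""
    root = Path(root)
    lines = [root.name or str(root)]

    def walk(directory, depth):
        if depth > max_depth:
            return
        for entry in sorted(directory.iterdir(), key=lambda p: p.name):
            # Skip any entry whose name matches an ignore pattern.
            if any(fnmatch.fnmatch(entry.name, pat) for pat in ignore):
                continue
            lines.append("    " * depth + entry.name)
            if entry.is_dir():
                walk(entry, depth + 1)

    walk(root, 1)
    return lines


# Only prints when run from a directory that contains torch_molecule.
if Path("torch_molecule").is_dir():
    print("\n".join(tree_lines("torch_molecule")))
```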

Plan

  1. Predictive Models: Done: GREA, SGIR, IRM, GIN/GCN with virtual nodes, DIR, SMILES-based LSTM/Transformers. TODO: more
  2. Generative Models: Done: Graph DiT, GraphGA, DiGress, GDSS, MolGPT. TODO: more
  3. Representation Models: Done: MoAMa, AttrMasking, ContextPred, EdgePred, and many pretrained models from Hugging Face. TODO: checkpoints, more

Note: This project is in active development, and features may change.

Acknowledgements

The project template was adapted from https://github.com/lwaekfjlk/python-project-template. We thank the authors for their contribution to the open-source community.
