Deep learning packages for molecular discovery with a simple sklearn-style interface
torch-molecule is a package that facilitates molecular discovery through deep learning, featuring a user-friendly, sklearn-style interface. It includes model checkpoints for efficient deployment and benchmarking across a range of molecular tasks. Currently, the package focuses on three main components: Predictive Models, Generative Models, and Representation Models. See the List of Supported Models section for all available models.
API Comparison
| Functionality | scikit-learn | torch-molecule |
|---|---|---|
| Property Prediction | predictor.fit/predict(...) | predictor.fit/autofit/predict(...) |
| Representation Learning | Not supported | encoder.fit/encode(...) |
| Molecular Generation | Not supported | generator.fit/generate(...) |
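To make the shared fit/predict contract in the table concrete, here is a minimal sketch in plain Python. `MeanBaselinePredictor` is a hypothetical stand-in written for illustration only; it is not part of torch-molecule or scikit-learn. It assumes SMILES strings as inputs, one label list per molecule, and a dictionary output with a `prediction` key like the one used in the Checkpoints example of this README.

```python
from statistics import mean


class MeanBaselinePredictor:
    """Hypothetical stand-in that predicts the per-task training mean."""

    def fit(self, X, y):
        # X: list of SMILES strings (unused by this trivial baseline).
        # y: list of label lists with shape [n_samples, n_tasks].
        n_tasks = len(y[0])
        self.task_means_ = [mean(row[t] for row in y) for t in range(n_tasks)]
        return self  # returning self allows call chaining, as in scikit-learn

    def predict(self, X):
        # One row of task means per input molecule.
        return {"prediction": [list(self.task_means_) for _ in X]}


model = MeanBaselinePredictor().fit(["CCO", "c1ccccc1"], [[0.5], [1.5]])
output = model.predict(["CCN"])
print(output)
```

The point of the pattern is that any estimator exposing the same two methods can be swapped in without changing the calling code.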
Installation
1. Create a Conda environment:
    conda create --name torch_molecule python=3.11.7
    conda activate torch_molecule
2. Install using pip (0.1.2):
    pip install torch-molecule
3. Install from source for the latest version:
   Clone the repository:
    git clone https://github.com/liugangcode/torch-molecule
    cd torch-molecule
   Install:
    pip install .
Additional Packages
| Model | Required Packages |
|---|---|
| HFPretrainedMolecularEncoder | transformers |
| BFGNNMolecularPredictor | torch-scatter |
| GRINMolecularPredictor | torch-scatter |
For models that require torch-scatter, install it with pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html, for example:
    pip install torch-scatter -f https://data.pyg.org/whl/torch-2.7.1+cu128.html
For models that require transformers:
    pip install transformers
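The `${TORCH}` and `${CUDA}` placeholders in the torch-scatter URL have to match the local PyTorch build. As a small sketch, the helper below (a hypothetical name, not part of any package) fills in the URL template from the two values. With PyTorch installed, `torch.__version__` usually reports both parts in one string such as `2.7.1+cu128`, but that lookup is left out here so the snippet runs without PyTorch.

```python
def torch_scatter_wheel_index(torch_version: str, cuda_tag: str) -> str:
    """Fill the wheel index URL template from this README.

    torch_version: PyTorch version without a build suffix, e.g. "2.7.1".
    cuda_tag: build tag such as "cu128".
    """
    return f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_tag}.html"


url = torch_scatter_wheel_index("2.7.1", "cu128")
print(f"pip install torch-scatter -f {url}")
```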
Usage
Refer to the examples and tests folders for more use cases.
Python API Example
The following example demonstrates how to use the GREAMolecularPredictor class from torch_molecule:
from torch_molecule import GREAMolecularPredictor

# Number of prediction targets per molecule
num_task = 1

# Configure the GREA model
grea_model = GREAMolecularPredictor(
    num_task=num_task,
    task_type="regression",
    model_name="GREA_multitask",
    evaluate_criterion='r2',
    evaluate_higher_better=True,
    verbose=True
)

# Toy data: lists of SMILES strings and labels of shape [n_samples, n_tasks]
X_train = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_train = [[0.5], [1.5]]
X_val = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_val = [[0.5], [1.5]]
N_trial = 10

# Fit the model with hyperparameter search
# (X_train and X_val are already plain lists, so no .tolist() call is needed)
grea_model.autofit(
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    n_trials=N_trial,
)
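The toy example above reuses the same molecules for training and validation. In practice one would usually hold out a separate validation split before calling autofit. A minimal sketch using only the standard library follows; `split_train_val` is a hypothetical helper, not a torch-molecule function, and the SMILES strings and labels are made-up placeholders.

```python
import random


def split_train_val(smiles, labels, val_fraction=0.2, seed=0):
    """Shuffle paired SMILES/labels and hold out a validation fraction."""
    if len(smiles) != len(labels):
        raise ValueError("smiles and labels must have the same length")
    indices = list(range(len(smiles)))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_val = max(1, int(len(indices) * val_fraction))
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    X_train = [smiles[i] for i in train_idx]
    y_train = [labels[i] for i in train_idx]
    X_val = [smiles[i] for i in val_idx]
    y_val = [labels[i] for i in val_idx]
    return X_train, y_train, X_val, y_val


# Placeholder data: five molecules, one regression target each.
smiles = ["CCO", "CCN", "CCC", "c1ccccc1", "CC(=O)O"]
labels = [[0.1], [0.2], [0.3], [0.4], [0.5]]
X_train, y_train, X_val, y_val = split_train_val(smiles, labels)
print(len(X_train), len(X_val))
```

The four returned lists have the same shapes as the arguments passed to autofit in the example above. A random split is the simplest choice; other strategies may suit molecular data better, depending on the task.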
Checkpoints
torch-molecule provides functions for saving model checkpoints to Hugging Face and loading them back.
from torch_molecule import GREAMolecularPredictor
from sklearn.metrics import mean_absolute_error

# X, y_train, X_val, y_val, X_test, y_test, smiles_list, task_name, and
# model_dir are placeholders assumed to be defined by the caller.

# Define the repository ID for Hugging Face
repo_id = "user/repo_id"

# Initialize the GREAMolecularPredictor model
model = GREAMolecularPredictor()

# Train the model using autofit
model.autofit(
    X_train=X.tolist(),    # List of SMILES strings for training
    y_train=y_train,       # numpy array [n_samples, n_tasks] for training labels
    X_val=X_val.tolist(),  # List of SMILES strings for validation
    y_val=y_val,           # numpy array [n_samples, n_tasks] for validation labels
)

# Make predictions on the test set
output = model.predict(X_test.tolist())  # (n_sample, n_task)

# Calculate the mean absolute error
mae = mean_absolute_error(y_test, output['prediction'])
metrics = {'MAE': mae}

# Save the trained model to Hugging Face
model.save_to_hf(
    repo_id=repo_id,
    task_id=f"{task_name}",
    metrics=metrics,
    commit_message=f"Upload GREA_{task_name} model with metrics: {metrics}",
    private=False
)

# Load a pretrained checkpoint from Hugging Face
model = GREAMolecularPredictor()
model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")

# Set model parameters
model.set_params(verbose=True)

# Make predictions using the loaded model
predictions = model.predict(smiles_list)
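The metrics dictionary uploaded with the checkpoint holds plain numbers computed from the `prediction` entry of the predict output. As an illustration of what the MAE entry represents, here is a plain-Python sketch that should agree with scikit-learn's `mean_absolute_error` under its default settings for complete, unweighted labels. The function name and the numeric values are placeholders chosen for this example.

```python
def mean_absolute_error_multitask(y_true, y_pred):
    """Average |y_true - y_pred| over all samples and tasks."""
    total, count = 0.0, 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            total += abs(t - p)
            count += 1
    return total / count


# Placeholder values standing in for model.predict(...) and the test labels.
output = {"prediction": [[1.0], [2.0], [4.0]]}
y_test = [[1.5], [2.0], [3.0]]
metrics = {"MAE": mean_absolute_error_multitask(y_test, output["prediction"])}
print(metrics)
```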
List of Supported Models
Predictive Models
Generative Models
Representation Models
| Model | Reference |
|---|---|
| MoAMa | Motif-aware Attribute Masking for Molecular Graph Pre-training. LoG 2024 |
| GraphMAE | GraphMAE: Self-Supervised Masked Graph Autoencoders. KDD 2022 |
| AttrMasking | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| ContextPred | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| EdgePred | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| InfoGraph | InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. ICLR 2020 |
| Supervised | Supervised pretraining |
| Pretrained | More than ten pretrained models from Hugging Face |
Project Structure
See the structure of torch_molecule with the command tree -L 2 torch_molecule -I '__pycache__|*.pyc|*.pyo|.git|old*'
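Where the tree command is not available, a rough equivalent can be sketched with the standard library. The `list_tree` function below is a hypothetical helper, not part of torch-molecule; it walks two levels deep and skips the same patterns as the command above. The demo runs on a throwaway directory with made-up file names rather than on the real package.

```python
import tempfile
from fnmatch import fnmatch
from pathlib import Path

# Same ignore patterns as the tree command in this README.
IGNORE = ["__pycache__", "*.pyc", "*.pyo", ".git", "old*"]


def list_tree(root, max_depth=2, ignore=IGNORE):
    """Return indented entry names under root, down to max_depth levels."""
    lines = []

    def walk(directory, depth):
        if depth > max_depth:
            return
        for entry in sorted(directory.iterdir(), key=lambda p: p.name):
            if any(fnmatch(entry.name, pattern) for pattern in ignore):
                continue
            lines.append("    " * (depth - 1) + entry.name)
            if entry.is_dir():
                walk(entry, depth + 1)

    walk(Path(root), 1)
    return lines


# Demo on a temporary directory with placeholder names.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub" / "inner").mkdir(parents=True)
    (root / "sub" / "inner" / "too_deep.py").touch()
    (root / "__pycache__").mkdir()
    (root / "module.py").touch()
    demo = list_tree(root)

print("\n".join(demo))
```

To inspect the actual package, one would call `list_tree("torch_molecule")` from the repository root.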
Plan
- Predictive Models: Done: GREA, SGIR, IRM, GIN/GCN w/ virtual, DIR, SMILES-based LSTM/Transformers. TODO: more
- Generative Models: Done: Graph DiT, GraphGA, DiGress, GDS, MolGPT. TODO: more
- Representation Models: Done: MoAMa, AttrMasking, ContextPred, EdgePred, many pretrained models from HF. TODO: checkpoints, more
Note: This project is in active development, and features may change.
Acknowledgements
The project template was adapted from https://github.com/lwaekfjlk/python-project-template. We thank the authors for their contribution to the open-source community.