Deep learning packages for molecular discovery with a simple sklearn-style interface
torch-molecule is a package that facilitates molecular discovery through deep learning, featuring a user-friendly, sklearn-style interface. It includes model checkpoints for efficient deployment and benchmarking across a range of molecular tasks. The package focuses on three main components: Predictive Models, Generative Models, and Representation Models, which make molecular AI models easy to implement and deploy.
See the List of Supported Models section for all available models.
Installation
- Create a Conda environment:
  conda create --name torch_molecule python=3.11.7
  conda activate torch_molecule
- Install using pip (v0.1.3):
  pip install torch-molecule
- Install from source for the latest version:
  Clone the repository:
  git clone https://github.com/liugangcode/torch-molecule
  cd torch-molecule
  Install:
  pip install .
Additional Packages
| Model | Required Packages |
|---|---|
| HFPretrainedMolecularEncoder | transformers |
| BFGNNMolecularPredictor | torch-scatter |
| GRINMolecularPredictor | torch-scatter |
For models that require torch-scatter, install it with pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html, where ${TORCH} and ${CUDA} stand for your installed PyTorch version and CUDA tag, e.g.,
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.7.1+cu128.html
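The URL pattern above can be filled in programmatically. The following is a minimal sketch, not part of torch-molecule: the helper names `pyg_wheel_index_url` and `split_torch_version` are hypothetical, and the assumption that a version string looks like `2.7.1+cu128` holds for typical CUDA builds of PyTorch but may not hold on every platform.

```python
def pyg_wheel_index_url(torch_version: str, cuda_tag: str) -> str:
    """Fill the torch-${TORCH}+${CUDA} pattern quoted in the install note."""
    return f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_tag}.html"


def split_torch_version(version: str, default_tag: str = "cpu") -> tuple[str, str]:
    """Split a string such as '2.7.1+cu128' into ('2.7.1', 'cu128').

    Assumption: when no '+tag' suffix is present, fall back to default_tag.
    """
    base, _, tag = version.partition("+")
    return base, tag or default_tag


# Hypothetical usage with a hard-coded version string.
base, tag = split_torch_version("2.7.1+cu128")
print(f"pip install torch-scatter -f {pyg_wheel_index_url(base, tag)}")
```

With PyTorch installed, the version string would normally come from `torch.__version__` rather than being typed by hand.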
For models that require transformers: pip install transformers
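To check which of the optional packages from the table are already installed, something like the following sketch may help. It uses only the standard library; the function name `missing_optional_packages` is hypothetical, and the import names (`transformers`, `torch_scatter`) are the usual module names for those packages rather than anything taken from torch-molecule itself.

```python
from importlib.util import find_spec

# Model name -> importable module name for its extra dependency.
OPTIONAL_IMPORTS = {
    "HFPretrainedMolecularEncoder": "transformers",
    "BFGNNMolecularPredictor": "torch_scatter",
    "GRINMolecularPredictor": "torch_scatter",
}


def missing_optional_packages(imports=OPTIONAL_IMPORTS):
    """Return {model: module} for every module that cannot be found."""
    return {
        model: module
        for model, module in imports.items()
        if find_spec(module) is None
    }


for model, module in missing_optional_packages().items():
    print(f"{model} needs the '{module}' module, which is not installed")
```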
Usage
More examples can be found in the examples and tests folders.
torch-molecule supports applications across chemistry, biology, and materials science. To get started, you can load prepared datasets from torch_molecule.datasets (updated after v0.1.3):
| Dataset | Description | Function |
|---|---|---|
| qm9 | Quantum chemical properties (DFT level) | load_qm9 |
| chembl2k | Bioactive molecules with drug-like properties | load_chembl2k |
| broad6k | Bioactive molecules with drug-like properties | load_broad6k |
| toxcast | Toxicity of chemical compounds | load_toxcast |
| admet | Chemical absorption, distribution, metabolism, excretion, and toxicity | load_admet |
| gasperm | Six gas permeability properties for polymeric materials | load_gasperm |
| zinc250k | A common subset of the ZINC dataset; it has no labels and can be used for unconditional generation or virtual screening | load_zinc250k |
from torch_molecule.datasets import load_qm9
# local_dir is the local path where the dataset will be saved
molecular_data = load_qm9(local_dir='torchmol_data')
smiles_list, property_np_array = molecular_data.data, molecular_data.target
# len(smiles_list): 133885
# Property array shape: (133885, 1)
# load_qm9 returns the target "gap" by default, but you can adjust it by passing new target_cols
target_cols = ['homo', 'lumo', 'gap']
molecular_data = load_qm9(local_dir='torchmol_data', target_cols=target_cols)
smiles_list, property_np_array = molecular_data.data, molecular_data.target
# the target could be None if loading an unlabeled dataset
from torch_molecule.datasets import load_zinc250k
molecular_data = load_zinc250k(local_dir='torchmol_data')
smiles_list = molecular_data.data
assert molecular_data.target is None
(We are actively adding more datasets. We welcome your suggestions and contributions on your datasets!)
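Before fitting, it can be useful to inspect the target array returned by a loader. The sketch below is an illustration only: `summarize_targets` is a hypothetical helper, the toy array stands in for `molecular_data.target`, and the idea that some datasets contain missing labels encoded as NaN is an assumption, not something this README states.

```python
import numpy as np


def summarize_targets(target, names=None):
    """Return the fraction of missing (NaN) labels for each target column."""
    target = np.asarray(target, dtype=float)
    if target.ndim == 1:
        target = target.reshape(-1, 1)  # treat a flat array as a single task
    names = names or [f"task_{i}" for i in range(target.shape[1])]
    missing = np.isnan(target).mean(axis=0)
    return {name: float(frac) for name, frac in zip(names, missing)}


# Toy stand-in for a (num_molecules, num_tasks) property array.
toy_target = np.array([
    [1.0, np.nan],
    [2.0, 0.5],
    [np.nan, 0.1],
    [4.0, 0.2],
])
summary = summarize_targets(toy_target, names=["homo", "lumo"])
print(summary)
```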
Fit a Model
After preparing the dataset, we can fit a model much as we would in sklearn. The code is arguably simpler, because sklearn would still require feature engineering to convert molecule SMILES into vectors:
from torch_molecule import GREAMolecularPredictor
split = int(0.8 * len(smiles_list))
# One task per target column, e.g. 1 for the default "gap" target
num_task = property_np_array.shape[1]
grea = GREAMolecularPredictor(
    num_task=num_task,
    task_type="regression",
    evaluate_higher_better=False,
    verbose="progress_bar",  # or "print_statement" (recommended for Jupyter notebooks), or "none"
)
# Fit with automatic hyperparameter tuning over 10 trials, or call .fit() to use the default/manual hyperparameters
grea.autofit(
    X_train=smiles_list[:split],
    y_train=property_np_array[:split],
    X_val=smiles_list[split:],
    y_val=property_np_array[split:],
    n_trials=10,
)
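The example above splits the data by position (the first 80% for training). If the dataset happens to be ordered, a shuffled split may be preferable. The following sketch uses NumPy only; `shuffled_split` is a hypothetical helper, not a torch-molecule function, and the toy SMILES and targets stand in for `smiles_list` and `property_np_array`.

```python
import numpy as np


def shuffled_split(smiles, targets, train_frac=0.8, seed=0):
    """Randomly split SMILES strings and their targets into train/validation parts."""
    targets = np.asarray(targets)
    if len(smiles) != len(targets):
        raise ValueError("smiles and targets must have the same length")
    rng = np.random.default_rng(seed)       # seeded for reproducibility
    order = rng.permutation(len(smiles))    # shuffled indices
    cut = int(train_frac * len(smiles))
    train_idx, val_idx = order[:cut], order[cut:]
    X_train = [smiles[i] for i in train_idx]
    X_val = [smiles[i] for i in val_idx]
    return X_train, targets[train_idx], X_val, targets[val_idx]


# Toy data with the same layout as the loader outputs.
toy_smiles = ["C", "CC", "CCC", "CCO", "CCN"]
toy_targets = np.arange(5, dtype=float).reshape(-1, 1)
X_train, y_train, X_val, y_val = shuffled_split(toy_smiles, toy_targets)
```

The four returned values map directly onto the `X_train`, `y_train`, `X_val`, and `y_val` arguments shown above.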
Checkpoints
torch-molecule provides functions for saving checkpoints to, and loading them from, Hugging Face:
from torch_molecule import GREAMolecularPredictor
repo_id = "user/repo_id" # replace with your own Hugging Face username and repo_id
# Save the trained model to Hugging Face
grea.save_to_hf(
    repo_id=repo_id,
    task_id="qm9_grea",
    commit_message="Upload qm9_grea",
    private=False,
)
# Load a pretrained checkpoint from Hugging Face
# model_dir and task_name are placeholders that you define for the local cache path
model = GREAMolecularPredictor()
model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")
# Adjust model parameters and make predictions
model.set_params(verbose='none')
predictions = model.predict(smiles_list)
Or you can save the model to a local path:
grea.save_to_local("qm9_grea.pt")
new_model = GREAMolecularPredictor()
new_model.load_from_local("qm9_grea.pt")
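After loading a checkpoint and predicting, a quick per-task error check can confirm the model was restored as expected. The sketch below is independent of torch-molecule: `mae_per_task` is a hypothetical helper, and it assumes you have already extracted the predicted values as an array, since this README does not specify the structure returned by `predict`.

```python
import numpy as np


def mae_per_task(y_true, y_pred):
    """Mean absolute error for each task column, ignoring NaN labels."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = y_true.reshape(len(y_true), -1)  # force (n_samples, n_tasks)
    y_pred = y_pred.reshape(len(y_pred), -1)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: {y_true.shape} vs {y_pred.shape}")
    return np.nanmean(np.abs(y_true - y_pred), axis=0)


# Toy values standing in for true properties and model predictions.
toy_true = [[1.0, 2.0], [3.0, 4.0]]
toy_pred = [[1.5, 2.0], [2.5, 5.0]]
errors = mae_per_task(toy_true, toy_pred)
print(errors)
```

A lower value is better here, which matches the `evaluate_higher_better=False` setting used for the regression example.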
List of Supported Models
Predictive Models
Generative Models
Representation Models
Acknowledgements
The project template was adapted from https://github.com/lwaekfjlk/python-project-template. We thank the authors for their contribution to the open-source community.