COMO: Closed-loop Optical Molecule recOgnition with Minimum Risk Training
COMO
COMO (Closed-loop Optical Molecule recOgnition) is a deep learning framework that recognizes chemical structure diagrams from images and predicts SMILES strings with atom-level coordinates and bond matrices. It uses Minimum Risk Training (MRT) to directly optimize molecular-level, non-differentiable objectives.
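MRT replaces the token-level likelihood objective with the expected cost of sampled outputs, so any sequence-level score can be optimized even when it is not differentiable. Examples are the Tanimoto, edit-distance and visual rewards listed under Model Weights. The sketch below computes that expected risk for one input in the standard MRT form (Shen et al., 2016). It is an illustration under stated assumptions, not COMO's training code. The candidate sampling, the smoothing exponent `alpha`, and turning rewards into costs as `1 - reward` (with rewards in [0, 1]) are all assumptions here.

```python
import math
from typing import Sequence


def mrt_expected_risk(
    log_probs: Sequence[float],
    rewards: Sequence[float],
    alpha: float = 1.0,
) -> float:
    """Expected risk over a set of sampled candidates for one input image.

    log_probs: model log-probability log P(y|x) of each sampled candidate y.
    rewards:   sequence-level reward of each candidate, assumed in [0, 1]
               (e.g. Tanimoto similarity to the ground-truth molecule).
    alpha:     smoothing exponent in Q(y|x) ∝ P(y|x)**alpha (assumed value).
    """
    if not log_probs or len(log_probs) != len(rewards):
        raise ValueError("need at least one candidate and one reward per candidate")
    scaled = [alpha * lp for lp in log_probs]
    top = max(scaled)  # subtract the max before exp() for numerical stability
    weights = [math.exp(s - top) for s in scaled]
    z = sum(weights)  # renormalize over the sampled candidates only
    # Risk = sum over candidates of Q(y|x) * cost(y), with cost = 1 - reward
    return sum((w / z) * (1.0 - r) for w, r in zip(weights, rewards))


# Two equally likely candidates: one perfect (reward 1), one wrong (reward 0)
print(mrt_expected_risk([-1.0, -1.0], [1.0, 0.0]))  # 0.5
```

During training, the same expression would be computed on framework tensors so that gradients flow into the log-probabilities. Lowering the risk then shifts probability mass toward high-reward candidates.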
Installation
```bash
pip install como-ocsr
```
Quick Start
```python
import como

# Load a model checkpoint (on GPU 0)
model = como.load_model("path/to/checkpoint.pth", device="cuda:0")

# Predict SMILES from a single image
smiles = como.predict(model, "molecule.png")
print(smiles)  # e.g. "CC(=O)O"

# Batch prediction on a specific GPU
smiles_list = como.predict_batch(model, ["mol1.png", "mol2.png"], device="cuda:1")

# Evaluate on a benchmark from local files (single GPU by default)
metrics = como.evaluate(
    model,
    benchmark_dir="benchmark/USPTO/",
    csv_path="benchmark/USPTO.csv",
)
print(f"Exact Match: {metrics['postprocess/exact_match_acc']:.2%}")

# Evaluate directly from HuggingFace (no local files needed)
metrics = como.evaluate(
    model,
    hf_dataset="Keylab/OCSR-Benchmarks",
    hf_config="USPTO",
)

# Multi-GPU, multi-benchmark evaluation (file-based and HF specs can be mixed)
benchmarks = [
    {"name": "USPTO", "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "USPTO"},
    {"name": "CLEF", "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "CLEF"},
]
results = como.evaluate_benchmarks(model, benchmarks, gpus="0,1,2,3")
for name, m in results.items():
    print(f"{name}: {m['postprocess/exact_match_acc']:.2%}")
```
API Reference
GPU Selection
Model loading and prediction accept a `device` parameter for single-GPU usage:
```python
model = como.load_model("checkpoint.pth", device="cuda:0")
como.predict(model, "img.png", device="cuda:1")
como.predict_batch(model, [...], device="cuda:2")
```
For evaluation, which uses multiple GPUs internally via `mp.spawn`, use the `gpus` parameter instead:
| Function | GPU control |
|---|---|
| `load_model` | `device="cuda:0"` |
| `predict` | `device="cuda:0"` |
| `predict_batch` | `device="cuda:0"` |
| `evaluate` | `gpus="0"` (default), `gpus="0,1,2"`, `gpus=None` (all) |
| `evaluate_benchmarks` | `gpus="0"` (default), `gpus="0,1,2"`, `gpus=None` (all) |
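The `gpus` strings in this table follow a simple convention. Below is a minimal sketch of how such a string could be mapped to device indices, with one worker per GPU. `parse_gpus` and its `num_visible` argument are hypothetical; in practice the count would come from something like `torch.cuda.device_count()`. como's own parsing may differ, for example in how it treats invalid IDs.

```python
from typing import List, Optional


def parse_gpus(gpus: Optional[str], num_visible: int) -> List[int]:
    """Turn a gpus string such as "0,1,2" into a list of device indices.

    Hypothetical helper: None is treated as "all visible GPUs", as in the
    table above; num_visible stands in for torch.cuda.device_count().
    """
    if gpus is None:
        return list(range(num_visible))
    # Skip empty tokens so inputs like "0,1," still parse
    ids = [int(tok) for tok in gpus.split(",") if tok.strip()]
    for i in ids:
        if not 0 <= i < num_visible:
            raise ValueError(f"GPU {i} not available (found {num_visible})")
    return ids


# One worker process per listed GPU, e.g.
# torch.multiprocessing.spawn(worker_fn, nprocs=len(ids))
print(parse_gpus("0,1,2", num_visible=4))  # [0, 1, 2]
print(parse_gpus(None, num_visible=2))     # [0, 1]
```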
como.load_model(checkpoint_path, device="cuda", **kwargs)
Load a COMO model from a `.pth` checkpoint. Returns a `ComoModel` instance in evaluation mode.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `checkpoint_path` | `str` | required | Path to `.pth` checkpoint |
| `device` | `str` | `"cuda"` | `"cuda"`, `"cuda:0"`, or `"cpu"` |
Returns: ComoModel
como.predict(model, image, *, beam_size=1, max_len=500, smiles_mode="postprocess", device=None)
Predict the SMILES string for a single molecular image.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `ComoModel` | required | A loaded model |
| `image` | `str` / `np.ndarray` / `PIL.Image` / `torch.Tensor` | required | Input image (file path, array, PIL image, or preprocessed tensor) |
| `beam_size` | `int` | `1` | Beam width (1 = greedy decoding; values above 1, e.g. 3, use beam search) |
| `max_len` | `int` | `500` | Maximum number of tokens to generate |
| `smiles_mode` | `str` or `None` | `"postprocess"` | `"postprocess"` (best quality), `"graph"`, `"decoder"`, or `None` (raw result dict) |
| `device` | `str` or `None` | `None` | Optional device override (e.g. `"cuda:1"`) |
Returns:
- `str`: predicted SMILES string (if `smiles_mode` is not `None`)
- `dict`: full result dict with keys `tokens`, `symbols`, `coords`, `bond_mat`, `decode_smiles`, `success` (if `smiles_mode=None`)
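With `smiles_mode=None`, `predict` returns the raw dict instead of a string. The small helper below is a hypothetical way to inspect it: `summarize_raw_result` checks for the documented keys and counts atoms and bonds. The bond count assumes that `bond_mat` is an N x N matrix aligned with `symbols`, where a nonzero entry marks a bond. This page does not specify that layout, so treat it as an assumption.

```python
from typing import Any, Dict, List

# Keys documented for the raw result dict (smiles_mode=None)
RAW_KEYS = ("tokens", "symbols", "coords", "bond_mat", "decode_smiles", "success")


def summarize_raw_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize a raw prediction dict.

    Assumption: bond_mat is an N x N matrix aligned with symbols, and a
    nonzero entry at [i][j] means atoms i and j are bonded.
    """
    missing = [k for k in RAW_KEYS if k not in result]
    if missing:
        raise KeyError(f"missing keys: {missing}")
    symbols: List[str] = list(result["symbols"])
    bond_mat = result["bond_mat"]
    n = len(symbols)
    # Count each bonded pair once by scanning the upper triangle only
    n_bonds = sum(1 for i in range(n) for j in range(i + 1, n) if bond_mat[i][j])
    return {
        "success": bool(result["success"]),
        "decoder_smiles": result["decode_smiles"],
        "num_atoms": n,
        "num_bonds": n_bonds,
    }


# Usage (requires a loaded model):
# raw = como.predict(model, "molecule.png", smiles_mode=None)
# print(summarize_raw_result(raw))
```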
como.predict_batch(model, images, *, beam_size=1, max_len=500, smiles_mode="postprocess", device=None)
Batch prediction on a single GPU.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `ComoModel` | required | A loaded model |
| `images` | `list` | required | List of file paths, NumPy arrays, PIL Images, or tensors |
| `beam_size` | `int` | `1` | Beam width (1 = greedy, recommended for batch) |
| `max_len` | `int` | `500` | Maximum tokens per image |
| `smiles_mode` | `str` or `None` | `"postprocess"` | SMILES reconstruction mode |
| `device` | `str` or `None` | `None` | Optional device override |
Returns:
- `list[str]`: predicted SMILES for each image (if `smiles_mode` is not `None`)
- `list[dict]`: raw result dicts (if `smiles_mode=None`)
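`predict_batch` runs the whole list it is given on one GPU. For very large image lists, one option is to split the list and call it chunk by chunk to bound memory use. This page does not document how `predict_batch` batches internally, so whether chunking is needed is an open question. `predict_in_chunks` and its default chunk size are hypothetical choices, not part of como. The predictor is passed in as a callable.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunks(items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("size must be >= 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def predict_in_chunks(
    predict_fn: Callable[[Sequence[T]], List[str]],
    images: Sequence[T],
    chunk_size: int = 64,
) -> List[str]:
    """Run a batch predictor over `images` in fixed-size chunks, keeping input order."""
    out: List[str] = []
    for chunk in chunks(images, chunk_size):
        preds = predict_fn(chunk)
        if len(preds) != len(chunk):
            raise RuntimeError("predictor returned a different number of results")
        out.extend(preds)
    return out


# Usage (requires a loaded model):
# smiles = predict_in_chunks(
#     lambda imgs: como.predict_batch(model, list(imgs), device="cuda:0"),
#     image_paths, chunk_size=32)
```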
como.evaluate(model, benchmark_dir=None, csv_path=None, *, hf_dataset=None, hf_config=None, hf_split="test", beam_size=1, postproc_workers=32, tautomer_standardize=True, gpus="0")
Evaluate on a single benchmark dataset. Returns a flat dict of metrics.
Two mutually exclusive input modes are supported. If `hf_dataset` is provided, it takes priority and `benchmark_dir`/`csv_path` are ignored:
```python
# File-based
metrics = como.evaluate(model, "benchmark/USPTO/", "benchmark/USPTO.csv")

# HuggingFace dataset (no local files required)
metrics = como.evaluate(model, hf_dataset="Keylab/OCSR-Benchmarks", hf_config="USPTO")
```
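The input-mode rule fits in a few lines. The sketch below is hypothetical (`resolve_input_mode` is not a como function): `hf_dataset` wins whenever it is set, and otherwise both file-based arguments are required. Raising an error when only one of `benchmark_dir` and `csv_path` is given is an assumption; como's actual behavior in that case is not documented here.

```python
from typing import Optional


def resolve_input_mode(
    benchmark_dir: Optional[str] = None,
    csv_path: Optional[str] = None,
    hf_dataset: Optional[str] = None,
) -> str:
    """Return "hf" or "file" following the documented priority rule."""
    if hf_dataset is not None:
        return "hf"  # file-based arguments, if any, are ignored
    if benchmark_dir is not None and csv_path is not None:
        return "file"
    raise ValueError("pass hf_dataset, or both benchmark_dir and csv_path")
```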
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `ComoModel` | required | A loaded model |
| `benchmark_dir` | `str` or `None` | `None` | Directory containing `.png` images (file-based mode); ignored if `hf_dataset` is provided |
| `csv_path` | `str` or `None` | `None` | CSV with columns `image_id`, `SMILES` (file-based mode); ignored if `hf_dataset` is provided |
| `hf_dataset` | `str` or `None` | `None` | HuggingFace dataset repo ID, e.g. `"Keylab/OCSR-Benchmarks"` |
| `hf_config` | `str` or `None` | `None` | Config (subset) name within the HF dataset, e.g. `"USPTO"` |
| `hf_split` | `str` | `"test"` | Dataset split to load |
| `beam_size` | `int` | `1` | Beam width for decoding |
| `postproc_workers` | `int` | `32` | Parallel workers for SMILES post-processing |
| `tautomer_standardize` | `bool` | `True` | Include tautomer-normalized exact match |
| `gpus` | `str` or `None` | `"0"` | GPU IDs (e.g. `"0,1"`), or `None` for all |
Returns: dict with the following keys:
| Key | Type | Description |
|---|---|---|
| `decoder/exact_match_acc` | `float` | Exact match accuracy (decoder mode) |
| `decoder/avg_tanimoto` | `float` | Average Tanimoto similarity (decoder) |
| `decoder/tautomer_match_acc` | `float` | Tautomer-normalized exact match (decoder, if `tautomer_standardize=True`) |
| `decoder/failed_predictions` | `int` | Number of failed predictions (decoder) |
| `decoder/valid` | `int` | Number of chemically valid predictions (decoder) |
| `decoder/total` | `int` | Total benchmark samples |
| `graph/exact_match_acc` | `float` | Exact match accuracy (graph mode) |
| `graph/avg_tanimoto` | `float` | Average Tanimoto similarity (graph) |
| `graph/tautomer_match_acc` | `float` | Tautomer-normalized exact match (graph, if `tautomer_standardize=True`) |
| `graph/failed_predictions` | `int` | Number of failed predictions (graph) |
| `graph/valid` | `int` | Number of chemically valid predictions (graph) |
| `graph/total` | `int` | Total benchmark samples |
| `postprocess/exact_match_acc` | `float` | Exact match accuracy (postprocess mode, primary metric) |
| `postprocess/avg_tanimoto` | `float` | Average Tanimoto similarity (postprocess) |
| `postprocess/tautomer_match_acc` | `float` | Tautomer-normalized exact match (postprocess, if `tautomer_standardize=True`) |
| `postprocess/failed_predictions` | `int` | Number of failed predictions (postprocess) |
| `postprocess/valid` | `int` | Number of chemically valid predictions (postprocess) |
| `postprocess/records_df` | `DataFrame` | Per-image results with columns `image_id`, `gt_smiles`, `pred_smiles`, `exact`, `tautomer`, `tanimoto` |
| `postprocess/total` | `int` | Total benchmark samples |
| `total` | `int` | Total benchmark samples |
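Every mode reports the same set of keys, so the three reconstruction strategies can be compared from a single `evaluate` call. The sketch below uses only keys from the table above. `compare_modes` is a hypothetical helper. Expressing validity as `valid / total` is a choice made here, not a metric como reports.

```python
from typing import Any, Dict, Mapping

MODES = ("decoder", "graph", "postprocess")


def compare_modes(metrics: Mapping[str, Any]) -> Dict[str, Dict[str, float]]:
    """Collect headline numbers for each SMILES reconstruction mode."""
    summary: Dict[str, Dict[str, float]] = {}
    for mode in MODES:
        total = metrics[f"{mode}/total"]
        summary[mode] = {
            "exact": metrics[f"{mode}/exact_match_acc"],
            "tanimoto": metrics[f"{mode}/avg_tanimoto"],
            # Fraction of chemically valid predictions for this mode
            "valid_frac": metrics[f"{mode}/valid"] / total if total else 0.0,
        }
    return summary


# Usage (requires a loaded model):
# metrics = como.evaluate(model, hf_dataset="Keylab/OCSR-Benchmarks", hf_config="USPTO")
# for mode, row in compare_modes(metrics).items():
#     print(f"{mode:12s} exact={row['exact']:.2%} tanimoto={row['tanimoto']:.4f}")
```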
como.evaluate_benchmarks(model, benchmarks, *, beam_size=1, postproc_workers=32, tautomer_standardize=True, gpus="0")
Evaluate on multiple benchmarks in one call. Returns a nested dict keyed by benchmark name.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `ComoModel` | required | A loaded model |
| `benchmarks` | `list[dict]` | required | List of benchmark spec dicts (see below) |
| `beam_size` | `int` | `1` | Beam width for decoding |
| `postproc_workers` | `int` | `32` | Parallel workers for SMILES post-processing |
| `tautomer_standardize` | `bool` | `True` | Include tautomer-normalized exact match |
| `gpus` | `str` or `None` | `"0"` | GPU IDs (e.g. `"0,1"`), or `None` for all |
Each dict in benchmarks must contain "name" plus one of:
| Mode | Required keys | Optional keys |
|---|---|---|
| File-based | `"benchmark_dir"`, `"csv_path"` | (none) |
| HuggingFace | `"hf_dataset"` | `"hf_config"` (default: benchmark name), `"hf_split"` (default: `"test"`) |
Returns: `dict[str, dict]`, a mapping from benchmark name to a metrics dict with the same structure as the return value of `evaluate`. Example:
```
{
    "USPTO": {
        "postprocess/exact_match_acc": 0.934,
        "postprocess/avg_tanimoto": 0.987,
        ...
    },
    "CLEF": {
        "postprocess/exact_match_acc": 0.948,
        ...
    },
}
```
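The rules for each entry of `benchmarks` can be expressed as a small validation step. The sketch below is hypothetical (`normalize_spec` is not part of como). It checks one spec and fills in the documented HuggingFace defaults. Giving `hf_dataset` precedence when both kinds of keys are present is an assumption, carried over from the priority rule of `evaluate`.

```python
from typing import Any, Dict


def normalize_spec(spec: Dict[str, Any]) -> Dict[str, Any]:
    """Validate one benchmark spec and fill the documented HF defaults."""
    if "name" not in spec:
        raise ValueError("every benchmark spec needs a 'name'")
    out = dict(spec)  # copy so the caller's dict is left untouched
    if "hf_dataset" in out:
        out.setdefault("hf_config", out["name"])  # default: benchmark name
        out.setdefault("hf_split", "test")
        return out
    if "benchmark_dir" in out and "csv_path" in out:
        return out  # file-based mode needs no defaults
    raise ValueError(
        f"{out['name']}: need 'hf_dataset' or both 'benchmark_dir' and 'csv_path'"
    )


print(normalize_spec({"name": "JPO", "hf_dataset": "Keylab/OCSR-Benchmarks"}))
# {'name': 'JPO', 'hf_dataset': 'Keylab/OCSR-Benchmarks', 'hf_config': 'JPO', 'hf_split': 'test'}
```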
Examples:
```python
# File-based
benchmarks = [
    {"name": "USPTO", "benchmark_dir": "data/benchmark/real/USPTO",
     "csv_path": "data/benchmark/real/USPTO.csv"},
    {"name": "CLEF", "benchmark_dir": "data/benchmark/real/CLEF",
     "csv_path": "data/benchmark/real/CLEF_corrected.csv"},
]

# HuggingFace dataset (recommended: no local files required)
benchmarks = [
    {"name": "USPTO", "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "USPTO"},
    {"name": "CLEF", "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "CLEF"},
    {"name": "JPO", "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "JPO"},
]

results = como.evaluate_benchmarks(model, benchmarks, gpus="0,1")
for name, metrics in results.items():
    acc = metrics["postprocess/exact_match_acc"]
    tan = metrics["postprocess/avg_tanimoto"]
    print(f"{name}: Exact={acc:.2%}, Tanimoto={tan:.4f}")
```
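The nested return value is easy to tabulate. The sketch below renders it as a Markdown table with one row per benchmark. `results_to_markdown` is a hypothetical helper, and its default columns are simply the two postprocess metrics printed in the example above.

```python
from typing import Any, Mapping, Sequence


def results_to_markdown(
    results: Mapping[str, Mapping[str, Any]],
    keys: Sequence[str] = ("postprocess/exact_match_acc", "postprocess/avg_tanimoto"),
) -> str:
    """Render evaluate_benchmarks() output as a Markdown table."""
    header = "| Benchmark | " + " | ".join(keys) + " |"
    rule = "|---|" + "---|" * len(keys)
    rows = [
        f"| {name} | " + " | ".join(f"{m[k]:.4f}" for k in keys) + " |"
        for name, m in results.items()
    ]
    return "\n".join([header, rule, *rows])


# Usage: print(results_to_markdown(results))
```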
como.canonicalize_smiles(smiles, *, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True)
Canonicalize a SMILES string using RDKit.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `smiles` | `str` | required | Input SMILES string |
| `ignore_chiral` | `bool` | `False` | Strip tetrahedral chirality before canonicalization |
| `ignore_cistrans` | `bool` | `False` | Strip cis–trans markers (`/` and `\`) before canonicalization |
| `replace_rgroup` | `bool` | `True` | If `True`, replace R-group tokens (R, R1, X, Ar, …) with wildcard `*` |
Returns: tuple[str, bool] — (canonical_smiles, ok) where ok is
True if the SMILES is chemically valid and canonicalization succeeded.
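The two `ignore_*` flags discard stereo information that SMILES writes as specific characters. The string-level sketch below shows what gets dropped. It is only an illustration, not how `canonicalize_smiles` works, since that function uses RDKit. Its output is not canonical, and extended chirality forms such as `@TH1` are not handled.

```python
import re


def strip_stereo(smiles: str, ignore_chiral: bool = False, ignore_cistrans: bool = False) -> str:
    """Remove SMILES stereo characters; an illustration, not como's implementation."""
    if ignore_cistrans:
        # Directional bond symbols / and \ encode double-bond geometry
        smiles = smiles.replace("/", "").replace("\\", "")
    if ignore_chiral:
        # @ and @@ inside bracket atoms encode tetrahedral chirality
        smiles = re.sub(r"@+", "", smiles)
    return smiles


print(strip_stereo("F/C=C/F", ignore_cistrans=True))    # FC=CF
print(strip_stereo("C[C@@H](N)O", ignore_chiral=True))  # C[CH](N)O
```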
como.canonicalize_tautomer(smiles)
Canonicalize a SMILES string via RDKit's TautomerEnumerator, normalizing different tautomeric forms (e.g., keto/enol, lactam/lactim) to the same canonical representation.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `smiles` | `str` | required | Input SMILES string |
Returns: tuple[str, bool] — (tautomer_canonical_smiles, ok) where
ok is False if the input SMILES is invalid or tautomer enumeration fails.
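Both canonicalizers return a `(smiles, ok)` pair, which suggests a simple way to score a prediction against a reference. In the sketch below, two SMILES match only if both canonicalize successfully to equal strings, so an unparsable reference never matches. `same_molecule` and `match_rate` are hypothetical. It is also an assumption that `evaluate` computes `exact_match_acc` and `tautomer_match_acc` this way.

```python
from typing import Callable, Iterable, Tuple

# Shape shared by como.canonicalize_smiles and como.canonicalize_tautomer:
# SMILES in, (canonical_smiles, ok) out
Canonicalizer = Callable[[str], Tuple[str, bool]]


def same_molecule(pred: str, gt: str, canon: Canonicalizer) -> bool:
    """True only if both inputs canonicalize successfully to the same string."""
    pred_c, pred_ok = canon(pred)
    gt_c, gt_ok = canon(gt)
    return pred_ok and gt_ok and pred_c == gt_c


def match_rate(pairs: Iterable[Tuple[str, str]], canon: Canonicalizer) -> float:
    """Fraction of (pred, gt) pairs that match under `canon`."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    return sum(same_molecule(p, g, canon) for p, g in pairs) / len(pairs)


# Usage with como (hypothetical evaluation loop):
# exact = match_rate(pairs, como.canonicalize_smiles)
# tautomer = match_rate(pairs, como.canonicalize_tautomer)
```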
como._result_to_smiles(result, mode="postprocess")
Low-level: convert a raw prediction result dict (from `predict` with `smiles_mode=None`) to a canonical SMILES string.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `result` | `dict` | required | Raw prediction dict with keys `decode_smiles`, `symbols`, `coords`, `bond_mat`, `success` |
| `mode` | `str` | `"postprocess"` | SMILES reconstruction mode |
mode options:
| Mode | Source | Chirality | Description |
|---|---|---|---|
| `"decoder"` | Decoder token sequence | ✗ | Raw decoder SMILES; no graph information is used. Fastest but lowest quality. |
| `"graph"` | Predicted atoms + bonds | ✓ | Reconstructs SMILES entirely from predicted atom symbols, coordinates, and bond matrix. Chirality is restored via `_verify_chirality`. |
| `"postprocess"` | Decoder + atoms + bonds | ✓ | Starts from the decoder SMILES, replaces R-groups/abbreviations, restores chirality from predicted coordinates and bond matrix, then expands functional groups back. Best quality. |
Returns: str or None — canonical SMILES string, or None if conversion fails.
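The modes trade quality for robustness, so a caller working with raw results might try them in the order of quality and fall back when one returns `None`. The sketch below is hypothetical; `smiles_with_fallback` is not part of como. The converter is passed in, so `como._result_to_smiles` is only one possible argument (it is private and underscore-prefixed, so it may change). Whether a lower-quality fallback is acceptable depends on the application.

```python
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

# Documented shape: _result_to_smiles(result, mode=...) -> str or None
Converter = Callable[..., Optional[str]]


def smiles_with_fallback(
    result: Dict[str, Any],
    convert: Converter,
    modes: Sequence[str] = ("postprocess", "graph", "decoder"),
) -> Tuple[Optional[str], Optional[str]]:
    """Try reconstruction modes in order; return (smiles, mode_used)."""
    for mode in modes:
        smiles = convert(result, mode=mode)
        if smiles is not None:
            return smiles, mode
    return None, None  # every mode failed


# Usage (requires a loaded model):
# raw = como.predict(model, "molecule.png", smiles_mode=None)
# smiles, mode = smiles_with_fallback(raw, como._result_to_smiles)
```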
Model Weights
Pre-trained model weights are available on HuggingFace:
| Checkpoint | Reward Mode | Description |
|---|---|---|
| `COMO_joint/tanimoto/final.pth` | Tanimoto | Joint MLE+MRT (Tanimoto reward) |
| `COMO_joint/edit_distance/final.pth` | Edit Distance | Joint MLE+MRT (Edit Distance reward) |
| `COMO_joint/visual/final.pth` | Visual | Joint MLE+MRT (Visual reward) |
Download from: https://huggingface.co/Keylab/COMO
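The checkpoints can also be fetched from code with `hf_hub_download` from the `huggingface_hub` package. The sketch below assumes that the paths in the table are relative to the root of the `Keylab/COMO` model repository linked above. The reward-mode keys of `CHECKPOINTS` are names chosen here for illustration.

```python
# Reward mode -> checkpoint path inside the Keylab/COMO repo (paths from the table above)
CHECKPOINTS = {
    "tanimoto": "COMO_joint/tanimoto/final.pth",
    "edit_distance": "COMO_joint/edit_distance/final.pth",
    "visual": "COMO_joint/visual/final.pth",
}


def checkpoint_filename(reward: str) -> str:
    """Look up the repo-relative checkpoint path for a reward mode."""
    try:
        return CHECKPOINTS[reward]
    except KeyError:
        raise ValueError(f"unknown reward mode {reward!r}; choose from {sorted(CHECKPOINTS)}") from None


def download_checkpoint(reward: str = "tanimoto") -> str:
    """Download (or reuse a cached copy of) a checkpoint and return its local path."""
    from huggingface_hub import hf_hub_download  # pip install huggingface_hub

    return hf_hub_download(repo_id="Keylab/COMO", filename=checkpoint_filename(reward))


# Usage:
# model = como.load_model(download_checkpoint("tanimoto"), device="cuda:0")
```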
Benchmark Datasets
Benchmark datasets are published as a HuggingFace Dataset with one config per benchmark: Keylab/OCSR-Benchmarks
| Config | Size | Type |
|---|---|---|
| CLEF | 992 | Real (patents) |
| JPO | 449 | Real (patents) |
| UOB | 5,740 | Real (academic) |
| USPTO | 5,719 | Real (patents) |
| USPTO-10K | 9,999 | Real (patents) |
| Staker | 50,000 | Real |
| ACS | 331 | Real (publications) |
| WildMol-10K | 9,889 | Real (wild) |
| Indigo | 5,719 | Synthetic |
| ChemDraw | 5,719 | Synthetic |
Each sample has three fields: image_id (str), image (PIL), SMILES (str).
```python
from datasets import load_dataset

# Load a single benchmark
ds = load_dataset("Keylab/OCSR-Benchmarks", name="USPTO", split="test")
sample = ds[0]
sample["image"].show()   # PIL Image
print(sample["SMILES"])  # ground-truth SMILES

# Iterate over all benchmarks
configs = ["CLEF", "JPO", "UOB", "USPTO", "USPTO-10K",
           "Staker", "ACS", "WildMol-10K", "Indigo", "ChemDraw"]
for name in configs:
    ds = load_dataset("Keylab/OCSR-Benchmarks", name=name, split="test")
    print(f"{name}: {len(ds)} samples")
```
Pre-packaged .tar.gz archives are also available for bulk download in the
COMO model repository.
Citation
If you use COMO in your research, please cite:
```bibtex
@article{lyu2026closed,
  title={COMO: Closed-Loop Optical Molecule Recognition with Minimum Risk Training},
  author={Lyu, Zhuoqi and Ke, Qing},
  journal={arXiv preprint arXiv:2604.23546},
  year={2026}
}
```
License
- Code (`como/` package): MIT License
- Model Weights (`.pth` files): CC BY-NC 4.0 (non-commercial use only)
- Benchmark Datasets: collected from existing public OCSR benchmarks; please refer to their original sources for license and attribution:
| Dataset | Source |
|---|---|
| USPTO, CLEF, JPO, UOB, Staker | Rajan et al., 2020, Xiong et al., 2023 |
| Indigo, ChemDraw, ACS, Staker | Qian et al., 2023 |
| USPTO-10K | Morin et al., 2023 |
| WildMol-10K | Fang et al., 2025 |
See LICENSE for full terms.
Download files
File details
Details for the file como_ocsr-1.2.3.tar.gz.
File metadata
- Download URL: como_ocsr-1.2.3.tar.gz
- Upload date:
- Size: 50.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 241d5fb055958ba9ec93559a5c3069a0a82a53c91a974eca3a5e76f5d2b267ef |
| MD5 | 727aa6beaafc6897ef16d2bb0dad07e3 |
| BLAKE2b-256 | 939d220b0a5e9e8dd1edb4a5631c91724f767c8f6334ee97ea93a20b43ff6c1f |
File details
Details for the file como_ocsr-1.2.3-py3-none-any.whl.
File metadata
- Download URL: como_ocsr-1.2.3-py3-none-any.whl
- Upload date:
- Size: 50.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0a063748ea294aa3ca871504ada6e487649358fd43f37d90411c220ba4ed455d |
| MD5 | 26caf964d15874296cbd04bf9ccffa8d |
| BLAKE2b-256 | 7f543e374938da64c44301f805f5ccfcec7518afea00cc468b128787b51d0195 |