COMO: Closed-loop Optical Molecule recOgnition with Minimum Risk Training

COMO

COMO (Closed-loop Optical Molecule recOgnition) is a deep learning framework that recognizes chemical structure diagrams from images and predicts SMILES strings with atom-level coordinates and bond matrices. It uses Minimum Risk Training (MRT) to directly optimize molecular-level, non-differentiable objectives.

Installation

pip install como-ocsr

Quick Start

import como

# Load a model checkpoint (on GPU 0)
model = como.load_model("path/to/checkpoint.pth", device="cuda:0")

# Predict SMILES from a single image
smiles = como.predict(model, "molecule.png")
print(smiles)  # "CC(=O)O"

# Batch prediction on a specific GPU
smiles_list = como.predict_batch(model, ["mol1.png", "mol2.png"], device="cuda:1")

# Evaluate on a benchmark (single GPU by default) — file-based
metrics = como.evaluate(
    model,
    benchmark_dir="benchmark/USPTO/",
    csv_path="benchmark/USPTO.csv",
)
print(f"Exact Match: {metrics['postprocess/exact_match_acc']:.2%}")

# Evaluate directly from HuggingFace (no local files needed)
metrics = como.evaluate(
    model,
    hf_dataset="Keylab/OCSR-Benchmarks",
    hf_config="USPTO",
)

# Multi-GPU, multi-benchmark evaluation (mix file-based and HF)
benchmarks = [
    {"name": "USPTO", "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "USPTO"},
    {"name": "CLEF",  "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "CLEF"},
]
results = como.evaluate_benchmarks(model, benchmarks, gpus="0,1,2,3")
for name, m in results.items():
    print(f"{name}: {m['postprocess/exact_match_acc']:.2%}")

API Reference

GPU Selection

load_model, predict, and predict_batch accept a device parameter for single-GPU usage:

model = como.load_model("checkpoint.pth", device="cuda:0")
como.predict(model, "img.png", device="cuda:1")
como.predict_batch(model, [...], device="cuda:2")

For evaluation (which uses multi-GPU internally via mp.spawn), use the gpus parameter:

| Function | GPU control |
| --- | --- |
| load_model | device="cuda:0" |
| predict | device="cuda:0" |
| predict_batch | device="cuda:0" |
| evaluate | gpus="0" (default), gpus="0,1,2", gpus=None (all) |
| evaluate_benchmarks | gpus="0" (default), gpus="0,1,2", gpus=None (all) |
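
As a rough illustration of the gpus values listed above, the sketch below turns a spec such as "0,1,2" or None into a list of GPU indices. The helper name parse_gpus and its available argument are hypothetical and are not part of COMO; how the package actually parses the string is not documented here.

```python
from typing import List, Optional


def parse_gpus(gpus: Optional[str], available: int) -> List[int]:
    """Turn a gpus spec into a list of GPU indices.

    Hypothetical helper, not part of COMO:
    "0,1,2" -> [0, 1, 2]; None -> every index below `available`.
    """
    if gpus is None:
        return list(range(available))
    ids = [int(part) for part in gpus.split(",") if part.strip()]
    if not ids:
        raise ValueError("gpus must name at least one GPU id")
    for gpu_id in ids:
        if not 0 <= gpu_id < available:
            raise ValueError(f"GPU {gpu_id} not available (found {available})")
    return ids


print(parse_gpus("0", available=4))      # [0]
print(parse_gpus("0,1,2", available=4))  # [0, 1, 2]
print(parse_gpus(None, available=4))     # [0, 1, 2, 3]
```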

como.load_model(checkpoint_path, device="cuda", **kwargs)

Load a COMO model from a .pth checkpoint. Returns a ComoModel instance in evaluation mode.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| checkpoint_path | str | required | Path to .pth checkpoint |
| device | str | "cuda" | "cuda", "cuda:0", or "cpu" |

Returns: ComoModel


como.predict(model, image, *, beam_size=1, max_len=500, smiles_mode="postprocess", device=None)

Predict the SMILES string for a single molecular image.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| model | ComoModel | required | A loaded model |
| image | str / np.ndarray / PIL.Image / torch.Tensor | required | Input image (file path, array, PIL image, or preprocessed tensor) |
| beam_size | int | 1 | Beam width (1 = greedy; larger values such as 3 enable beam search) |
| max_len | int | 500 | Maximum number of tokens to generate |
| smiles_mode | str or None | "postprocess" | "postprocess" (best quality), "graph", "decoder", or None (raw result dict) |
| device | str or None | None | Optional device override (e.g. "cuda:1") |

Returns:

  • str — predicted SMILES string (if smiles_mode is not None)
  • dict — full result dict with keys tokens, symbols, coords, bond_mat, decode_smiles, success (if smiles_mode=None)

como.predict_batch(model, images, *, beam_size=1, max_len=500, smiles_mode="postprocess", device=None)

Batch prediction on a single GPU.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| model | ComoModel | required | A loaded model |
| images | list | required | List of file paths, NumPy arrays, PIL Images, or tensors |
| beam_size | int | 1 | Beam width (1 = greedy, recommended for batch) |
| max_len | int | 500 | Maximum tokens per image |
| smiles_mode | str or None | "postprocess" | SMILES reconstruction mode |
| device | str or None | None | Optional device override |

Returns:

  • list[str] — predicted SMILES for each image (if smiles_mode is not None)
  • list[dict] — raw result dicts (if smiles_mode=None)
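
For long image lists, one option is to feed predict_batch in fixed-size chunks so that only a bounded number of images is in flight at once. The sketch below is an assumption-laden illustration: chunked and predict_in_chunks are hypothetical helpers, and whether predict_batch already splits its input internally is not stated in this document. The prediction function is passed in as an argument, and a stub stands in for it so the example runs without a model.

```python
from typing import Callable, Iterator, List, Sequence


def chunked(items: Sequence, size: int) -> Iterator[List]:
    """Yield consecutive slices of `items` holding at most `size` elements."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


def predict_in_chunks(predict_batch: Callable, model, images: Sequence,
                      chunk_size: int = 64, **kwargs) -> List:
    """Call `predict_batch(model, chunk, **kwargs)` per chunk, keeping order."""
    outputs: List = []
    for chunk in chunked(images, chunk_size):
        outputs.extend(predict_batch(model, chunk, **kwargs))
    return outputs


# Stub standing in for como.predict_batch so the sketch runs without a model.
def fake_predict_batch(model, images, **kwargs):
    return [f"SMILES-for-{image}" for image in images]


paths = [f"mol{i}.png" for i in range(5)]
print(predict_in_chunks(fake_predict_batch, None, paths, chunk_size=2))
```

With COMO installed, the same wrapper would be called as predict_in_chunks(como.predict_batch, model, paths, chunk_size=64, beam_size=1).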

como.evaluate(model, benchmark_dir=None, csv_path=None, *, hf_dataset=None, hf_config=None, hf_split="test", beam_size=1, postproc_workers=32, tautomer_standardize=True, gpus="0")

Evaluate on a single benchmark dataset. Returns a flat dict of metrics.

Two input modes are supported, file-based and Hugging Face. Use one at a time; if hf_dataset is provided, it takes priority and benchmark_dir and csv_path are ignored:

# File-based
metrics = como.evaluate(model, "benchmark/USPTO/", "benchmark/USPTO.csv")

# HuggingFace dataset (no local files required)
metrics = como.evaluate(model, hf_dataset="Keylab/OCSR-Benchmarks", hf_config="USPTO")

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| model | ComoModel | required | A loaded model |
| benchmark_dir | str or None | None | Directory containing .png images (file-based mode); ignored if hf_dataset is provided |
| csv_path | str or None | None | CSV with columns image_id, SMILES (file-based mode); ignored if hf_dataset is provided |
| hf_dataset | str or None | None | HuggingFace dataset repo id, e.g. "Keylab/OCSR-Benchmarks" |
| hf_config | str or None | None | Config / subset name within the HF dataset, e.g. "USPTO" |
| hf_split | str | "test" | Dataset split to load |
| beam_size | int | 1 | Beam width for decoding |
| postproc_workers | int | 32 | Parallel workers for SMILES post-processing |
| tautomer_standardize | bool | True | Include tautomer-normalized exact match |
| gpus | str or None | "0" | GPU IDs ("0,1") or None for all |
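
For file-based mode, a custom benchmark needs a CSV with the image_id and SMILES columns described above. The following sketch writes one with the standard library. The helper write_benchmark_csv, the file name, and the sample records are made up for illustration, and how image_id maps to the .png file names in benchmark_dir is not specified here, so check that against your own data.

```python
import csv
import tempfile
from pathlib import Path


def write_benchmark_csv(csv_path, records):
    """Write (image_id, SMILES) pairs under the documented column names."""
    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["image_id", "SMILES"])
        writer.writerows(records)


# Illustrative records; the ids are invented for this example.
records = [("mol_0001", "CC(=O)O"), ("mol_0002", "c1ccccc1")]

root = Path(tempfile.mkdtemp())
csv_path = root / "my_benchmark.csv"
write_benchmark_csv(csv_path, records)

# Read the file back to confirm the header and rows round-trip.
with open(csv_path, newline="") as handle:
    rows = list(csv.DictReader(handle))
print(rows[0]["image_id"], rows[0]["SMILES"])  # mol_0001 CC(=O)O
```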

Returns: dict with the following keys:

| Key | Type | Description |
| --- | --- | --- |
| decoder/exact_match_acc | float | Exact match accuracy (decoder mode) |
| decoder/avg_tanimoto | float | Average Tanimoto similarity (decoder) |
| decoder/tautomer_match_acc | float | Tautomer-normalized exact match (decoder, if tautomer_standardize=True) |
| decoder/failed_predictions | int | Number of failed predictions (decoder) |
| decoder/valid | int | Number of chemically valid predictions (decoder) |
| decoder/total | int | Total benchmark samples |
| graph/exact_match_acc | float | Exact match accuracy (graph mode) |
| graph/avg_tanimoto | float | Average Tanimoto similarity (graph) |
| graph/tautomer_match_acc | float | Tautomer-normalized exact match (graph, if tautomer_standardize=True) |
| graph/failed_predictions | int | Number of failed predictions (graph) |
| graph/valid | int | Number of chemically valid predictions (graph) |
| graph/total | int | Total benchmark samples |
| postprocess/exact_match_acc | float | Exact match accuracy (postprocess mode, primary metric) |
| postprocess/avg_tanimoto | float | Average Tanimoto similarity (postprocess) |
| postprocess/tautomer_match_acc | float | Tautomer-normalized exact match (postprocess, if tautomer_standardize=True) |
| postprocess/failed_predictions | int | Number of failed predictions (postprocess) |
| postprocess/valid | int | Number of chemically valid predictions (postprocess) |
| postprocess/records_df | DataFrame | Per-image results with columns image_id, gt_smiles, pred_smiles, exact, tautomer, tanimoto |
| postprocess/total | int | Total benchmark samples |
| total | int | Total benchmark samples |
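
Because the returned dict is flat with "mode/metric" keys, comparing the three reconstruction modes is easier after regrouping by prefix. The sketch below shows one way to do that. group_by_mode is a hypothetical helper, and the numbers in the sample dict are placeholders for illustration, not measured results.

```python
from collections import defaultdict


def group_by_mode(metrics):
    """Split "mode/metric" keys into {mode: {metric: value}}.

    Keys without a "/" (such as "total") are collected under the "" group.
    """
    grouped = defaultdict(dict)
    for key, value in metrics.items():
        mode, _, name = key.rpartition("/")
        grouped[mode][name] = value
    return dict(grouped)


# Placeholder numbers for illustration only, not measured results.
metrics = {
    "decoder/exact_match_acc": 0.90,
    "graph/exact_match_acc": 0.91,
    "postprocess/exact_match_acc": 0.93,
    "postprocess/total": 100,
    "total": 100,
}

grouped = group_by_mode(metrics)
for mode in ("decoder", "graph", "postprocess"):
    print(f"{mode}: {grouped[mode]['exact_match_acc']:.2%}")
```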

como.evaluate_benchmarks(model, benchmarks, *, beam_size=1, postproc_workers=32, tautomer_standardize=True, gpus="0")

Evaluate on multiple benchmarks in one call. Returns a nested dict keyed by benchmark name.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| model | ComoModel | required | A loaded model |
| benchmarks | list[dict] | required | List of benchmark spec dicts (see below) |
| beam_size | int | 1 | Beam width for decoding |
| postproc_workers | int | 32 | Parallel workers for SMILES post-processing |
| tautomer_standardize | bool | True | Include tautomer-normalized exact match |
| gpus | str or None | "0" | GPU IDs ("0,1") or None for all |

Each dict in benchmarks must contain "name" plus one of:

| Mode | Required keys | Optional keys |
| --- | --- | --- |
| File-based | "benchmark_dir", "csv_path" | |
| HuggingFace | "hf_dataset" | "hf_config" (default: benchmark name), "hf_split" (default: "test") |
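
The rules in the table above can be checked before launching a long evaluation run. The sketch below validates one spec dict and fills in the documented defaults. normalize_benchmark_spec is a hypothetical helper written for this illustration; COMO does its own validation, and giving hf_dataset precedence here mirrors the behavior documented for evaluate rather than anything confirmed for evaluate_benchmarks.

```python
def normalize_benchmark_spec(spec):
    """Check one benchmark spec dict and fill in the documented defaults.

    Hypothetical helper for illustration; not part of COMO.
    """
    if "name" not in spec:
        raise ValueError('benchmark spec needs a "name"')
    if "hf_dataset" in spec:
        # HuggingFace mode: hf_config defaults to the benchmark name.
        return {
            "name": spec["name"],
            "hf_dataset": spec["hf_dataset"],
            "hf_config": spec.get("hf_config", spec["name"]),
            "hf_split": spec.get("hf_split", "test"),
        }
    # File-based mode: both keys are required.
    missing = [k for k in ("benchmark_dir", "csv_path") if k not in spec]
    if missing:
        raise ValueError(f"{spec['name']}: missing {missing} for file-based mode")
    return {k: spec[k] for k in ("name", "benchmark_dir", "csv_path")}


print(normalize_benchmark_spec(
    {"name": "USPTO", "hf_dataset": "Keylab/OCSR-Benchmarks"}
))
```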

Returns: dict[str, dict], a mapping from benchmark name to a metrics dict with the same structure as the return value of evaluate. Example:

{
  "USPTO": {
    "postprocess/exact_match_acc": 0.934,
    "postprocess/avg_tanimoto": 0.987,
    ...
  },
  "CLEF": {
    "postprocess/exact_match_acc": 0.948,
    ...
  },
}

Examples:

# File-based
benchmarks = [
    {"name": "USPTO", "benchmark_dir": "data/benchmark/real/USPTO",
     "csv_path": "data/benchmark/real/USPTO.csv"},
    {"name": "CLEF",  "benchmark_dir": "data/benchmark/real/CLEF",
     "csv_path": "data/benchmark/real/CLEF_corrected.csv"},
]

# HuggingFace dataset (recommended — no local files required)
benchmarks = [
    {"name": "USPTO", "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "USPTO"},
    {"name": "CLEF",  "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "CLEF"},
    {"name": "JPO",   "hf_dataset": "Keylab/OCSR-Benchmarks", "hf_config": "JPO"},
]

results = como.evaluate_benchmarks(model, benchmarks, gpus="0,1")
for name, metrics in results.items():
    acc = metrics["postprocess/exact_match_acc"]
    tan = metrics["postprocess/avg_tanimoto"]
    print(f"{name}: Exact={acc:.2%}, Tanimoto={tan:.4f}")
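
When many benchmarks are evaluated, an aligned text table can be easier to scan than one print per benchmark. The sketch below renders the nested results dict that way. format_results is a hypothetical helper, and the sample values are copied from the example return value shown earlier in this section, with a "-" printed where a key is absent.

```python
def format_results(results, keys=("postprocess/exact_match_acc",
                                  "postprocess/avg_tanimoto")):
    """Render {benchmark: metrics} as aligned text rows ("-" if a key is missing)."""
    header = ["benchmark"] + [key.split("/", 1)[1] for key in keys]
    rows = [header]
    for name, metrics in results.items():
        cells = [name]
        for key in keys:
            value = metrics.get(key)
            cells.append("-" if value is None else f"{value:.3f}")
        rows.append(cells)
    # Pad every column to its widest cell.
    widths = [max(len(row[i]) for row in rows) for i in range(len(header))]
    return "\n".join(
        "  ".join(cell.ljust(width) for cell, width in zip(row, widths)).rstrip()
        for row in rows
    )


# Values taken from the example return value in this section.
results = {
    "USPTO": {"postprocess/exact_match_acc": 0.934, "postprocess/avg_tanimoto": 0.987},
    "CLEF": {"postprocess/exact_match_acc": 0.948},
}
print(format_results(results))
```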

como.canonicalize_smiles(smiles, *, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True)

Canonicalize a SMILES string using RDKit.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| smiles | str | required | Input SMILES string |
| ignore_chiral | bool | False | Strip tetrahedral chirality before canonicalization |
| ignore_cistrans | bool | False | Strip cis–trans markers (/ and \) before canonicalization |
| replace_rgroup | bool | True | If True, replace R-group tokens (R, R1, X, Ar, …) with wildcard * |

Returns: tuple[str, bool] of (canonical_smiles, ok), where ok is True if the SMILES is chemically valid and canonicalization succeeded.
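
The (canonical_smiles, ok) return shape lends itself to a strict comparison that only counts a match when both sides canonicalize successfully. The sketch below illustrates that pattern. smiles_match is a hypothetical helper and is not necessarily how COMO computes its exact-match metric; the canonicalizer is injected, and a toy stand-in is used so the example runs without RDKit.

```python
from typing import Callable, Tuple

Canonicalizer = Callable[[str], Tuple[str, bool]]


def smiles_match(pred: str, truth: str, canonicalize: Canonicalizer) -> bool:
    """True only if both inputs canonicalize successfully to the same string."""
    pred_canon, pred_ok = canonicalize(pred)
    truth_canon, truth_ok = canonicalize(truth)
    return pred_ok and truth_ok and pred_canon == truth_canon


# Toy stand-in so the sketch runs without RDKit: it only strips whitespace
# and rejects empty strings. It is NOT chemically meaningful.
def toy_canonicalize(smiles: str) -> Tuple[str, bool]:
    cleaned = smiles.strip()
    return cleaned, bool(cleaned)


print(smiles_match(" CCO", "CCO ", toy_canonicalize))  # True
print(smiles_match("", "CCO", toy_canonicalize))       # False
```

With COMO installed, pass como.canonicalize_smiles as the third argument, or functools.partial(como.canonicalize_smiles, ignore_chiral=True) to compare while ignoring tetrahedral chirality.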


como.canonicalize_tautomer(smiles)

Canonicalize a SMILES string via RDKit's TautomerEnumerator, normalizing different tautomeric forms (e.g., keto/enol, lactam/lactim) to the same canonical representation.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| smiles | str | required | Input SMILES string |

Returns: tuple[str, bool] of (tautomer_canonical_smiles, ok), where ok is False if the input SMILES is invalid or tautomer enumeration fails.


como._result_to_smiles(result, mode="postprocess")

Low-level: convert a raw prediction result dict (from predict with smiles_mode=None) to a canonical SMILES string.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| result | dict | required | Raw prediction dict with keys decode_smiles, symbols, coords, bond_mat, success |
| mode | str | "postprocess" | SMILES reconstruction mode |

mode options:

| Mode | Source | Chirality | Description |
| --- | --- | --- | --- |
| "decoder" | Decoder token sequence | | Raw decoder SMILES, no graph info used. Fastest but lowest quality. |
| "graph" | Predicted atoms + bonds | | Reconstructs SMILES entirely from predicted atom symbols, coordinates, and bond matrix. Chirality restored via _verify_chirality. |
| "postprocess" | Decoder + atoms + bonds | | Starts from decoder SMILES, replaces R-groups/abbreviations, restores chirality from predicted coordinates and bond matrix, then expands functional groups back. Best quality. |

Returns: str or None — canonical SMILES string, or None if conversion fails.
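
Since the conversion can return None, one possible pattern is to try the modes in order of expected quality and keep the first that succeeds. The sketch below illustrates this under stated assumptions: first_successful_smiles is a hypothetical helper, the source does not say whether COMO already falls back between modes internally, and a stub converter stands in for como._result_to_smiles so the example runs on its own.

```python
from typing import Callable, Optional, Sequence, Tuple


def first_successful_smiles(
    result: dict,
    convert: Callable[[dict, str], Optional[str]],
    modes: Sequence[str] = ("postprocess", "graph", "decoder"),
) -> Tuple[Optional[str], Optional[str]]:
    """Try each mode in order; return (smiles, mode) for the first non-None result."""
    for mode in modes:
        smiles = convert(result, mode)
        if smiles is not None:
            return smiles, mode
    return None, None


# Stub converter for illustration: only the "decoder" mode succeeds.
def stub_convert(result: dict, mode: str) -> Optional[str]:
    return result.get("decode_smiles") if mode == "decoder" else None


print(first_successful_smiles({"decode_smiles": "CC(=O)O"}, stub_convert))
# ('CC(=O)O', 'decoder')
```

With a raw dict from como.predict(model, image, smiles_mode=None), the call would be first_successful_smiles(raw, como._result_to_smiles).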

Model Weights

Pre-trained model weights are available on HuggingFace:

| Checkpoint | Reward Mode | Description |
| --- | --- | --- |
| COMO_joint/tanimoto/final.pth | Tanimoto | Joint MLE+MRT (Tanimoto reward) |
| COMO_joint/edit_distance/final.pth | Edit Distance | Joint MLE+MRT (Edit Distance reward) |
| COMO_joint/visual/final.pth | Visual | Joint MLE+MRT (Visual reward) |

Download from: https://huggingface.co/Keylab/COMO

Benchmark Datasets

Benchmark datasets are published as a HuggingFace Dataset with one config per benchmark: Keylab/OCSR-Benchmarks

| Config | Size | Type |
| --- | --- | --- |
| CLEF | 992 | Real (patents) |
| JPO | 449 | Real (patents) |
| UOB | 5,740 | Real (academic) |
| USPTO | 5,719 | Real (patents) |
| USPTO-10K | 9,999 | Real (patents) |
| Staker | 50,000 | Real |
| ACS | 331 | Real (publications) |
| WildMol-10K | 9,889 | Real (wild) |
| Indigo | 5,719 | Synthetic |
| ChemDraw | 5,719 | Synthetic |

Each sample has three fields: image_id (str), image (PIL), SMILES (str).

from datasets import load_dataset

# Load a single benchmark
ds = load_dataset("Keylab/OCSR-Benchmarks", name="USPTO", split="test")
sample = ds[0]
sample["image"].show()   # PIL Image
print(sample["SMILES"])  # ground-truth SMILES

# Iterate over all benchmarks
configs = ["CLEF", "JPO", "UOB", "USPTO", "USPTO-10K",
           "Staker", "ACS", "WildMol-10K", "Indigo", "ChemDraw"]
for name in configs:
    ds = load_dataset("Keylab/OCSR-Benchmarks", name=name, split="test")
    print(f"{name}: {len(ds)} samples")

Pre-packaged .tar.gz archives are also available for bulk download in the COMO model repository.

Citation

If you use COMO in your research, please cite:

@article{lyu2026closed,
  title={COMO: Closed-Loop Optical Molecule Recognition with Minimum Risk Training},
  author={Lyu, Zhuoqi and Ke, Qing},
  journal={arXiv preprint arXiv:2604.23546},
  year={2026}
}

License

  • Code (como/ package): MIT License
  • Model Weights (.pth files): CC BY-NC 4.0 (non-commercial use only)
  • Benchmark Datasets: collected from existing public OCSR benchmarks; please refer to their original sources for license and attribution:

| Dataset | Source |
| --- | --- |
| USPTO, CLEF, JPO, UOB, Staker | Rajan et al., 2020; Xiong et al., 2023 |
| Indigo, ChemDraw, ACS, Staker | Qian et al., 2023 |
| USPTO-10K | Morin et al., 2023 |
| WildMol-10K | Fang et al., 2025 |

See LICENSE for full terms.
