gnn_decoders

GraphSAGE-based decoders for the Stim / Sinter quantum-error-correction benchmarking framework.

This package wraps the trained GraphSAGE models from the quantum-error-correction research project and exposes them as sinter.Decoder instances, so they can be benchmarked head-to-head against pymatching, fusion_blossom, and other Sinter-supported decoders in a single sinter.collect call.

Scope of v0.1

  • Codes: rotated-surface-code memory-Z experiments only.
  • Distances: d ∈ {3, 5, 7} (the bundled baseline checkpoints).
  • Seeds per distance: 1-5 (15 checkpoints total, ~6.7 MB bundled).
  • Single observable. Multi-observable DEMs raise NotImplementedError.

If Sinter hands the decoder a DEM whose detector count does not correspond to a bundled checkpoint, or whose num_observables != 1, compilation fails fast with a clear error rather than producing misleading predictions.
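The fail-fast behaviour described above can be pictured with a small, dependency-free sketch. This is an illustration, not the package's actual code: the function name check_dem_supported, the ValueError for an unknown detector count, and the example detector counts are all assumptions. In real use the two integers would come from the num_detectors and num_observables properties of the stim.DetectorErrorModel that Sinter passes in.

```python
def check_dem_supported(num_detectors, num_observables, supported_detector_counts):
    """Raise early if the DEM shape cannot be matched to a bundled checkpoint.

    Hypothetical helper: the package's real check may be structured differently.
    """
    if num_observables != 1:
        # The README states that multi-observable DEMs raise NotImplementedError.
        raise NotImplementedError(
            f"expected exactly 1 observable, got {num_observables}"
        )
    if num_detectors not in supported_detector_counts:
        # The exception type here is an assumption; the README only promises
        # "a clear error".
        raise ValueError(
            f"no bundled checkpoint for a DEM with {num_detectors} detectors; "
            f"supported counts: {sorted(supported_detector_counts)}"
        )


# Placeholder detector counts for illustration only, not read from the package.
EXAMPLE_COUNTS = {24, 120, 336}

check_dem_supported(120, 1, EXAMPLE_COUNTS)  # passes silently
```

Checking the shape once at compile time means a mismatched circuit is rejected before any shots are decoded, instead of yielding plausible-looking but meaningless error rates.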

Install

pip install -e .

Dependencies (stim, sinter, torch, torch-geometric, numpy) are declared in pyproject.toml. torch-geometric usually has to be installed to match the torch build already in your environment, so install torch first and then follow https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html.

Quickstart

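import stim
import sinter
from gnn_decoders import GraphSAGEDecoder


def main():
    # Distance-5 rotated surface code memory-Z experiment, uniform noise p = 0.005.
    circuit = stim.Circuit.generated(
        "surface_code:rotated_memory_z",
        rounds=5,
        distance=5,
        after_clifford_depolarization=0.005,
        after_reset_flip_probability=0.005,
        before_measure_flip_probability=0.005,
        before_round_data_depolarization=0.005,
    )

    # Register the GraphSAGE decoder under the name "graphsage" and sample with it.
    stats = sinter.collect(
        num_workers=4,
        tasks=[sinter.Task(circuit=circuit, json_metadata={"d": 5, "p": 0.005})],
        decoders=["graphsage"],
        custom_decoders={"graphsage": GraphSAGEDecoder(seed=1)},
        max_shots=10_000,
    )
    for s in stats:
        print(s)


# sinter.collect starts worker processes; the guard stops them from
# re-running the collection when they import this script.
if __name__ == "__main__":
    main()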

See examples/benchmark_vs_pymatching.py for the full apples-to-apples comparison against PyMatching across all bundled (d, p) combinations.
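Each item returned by sinter.collect is a sinter.TaskStats carrying, among other fields, decoder, json_metadata, shots, errors, and discards. As a rough sketch of how those fields turn into a per-decoder comparison table, the helper below (a hypothetical name, summarize) computes a logical error rate per entry. The stand-in objects and their counts are arbitrary placeholders so the snippet runs without Stim or Sinter installed; they are not benchmark results.

```python
from types import SimpleNamespace


def summarize(stats):
    """Return (decoder, d, p, kept_shots, errors, error_rate) rows, sorted."""
    rows = []
    for s in stats:
        kept = s.shots - s.discards  # shots that were not post-selected away
        rate = s.errors / kept if kept else float("nan")
        rows.append(
            (s.decoder, s.json_metadata["d"], s.json_metadata["p"], kept, s.errors, rate)
        )
    return sorted(rows, key=lambda row: row[:3])


# Stand-ins for sinter.TaskStats with arbitrary placeholder counts.
fake_stats = [
    SimpleNamespace(decoder="decoder_b", json_metadata={"d": 5, "p": 0.005},
                    shots=1000, errors=20, discards=0),
    SimpleNamespace(decoder="decoder_a", json_metadata={"d": 5, "p": 0.005},
                    shots=1000, errors=10, discards=0),
]

for decoder, d, p, kept, errors, rate in summarize(fake_stats):
    print(f"{decoder:12s} d={d} p={p} shots={kept} errors={errors} rate={rate:.4f}")
```

With real data, passing the list returned by sinter.collect to the same function should work unchanged, provided every task was created with "d" and "p" in its json_metadata as in the quickstart.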

API

GraphSAGEDecoder(seed=1, k_neighbors=6, device="auto", batch_size=1024)

A sinter.Decoder subclass.

arg          meaning
seed         Which trained seed to use (1-5). Use available_checkpoints() to inspect.
k_neighbors  k-NN parameter for the sparse-graph encoding. Must match training (6).
device       "auto", "cpu", "cuda", or a torch.device.
batch_size   Shots per PyG Batch per forward pass. Tune for GPU memory.
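To make the batch_size trade-off concrete, here is a minimal sketch of how a block of shots could be split into forward passes. The helper name batch_slices is an assumption for illustration and the decoder's internal batching may differ; the point is only that the number of forward passes is ceil(num_shots / batch_size), with the last batch possibly smaller.

```python
def batch_slices(num_shots, batch_size=1024):
    """Yield (start, stop) index pairs covering num_shots in batch_size chunks."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, num_shots, batch_size):
        yield start, min(start + batch_size, num_shots)


# 2500 shots with the default batch size need three forward passes.
print(list(batch_slices(2500)))
# → [(0, 1024), (1024, 2048), (2048, 2500)]
```

A larger batch_size means fewer, bigger forward passes and higher peak memory; a smaller one lowers memory use at the cost of more passes.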

available_checkpoints() -> dict

Returns the parsed MANIFEST.json describing every bundled .pt file.

supported_distances() -> list[int], supported_seeds(distance) -> list[int]

Introspection helpers.
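One way these helpers might be used is to benchmark every bundled seed in a single sinter.collect call. The sketch below builds the name-to-decoder mapping that custom_decoders expects. The helper build_decoder_table and the "graphsage_s" name prefix are hypothetical, and a stand-in factory replaces GraphSAGEDecoder so the snippet runs without the package installed.

```python
def build_decoder_table(seeds, make_decoder, prefix="graphsage_s"):
    """Map a unique decoder name to one decoder instance per seed."""
    return {f"{prefix}{seed}": make_decoder(seed) for seed in seeds}


# Stand-in factory. In real use this would be
#   lambda seed: GraphSAGEDecoder(seed=seed)
# and the seeds would come from supported_seeds(distance).
table = build_decoder_table(range(1, 6), make_decoder=lambda seed: ("stand-in", seed))
decoder_names = list(table)
print(decoder_names)

# The two values would then be passed to Sinter as
#   sinter.collect(..., decoders=decoder_names, custom_decoders=table)
```

Giving each seed its own decoder name keeps the per-seed statistics separate in Sinter's output, which makes seed-to-seed variance visible next to the comparison with other decoders.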

What's bundled

Each .pt in src/gnn_decoders/checkpoints/ is a dict {state_dict, config, nickname, timestamp} saved by the original GraphSAGE.save method in the research project's code/models.py. Each checkpoint carries its own architectural hyperparameters (in_channels, hidden_dim, num_layers, dropout, aggr), so the model is reconstructed exactly as it was trained — the decoder does not assume any particular configuration.

The bundled "baseline" checkpoints use:

in_channels=5, hidden_dim=128, num_layers=4, dropout=0.0, aggr='mean'
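The "reconstructed exactly as it was trained" idea can be sketched without PyTorch: read the hyperparameters from the checkpoint's own config entry instead of hard-coding them. Everything below is illustrative. GraphSAGEConfig and config_from_checkpoint are hypothetical names, and the plain dict stands in for what torch.load would return for a bundled .pt file; only the key names and the baseline values are taken from this README.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class GraphSAGEConfig:
    in_channels: int
    hidden_dim: int
    num_layers: int
    dropout: float
    aggr: str


def config_from_checkpoint(checkpoint):
    """Build the architecture config from the checkpoint's stored values."""
    missing = {"state_dict", "config"} - set(checkpoint)
    if missing:
        raise KeyError(f"checkpoint is missing keys: {sorted(missing)}")
    raw = checkpoint["config"]
    names = [f.name for f in fields(GraphSAGEConfig)]
    # A missing hyperparameter raises KeyError rather than being defaulted.
    return GraphSAGEConfig(**{name: raw[name] for name in names})


# Stand-in for a loaded checkpoint; nickname and timestamp are placeholders.
checkpoint = {
    "state_dict": {},
    "config": {"in_channels": 5, "hidden_dim": 128, "num_layers": 4,
               "dropout": 0.0, "aggr": "mean"},
    "nickname": "placeholder",
    "timestamp": "placeholder",
}

print(config_from_checkpoint(checkpoint))
```

Because nothing is defaulted, a checkpoint trained with a different configuration would be rebuilt with its own values, and an incomplete one would fail loudly rather than load into a mismatched architecture.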

The "best tuned" configuration recorded in code/gSAGE/tuning/best_model_config.json (hidden_dim=256, num_layers=5, dropout=0.1, aggr='max') is not bundled because the source project trained it only at d=7. A future release could add it by retraining across d ∈ {3, 5, 7} and copying the resulting .pt files into src/gnn_decoders/checkpoints/.

License

MIT.
