
Datasets and evaluation from the Spatial Reasoning with Denoising Models paper


SRM Benchmarks


A minimalistic package with benchmark datasets and evaluation metrics for measuring how well your image generative model understands complex spatial relationships. These are the datasets used in the ICML 2025 paper Spatial Reasoning with Denoising Models. All dataset files and evaluation models are hosted in minimal form on Hugging Face and are downloaded automatically when you use the package.

SRM Benchmark Datasets

Installation

From PyPI

pip install srmbench

From source

git clone https://github.com/spatialreasoners/srmbench.git
cd srmbench
pip install -e .

Development installation

git clone https://github.com/spatialreasoners/srmbench.git
cd srmbench
pip install -e ".[dev]"

Datasets

SRM Benchmarks provides three main datasets for evaluating spatial reasoning capabilities in generative models. Each dataset tests different aspects of spatial understanding and constraint satisfaction.

🧩 MNIST Sudoku

MNIST Sudoku Examples

Challenge: Inpaint the image by filling the missing cells with MNIST digits so that no digit repeats in any row, column, or 3×3 subgrid.

What the model needs to understand:

  • Global constraints: Sudoku validity rules that span the entire image
  • Spatial relationships: Row, column, and subgrid membership
  • Digit recognition: Understanding and generating MNIST digits correctly
  • Constraint propagation: How placing one digit affects valid placements elsewhere

Dataset Details:

  • Image size: 252×252 pixels (9×9 grid of 28×28 MNIST digits)
  • Format: Grayscale images with corresponding masks
  • Masks: Mark which cells are given (black) and which need to be filled (white)
  • Difficulty: Configurable via min_given_cells and max_given_cells parameters
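Because each image is a 9×9 grid of 28×28 cells, a pixel-level mask can be collapsed into a per-cell grid by averaging each 28×28 block. A minimal PyTorch sketch, assuming the mask is a (252, 252) tensor in [0, 1] with 1 (white) for cells to fill and 0 (black) for given cells, as produced by the `mask_transform` in the Quick Start; `cells_to_fill` is a hypothetical helper, not part of srmbench:

```python
import torch

CELL = 28  # each Sudoku cell is one 28x28 MNIST digit
GRID = 9   # 9x9 cells -> 252x252 pixels


def cells_to_fill(pixel_mask: torch.Tensor) -> torch.Tensor:
    """Collapse a (252, 252) pixel mask into a (9, 9) boolean grid.

    Assumes 1 marks cells to fill, 0 marks given cells, and the value
    is constant inside each cell.
    """
    blocks = pixel_mask.reshape(GRID, CELL, GRID, CELL)  # (row, y, col, x)
    return blocks.mean(dim=(1, 3)) > 0.5


# Build a toy pixel mask from a cell-level grid, then recover the grid.
cell_grid = torch.zeros(GRID, GRID, dtype=torch.bool)
cell_grid[0, 0] = True  # top-left cell must be filled
cell_grid[4, 7] = True
pixel_mask = (
    cell_grid.float()
    .repeat_interleave(CELL, dim=0)
    .repeat_interleave(CELL, dim=1)
)
print(pixel_mask.shape)                 # torch.Size([252, 252])
print(cells_to_fill(pixel_mask).sum())  # tensor(2)
```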

Evaluation Metrics:

  • is_valid_sudoku: Boolean indicating valid Sudoku (no duplicates in any row/column/subgrid)
  • duplicate_count: Number of constraint violations (0 = perfect)
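The two Sudoku metrics can be illustrated on a grid of already-recognized digits. This plain-Python sketch assumes a 9×9 list of integers from 1 to 9; the package presumably recognizes the MNIST digits with its evaluation model first, which is not shown here. It also assumes that `duplicate_count` adds one for each extra occurrence of a digit within a row, column, or subgrid; the package's exact counting rule may differ.

```python
def duplicate_count(grid):
    """Count extra digit occurrences over all rows, columns, and 3x3 subgrids.

    Assumption: a unit in which some digit appears k times adds k - 1.
    """
    rows = [list(row) for row in grid]
    cols = [[grid[r][c] for r in range(9)] for c in range(9)]
    boxes = [
        [grid[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)]
        for br in range(0, 9, 3)
        for bc in range(0, 9, 3)
    ]
    return sum(len(unit) - len(set(unit)) for unit in rows + cols + boxes)


def is_valid_sudoku(grid):
    return duplicate_count(grid) == 0


# A solved grid built from a standard shifting pattern, then one corrupted cell.
solved = [[(3 * r + r // 3 + c) % 9 + 1 for c in range(9)] for r in range(9)]
broken = [row[:] for row in solved]
broken[0][1] = broken[0][0]  # duplicate in row 0, column 1, and the top-left box
print(is_valid_sudoku(solved), duplicate_count(broken))  # True 3
```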

🎨 Even Pixels

Even Pixels Examples

Challenge: Generate images where exactly 50% of pixels are one color and 50% are another color, with uniform saturation and brightness.

What the model needs to understand:

  • Pixel-level counting: Precise balance between two colors
  • Global distribution: Maintaining exact 50/50 split across entire image
  • Color consistency: Uniform saturation and value (HSV color space)
  • Statistical properties: Perfect balance down to the pixel level

Dataset Details:

  • Image size: 32×32 pixels (1,024 total pixels)
  • Format: RGB images
  • Color constraint: Each image contains two colors with opposite hues, randomly positioned, and each color covers exactly 50% of the pixels (512 of the 1,024).

Evaluation Metrics:

  • saturation_std: Standard deviation of saturation (should be ~0)
  • value_std: Standard deviation of brightness (should be ~0)
  • color_imbalance_count: Deviation from perfect 50/50 split (0 = perfect)
  • is_color_count_even: Boolean for exact pixel balance (1.0 = perfect)
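The Even Pixels metrics can be approximated with the standard library's `colorsys` module. This is an illustrative sketch, not the package's implementation: it assumes pixels are `(r, g, b)` floats in [0, 1], assigns each pixel to one of two hue groups by its circular hue distance to the first pixel, and defines the imbalance as how far one group's count is from half the pixels (512 for a 32×32 image).

```python
import colorsys
import statistics


def even_pixels_stats(pixels):
    """Approximate Even Pixels metrics for a list of (r, g, b) floats in [0, 1]."""
    hsv = [colorsys.rgb_to_hsv(r, g, b) for r, g, b in pixels]
    ref_hue = hsv[0][0]

    def in_first_group(hue):
        # Hue lives on a circle in [0, 1); opposite hues are 0.5 apart.
        d = abs(hue - ref_hue)
        return min(d, 1.0 - d) < 0.25

    count_first = sum(in_first_group(h) for h, _, _ in hsv)
    imbalance = abs(count_first - len(pixels) / 2)
    return {
        "saturation_std": statistics.pstdev([s for _, s, _ in hsv]),
        "value_std": statistics.pstdev([v for _, _, v in hsv]),
        "color_imbalance_count": imbalance,
        "is_color_count_even": imbalance == 0,
    }


red, cyan = (1.0, 0.0, 0.0), (0.0, 1.0, 1.0)  # opposite hues: 0.0 and 0.5
print(even_pixels_stats([red] * 512 + [cyan] * 512))
```

With a perfectly balanced image of fully saturated, full-brightness colors, both standard deviations are 0 and the imbalance is 0; shifting 8 pixels from one color to the other would give an imbalance of 8.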

🔢 Counting Objects

Counting Objects Examples

Challenge: Generate images with a specific number of objects (polygons or stars) where the displayed numbers match the actual object counts.

What the model needs to understand:

  • Object counting: Generating exact numbers of distinct objects
  • Number placement: Positioning numbers that accurately represent counts
  • Object consistency: All objects having the same number of vertices (uniform constraint)
  • Semantic coherence: Numbers matching what's visually present

Dataset Details:

  • Image size: Configurable (typically 128×128 or 256×256)
  • Format: RGB images with objects overlaid on FFHQ background faces
  • Variants:
    • Polygons: 3 to 9 sides
    • Stars: 4 to 10 points
  • Numbers: Optional overlay showing object counts (enabled via the are_nums_on_images parameter)

Evaluation Metrics:

  • are_vertices_uniform: Fraction of images in which all objects have the same vertex count
  • numbers_match_objects: Fraction of images in which the displayed numbers match the actual counts
  • relative_vertex_count_N: Relative frequency of objects with N vertices
  • relative_polygons_count_N: Relative frequency of images containing N objects
  • Confidence scores: The evaluation model's confidence in its predictions
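Given per-image predictions from an object-counting model, the fraction-style metrics reduce to simple averages. In this sketch, `ImagePrediction` and `counting_metrics` are hypothetical names, and each image is assumed to carry a single displayed number meant to equal its object count; the package's evaluation uses its own trained model and may compare more than one overlaid number.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class ImagePrediction:
    """Hypothetical per-image output of an object-counting classifier."""
    vertex_counts: list[int]  # predicted vertex count of each detected object
    displayed_number: int     # number read from the overlay (assumed: object count)


def counting_metrics(preds: list[ImagePrediction]) -> dict[str, float]:
    n = len(preds)
    metrics = {
        # All objects in an image share one vertex count.
        "are_vertices_uniform": sum(len(set(p.vertex_counts)) == 1 for p in preds) / n,
        # The overlaid number equals the number of detected objects.
        "numbers_match_objects": sum(
            p.displayed_number == len(p.vertex_counts) for p in preds
        ) / n,
    }
    # Relative frequency of images containing N objects.
    for count, k in Counter(len(p.vertex_counts) for p in preds).items():
        metrics[f"relative_polygons_count_{count}"] = k / n
    return metrics


preds = [
    ImagePrediction([4, 4, 4], 3),
    ImagePrediction([5, 5], 3),   # number does not match the 2 objects
    ImagePrediction([3, 6], 2),   # mixed vertex counts
    ImagePrediction([7, 7, 7], 3),
]
print(counting_metrics(preds))
```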

Usage

Available Datasets

SRM Benchmarks provides three main datasets for evaluating spatial reasoning capabilities:

  • MNIST Sudoku: Sudoku puzzles with MNIST digits
  • Even Pixels: Images with specific color distribution constraints
  • Counting Objects: Images with polygons or stars to count (with optional numbers overlay)

Quick Start

1. MNIST Sudoku Dataset

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as transforms
from srmbench.datasets import MnistSudokuDataset
from srmbench.evaluations import MnistSudokuEvaluation

# Define transforms for images and masks
image_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # Scales from [0,255] to [0,1]
    transforms.Lambda(lambda x: x.squeeze(0)),       # Remove channel dimension
])

mask_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # Scales from [0,255] to [0,1]
    transforms.Lambda(lambda x: x.squeeze(0)),       # Remove channel dimension
])

# Create dataset with transforms
dataset = MnistSudokuDataset(
    stage="test",
    transform=image_transform,
    mask_transform=mask_transform
)

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

# Use with evaluation
evaluation = MnistSudokuEvaluation()

# Evaluate batches
for images, masks in dataloader:
    # Here you can apply the mask and reconstruct using your model.
    # For example:
    # masked_images = images * (1 - masks)  # Zero out the cells to be filled (mask = 1)
    # reconstructed = model(masked_images, masks)

    results = evaluation.evaluate(images)
    # duplicate_count = 0 means valid sudoku (no duplicates)
    print(f"Valid Sudoku: {results['is_valid_sudoku'].float().mean():.2%}")
    print(f"Avg Duplicate Count: {results['duplicate_count'].float().mean():.2f}")
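Before calling `evaluation.evaluate`, a common inpainting step is to paste the given cells back over the model output, so that only the missing cells come from the model. A minimal sketch, assuming the masks follow the convention in Dataset Details (1 = cell to fill, 0 = given cell); `composite` is a hypothetical helper and `prediction` stands in for your own model's output:

```python
import torch


def composite(images: torch.Tensor, masks: torch.Tensor, prediction: torch.Tensor) -> torch.Tensor:
    # masks == 1 -> cell to fill: take the model output
    # masks == 0 -> given cell: keep the original pixels
    return masks * prediction + (1 - masks) * images


# Toy batch: 2 grayscale Sudoku images of 252x252 in [0, 1].
images = torch.rand(2, 252, 252)
masks = torch.zeros(2, 252, 252)
masks[:, :28, :28] = 1.0                # top-left cell must be filled
masked_images = images * (1 - masks)    # hide the cells to fill
prediction = torch.ones_like(images)    # stand-in for model(masked_images, masks)
reconstructed = composite(images, masks, prediction)
```

`reconstructed` can then be passed to `evaluation.evaluate(...)` in place of `images`.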

2. Even Pixels Dataset

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as transforms
from srmbench.datasets import EvenPixelsDataset
from srmbench.evaluations import EvenPixelsEvaluation

# Define transform: PIL RGB (H, W, 3) -> Tensor (3, H, W) in [-1, 1]
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # Scales from [0,255] to [0,1]
    transforms.Lambda(lambda x: x * 2.0 - 1.0),      # Normalize to [-1,1]
])

# Create dataset with transforms
dataset = EvenPixelsDataset(stage="test", transform=transform)

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

# Use with evaluation
evaluation = EvenPixelsEvaluation()

# Evaluate batches
for images in dataloader:
    results = evaluation.evaluate(images)
    print(f"Saturation STD: {results['saturation_std']:.4f}")
    print(f"Value STD: {results['value_std']:.4f}")
    print(f"Color Imbalance: {results['color_imbalance_count']:.0f} pixels")
    print(f"Perfect Balance: {results['is_color_count_even']:.2%}")

3. Counting Objects Dataset

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as transforms
from srmbench.datasets import CountingObjectsFFHQ
from srmbench.evaluations import CountingObjectsEvaluation

# Define transform: PIL RGB (H, W, 3) -> Tensor (3, H, W) in [-1, 1]
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # Scales from [0,255] to [0,1]
    transforms.Lambda(lambda x: x * 2.0 - 1.0),      # Normalize to [-1,1]
])

# Create dataset with transforms (polygons or stars variant)
# NOTE: Use image_resolution=(128, 128) to match model training resolution
dataset = CountingObjectsFFHQ(
    stage="test",
    object_variant="polygons",  # or "stars"
    image_resolution=(128, 128),
    are_nums_on_images=True,
    transform=transform,
)

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

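# Use with evaluation; the variant should match the dataset's object_variant.
# Falls back to the CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
evaluation = CountingObjectsEvaluation(object_variant="polygons", device=device)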

# Evaluate batches
for images in dataloader:
    results = evaluation.evaluate(images, include_counts=True)
    print(f"Vertices Uniform: {results['are_vertices_uniform']:.2%}")
    print(f"Numbers Match Objects: {results['numbers_match_objects']:.2%}")

License

This project's code is licensed under the MIT License - see the LICENSE file for details. The benchmark datasets included in this package are subject to their respective licenses:

MNIST Sudoku Dataset

Counting Objects Dataset

Note: When using this package, please ensure compliance with the respective dataset licenses, particularly for commercial use. The FFHQ dataset is generally restricted to non-commercial purposes under the CC BY-NC-SA 4.0 license.

Running tests

pytest

Citation

If you use this package in your research, please cite:

@inproceedings{wewer25srm,
  title     = {Spatial Reasoning with Denoising Models},
  author    = {Wewer, Christopher and Pogodzinski, Bartlomiej and Schiele, Bernt and Lenssen, Jan Eric},
  booktitle = {International Conference on Machine Learning ({ICML})},
  year      = {2025},
}

