
Datasets and evaluation from the Spatial Reasoning with Denoising Models paper


SRM Benchmarks

A package of benchmark datasets for measuring how well image generative models understand complex spatial relationships. These are the datasets used in the ICML 2025 paper Spatial Reasoning with Denoising Models.

Installation

From PyPI

pip install srmbench

From source

git clone https://github.com/spatialreasoners/srmbench.git
cd srmbench
pip install -e .

Development installation

git clone https://github.com/spatialreasoners/srmbench.git
cd srmbench
pip install -e ".[dev]"
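To confirm which release ended up in your environment (for example after switching between the PyPI and editable installs), you can read the installed version from package metadata with the standard library. This is a minimal sketch. It assumes the distribution name is srmbench, as in the pip command above. The helper name installed_version is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str = "srmbench") -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version()
    print(f"srmbench {found}" if found else "srmbench is not installed")
```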

Usage

Available Datasets

SRM Benchmarks provides three main datasets for evaluating spatial reasoning capabilities:

  • MNIST Sudoku: Sudoku puzzles with MNIST digits
  • Even Pixels: Images whose pixels must be split evenly between two colors, with uniform saturation and brightness
  • Counting Objects: Images with polygons or stars to count (with optional numbers overlay)
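To select a benchmark by name in a training or evaluation script, you can collect the dataset and evaluation classes from the Quick Start in a small lookup table. This is a sketch only. The keys ("mnist_sudoku" etc.) and the load_benchmark helper are hypothetical; only the class names come from this README. srmbench is imported lazily, so the table itself works without the package installed.

```python
import importlib
from typing import Tuple

# Dataset / evaluation class names as listed in this README; the keys are hypothetical.
BENCHMARKS = {
    "mnist_sudoku": ("MnistSudokuDataset", "MnistSudokuEvaluation"),
    "even_pixels": ("EvenPixelsDataset", "EvenPixelsEvaluation"),
    "counting_objects": ("CountingObjectsFFHQ", "CountingObjectsEvaluation"),
}


def load_benchmark(name: str) -> Tuple[type, type]:
    """Return the (dataset class, evaluation class) pair for a benchmark key."""
    if name not in BENCHMARKS:
        raise KeyError(f"unknown benchmark {name!r}; choose from {sorted(BENCHMARKS)}")
    dataset_name, evaluation_name = BENCHMARKS[name]
    # Import srmbench only when a benchmark is actually requested.
    datasets = importlib.import_module("srmbench.datasets")
    evaluations = importlib.import_module("srmbench.evaluations")
    return getattr(datasets, dataset_name), getattr(evaluations, evaluation_name)
```

CountingObjectsFFHQ and CountingObjectsEvaluation also take an object_variant argument, so the returned classes do not share a single constructor signature.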

Quick Start

1. MNIST Sudoku Dataset

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as transforms
from srmbench.datasets import MnistSudokuDataset
from srmbench.evaluations import MnistSudokuEvaluation

# Create dataset
dataset = MnistSudokuDataset(stage="test")

# Define transform: PIL Image (H, W) -> Tensor (H, W) in [0, 1]
# Note: ToImage() converts PIL to Tensor, ToDtype with scale=True normalizes [0,255] -> [0,1]
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # Scales from [0,255] to [0,1]
    transforms.Lambda(lambda x: x.squeeze(0)),       # Remove channel dimension
])

# Collate function to handle (image, mask) tuples
def collate_fn(batch):
    images = torch.stack([transform(item[0]) for item in batch])
    masks = [item[1] for item in batch]
    return images, masks

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn
)

# Use with evaluation
evaluation = MnistSudokuEvaluation()

# Evaluate batches
for images, masks in dataloader:
    # Here you can apply the mask and reconstruct using your model.
    # For example:
    # images = model(images * masks)

    results = evaluation.evaluate(images)
    # duplicate_count = 0 means valid sudoku (no duplicates)
    print(f"Valid Sudoku: {results['is_valid_sudoku'].float().mean():.2%}")
    print(f"Avg Duplicate Count: {results['duplicate_count'].float().mean():.2f}")
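The loop above prints statistics for each batch. To report one number for the whole test split, accumulate per-image values across batches instead. Averaging batch means would give a smaller final batch the same weight as the full ones. This is a minimal sketch. It assumes evaluate() returns per-image tensors, as the .float().mean() calls suggest. The RunningMean class is hypothetical.

```python
import math
from typing import Iterable


class RunningMean:
    """Mean over all values seen so far, accumulated across batches."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, values: Iterable[float]) -> None:
        # Add every per-image value from one batch (e.g. tensor.tolist()).
        for v in values:
            self.total += float(v)
            self.count += 1

    @property
    def mean(self) -> float:
        # NaN until at least one value has been added.
        return self.total / self.count if self.count else math.nan


# Possible use inside the Sudoku loop (assumes per-image result tensors):
# valid = RunningMean()
# for images, masks in dataloader:
#     results = evaluation.evaluate(images)
#     valid.update(results["is_valid_sudoku"].float().tolist())
# print(f"Valid Sudoku: {valid.mean:.2%}")
```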

2. Even Pixels Dataset

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as transforms
from srmbench.datasets import EvenPixelsDataset
from srmbench.evaluations import EvenPixelsEvaluation

# Create dataset
dataset = EvenPixelsDataset(stage="test")

# Define transform: PIL RGB (H, W, 3) -> Tensor (3, H, W) in [-1, 1]
# Note: ToImage() converts PIL to Tensor, ToDtype with scale=True normalizes [0,255] -> [0,1]
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # Scales from [0,255] to [0,1]
    transforms.Lambda(lambda x: x * 2.0 - 1.0),      # Normalize to [-1,1]
])

# Collate function (dataset returns (image, mask) tuple)
def collate_fn(batch):
    images = torch.stack([transform(item[0]) for item in batch])
    return images

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn
)

# Use with evaluation
evaluation = EvenPixelsEvaluation()

# Evaluate batches
for images in dataloader:
    results = evaluation.evaluate(images)
    print(f"Saturation STD: {results['saturation_std']:.4f}")
    print(f"Value STD: {results['value_std']:.4f}")
    print(f"Color Imbalance: {results['color_imbalance_count']:.0f} pixels")
    print(f"Perfect Balance: {results['is_color_count_even']:.2%}")
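The Even Pixels transform (and the identical one used for Counting Objects below) maps an 8-bit pixel value p to [-1, 1] in two steps. ToDtype(..., scale=True) divides by 255, and the Lambda then applies 2x - 1. The inverse is useful for converting model outputs back to viewable 8-bit images. The clipping step in the inverse is an assumption that guards against outputs slightly outside [-1, 1].

```latex
x_{[0,1]} = \frac{p}{255}, \qquad
x_{[-1,1]} = 2\,x_{[0,1]} - 1 = \frac{2p}{255} - 1, \qquad p \in \{0, \dots, 255\}

p = \operatorname{round}\!\left(255 \cdot \frac{\operatorname{clip}\left(x_{[-1,1]}, -1, 1\right) + 1}{2}\right)
```

For example, p = 0 maps to -1 and p = 255 maps to 1, and the inverse sends -1 and 1 back to 0 and 255.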

3. Counting Objects Dataset

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as transforms
from srmbench.datasets import CountingObjectsFFHQ
from srmbench.evaluations import CountingObjectsEvaluation

# Create dataset (polygons or stars variant)
# NOTE: Use image_resolution=(128, 128) to match model training resolution
dataset = CountingObjectsFFHQ(
    stage="test",
    object_variant="polygons",  # or "stars"
    image_resolution=(128, 128),
    are_nums_on_images=True,
)

# Define transform: PIL RGB (H, W, 3) -> Tensor (3, H, W) in [-1, 1]
# Note: ToImage() converts PIL to Tensor, ToDtype with scale=True normalizes [0,255] -> [0,1]
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # Scales from [0,255] to [0,1]
    transforms.Lambda(lambda x: x * 2.0 - 1.0),      # Normalize to [-1,1]
])

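# Collate function: transform each image and stack the results into a batch
def collate_fn(batch):
    return torch.stack([transform(img) for img in batch])

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,  # a named function, unlike a lambda, can be pickled for worker processes
)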

# Use with evaluation (runs on CPU here; change device to use a GPU if one is available)
evaluation = CountingObjectsEvaluation(object_variant="polygons", device="cpu")

# Evaluate batches
for images in dataloader:
    results = evaluation.evaluate(images, include_counts=True)
    print(f"Vertices Uniform: {results['are_vertices_uniform']:.2%}")
    print(f"Numbers Match Objects: {results['numbers_match_objects']:.2%}")
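The dataset and the evaluation must use the same object_variant. A mismatch, such as a "stars" dataset scored as "polygons", would likely produce misleading numbers. When running both variants, one way to keep them in sync is to build the keyword arguments together. In this sketch, counting_objects_configs is a hypothetical helper, and the keyword names are copied from the calls above.

```python
from typing import Any, Dict, List, Tuple

Kwargs = Dict[str, Any]


def counting_objects_configs(
    variants: Tuple[str, ...] = ("polygons", "stars"),
    image_resolution: Tuple[int, int] = (128, 128),
    are_nums_on_images: bool = True,
    device: str = "cpu",
) -> List[Tuple[Kwargs, Kwargs]]:
    """Build (dataset kwargs, evaluation kwargs) pairs, one per object variant."""
    configs = []
    for variant in variants:
        if variant not in ("polygons", "stars"):
            raise ValueError(f"unsupported object_variant: {variant!r}")
        dataset_kwargs = {
            "stage": "test",
            "object_variant": variant,
            "image_resolution": image_resolution,
            "are_nums_on_images": are_nums_on_images,
        }
        # The evaluation always receives the same variant as its dataset.
        evaluation_kwargs = {"object_variant": variant, "device": device}
        configs.append((dataset_kwargs, evaluation_kwargs))
    return configs


# Possible use:
# for dataset_kwargs, evaluation_kwargs in counting_objects_configs():
#     dataset = CountingObjectsFFHQ(**dataset_kwargs)
#     evaluation = CountingObjectsEvaluation(**evaluation_kwargs)
```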

Evaluation Metrics

Each evaluation returns its own set of metrics:

MNIST Sudoku:

  • is_valid_sudoku: Boolean indicating whether the sudoku is valid (no duplicate digits in any row, column, or subgrid)
  • duplicate_count: Total count of duplicate violations (0 = a perfectly valid sudoku, higher = more duplicates)
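To make the two Sudoku metrics concrete: once each cell has been read as a digit, you can compute validity and a duplicate count by checking the 9 rows, 9 columns, and 9 3x3 boxes. The sketch below assumes a fully filled 9x9 grid of integers. For each unit, it counts how many cells repeat a digit already present in that unit. srmbench's exact counting (and its MNIST digit classification step) may differ. The function name sudoku_duplicates is hypothetical.

```python
from typing import List, Sequence, Tuple

Grid = Sequence[Sequence[int]]


def _units(grid: Grid) -> List[List[int]]:
    """Return the 9 rows, 9 columns, and 9 3x3 boxes of a 9x9 grid."""
    rows = [list(row) for row in grid]
    cols = [[grid[r][c] for r in range(9)] for c in range(9)]
    boxes = [
        [grid[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)]
        for br in range(0, 9, 3)
        for bc in range(0, 9, 3)
    ]
    return rows + cols + boxes


def sudoku_duplicates(grid: Grid) -> Tuple[bool, int]:
    """Return (is_valid, duplicate_count) for a filled 9x9 grid of digits 1-9."""
    # A unit with k distinct digits among its 9 cells contributes 9 - k duplicates.
    duplicate_count = sum(len(unit) - len(set(unit)) for unit in _units(grid))
    return duplicate_count == 0, duplicate_count
```

Overwriting one cell with its row neighbour's digit in a valid grid yields one duplicate each in that row, that column, and that box, so duplicate_count becomes 3.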

Even Pixels:

  • saturation_std: Standard deviation of saturation across the image (should be ~0 for uniform saturation)
  • value_std: Standard deviation of value/brightness across the image (should be ~0 for uniform brightness)
  • color_imbalance_count: Number of pixels deviating from a perfect 50/50 split between the two main colors (0 = perfectly balanced)
  • is_color_count_even: Boolean indicating whether the two main colors have exactly equal pixel counts (1.0 = balanced, 0.0 = unbalanced)
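The following is a simplified sketch of the Even Pixels statistics on a flat list of 8-bit RGB pixels, using the standard library's colorsys for the HSV conversion. It makes two assumptions. First, colors are exact; generated images would need their pixels grouped into two colors first. Second, the imbalance is taken as |n_a - n_b| for the two most frequent colors. srmbench may define it relative to N/2 or group colors differently. even_pixels_stats is a hypothetical name.

```python
import colorsys
import statistics
from collections import Counter
from typing import Dict, Sequence, Tuple

RGB = Tuple[int, int, int]


def even_pixels_stats(pixels: Sequence[RGB]) -> Dict[str, float]:
    """Simplified Even Pixels-style statistics for a non-empty list of 8-bit RGB pixels."""
    hsv = [colorsys.rgb_to_hsv(r / 255, g / 255, b / 255) for r, g, b in pixels]
    saturation_std = statistics.pstdev([s for _, s, _ in hsv])
    value_std = statistics.pstdev([v for _, _, v in hsv])

    # Pixel counts of the two most frequent exact colors (0 if only one color is present).
    top = Counter(pixels).most_common(2)
    n_a = top[0][1]
    n_b = top[1][1] if len(top) > 1 else 0
    imbalance = abs(n_a - n_b)
    return {
        "saturation_std": saturation_std,
        "value_std": value_std,
        "color_imbalance_count": float(imbalance),
        "is_color_count_even": float(imbalance == 0),
    }
```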

Counting Objects:

  • are_vertices_uniform: Fraction of images where all objects have the same number of vertices
  • numbers_match_objects: Fraction of images where the displayed numbers match the actual object counts (high for dataset images, low for random images)
  • Additional vertex/polygon count distributions (with include_counts=True)
  • Confidence scores (with include_confidences=True)
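The two fractional Counting Objects metrics can be illustrated on per-image detections. In this sketch, the ImageDetections structure and the counting_metrics function are hypothetical; srmbench obtains these values from its own models. The displayed number is assumed to state how many objects the image contains.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class ImageDetections:
    """Hypothetical per-image detector output."""
    vertex_counts: List[int]         # vertices detected for each object in the image
    displayed_number: Optional[int]  # number read from the overlay, None if absent


def counting_metrics(images: Sequence[ImageDetections]) -> dict:
    n = len(images)
    if n == 0:
        raise ValueError("need at least one image")
    # Images with no detected objects count as non-uniform (a choice made for this sketch).
    uniform = sum(len(set(img.vertex_counts)) == 1 for img in images)
    matched = sum(
        img.displayed_number is not None and img.displayed_number == len(img.vertex_counts)
        for img in images
    )
    return {
        "are_vertices_uniform": uniform / n,
        "numbers_match_objects": matched / n,
    }
```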

Running tests

pytest

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this package in your research, please cite:

@inproceedings{wewer25srm,
  title     = {Spatial Reasoning with Denoising Models},
  author    = {Wewer, Christopher and Pogodzinski, Bartlomiej and Schiele, Bernt and Lenssen, Jan Eric},
  booktitle = {International Conference on Machine Learning ({ICML})},
  year      = {2025},
}

Download files

Source Distribution

srmbench-0.1.2.tar.gz (29.3 kB)

Built Distribution

srmbench-0.1.2-py3-none-any.whl (26.9 kB)

File details

Details for the file srmbench-0.1.2.tar.gz.

File metadata

  • Download URL: srmbench-0.1.2.tar.gz
  • Size: 29.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for srmbench-0.1.2.tar.gz
Algorithm Hash digest
SHA256 746dd3b09c3b5934cd02269e8106d0b3dfe998384c7e0f0e072a91cee88887d6
MD5 70ad6c5392f484ed9d45b352f437b5a7
BLAKE2b-256 d5b6ed67e73beb79407966cdca37214a56f64cbf5eb1c367dec6c2b0427f93f8

File details

Details for the file srmbench-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: srmbench-0.1.2-py3-none-any.whl
  • Size: 26.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for srmbench-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 6ecb400f2c25dc7c02aaa119f64ee546acaca5849c53750500cfd3b93112effd
MD5 68669ce74c9384c87473765fe226c26b
BLAKE2b-256 b876c6171df63766a43de56523d1e495e1fd38a461144b58145356e74b1fc049