Unified PyTorch image quality metrics library

torch-image-metrics

torch-image-metrics is a unified PyTorch library for image quality evaluation, providing implementations of popular metrics like PSNR, SSIM, LPIPS, and FID with a simple, consistent API.

✨ Features

  • 🚀 Fast & Efficient: GPU-accelerated computations with PyTorch
  • 📊 Comprehensive Metrics: PSNR, SSIM, MSE, MAE, LPIPS, ImprovedSSIM, FID
  • 🎯 Unified API: Consistent interface across all metrics
  • 📁 Dataset Evaluation: Bulk evaluation with statistical analysis
  • 🔄 Batch Processing: Efficient batch computation support
  • 🛠️ Flexible: Optional dependencies for advanced metrics
  • 📈 Easy Integration: Drop-in replacement for existing workflows

📦 Installation

Basic Installation

pip install torch-image-metrics

With Optional Dependencies (Recommended)

pip install "torch-image-metrics[full]"

Development Installation

git clone https://github.com/mdipcit/torch-image-metrics.git
cd torch-image-metrics
pip install -e ".[dev]"

🚀 Quick Start

Quick API (Single Metrics)

import torch
import torch_image_metrics as tim

# Generate sample images
img1 = torch.rand(1, 3, 256, 256)  # Reference image
img2 = torch.rand(1, 3, 256, 256)  # Test image

# Calculate individual metrics
psnr = tim.quick_psnr(img1, img2)      # Peak Signal-to-Noise Ratio
ssim = tim.quick_ssim(img1, img2)      # Structural Similarity Index
mse = tim.quick_mse(img1, img2)        # Mean Squared Error
mae = tim.quick_mae(img1, img2)        # Mean Absolute Error

print(f"PSNR: {psnr:.2f} dB")
print(f"SSIM: {ssim:.4f}")
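
Each quick function returns a single number per call. To show what PSNR, MSE, and MAE measure, here is a pure-Python sketch of the standard formulas over flat lists of pixel values. It assumes values in [0, 1] (as produced by torch.rand above), so the peak value data_range is 1.0; this page does not state the library's default data range, and mse, mae, and psnr below are illustrative helpers, not the library's tensor implementation.

```python
import math


def mse(reference, distorted):
    """Mean squared error over two equal-length sequences of pixel values."""
    if len(reference) != len(distorted):
        raise ValueError("inputs must have the same number of pixels")
    return sum((r - d) ** 2 for r, d in zip(reference, distorted)) / len(reference)


def mae(reference, distorted):
    """Mean absolute error over two equal-length sequences of pixel values."""
    if len(reference) != len(distorted):
        raise ValueError("inputs must have the same number of pixels")
    return sum(abs(r - d) for r, d in zip(reference, distorted)) / len(reference)


def psnr(reference, distorted, data_range=1.0):
    """PSNR in dB: 10 * log10(data_range**2 / MSE); infinite for identical inputs."""
    err = mse(reference, distorted)
    if err == 0:
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / err)


reference = [0.0, 0.5, 1.0, 0.25]
distorted = [0.1, 0.5, 0.9, 0.25]
print(f"MSE: {mse(reference, distorted):.4f}")       # MSE: 0.0050
print(f"MAE: {mae(reference, distorted):.4f}")       # MAE: 0.0500
print(f"PSNR: {psnr(reference, distorted):.2f} dB")  # PSNR: 23.01 dB
```

Because PSNR is a log of the inverse MSE, every halving of the MSE adds about 3.01 dB.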

Calculator API (Multiple Metrics)

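import torch
import torch_image_metrics as tim

# Sample images (same shapes as in the Quick API example)
img1 = torch.rand(1, 3, 256, 256)  # Reference image
img2 = torch.rand(1, 3, 256, 256)  # Test image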

# Initialize calculator
calc = tim.Calculator(device='cuda')  # or 'cpu'

# Compute all available metrics at once
metrics = calc.compute_all_metrics(img1, img2)

print(f"PSNR: {metrics.psnr_db:.2f} dB")
print(f"SSIM: {metrics.ssim:.4f}")
print(f"MSE: {metrics.mse:.6f}")
print(f"MAE: {metrics.mae:.6f}")

# Optional metrics (if dependencies available)
if metrics.lpips is not None:
    print(f"LPIPS: {metrics.lpips:.4f}")
if metrics.ssim_improved is not None:
    print(f"SSIM++: {metrics.ssim_improved:.4f}")
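
The `is not None` checks above generalize to a small reporting helper. This is a sketch that assumes only the attribute names used in this example (psnr_db, ssim, mse, mae, lpips, ssim_improved); FIELDS, format_metrics, and the SimpleNamespace stand-in are hypothetical, and the values are placeholders rather than real library output.

```python
from types import SimpleNamespace

# (label, attribute name, format) for each metric; attribute names follow the example above.
FIELDS = [
    ("PSNR", "psnr_db", "{:.2f} dB"),
    ("SSIM", "ssim", "{:.4f}"),
    ("MSE", "mse", "{:.6f}"),
    ("MAE", "mae", "{:.6f}"),
    ("LPIPS", "lpips", "{:.4f}"),
    ("SSIM++", "ssim_improved", "{:.4f}"),
]


def format_metrics(metrics):
    """Return one 'NAME: value' line per metric, skipping metrics that are None or absent."""
    lines = []
    for label, attr, fmt in FIELDS:
        value = getattr(metrics, attr, None)
        if value is not None:
            lines.append(f"{label}: {fmt.format(value)}")
    return lines


# Stand-in result object; a real one would come from calc.compute_all_metrics(...).
demo = SimpleNamespace(psnr_db=23.01, ssim=0.8123, mse=0.005, mae=0.05,
                       lpips=None, ssim_improved=None)
print("\n".join(format_metrics(demo)))
```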

Dataset Evaluation

import torch_image_metrics as tim
from pathlib import Path

# Initialize evaluator with desired metrics
evaluator = tim.Evaluator(
    device='cuda',
    use_lpips=True,           # Enable LPIPS (requires lpips package)
    use_improved_ssim=True,   # Enable SSIM++ (requires torchmetrics)
    use_fid=True,             # Enable FID (requires pytorch-fid)
    batch_size=16
)

# Evaluate entire datasets
test_dir = Path("path/to/test/images")
ref_dir = Path("path/to/reference/images")

results = evaluator.evaluate_dataset(test_dir, ref_dir)

# Access results
print(f"Total images evaluated: {results.total_images}")
print(f"FID Score: {results.fid_score:.2f}")

# Statistical summary
stats = results.statistics
print(f"PSNR: {stats['psnr_db']['mean']:.2f} ± {stats['psnr_db']['std']:.2f} dB")
print(f"SSIM: {stats['ssim']['mean']:.4f} ± {stats['ssim']['std']:.4f}")

# Print comprehensive summary
evaluator.print_summary(results)
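
results.statistics exposes a mean and a standard deviation per metric. This page does not say whether the spread is a population or a sample standard deviation, so the sketch below, using only the standard-library statistics module, makes that choice explicit. summarize and the PSNR values are illustrative, not library output.

```python
import statistics


def summarize(values, sample=False):
    """Return {'mean', 'std'} for per-image scores.

    sample=False uses the population std (divide by n);
    sample=True uses the sample std (divide by n - 1).
    """
    std = statistics.stdev(values) if sample else statistics.pstdev(values)
    return {"mean": statistics.fmean(values), "std": std}


per_image_psnr = [30.0, 32.0, 34.0]  # made-up per-image PSNR values in dB
for sample in (False, True):
    s = summarize(per_image_psnr, sample=sample)
    print(f"sample={sample}: PSNR {s['mean']:.2f} ± {s['std']:.2f} dB")
```

With only a few images the two conventions differ noticeably (1.63 vs 2.00 dB here), so it is worth checking which one a reported ± uses.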

📊 Supported Metrics

| Metric | Description | Type | Requires |
|--------|-------------|------|----------|
| PSNR | Peak Signal-to-Noise Ratio | Full-Reference | Core |
| SSIM | Structural Similarity Index | Full-Reference | Core |
| MSE | Mean Squared Error | Full-Reference | Core |
| MAE | Mean Absolute Error | Full-Reference | Core |
| LPIPS | Learned Perceptual Image Patch Similarity | Full-Reference | lpips package |
| SSIM++ | Improved Structural Similarity | Full-Reference | torchmetrics package |
| FID | Fréchet Inception Distance | Dataset-level | pytorch-fid package |
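
For reference, these are the standard textbook definitions behind two of the less obvious entries. SSIM compares the local means (mu), variances (sigma^2), and covariance (sigma_xy) of images x and y, with stabilizing constants C1 = (K1 L)^2 and C2 = (K2 L)^2 (commonly K1 = 0.01, K2 = 0.03, L the data range); per-window values are averaged over the image. FID fits Gaussians with mean mu and covariance Sigma to Inception-network features of the reference (r) and test (t) image sets. The window size and feature layer used by the underlying packages are not specified on this page.

```latex
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}
                           {(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}

\mathrm{FID} = \lVert \mu_r - \mu_t \rVert_2^2
  + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_t - 2\,(\Sigma_r \Sigma_t)^{1/2}\right)
```

Higher SSIM (at most 1) and lower FID (at least 0) indicate closer agreement.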

🔧 Advanced Usage

Custom Image Matching

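import torch_image_metrics as tim
from pathlib import Path

test_dir = Path("path/to/test/images")
ref_dir = Path("path/to/reference/images")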

# Initialize image matcher for dataset evaluation
matcher = tim.ImageMatcher(match_strategy='stem')  # or 'full_name'

# Validate dataset structure
is_valid, message = matcher.validate_datasets(test_dir, ref_dir)
if not is_valid:
    print(f"Dataset validation failed: {message}")

# Find matching image pairs
pairs = matcher.find_image_pairs(test_dir, ref_dir)
print(f"Found {len(pairs)} matching pairs")

# Get detailed statistics
stats = matcher.get_matching_statistics(test_dir, ref_dir)
print(f"Matching statistics: {stats}")
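
This page does not spell out how match_strategy='stem' and 'full_name' differ. Judging by the names, a plausible reading is that 'stem' pairs files whose names match without the extension (0001.png with 0001.jpg), while 'full_name' requires identical file names. The sketch below implements that assumed behavior with pathlib; find_pairs, IMAGE_SUFFIXES, and demo are hypothetical names, not part of torch-image-metrics.

```python
import tempfile
from pathlib import Path

# Hypothetical set of extensions treated as images.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}


def find_pairs(test_dir, ref_dir, strategy="stem"):
    """Pair test and reference images by file stem or by full file name."""
    if strategy not in ("stem", "full_name"):
        raise ValueError(f"unknown strategy: {strategy!r}")

    def key(path):
        # 'stem' ignores the extension; 'full_name' compares the whole name.
        return path.stem if strategy == "stem" else path.name

    def index(directory):
        # If two files in one folder share a key, the later one in sorted order wins.
        return {
            key(p): p
            for p in sorted(Path(directory).iterdir())
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        }

    test_index, ref_index = index(test_dir), index(ref_dir)
    common = sorted(test_index.keys() & ref_index.keys())
    return [(test_index[k], ref_index[k]) for k in common]


def demo():
    """Count pairs under both strategies using two throwaway folders."""
    with tempfile.TemporaryDirectory() as tmp:
        test_dir, ref_dir = Path(tmp, "test"), Path(tmp, "ref")
        test_dir.mkdir()
        ref_dir.mkdir()
        for name in ("a.png", "b.png", "c.png"):
            (test_dir / name).touch()
        for name in ("a.png", "b.jpg"):
            (ref_dir / name).touch()
        return (
            len(find_pairs(test_dir, ref_dir, "stem")),
            len(find_pairs(test_dir, ref_dir, "full_name")),
        )


print(demo())  # (2, 1): 'stem' pairs a and b, 'full_name' pairs only a.png
```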

Batch Processing

import torch
import torch_image_metrics as tim

# A batch of test/reference images with shape (N, C, H, W)
batch_size = 8
calc = tim.Calculator(device='cuda')

test_batch = torch.rand(batch_size, 3, 256, 256)
ref_batch = torch.rand(batch_size, 3, 256, 256)

# Compute metrics per image; slicing with i:i+1 keeps the batch dimension (1, C, H, W)
for i in range(batch_size):
    metrics = calc.compute_all_metrics(
        test_batch[i:i+1], 
        ref_batch[i:i+1]
    )
    print(f"Image {i}: PSNR={metrics.psnr_db:.2f} dB")
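
Computing metrics per image, as the loop above does, matters when summarizing: because PSNR is a logarithm of MSE, the average of per-image PSNR values is generally not the PSNR of the pooled MSE (by Jensen's inequality it is never smaller, for equally sized images). A minimal sketch with made-up MSE values, assuming a data range of 1.0; psnr_from_mse is a hypothetical helper.

```python
import math


def psnr_from_mse(mse, data_range=1.0):
    """Convert an MSE value to PSNR in dB."""
    return 10.0 * math.log10(data_range ** 2 / mse)


# Hypothetical per-image MSE values for a batch of three equally sized images.
per_image_mse = [0.001, 0.01, 0.1]

mean_of_psnr = sum(psnr_from_mse(m) for m in per_image_mse) / len(per_image_mse)
psnr_of_mean = psnr_from_mse(sum(per_image_mse) / len(per_image_mse))

print(f"mean of per-image PSNR: {mean_of_psnr:.2f} dB")  # 20.00 dB
print(f"PSNR of pooled MSE:     {psnr_of_mean:.2f} dB")  # 14.32 dB
```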

🔧 Configuration

Optional Dependencies

torch-image-metrics gracefully handles optional dependencies:

  • LPIPS: Install with pip install lpips
  • ImprovedSSIM: Install with pip install torchmetrics
  • FID: Install with pip install pytorch-fid

If an optional dependency is not installed, the corresponding metric is returned as None.
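
Before launching a long evaluation, it can help to check which optional metrics will actually be computed. This sketch uses the standard-library importlib.util.find_spec to test whether each package is importable, using the import names lpips, torchmetrics, and pytorch_fid for the pip packages listed above. It only checks installation and is not the library's own detection logic; OPTIONAL_MODULES and available_optional_metrics are hypothetical.

```python
import importlib.util

# Optional metric -> import name of the package that provides it.
OPTIONAL_MODULES = {
    "LPIPS": "lpips",
    "ImprovedSSIM": "torchmetrics",
    "FID": "pytorch_fid",
}


def available_optional_metrics():
    """Map each optional metric to True if its backing package is importable."""
    return {
        metric: importlib.util.find_spec(module) is not None
        for metric, module in OPTIONAL_MODULES.items()
    }


for metric, ok in available_optional_metrics().items():
    print(f"{metric}: {'available' if ok else 'missing (reported as None)'}")
```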

Device Management

import torch_image_metrics as tim

# Automatic device detection
calc = tim.Calculator()  # Uses CUDA if available, else CPU

# Explicit device specification
calc_gpu = tim.Calculator(device='cuda')
calc_cpu = tim.Calculator(device='cpu')

# Device verification
print(f"Using device: {calc.device}")
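
The comment on tim.Calculator() describes the usual PyTorch fallback: CUDA when torch.cuda.is_available() reports a GPU, otherwise CPU. A hypothetical helper, resolve_device, that mirrors that rule is handy for putting your own tensors on the same device as the calculator; it is a sketch, not the library's code, and it also returns 'cpu' when PyTorch is not importable.

```python
def resolve_device(device=None):
    """Explicit device wins; otherwise 'cuda' if available, else 'cpu'."""
    if device is not None:
        return device
    try:
        import torch
    except ImportError:  # without PyTorch there is no CUDA to detect
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(resolve_device())       # 'cuda' or 'cpu', depending on the machine
print(resolve_device("cpu"))  # an explicit choice is returned unchanged
```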

📈 Performance Tips

  1. Use GPU: Enable CUDA for significant speedup
  2. Batch Processing: Process multiple images together when possible
  3. Appropriate Image Size: Resize large images for faster computation
  4. Disable Unused Metrics: Turn off expensive metrics like LPIPS/FID if not needed

import torch_image_metrics as tim

# Performance-optimized evaluator
evaluator = tim.Evaluator(
    device='cuda',
    use_lpips=False,      # Disable for speed
    use_fid=False,        # Disable for speed
    batch_size=32,        # Larger batches
    image_size=256        # Resize for speed
)
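
The batch_size setting controls how many images are processed per step. Assuming the evaluator groups image pairs into consecutive batches of that size (this page does not describe its internals), a dataset of N images takes ceil(N / batch_size) steps. A minimal standard-library sketch of that grouping; chunked and the file names are hypothetical.

```python
from itertools import islice


def chunked(items, batch_size):
    """Yield consecutive lists of at most batch_size items (the last may be shorter)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(items)
    while batch := list(islice(it, batch_size)):
        yield batch


paths = [f"img_{i:03d}.png" for i in range(100)]  # hypothetical file names
for bs in (16, 32):
    batches = list(chunked(paths, bs))
    print(f"batch_size={bs}: {len(batches)} batches, last has {len(batches[-1])}")
```

Larger batches mean fewer steps but more GPU memory per step; if memory runs out, lowering batch_size is the usual remedy.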

🤝 Contributing

We welcome contributions! Please see our Contributing Guidelines for details.

Development Setup

# Clone and setup development environment
git clone https://github.com/mdipcit/torch-image-metrics.git
cd torch-image-metrics

# Install with development dependencies
pip install -e ".[dev]"

# Run tests
pytest

# Code quality checks
ruff check src/ --fix
ruff format src/
mypy src/torch_image_metrics/

📚 Documentation

  • API Reference: [Coming Soon]
  • Examples: See the examples/ directory
  • Migration Guide: [Coming Soon]

🧪 Testing

# Run all tests
pytest

# Run with coverage
pytest --cov=src/torch_image_metrics --cov-report=term-missing

# Run specific test categories
pytest tests/unit/          # Unit tests
pytest tests/integration/   # Integration tests

🛡️ Requirements

  • Python: 3.10 or 3.11
  • PyTorch: ≥2.0.0
  • torchvision: ≥0.15.0
  • Pillow: ≥9.0.0
  • NumPy: ≥1.21.0

Optional Dependencies

  • lpips: ≥0.1.4 (for LPIPS metric)
  • pytorch-fid: ≥0.3.0 (for FID metric)
  • torchmetrics: ≥1.8.2 (for ImprovedSSIM metric)

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

This library was developed as part of the Generative-Latent-Optimization project and extracted into a standalone package for broader use.

Made with ❤️ for the computer vision community

Download files

Source Distribution

torch_image_metrics-0.1.0.tar.gz (24.3 kB)

Uploaded Source

Built Distribution

torch_image_metrics-0.1.0-py3-none-any.whl (32.3 kB)

Uploaded Python 3

File details

Details for the file torch_image_metrics-0.1.0.tar.gz.

File metadata

  • Download URL: torch_image_metrics-0.1.0.tar.gz
  • Size: 24.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.6

File hashes

Hashes for torch_image_metrics-0.1.0.tar.gz
Algorithm Hash digest
SHA256 97e45cfadeabbcf83e4bfd8045e0fdcef4a3fa6c8b000884e1ce46cbee8d6087
MD5 4f518bc2d2daa878a91f34999c09dc0c
BLAKE2b-256 bcae58d5034a5050a014b4482d4c0e119c95fe429c70a082c552a65170711ec6

File details

Details for the file torch_image_metrics-0.1.0-py3-none-any.whl.

File hashes

Hashes for torch_image_metrics-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 bd12a8061b0cc65d5a2305c1577bf23753b8b6fb277948301c8525cb30c9172b
MD5 8365092dfae7d1f73408f07b9b305125
BLAKE2b-256 7e6a7864190af701aa34b0f8cdb9b3411652870676a12e02b027950e8e390888
