density-maps

A minimal library for Gaussian density map workflows with PyTorch and albumentations.

The core idea is that DensityMapGenerator owns sigma and alpha. Every other component (transforms, metrics) derives what it needs from the generator instance, so density scaling is configured in exactly one place.


Installation

pip install density-maps

For PyTorch Lightning support:

pip install density-maps[lightning]

Concepts

| Component | Responsibility |
| --- | --- |
| DensityMapGenerator | keypoints → density map tensor; owns sigma / alpha |
| KeypointTransform | wraps any albumentations transform, injects keypoint sync automatically |
| DensityMapTransform | full pipeline: image + keypoints → (image_tensor, density_map) |
| TiledPredictor | runs any nn.Module on large images via overlapping tiles |
| count_* metrics | operate on plain counts; no knowledge of sigma required |
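
To make the role of sigma and alpha concrete, here is a minimal NumPy sketch of a Gaussian density map. It is not the library's implementation: the helper name `gaussian_density_map` is hypothetical, and the sketch assumes that alpha is a multiplicative scale on the map, so that a count is recovered as the map's sum divided by alpha.

```python
import numpy as np


def gaussian_density_map(keypoints, height, width, sigma=2.0, alpha=1.0):
    """Render (x, y) keypoints as a sum of normalized 2D Gaussians.

    Hypothetical helper for illustration only. Each Gaussian is normalized
    over the image grid so that it contributes exactly `alpha` to the sum.
    """
    ys, xs = np.mgrid[0:height, 0:width].astype(np.float64)
    density = np.zeros((height, width), dtype=np.float64)
    for x, y in keypoints:
        kernel = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2.0 * sigma**2))
        density += alpha * kernel / kernel.sum()
    return density


keypoints = np.array([[10.0, 12.0], [40.0, 30.0], [41.0, 31.0]])
density = gaussian_density_map(keypoints, height=64, width=64, sigma=2.0, alpha=100.0)
count = density.sum() / 100.0  # undo the assumed alpha scaling
```

Under these assumptions `count` equals the number of keypoints (3) up to floating-point error, regardless of sigma. This is why a single owner for sigma and alpha is convenient: anything that converts a map back to a count has to use the same scaling that produced it.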

Simple Single-Class Workflow

Here's a basic end-to-end workflow with a simple dataset and model:

import numpy as np
import torch
from density_maps import DensityMapGenerator, DensityMapTransform, count_metrics
import albumentations as A

# Setup
generator = DensityMapGenerator(sigma=2.0)
transform = DensityMapTransform(density_generator=generator)

# Example data
keypoints = np.array([[100.0, 150.0], [200.0, 120.0]])  # (N, 2) coordinates
image = np.random.randn(224, 224, 3).astype(np.float32)  # (H, W, C)

# Generate density map
image_tensor, density_map = transform(image, keypoints)
# image_tensor : (3, 224, 224)
# density_map  : (1, 224, 224)

# Simple model inference (assuming you have a trained model)
model = torch.nn.Conv2d(3, 1, 3, padding=1)  # Any torch.nn.Module that outputs a tensor of shape (B, C, H, W)
with torch.no_grad():
    pred_density_map = model(image_tensor.unsqueeze(0))  # (1, 1, 224, 224)

# Add a batch dimension to the target so it matches the prediction
density_map = density_map.unsqueeze(0)  # (1, 1, 224, 224)

# Calculate metrics
pred_count = generator.to_count(pred_density_map)  # (1,)
target_count = generator.to_count(density_map)     # (1,)
metrics = count_metrics(pred_count, target_count)
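
Which values count_metrics returns is not listed here. As a rough sketch of what metrics on plain counts typically look like, the following computes the mean absolute error and root mean squared error between predicted and target counts. The helper `count_errors` and its dictionary keys are hypothetical and are not part of the library.

```python
import math


def count_errors(pred_counts, target_counts):
    """Return MAE and RMSE between predicted and target counts.

    Hypothetical helper; the metrics reported by the library may differ.
    """
    if len(pred_counts) != len(target_counts) or not pred_counts:
        raise ValueError("expected two non-empty sequences of equal length")
    errors = [p - t for p, t in zip(pred_counts, target_counts)]
    mae = sum(abs(e) for e in errors) / len(errors)
    rmse = math.sqrt(sum(e * e for e in errors) / len(errors))
    return {"mae": mae, "rmse": rmse}


# Three images with predicted vs. annotated counts
metrics = count_errors([2.5, 10.0, 0.0], [2.0, 12.0, 0.0])
```

Because these metrics only see counts, they need no knowledge of sigma or of the density map resolution.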

Image Augmentation

DensityMapTransform applies the Albumentations pipeline to the image and its keypoints together, so the keypoints stay aligned with the augmented image before the density map is created:

import albumentations as A
from density_maps import DensityMapGenerator, DensityMapTransform

generator = DensityMapGenerator(sigma=2.0)
albu_transform = A.Compose([...])
transform = DensityMapTransform(albu_transform=albu_transform, density_generator=generator)

Note: The A.Compose must include keypoint_params with format "xy" or "yx".
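
Keeping keypoints aligned matters because any geometric transform moves the annotated pixels. The sketch below illustrates the idea for a horizontal flip in plain NumPy. It only shows what synchronization means and is not how Albumentations or DensityMapTransform implement it; `hflip_with_keypoints` is a hypothetical helper, and it assumes integer pixel-index coordinates in "xy" order.

```python
import numpy as np


def hflip_with_keypoints(image, keypoints):
    """Flip an (H, W, C) image left-right and mirror its (x, y) keypoints.

    Assumes integer pixel-index coordinates, so column x maps to W - 1 - x.
    """
    width = image.shape[1]
    flipped = image[:, ::-1].copy()
    mirrored = keypoints.astype(np.float64).copy()
    mirrored[:, 0] = (width - 1) - mirrored[:, 0]
    return flipped, mirrored


image = np.zeros((4, 6, 1), dtype=np.uint8)
keypoints = np.array([[1, 2]])  # x=1, y=2
image[2, 1, 0] = 255            # mark the annotated pixel

flipped, mirrored = hflip_with_keypoints(image, keypoints)
x, y = mirrored[0].astype(int)
still_aligned = flipped[y, x, 0] == 255  # keypoint still points at the marked pixel
```

If only the image were flipped, the original keypoint would point at an empty pixel and the resulting density map would be wrong.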


Multi-class Support

For multi-class scenarios, pass a dictionary mapping class indices to keypoints:

import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from density_maps import DensityMapGenerator, DensityMapTransform

# Configure for multi-class
generator = DensityMapGenerator(sigma=2.0, n_classes=3)
transform = DensityMapTransform(
    albu_transform=A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(), # Optional: Otherwise gets handled internally during call
    ]),
    density_generator=generator
)

# Multi-class keypoints
keypoints = {
    0: np.array([[120.0, 80.0], [200.0, 150.0]]),   # class 0
    1: np.array([[55.0, 310.0]]),                    # class 1
    2: np.array([[400.0, 220.0], [310.0, 90.0]]),   # class 2
}

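# Example image, large enough to contain the keypoints above (H, W, C)
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)

image_tensor, density_map = transform(image, keypoints)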
# density_map : (3, 224, 224) — one channel per class
counts = generator.to_count(density_map)  # tensor([2., 1., 2.])
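
The mapping from a class-indexed dictionary to one channel per class can be sketched as follows. This is an illustration under simplifying assumptions, not the library's code: `stack_class_impulses` is a hypothetical helper, and it places a unit impulse per keypoint instead of a Gaussian to keep the example short.

```python
import numpy as np


def stack_class_impulses(keypoints_by_class, n_classes, height, width):
    """Build an (n_classes, H, W) map with a unit impulse per keypoint.

    Hypothetical helper. A real density map would blur each impulse with a
    normalized Gaussian, which leaves the per-channel sums unchanged as long
    as the Gaussian does not spill over the image border.
    """
    maps = np.zeros((n_classes, height, width), dtype=np.float32)
    for class_idx, points in keypoints_by_class.items():
        for x, y in points:
            maps[class_idx, int(round(y)), int(round(x))] += 1.0
    return maps


keypoints = {
    0: np.array([[12.0, 8.0], [20.0, 15.0]]),
    2: np.array([[40.0, 22.0]]),
}  # class 1 has no annotations and yields an all-zero channel

maps = stack_class_impulses(keypoints, n_classes=3, height=32, width=48)
counts = maps.sum(axis=(1, 2))  # one count per class
```

Summing each channel over its spatial dimensions gives one count per class, and classes without annotations simply contribute a count of zero.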

Large Image Inference

For large images, or images whose spatial resolution varies, use TiledPredictor:

import torch
from density_maps import TiledPredictor

# TiledPredictor needs a tile_size that matches your augmentation
# (model is your trained nn.Module)
predictor = TiledPredictor(model, tile_size=224, overlap=0.25, n_classes=1)

# Works on any spatial resolution
full_image = torch.rand(1, 3, 1024, 768)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
result = predictor(full_image, device=device)

# Convert to count (generator is the DensityMapGenerator used during training)
count = generator.to_count(result)  # scalar count for the full image

Important: When using TiledPredictor, ensure your augmentations use a RandomCrop, CenterCrop or Crop that matches the tile_size parameter.
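
The general idea behind overlapping tiles can be sketched in NumPy. How TiledPredictor places its tiles and blends overlapping regions is not described here, so the stride rule and the averaging below are assumptions, and `tile_starts` and `predict_tiled` are hypothetical helpers.

```python
import numpy as np


def tile_starts(length, tile_size, overlap):
    """Start offsets of tiles covering [0, length) with fractional overlap."""
    if length <= tile_size:
        return [0]
    stride = max(1, int(tile_size * (1.0 - overlap)))
    starts = list(range(0, length - tile_size + 1, stride))
    if starts[-1] != length - tile_size:
        starts.append(length - tile_size)  # last tile sits flush with the border
    return starts


def predict_tiled(predict_fn, image, tile_size, overlap=0.25):
    """Run predict_fn on overlapping tiles of an (H, W) array and average overlaps."""
    height, width = image.shape
    total = np.zeros((height, width), dtype=np.float64)
    hits = np.zeros((height, width), dtype=np.float64)
    for top in tile_starts(height, tile_size, overlap):
        for left in tile_starts(width, tile_size, overlap):
            window = (slice(top, top + tile_size), slice(left, left + tile_size))
            total[window] += predict_fn(image[window])
            hits[window] += 1.0
    return total / hits  # every pixel is covered by at least one tile


image = np.random.default_rng(0).random((100, 70))
# An identity "model" should reproduce the input after stitching
result = predict_tiled(lambda tile: tile, image, tile_size=32, overlap=0.25)
```

Averaging rather than summing matters for density maps: summing overlapping predictions would count objects in the overlap regions more than once.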


Custom Datasets

If you have a custom dataset, you can implement it as a torch.utils.data.Dataset subclass. In that case you need to provide a transform that applies the necessary augmentations.

import cv2
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from density_maps import DensityMapGenerator, DensityMapTransform
import albumentations as A

class DensityMapDataset(Dataset):
    def __init__(self, image_paths: list[Path], keypoint_paths: list[Path], transform: DensityMapTransform):
        self.image_paths = image_paths
        self.keypoint_paths = keypoint_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # cv2.imread expects a string path; OpenCV loads BGR, so convert to RGB
        image = cv2.cvtColor(cv2.imread(str(self.image_paths[idx])), cv2.COLOR_BGR2RGB)
        keypoints = np.load(self.keypoint_paths[idx])  # (N, 2) float32 (x, y)

        image_tensor, density_map = self.transform(image, keypoints)
        # image_tensor : (3, H, W)
        # density_map  : (1, H, W) or (n_classes, H, W) for multi-class
        return image_tensor, density_map

# Usage
generator = DensityMapGenerator(sigma=2.0)
transform = DensityMapTransform(
    albu_transform=A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]),
    density_generator=generator
)

dataset = DensityMapDataset(image_paths=[...], keypoint_paths=[...], transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

Training with Provided Models

The library includes several suggested model architectures. You can use them with any PyTorch training workflow. Here is an example using PyTorch Lightning.

import lightning as L
import torch
from density_maps import DensityMapGenerator, DensityMapLoss
from density_maps.models import FCRN, Unet, SAUnet

class DensityMapLightningModule(L.LightningModule):
    def __init__(self, model, loss):
        super().__init__()
        self.model = model
        self.loss = loss

    def training_step(self, batch, batch_idx):
        imgs, dmap = batch
        pred_dmap = self.model(imgs)
        loss = self.loss(pred_dmap, dmap)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        imgs, dmap = batch
        pred_dmap = self.model(imgs)
        loss = self.loss(pred_dmap, dmap)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)

# Setup
generator = DensityMapGenerator(sigma=2.0)
model = FCRN(hidden_dims=[8, 16, 32], input_channels=3, output_channels=1)
loss = DensityMapLoss(generator)
lightning_model = DensityMapLightningModule(model, loss)

# Train (train_loader is a DataLoader, e.g. the loader built in the Custom Datasets example)
trainer = L.Trainer(max_epochs=10)
trainer.fit(lightning_model, train_dataloaders=train_loader)

Saving and Loading

Model and Transform Persistence

Save both your model and transform configuration:

# Save model
torch.save(model.state_dict(), 'checkpoints/model_weights.pth')

# Save transform config
transform.save_config('checkpoints/transform_config.json')

# Load model
model.load_state_dict(torch.load('checkpoints/model_weights.pth'))

# Load transform
loaded_transform = DensityMapTransform.load_config('checkpoints/transform_config.json')

Note: For inference to be valid, the transform must be reconstructed at load time with the same parameters that were used during training.
