A package for generating synthetic data to augment image classifiers.

Synderm

Synderm is a package designed to enhance image classification tasks using synthetic data generation. It provides tools to generate high-quality synthetic images using diffusion models, fine-tune these models on your specific datasets, and seamlessly integrate synthetic data into your training pipelines to improve classifier performance.

Features

  • Synthetic Data Generation: Utilize diffusion models to create high-quality synthetic images tailored to your dataset.
  • Fine-Tuning: Adapt pre-trained diffusion models to your specific classes using minimal real data.
  • Dataset Augmentation: Combine real and synthetic data effortlessly to enhance your training datasets.
  • Seamless Integration: Compatible with popular deep learning frameworks like PyTorch and FastAI.
  • Flexible Configuration: Easily customize prompts, training parameters, and data splits to fit your project's needs.

Dataset Details

We have developed a Hugging Face dataset with over 1 million synthetic images (more than 600 GB). To support efficient use and reuse of such a large dataset, we store it in the WebDataset format: the data is split into TAR shards that each contain at most 5,000 images (up to ~2 GB). This allows fine-grained subsetting with minimal memory and time overhead.
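To make the shard layout concrete, here is a minimal sketch of the WebDataset convention using only Python's standard library: every file inside a TAR shard is named <key>.<extension>, and files that share a key form one sample. The extensions used below (png for the image bytes, cls for the label) are placeholders; the actual field names inside the synderm shards are not listed in this README. In practice, the webdataset package performs this grouping for you.

```python
import io
import tarfile
from collections import defaultdict


def write_shard(samples, fileobj):
    """Write samples to a TAR archive following the WebDataset naming convention.

    Each sample is a dict with a "__key__" entry plus one entry per file,
    mapping a placeholder extension (e.g. "png", "cls") to raw bytes.
    """
    with tarfile.open(fileobj=fileobj, mode="w") as tar:
        for sample in samples:
            key = sample["__key__"]
            for ext, payload in sample.items():
                if ext == "__key__":
                    continue
                # All files of one sample share the same key prefix.
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


def read_shard(fileobj):
    """Group TAR members back into samples by key (flat names, text before the first dot)."""
    samples = defaultdict(dict)
    with tarfile.open(fileobj=fileobj, mode="r") as tar:
        for member in tar.getmembers():
            key, _, ext = member.name.partition(".")
            samples[key][ext] = tar.extractfile(member).read()
    return dict(samples)
```

Usage: write to an io.BytesIO buffer, call seek(0), then read_shard(buffer) returns one dict per key, for example {"img_0001": {"png": ..., "cls": b"vitiligo"}}.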

We have developed two versions of the dataset to support different applications. These are:

  1. synthetic-derm-1M: This dataset is intended for fine-grained retrieval of particular labels and generation methods. Each shard is named using the format: shard-{disease-label}-{synthetic-generation-method}-{submethod}-{index}.tar. An example shard name is shard-vitiligo-finetune-text-to-image-text-to-image-00000.tar.

  2. synthetic-derm-1M-train: This dataset is intended to be used directly for training models. We group images by generation method, shuffle them, and then shard them. Each shard is named using the format: shard-{synthetic-generation-method}-{index}.tar. Although shard names do not include labels, the dataset can still be subset to specific labels for model training.
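Because the complete dataset encodes the disease label at the start of each shard name, label subsets can be selected from a list of shard file names before anything is downloaded. The helpers below are hypothetical (not part of synderm) and only assume the two naming patterns described above. One caveat: labels and method names contain hyphens, so prefix matching can over-select if one label followed by a hyphen is the start of another label; check the dataset's label list before relying on it.

```python
import re

# Hypothetical helpers (not part of synderm); they rely only on the shard
# naming patterns described in this README.
TRAIN_SHARD = re.compile(r"^shard-(?P<method>.+)-(?P<index>\d+)\.tar$")


def shard_index(shard_name):
    """Return the trailing numeric index, e.g. 0 for '...-00000.tar'."""
    stem = shard_name.removesuffix(".tar")
    return int(stem.rsplit("-", 1)[1])


def shards_for_label(shard_names, label):
    """Select synthetic-derm-1M shards whose name starts with the given label."""
    prefix = f"shard-{label}-"
    # Zero-padded indices sort correctly as strings, grouped by method.
    return sorted(name for name in shard_names if name.startswith(prefix))


def train_shard_method(shard_name):
    """Extract the generation method from a synthetic-derm-1M-train shard name."""
    match = TRAIN_SHARD.match(shard_name)
    if match is None:
        raise ValueError(f"unexpected shard name: {shard_name}")
    return match.group("method")
```

For example, shards_for_label(names, "vitiligo") keeps only the vitiligo shards in name order, and for a hypothetical training shard name "shard-finetune-text-to-image-00003.tar", train_shard_method returns "finetune-text-to-image".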

See the WebDataset FAQ for many more examples of how to use these two datasets. We also provide a vignette demonstrating how to use them.

Models

Synderm directly supports the following models for image generation:

Other diffusion models can be used but are currently untested.

All functions assume that the training and validation datasets return entries with image, label, and id fields. If your dataset does not follow this structure, adapt it accordingly (see the examples below).
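If your data comes as (image, label) tuples, for example from torchvision's ImageFolder, a thin wrapper is enough to produce the expected entries. This is a minimal sketch and the class name DictDatasetAdapter is hypothetical. It implements __len__ and __getitem__, which is all a PyTorch DataLoader needs from a map-style dataset, so it does not subclass torch.utils.data.Dataset. Stable ids are worth supplying: the Quick Start calls synthetic_train_val_split with mapping_real_to_synthetic="id", which suggests ids link synthetic images back to real ones.

```python
class DictDatasetAdapter:
    """Hypothetical wrapper that turns (image, label) tuples into
    {"id": ..., "image": ..., "label": ...} entries."""

    def __init__(self, base, ids=None, label_names=None):
        self.base = base
        # Optional stable identifiers, one per item (e.g. file stems).
        self.ids = list(ids) if ids is not None else None
        if self.ids is not None and len(self.ids) != len(base):
            raise ValueError("ids must contain exactly one entry per item")
        # Optional lookup from integer class index to class name.
        self.label_names = label_names

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        if self.label_names is not None:
            label = self.label_names[label]
        # Fall back to the index as a string when no ids were supplied.
        item_id = self.ids[idx] if self.ids is not None else str(idx)
        return {"id": item_id, "image": image, "label": label}
```

With ImageFolder, passing label_names=base.classes turns integer targets back into folder names, and ids=[Path(p).stem for p, _ in base.samples] mirrors the ids produced by SampleDataset.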

Installation

pip install synderm

Build from source

pip install -r requirements.txt
pip install -e .

Ensure you have PyTorch and FastAI installed.
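A quick way to confirm the environment before running the examples is to check that each package can be located without importing it. This sketch uses importlib.util.find_spec from the standard library; torch and fastai are the import names of PyTorch and FastAI.

```python
import importlib.util


def check_dependencies(modules=("torch", "fastai", "synderm")):
    """Report which top-level modules can be found in the current environment."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_dependencies().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```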

Quick Start

1. Creating the Dataset

Synderm requires datasets to return entries with image, label, and id fields. Here's an example of how to create a custom dataset:

from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
import os

class SampleDataset(Dataset):
    def __init__(self, dataset_dir, split="train"):
        self.dataset_dir = Path(dataset_dir)
        self.image_paths = []
        self.labels = []
        self.split = split

        # Walk through class folders
        data_dir = self.dataset_dir / self.split
        for class_name in os.listdir(data_dir):
            class_dir = data_dir / class_name
            if not class_dir.is_dir():
                continue
                
            # Get all png images in this class folder
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith('.png'):
                    self.image_paths.append(class_dir / img_name)
                    self.labels.append(class_name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # Load and convert image to RGB
        image = Image.open(image_path).convert('RGB')
        image_name = image_path.stem

        return {"id": image_name, "image": image, "label": label}
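SampleDataset expects the layout <dataset_dir>/<split>/<class_name>/*.png and uses each class folder name as the label, so a label_filter such as "English_springer" must match a folder name exactly. The hypothetical helper below repeats the same scan with the standard library and reports how many PNG files each class has, which can catch an empty or misspelled class folder before fine-tuning starts.

```python
import os
from pathlib import Path


def count_images_per_class(dataset_dir, split="train"):
    """Count .png files per class folder, mirroring SampleDataset's directory scan."""
    data_dir = Path(dataset_dir) / split
    counts = {}
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = data_dir / class_name
        if not class_dir.is_dir():
            continue  # skip stray files next to the class folders
        counts[class_name] = sum(
            1 for name in os.listdir(class_dir) if name.lower().endswith(".png")
        )
    return counts
```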

2. Training the Synthetic Image Generator

Fine-tune a diffusion model using your dataset to generate synthetic images:

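import os

from synderm.fine_tune import fine_tune_text_to_image

# Placeholders: point these at your own experiment folder and dataset location.
EXPERIMENT_DIR = "experiments"
train_dataset = SampleDataset("path/to/dataset", split="train")

output_dir = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs")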

fine_tune_text_to_image(
    train_dataset=train_dataset,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",
    instance_prompt="An image of an English Springer",
    validation_prompt_format="An image of an English Springer",
    output_dir=output_dir,
    label_filter="English_springer",
    resolution=512,
    train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=5e-6,
    lr_scheduler="constant",
    # Additional parameters...
)
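train_batch_size and gradient_accumulation_steps together set how many images contribute to each optimizer update. Assuming the fine-tuning loop follows the usual gradient-accumulation convention (as in the Hugging Face diffusers training scripts; this is not documented for synderm), the arithmetic is sketched below. Only images that pass label_filter count toward num_images.

```python
import math


def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_devices=1):
    """Images contributing to each optimizer update under gradient accumulation."""
    return train_batch_size * gradient_accumulation_steps * num_devices


def optimizer_steps_per_epoch(num_images, train_batch_size,
                              gradient_accumulation_steps, num_devices=1):
    """Approximate optimizer updates per pass over the filtered training images."""
    # Each device sees its own share of the batches.
    batches_per_device = math.ceil(num_images / (train_batch_size * num_devices))
    # One update is taken every gradient_accumulation_steps batches.
    return math.ceil(batches_per_device / gradient_accumulation_steps)
```

With the settings above (train_batch_size=4, gradient_accumulation_steps=1) on one device, each update sees 4 images, and 100 filtered images would give 25 updates per epoch.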

3. Generating Synthetic Images

Use the fine-tuned diffusion model to generate a set of synthetic images:

# generate_synthetic_dataset is provided by synderm; its import path is not shown in this README.
model_path = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs", "English_springer")
image_output_path = os.path.join(EXPERIMENT_DIR, "generations")

generate_synthetic_dataset(
    dataset=train_dataset,
    model_path=model_path,
    output_dir_path=image_output_path,
    generation_type="text-to-image",
    label_filter="English_springer",
    instance_prompt="An image of an English Springer",
    batch_size=16,
    start_index=0,
    num_generations_per_image=10,
    guidance_scale=3.0,
    num_inference_steps=50,
    strength_inpaint=0.970,
    strength_outpaint=0.950,
    mask_fraction=0.25
)
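Before launching generation it can help to estimate the output size. Going by the parameter name, num_generations_per_image should yield that many synthetic images per source image that passes label_filter; the batch count below additionally assumes that requests are grouped into batches of batch_size, which is a guess about synderm's internals. The generated images are read back in step 4 from the text-to-image subfolder of output_dir_path.

```python
import math


def plan_generation(num_source_images, num_generations_per_image, batch_size):
    """Estimate image and batch counts for one generation run.

    Assumes one output per (source image, generation) pair and batches of
    batch_size; both are assumptions, not documented synderm behavior.
    """
    total_images = num_source_images * num_generations_per_image
    return {
        "total_images": total_images,
        "batches": math.ceil(total_images / batch_size),
    }
```

For example, 50 source images with the settings above give 500 images in 32 batches of up to 16.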

4. Augmenting the Classifier with Synthetic Images

Combine real and synthetic data into training and validation splits for the classifier:

from synderm.utils import synthetic_train_val_split

# SyntheticDataset wraps the generated images; its definition is not shown in this README.
synthetic_dataset = SyntheticDataset(os.path.join(image_output_path, "text-to-image"))

train, val = synthetic_train_val_split(
    real_data=train_dataset,
    synthetic_data=synthetic_dataset,
    per_class_test_size=5,
    random_state=42,
    mapping_real_to_synthetic="id"
)
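This README does not define what synthetic_train_val_split does internally. A plausible reading of per_class_test_size and mapping_real_to_synthetic="id" is: hold out a few real images per class for validation, and drop every synthetic image generated from a held-out real image so that near-duplicates do not leak into training. The sketch below is a hypothetical plain-Python illustration of that idea, not synderm's implementation; in particular, the source_id field linking a synthetic image to its real source, and keeping only real images in validation, are assumptions.

```python
import random
from collections import defaultdict


def leakage_aware_split(real, synthetic, per_class_test_size, random_state=0):
    """Hold out real images per class and exclude synthetic images derived from them.

    real: list of dicts with "id" and "label".
    synthetic: list of dicts with a hypothetical "source_id" (id of the real
    image it was generated from) and "label".
    """
    rng = random.Random(random_state)
    by_label = defaultdict(list)
    for item in real:
        by_label[item["label"]].append(item)

    # Sample held-out real ids class by class (sorted for reproducibility).
    val_ids = set()
    for label in sorted(by_label):
        items = by_label[label]
        k = min(per_class_test_size, len(items))
        val_ids.update(item["id"] for item in rng.sample(items, k))

    train = [item for item in real if item["id"] not in val_ids]
    # Keep only synthetic images whose source stayed in the training split.
    train += [item for item in synthetic if item["source_id"] not in val_ids]
    val = [item for item in real if item["id"] in val_ids]
    return train, val
```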

Example Scripts

We include several example scripts at synderm/example_scripts:

  • train_diffusion_model_text_to_image.py: Script for fine-tuning the Stable Diffusion model conditioned on a text prompt.
  • train_diffusion_model_inpaint.py: Script for fine-tuning the Stable Diffusion model conditioned on a text prompt and random image masks.
  • generate_synthetic_images.py: Script for generating synthetic images using fine-tuned models.
  • sample_datasets.py: A few example PyTorch datasets that are compatible with this package, including a FitzDataset sample that can be used once the original images are downloaded (see Data).

Data

The original Fitzpatrick17k dataset can be obtained from this GitHub link, but the images themselves must be downloaded from their original sources. We include clean training and held-out splits in the fitz_metadata folder.

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.

License

This project is licensed under the MIT License.


Download files

Source Distribution

synderm-0.1.4.tar.gz (31.6 kB)

Uploaded Source

Built Distribution

synderm-0.1.4-py3-none-any.whl (33.8 kB)

Uploaded Python 3

File details

Details for the file synderm-0.1.4.tar.gz.

File metadata

  • Download URL: synderm-0.1.4.tar.gz
  • Upload date:
  • Size: 31.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for synderm-0.1.4.tar.gz
Algorithm Hash digest
SHA256 9b16ee009218b998e27858e2e3bcbe562ce449297c5a01bc885eede094dd77fa
MD5 fe99b1eb53193d74a1a925780c22b64f
BLAKE2b-256 bae36fb11c9e298e6d9d6cb0c14c93edfbab97d76dcda28cb197cb9063755ced

File details

Details for the file synderm-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: synderm-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 33.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for synderm-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 d533fa2ef533d2e06faeab6e9f95b50ccd430d01aa69429ee17ddbf4f01069d9
MD5 6f9415ab3aaa3d35eae4598fcf28ccd7
BLAKE2b-256 2c2cbe8c2da4e7e944f27cd2ae4b1d60a401702f7a6e8119783b8323f991b6f4
