A package for generating synthetic data to augment image classifiers.

Synderm

Synderm is a package designed to enhance image classification tasks using synthetic data generation. It provides tools to generate high-quality synthetic images using diffusion models, fine-tune these models on your specific datasets, and seamlessly integrate synthetic data into your training pipelines to improve classifier performance.

Features

  • Synthetic Data Generation: Utilize diffusion models to create high-quality synthetic images tailored to your dataset.
  • Fine-Tuning: Adapt pre-trained diffusion models to your specific classes using minimal real data.
  • Dataset Augmentation: Combine real and synthetic data effortlessly to enhance your training datasets.
  • Seamless Integration: Compatible with popular deep learning frameworks like PyTorch and FastAI.
  • Flexible Configuration: Easily customize prompts, training parameters, and data splits to fit your project's needs.

Models

Synderm directly supports Stable Diffusion models for image generation (the examples below use stabilityai/stable-diffusion-2-1-base, with both text-to-image and inpainting fine-tuning). Other diffusion models can be used but are currently untested.

All functions assume that the training and validation datasets return entries with an image, label, and id field. If your dataset does not conform to this structure, please adjust it accordingly (see examples below).
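For example, a dataset that yields plain (image, label) tuples can be wrapped in a thin adapter. This is a minimal sketch of our own; the DictAdapter name and the index-based id scheme are illustrative, not part of Synderm:

```python
class DictAdapter:
    """Adapts a map-style dataset of (image, label) pairs to the
    {"id", "image", "label"} entry format that Synderm expects."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        # Use the positional index as a stable string id
        return {"id": str(idx), "image": image, "label": label}
```

Because the adapter only needs __len__ and __getitem__, it still behaves as a PyTorch map-style dataset.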

Installation

# Install from the Python Package Index:
pip install synderm

# Or install from source (run from the repository root):
pip install -e .

Ensure you have PyTorch and FastAI installed.

Quick Start

1. Creating the Dataset

Synderm requires datasets to return entries with image, label, and id fields. Here's an example of how to create a custom dataset:

from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path

class SampleDataset(Dataset):
    def __init__(self, dataset_dir, split="train"):
        self.dataset_dir = Path(dataset_dir)
        self.image_paths = []
        self.labels = []
        self.split = split

        # Walk through class folders (one subfolder per class)
        data_dir = self.dataset_dir / self.split
        for class_dir in sorted(data_dir.iterdir()):
            if not class_dir.is_dir():
                continue

            # Collect all PNG images in this class folder
            for img_path in sorted(class_dir.iterdir()):
                if img_path.suffix.lower() == ".png":
                    self.image_paths.append(img_path)
                    self.labels.append(class_dir.name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # Load and convert image to RGB
        image = Image.open(image_path).convert('RGB')
        image_name = image_path.stem

        return {"id": image_name, "image": image, "label": label}

2. Training the Synthetic Image Generator

Fine-tune a diffusion model using your dataset to generate synthetic images:

import os

from synderm.synderm.fine_tune.text_to_image_diffusion import fine_tune_text_to_image

# EXPERIMENT_DIR is any directory you choose for experiment outputs
output_dir = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs")

fine_tune_text_to_image(
    train_dataset=train_dataset,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",
    instance_prompt="An image of an English Springer",
    validation_prompt_format="An image of an English Springer",
    output_dir=output_dir,
    label_filter="English_springer",
    resolution=512,
    train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=5e-6,
    lr_scheduler="constant",
    # Additional parameters...
)
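As a quick sanity check on the hyperparameters above, the effective batch size and steps per epoch follow from generic fine-tuning arithmetic (not a Synderm API; the class size here is hypothetical):

```python
# Generic fine-tuning arithmetic: the effective batch size is the
# per-step batch multiplied by the number of accumulation steps.
train_batch_size = 4
gradient_accumulation_steps = 1
num_class_images = 40  # hypothetical size of the filtered class

effective_batch_size = train_batch_size * gradient_accumulation_steps
steps_per_epoch = -(-num_class_images // effective_batch_size)  # ceil division

print(effective_batch_size)  # 4
print(steps_per_epoch)       # 10
```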

3. Generate Synthetic Images

Use the fine-tuned diffusion model to generate a set of synthetic images:

model_path = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs", "English_springer")
image_output_path = os.path.join(EXPERIMENT_DIR, "generations")

generate_synthetic_dataset(
    dataset=train_dataset,
    model_path=model_path,
    output_dir_path=image_output_path,
    generation_type="text-to-image",
    label_filter="English_springer",
    instance_prompt="An image of an English Springer",
    batch_size=16,
    start_index=0,
    num_generations_per_image=10,
    guidance_scale=3.0,
    num_inference_steps=50,
    strength_inpaint=0.970,
    strength_outpaint=0.950,
    mask_fraction=0.25
)
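Assuming num_generations_per_image applies per real image in the filtered class (as the parameter name suggests), the total output volume is easy to estimate before committing GPU time; the image count below is hypothetical:

```python
# Back-of-the-envelope estimate of generation volume (our own arithmetic,
# based on a plain reading of the parameter names above).
num_real_images = 100          # hypothetical number of "English_springer" images
num_generations_per_image = 10
batch_size = 16

total_synthetic = num_real_images * num_generations_per_image
num_batches = -(-total_synthetic // batch_size)  # ceil division

print(total_synthetic)  # 1000
print(num_batches)      # 63
```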

4. Augmenting the Classifier with Synthetic Images

Combine real and synthetic data to train and evaluate the classifier:

from synderm.utils.utils import synthetic_train_val_split

synthetic_dataset = SyntheticDataset(os.path.join(image_output_path, "text-to-image"))

train, val = synthetic_train_val_split(
    real_data=train_dataset,
    synthetic_data=synthetic_dataset,
    per_class_test_size=5,
    random_state=42,
    mapping_real_to_synthetic="id"
)
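One plausible reading of this split, sketched as a toy re-implementation for intuition only (the source_id field and function below are our own illustration; the real synderm API may differ): hold out a few real images per class for validation, and keep a synthetic image in training only if the real image it was generated from also stayed in training, so nothing derived from a validation image leaks into the training set.

```python
import random
from collections import defaultdict

def toy_train_val_split(real_data, synthetic_data, per_class_test_size, random_state):
    """Hold out `per_class_test_size` real entries per class for validation;
    train on the remaining real entries plus any synthetic entry whose
    source image (matched via a hypothetical `source_id` field) stayed
    in the training split."""
    rng = random.Random(random_state)
    by_label = defaultdict(list)
    for entry in real_data:
        by_label[entry["label"]].append(entry)

    train, val = [], []
    for label, entries in by_label.items():
        rng.shuffle(entries)
        val.extend(entries[:per_class_test_size])
        train.extend(entries[per_class_test_size:])

    # Only keep synthetic images derived from real images still in train
    train_ids = {e["id"] for e in train}
    train.extend(s for s in synthetic_data if s["source_id"] in train_ids)
    return train, val
```

This mirrors why the mapping_real_to_synthetic="id" argument matters: without linking each synthetic image back to its real source, synthetic copies of held-out images could inflate validation scores.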

Example Scripts

We include several example scripts at synderm/example_scripts:

  • train_diffusion_model_text_to_image.py: Script for fine-tuning the Stable Diffusion model conditioned on a text prompt.
  • train_diffusion_model_inpaint.py: Script for fine-tuning the Stable Diffusion model conditioned on a text prompt and on randomly masked regions of an image.
  • generate_synthetic_images.py: Script for generating synthetic images using fine-tuned models.
  • sample_datasets.py: A few example Torch datasets that are compatible with this package, including a FitzDataset sample that can be used once the original images are downloaded (see Data).

Data

The original Fitzpatrick17k dataset is available on GitHub; the images themselves need to be downloaded from the original source. We include clean training and held-out splits in the fitz_metadata folder.

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.

License

This project is licensed under the MIT License.
