A package for generating synthetic data to augment image classifiers.

Synderm

Synderm is a package designed to enhance image classification tasks using synthetic data generation. It provides tools to generate high-quality synthetic images using diffusion models, fine-tune these models on your specific datasets, and seamlessly integrate synthetic data into your training pipelines to improve classifier performance.

Features

  • Synthetic Data Generation: Utilize diffusion models to create high-quality synthetic images tailored to your dataset.
  • Fine-Tuning: Adapt pre-trained diffusion models to your specific classes using minimal real data.
  • Dataset Augmentation: Combine real and synthetic data effortlessly to enhance your training datasets.
  • Seamless Integration: Compatible with popular deep learning frameworks like PyTorch and FastAI.
  • Flexible Configuration: Easily customize prompts, training parameters, and data splits to fit your project's needs.

Models

Synderm directly supports Stable Diffusion models for image generation (e.g. stabilityai/stable-diffusion-2-1-base, used in the examples below). Other diffusion models can be used but are currently untested.

All functions assume that the training and validation datasets return entries with an image, label, and id field. If your dataset does not conform to this structure, please adjust it accordingly (see examples below).
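If your existing dataset yields plain (image, label) tuples, a thin adapter can convert it to the expected dictionary format. The sketch below is illustrative (the wrapper class name and the index-based id scheme are not part of Synderm):

```python
from torch.utils.data import Dataset

class DictWrapper(Dataset):
    """Adapt a dataset yielding (image, label) tuples to entries of the
    form {"id": ..., "image": ..., "label": ...} expected by Synderm."""

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        # Use the index as a stable id; swap in a filename stem if you have one
        return {"id": str(idx), "image": image, "label": label}
```

Any indexable collection of (image, label) pairs works as the base dataset, so you can wrap `torchvision` datasets or plain lists alike.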

Installation

pip install synderm

Ensure you have PyTorch and FastAI installed.

Quick Start

1. Creating the Dataset

Synderm requires datasets to return entries with image, label, and id fields. Here's an example of how to create a custom dataset:

from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path

class SampleDataset(Dataset):
    def __init__(self, dataset_dir, split="train"):
        self.dataset_dir = Path(dataset_dir)
        self.split = split
        self.image_paths = []
        self.labels = []

        # Walk through class folders (sorted for deterministic ordering)
        data_dir = self.dataset_dir / self.split
        for class_dir in sorted(data_dir.iterdir()):
            if not class_dir.is_dir():
                continue

            # Collect all png images in this class folder
            for image_path in sorted(class_dir.iterdir()):
                if image_path.suffix.lower() == ".png":
                    self.image_paths.append(image_path)
                    self.labels.append(class_dir.name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # Load and convert image to RGB
        image = Image.open(image_path).convert('RGB')
        image_name = image_path.stem

        return {"id": image_name, "image": image, "label": label}

2. Training the Synthetic Image Generator

Fine-tune a diffusion model using your dataset to generate synthetic images:

import os

from synderm.synderm.fine_tune.text_to_image_diffusion import fine_tune_text_to_image

# EXPERIMENT_DIR is wherever you keep experiment outputs, e.g.:
EXPERIMENT_DIR = "experiments"
output_dir = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs")

fine_tune_text_to_image(
    train_dataset=train_dataset,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",
    instance_prompt="An image of an English Springer",
    validation_prompt_format="An image of an English Springer",
    output_dir=output_dir,
    label_filter="English_springer",
    resolution=512,
    train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=5e-6,
    lr_scheduler="constant",
    # Additional parameters...
)

3. Generate Synthetic Images

Use the fine-tuned diffusion model to generate a set of synthetic images:

model_path = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs", "English_springer")
image_output_path = os.path.join(EXPERIMENT_DIR, "generations")

generate_synthetic_dataset(
    dataset=train_dataset,
    model_path=model_path,
    output_dir_path=image_output_path,
    generation_type="text-to-image",
    label_filter="English_springer",
    instance_prompt="An image of an English Springer",
    batch_size=16,
    start_index=0,
    num_generations_per_image=10,
    guidance_scale=3.0,
    num_inference_steps=50,
    strength_inpaint=0.970,
    strength_outpaint=0.950,
    mask_fraction=0.25,
)

4. Augmenting the Classifier with Synthetic Images

Combine real and synthetic data to train and evaluate the classifier:

from synderm.utils.utils import synthetic_train_val_split

synthetic_dataset = SyntheticDataset(os.path.join(image_output_path, "text-to-image"))

train, val = synthetic_train_val_split(
    real_data=train_dataset,
    synthetic_data=synthetic_dataset,
    per_class_test_size=5,
    random_state=42,
    mapping_real_to_synthetic="id"
)
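Conceptually, the split holds out a few real images per class for validation and trains on the remaining real images plus the synthetic ones. The sketch below is a simplified reimplementation on plain lists, not Synderm's actual code; it assumes each synthetic entry's id records the id of the real image it was generated from (one plausible reading of mapping_real_to_synthetic="id"), so synthetic images derived from held-out validation images can be excluded from training:

```python
import random
from collections import defaultdict

def toy_synthetic_train_val_split(real_data, synthetic_data,
                                  per_class_test_size=5, random_state=42):
    """Hold out per_class_test_size real entries per class for validation;
    train on the remaining real entries plus synthetic entries whose
    source id was not held out."""
    rng = random.Random(random_state)

    # Group real entries by class label
    by_class = defaultdict(list)
    for entry in real_data:
        by_class[entry["label"]].append(entry)

    train, val, held_out_ids = [], [], set()
    for label, entries in by_class.items():
        rng.shuffle(entries)
        val.extend(entries[:per_class_test_size])
        held_out_ids.update(e["id"] for e in entries[:per_class_test_size])
        train.extend(entries[per_class_test_size:])

    # Keep only synthetic images not generated from held-out real images,
    # so no information about the validation set leaks into training
    train.extend(e for e in synthetic_data if e["id"] not in held_out_ids)
    return train, val
```

The leakage guard is the important design point: without it, the validation score would be inflated by synthetic near-copies of validation images appearing in the training set.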

Examples

Please see the notebook at examples/train_with_synthetic_images.ipynb for a complete example.

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.

License

This project is licensed under the MIT License.
