
A package for generating synthetic data to augment image classifiers.


Synderm

Synderm is a package designed to enhance image classification tasks using synthetic data generation. It provides tools to generate high-quality synthetic images using diffusion models, fine-tune these models on your specific datasets, and seamlessly integrate synthetic data into your training pipelines to improve classifier performance.


Features

  • Synthetic Data Generation: Utilize diffusion models to create high-quality synthetic images tailored to your dataset.
  • Fine-Tuning: Adapt pre-trained diffusion models to your specific classes using minimal real data.
  • Dataset Augmentation: Combine real and synthetic data effortlessly to enhance your training datasets.
  • Seamless Integration: Compatible with popular deep learning frameworks like PyTorch and FastAI.
  • Flexible Configuration: Easily customize prompts, training parameters, and data splits to fit your project's needs.
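As an example of prompt customization: when many classes are involved, prompts like the Quick Start's "An image of an English Springer" can be derived from class folder names instead of being typed by hand. The helper below is a hypothetical sketch, not part of synderm's API, and its article choice is a simple heuristic.

```python
def prompt_for_label(label: str) -> str:
    """Build a text prompt from a class folder name (hypothetical helper, not part of synderm)."""
    # Underscores become word breaks: "English_springer" -> ["English", "springer"]
    words = [w for w in label.split("_") if w]
    if not words:
        raise ValueError(f"cannot build a prompt from label {label!r}")
    # Upper-case only the first letter of each word so acronyms like "USA" are kept
    name = " ".join(w[0].upper() + w[1:] for w in words)
    # Naive article choice from the first letter (misses cases such as "hour")
    article = "an" if name[0].lower() in "aeiou" else "a"
    return f"An image of {article} {name}"


print(prompt_for_label("English_springer"))  # An image of an English Springer
```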

Models

Synderm directly supports the following models for image generation:

Other diffusion models can be used but are currently untested.

All functions assume that the training and validation datasets return each entry as a dictionary with image, label, and id fields. If your dataset does not follow this structure, adapt it accordingly (see the examples below).
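For datasets that already yield (image, label) tuples, such as torchvision's ImageFolder (which exposes its class names as .classes), a thin wrapper can provide the dictionary structure. The sketch below is a hypothetical adapter, not part of synderm; it assumes the base dataset supports len() and integer indexing, and it builds ids from the item index, so those ids are only stable if the base dataset's order is stable.

```python
from typing import Any, Dict, Optional, Sequence


class DictDatasetAdapter:
    """Wrap a dataset of (image, label) tuples so each item is a dict with
    "id", "image", and "label" keys (hypothetical helper, not part of synderm)."""

    def __init__(self, base: Sequence[Any], class_names: Optional[Sequence[str]] = None,
                 id_prefix: str = "img"):
        self.base = base
        # Optional lookup from integer class index to class name (e.g. ImageFolder.classes)
        self.class_names = class_names
        self.id_prefix = id_prefix

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        image, label = self.base[idx]
        if self.class_names is not None:
            label = self.class_names[label]  # use string labels such as "English_springer"
        # Index-based id: unique within this dataset, stable only if the base order is stable
        return {"id": f"{self.id_prefix}_{idx:06d}", "image": image, "label": label}


# Usage with torchvision (not run here):
# folder = ImageFolder("path/to/train")
# train_dataset = DictDatasetAdapter(folder, class_names=folder.classes)
```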

Installation

# To install from the Python Package Index:
pip install synderm

# To build from source, run this from the root of the cloned repository:
pip install -e .

Ensure you have PyTorch and FastAI installed.
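To confirm that these frameworks are importable in the active environment, Python's importlib.util.find_spec can locate a top-level module without importing it. The helper below is a hypothetical convenience, not part of synderm.

```python
import importlib.util


def check_dependencies(modules=("torch", "fastai")):
    """Map each top-level module name to whether it can be found (hypothetical helper)."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_dependencies().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```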

Quick Start

1. Creating the Dataset

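Synderm requires each dataset entry to provide image, label, and id fields. The example below loads PNG images from a folder-per-class layout (dataset_dir/<split>/<class_name>/*.png), using the folder name as the label and the file name without its extension as the id: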

from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
import os

class SampleDataset(Dataset):
    def __init__(self, dataset_dir, split="train"):
        self.dataset_dir = Path(dataset_dir)
        self.image_paths = []
        self.labels = []
        self.split = split

        # Walk through class folders (sorted so the item order is reproducible)
        data_dir = self.dataset_dir / self.split
        for class_name in sorted(os.listdir(data_dir)):
            class_dir = data_dir / class_name
            if not class_dir.is_dir():
                continue

            # Get all png images in this class folder
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith('.png'):
                    self.image_paths.append(class_dir / img_name)
                    self.labels.append(class_name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # Load and convert image to RGB
        image = Image.open(image_path).convert('RGB')
        image_name = image_path.stem

        return {"id": image_name, "image": image, "label": label}
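Before fine-tuning, it can help to check that the directory layout matches what SampleDataset walks and to see how many real images each class (and therefore each label_filter value) will contribute. The function below is a hypothetical sketch that mirrors SampleDataset's folder walk without opening any images; it is not part of synderm.

```python
import os
import tempfile
from pathlib import Path


def count_images_per_class(dataset_dir, split="train"):
    """Count .png files per class folder under dataset_dir/split (hypothetical helper)."""
    data_dir = Path(dataset_dir) / split
    counts = {}
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = data_dir / class_name
        if not class_dir.is_dir():
            continue  # skip stray files next to the class folders, as SampleDataset does
        counts[class_name] = sum(
            1 for name in os.listdir(class_dir) if name.lower().endswith(".png")
        )
    return counts


if __name__ == "__main__":
    # Build a tiny throwaway layout with empty placeholder files (never opened as images)
    with tempfile.TemporaryDirectory() as root:
        for cls, n in [("English_springer", 3), ("golden_retriever", 2)]:
            class_dir = Path(root) / "train" / cls
            class_dir.mkdir(parents=True)
            for i in range(n):
                (class_dir / f"{i}.png").touch()
        print(count_images_per_class(root))
```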

2. Training the Synthetic Image Generator

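Fine-tune a pretrained diffusion model on your dataset so that it can generate synthetic images. In this example, label_filter restricts fine-tuning to images of the English_springer class:

import os
from synderm.synderm.fine_tune.text_to_image_diffusion import fine_tune_text_to_image

# Placeholder locations: replace with your own paths
EXPERIMENT_DIR = "experiments"
train_dataset = SampleDataset("path/to/dataset", split="train")

output_dir = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs")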

fine_tune_text_to_image(
    train_dataset=train_dataset,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",
    instance_prompt="An image of an English Springer",
    validation_prompt_format="An image of an English Springer",
    output_dir=output_dir,
    label_filter="English_springer",
    resolution=512,
    train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=5e-6,
    lr_scheduler="constant",
    # Additional parameters...
)
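To estimate how long fine-tuning runs, it helps to see how train_batch_size and gradient_accumulation_steps combine. The sketch below assumes synderm follows the convention of the Hugging Face diffusers training scripts, where each optimizer update consumes train_batch_size × gradient_accumulation_steps × (number of processes) images; that is an assumption, and the function name is hypothetical.

```python
import math


def optimizer_steps_per_epoch(num_images, train_batch_size, gradient_accumulation_steps=1,
                              num_processes=1):
    """Estimate optimizer updates per epoch under diffusers-style accounting (hypothetical helper)."""
    # Batches each process sees per epoch; the last batch may be partial
    batches_per_process = math.ceil(num_images / (train_batch_size * num_processes))
    # One optimizer update every gradient_accumulation_steps batches
    return math.ceil(batches_per_process / gradient_accumulation_steps)


# Illustrative count only: 120 images of the filtered class with the settings above
print(optimizer_steps_per_epoch(120, train_batch_size=4, gradient_accumulation_steps=1))
```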

3. Generate Synthetic Images

Use the fine-tuned model to generate synthetic images of the same class:

model_path = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs", "English_springer")
image_output_path = os.path.join(EXPERIMENT_DIR, "generations")

generate_synthetic_dataset(
    dataset=train_dataset,
    model_path=model_path,
    output_dir_path=image_output_path,
    generation_type="text-to-image",
    label_filter="English_springer",
    instance_prompt="An image of an English Springer",
    batch_size=16,
    start_index=0,
    num_generations_per_image=10,
    guidance_scale=3.0,
    num_inference_steps=50,
    strength_inpaint=0.970,
    strength_outpaint=0.950,
    mask_fraction=0.25,
)
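Before launching generation, it can help to estimate the output size. The sketch below assumes that num_generations_per_image is the number of synthetic images produced from each real image that passes label_filter; this is an interpretation of the parameter name, not documented behavior, and the helper is hypothetical. It takes a plain list of labels so that no images have to be loaded.

```python
import math


def estimate_generation(labels, label_filter, num_generations_per_image, batch_size):
    """Return (total_images, num_batches) under the assumption stated above (hypothetical helper)."""
    # Number of real images in the selected class
    num_real = sum(1 for label in labels if label == label_filter)
    total = num_real * num_generations_per_image
    # Generation runs in batches of batch_size; the last batch may be partial
    return total, math.ceil(total / batch_size)


# With SampleDataset, labels are available without opening any image files:
# estimate_generation(train_dataset.labels, "English_springer", 10, 16)
```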

4. Augmenting the Classifier with Synthetic Images

Combine the real and synthetic images into training and validation sets for the classifier:

from synderm.utils.utils import synthetic_train_val_split

synthetic_dataset = SyntheticDataset(os.path.join(image_output_path, "text-to-image"))

train, val = synthetic_train_val_split(
    real_data=train_dataset,
    synthetic_data=synthetic_dataset,
    per_class_test_size=5,
    random_state=42,
    mapping_real_to_synthetic="id"
)
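The parameters of synthetic_train_val_split suggest that real images are held out per class for validation and that synthetic images are matched to real images by id (mapping_real_to_synthetic="id"). The sketch below illustrates why such a pairing matters: a synthetic image associated with a held-out real image would leak validation information into training. It is not synderm's implementation; the function name and the "source_id" key on synthetic entries are assumptions made for illustration.

```python
import random
from collections import defaultdict


def split_without_leakage(real_data, synthetic_data, per_class_test_size=5, random_state=42):
    """Hold out per_class_test_size real images per class and keep only synthetic images
    whose source image stayed in training (illustrative sketch, not synderm's code).

    Synthetic entries are assumed to carry a hypothetical "source_id" key naming
    the real image they were generated from.
    """
    rng = random.Random(random_state)
    by_class = defaultdict(list)
    for entry in real_data:
        by_class[entry["label"]].append(entry)

    train, val, val_ids = [], [], set()
    for label in sorted(by_class):  # sorted so the split is reproducible
        entries = list(by_class[label])
        rng.shuffle(entries)
        held_out = entries[:per_class_test_size]
        val.extend(held_out)
        val_ids.update(e["id"] for e in held_out)
        train.extend(entries[per_class_test_size:])

    # Validation stays real-only; drop synthetic images derived from held-out images
    train.extend(s for s in synthetic_data if s["source_id"] not in val_ids)
    return train, val
```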

Examples

See the notebook examples/train_with_synthetic_images.ipynb for a complete example.

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.

License

This project is licensed under the MIT License.



Download files

Download the file for your platform.

Source Distribution

synderm-0.1.1.tar.gz (4.7 kB)

Uploaded Source

Built Distribution


synderm-0.1.1-py3-none-any.whl (4.4 kB)

Uploaded Python 3

File details

Details for the file synderm-0.1.1.tar.gz.

File metadata

  • Download URL: synderm-0.1.1.tar.gz
  • Upload date:
  • Size: 4.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.11

File hashes

Hashes for synderm-0.1.1.tar.gz
Algorithm Hash digest
SHA256 b3cfda3e06046c543cb53e8656dd6cf9df1df421cd11ab645e1798f3d541bca2
MD5 044967a2ba8590fd1354be87c8cb0f0b
BLAKE2b-256 6fb80f22145419eff2ae3452528b7bc182d9dfb7ca69086d5d14526806f39a8a


File details

Details for the file synderm-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: synderm-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 4.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.11

File hashes

Hashes for synderm-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 1af34de9b561a9f672ac2473c4c1b88205ef45acfaa47219e85c8434ba0f69b6
MD5 e8902e1e4136854270484020cc908222
BLAKE2b-256 9aeb5970e9d28965026507d4b41c6913fea3d866636e0b9dd6e1302ff5b77b12

