A package for generating synthetic data to augment image classifiers.
Synderm
Synderm is a package designed to enhance image classification tasks using synthetic data generation. It provides tools to generate high-quality synthetic images using diffusion models, fine-tune these models on your specific datasets, and seamlessly integrate synthetic data into your training pipelines to improve classifier performance.
Features
- Synthetic Data Generation: Utilize diffusion models to create high-quality synthetic images tailored to your dataset.
- Fine-Tuning: Adapt pre-trained diffusion models to your specific classes using minimal real data.
- Dataset Augmentation: Combine real and synthetic data effortlessly to enhance your training datasets.
- Seamless Integration: Compatible with popular deep learning frameworks like PyTorch and FastAI.
- Flexible Configuration: Easily customize prompts, training parameters, and data splits to fit your project's needs.
Models
Synderm directly supports the following models for image generation:
- Inpainting: runwayml/stable-diffusion-inpainting
- Outpainting: stabilityai/stable-diffusion-2-1-base
Other diffusion models may also work but are currently untested.
All functions assume that the training and validation datasets return entries with an image, label, and id field. If your dataset does not conform to this structure, please adjust it accordingly (see examples below).
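If your dataset yields plain (image, label) tuples instead, a small adapter can bring it into the expected format. The sketch below is an illustration, not part of Synderm; `DictEntryAdapter` is a hypothetical name, and the integer-index id is one possible fallback when no filename is available:

```python
class DictEntryAdapter:
    """Wrap any map-style dataset that yields (image, label) tuples so
    entries match the {"id", "image", "label"} format Synderm expects.
    To use it with PyTorch DataLoaders, the same class can also
    subclass torch.utils.data.Dataset."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        # Fall back to the integer index as a stable id when the
        # underlying dataset has no filenames to derive one from.
        return {"id": str(idx), "image": image, "label": label}
```

Any object with `__len__` and `__getitem__` works as a map-style dataset, so the adapter keeps the underlying data untouched.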
Installation
# Install from the Python Package Index:
pip install synderm

# Or install from source (from a cloned copy of the repository):
pip install -e .
Ensure you have PyTorch and FastAI installed.
Quick Start
1. Creating the Dataset
Synderm requires datasets to return entries with image, label, and id fields. Here's an example of how to create a custom dataset:
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
import os
class SampleDataset(Dataset):
    def __init__(self, dataset_dir, split="train"):
        self.dataset_dir = Path(dataset_dir)
        self.image_paths = []
        self.labels = []
        self.split = split

        # Walk through class folders
        data_dir = self.dataset_dir / self.split
        for class_name in os.listdir(data_dir):
            class_dir = data_dir / class_name
            if not class_dir.is_dir():
                continue
            # Collect all PNG images in this class folder
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith('.png'):
                    self.image_paths.append(class_dir / img_name)
                    self.labels.append(class_name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load and convert the image to RGB
        image = Image.open(image_path).convert('RGB')
        image_name = image_path.stem

        return {"id": image_name, "image": image, "label": label}
2. Training the Synthetic Image Generator
Fine-tune a diffusion model using your dataset to generate synthetic images:
from synderm.synderm.fine_tune.text_to_image_diffusion import fine_tune_text_to_image
output_dir = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs")
fine_tune_text_to_image(
    train_dataset=train_dataset,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",
    instance_prompt="An image of an English Springer",
    validation_prompt_format="An image of an English Springer",
    output_dir=output_dir,
    label_filter="English_springer",
    resolution=512,
    train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=5e-6,
    lr_scheduler="constant",
    # Additional parameters...
)
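Because the instance prompt is plain text, a folder-style class label such as English_springer can be turned into a prompt programmatically. The helper below is hypothetical (not part of Synderm) and simply illustrates one way to derive prompts from labels:

```python
def instance_prompt_for(label, template="An image of {name}"):
    # Hypothetical helper: map a folder-style class label like
    # "English_springer" to a natural-language instance prompt.
    name = label.replace("_", " ")
    return template.format(name=name)

print(instance_prompt_for("English_springer"))
# An image of English springer
```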
3. Generate Synthetic Images
Use the fine-tuned diffusion model to generate a set of synthetic images (import generate_synthetic_dataset from the package alongside the other helpers):
model_path = os.path.join(EXPERIMENT_DIR, "dreambooth-outputs", "English_springer")
image_output_path = os.path.join(EXPERIMENT_DIR, "generations")

generate_synthetic_dataset(
    dataset=train_dataset,
    model_path=model_path,
    output_dir_path=image_output_path,
    generation_type="text-to-image",
    label_filter="English_springer",
    instance_prompt="An image of an English Springer",
    batch_size=16,
    start_index=0,
    num_generations_per_image=10,
    guidance_scale=3.0,
    num_inference_steps=50,
    strength_inpaint=0.970,
    strength_outpaint=0.950,
    mask_fraction=0.25
)
4. Augmenting the Classifier with Synthetic Images
Combine real and synthetic data to train and evaluate the classifier:
from synderm.utils.utils import synthetic_train_val_split

synthetic_dataset = SyntheticDataset(os.path.join(image_output_path, "text-to-image"))

train, val = synthetic_train_val_split(
    real_data=train_dataset,
    synthetic_data=synthetic_dataset,
    per_class_test_size=5,
    random_state=42,
    mapping_real_to_synthetic="id"
)
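To make the per_class_test_size parameter concrete, here is a simplified, self-contained sketch of a per-class hold-out split over entries in Synderm's {"id", "image", "label"} format. It is an illustration only, not the library's implementation, and omits the real-to-synthetic id mapping:

```python
import random

def per_class_split(entries, per_class_test_size, random_state=42):
    """Hold out `per_class_test_size` entries of each label for
    validation; everything else goes to training."""
    rng = random.Random(random_state)
    by_label = {}
    for entry in entries:
        by_label.setdefault(entry["label"], []).append(entry)

    train, val = [], []
    for label, items in by_label.items():
        items = list(items)
        rng.shuffle(items)  # seeded shuffle for reproducibility
        val.extend(items[:per_class_test_size])
        train.extend(items[per_class_test_size:])
    return train, val
```

Splitting per class rather than globally keeps the validation set balanced even when class counts are skewed.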
Examples
See the notebook at examples/train_with_synthetic_images.ipynb for a complete example.
Contributing
Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
License
This project is licensed under the MIT License.
File details
Details for the file synderm-0.1.1.tar.gz.
File metadata
- Download URL: synderm-0.1.1.tar.gz
- Upload date:
- Size: 4.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.10.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b3cfda3e06046c543cb53e8656dd6cf9df1df421cd11ab645e1798f3d541bca2 |
| MD5 | 044967a2ba8590fd1354be87c8cb0f0b |
| BLAKE2b-256 | 6fb80f22145419eff2ae3452528b7bc182d9dfb7ca69086d5d14526806f39a8a |
File details
Details for the file synderm-0.1.1-py3-none-any.whl.
File metadata
- Download URL: synderm-0.1.1-py3-none-any.whl
- Upload date:
- Size: 4.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.10.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1af34de9b561a9f672ac2473c4c1b88205ef45acfaa47219e85c8434ba0f69b6 |
| MD5 | e8902e1e4136854270484020cc908222 |
| BLAKE2b-256 | 9aeb5970e9d28965026507d4b41c6913fea3d866636e0b9dd6e1302ff5b77b12 |