StableDiffusionInpaintingFineTune
=================================

This project provides a toolkit for fine-tuning the Stable Diffusion model for inpainting tasks (image restoration based on a mask) using the PyTorch and Hugging Face Diffusers libraries.

Requirements
------------

Before starting, install the following libraries:

- torch
- diffusers
- transformers
- accelerate
- huggingface_hub
- PIL (provided by the ``Pillow`` package)
- numpy
- tqdm

PyTorch built for CUDA 12.1 can be installed with:

.. code-block:: bash

   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
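
A quick way to confirm that the libraries listed above are importable is sketched below. It uses only the standard library; the helper name ``missing_modules`` is hypothetical and is not part of this project.

.. code-block:: python

   import importlib.util

   # Import names of the libraries listed above. ``PIL`` is the import
   # name of the Pillow distribution.
   REQUIRED_MODULES = [
       "torch", "diffusers", "transformers", "accelerate",
       "huggingface_hub", "PIL", "numpy", "tqdm",
   ]

   def missing_modules(names):
       """Return the top-level module names that cannot be found."""
       return [name for name in names if importlib.util.find_spec(name) is None]

   if __name__ == "__main__":
       missing = missing_modules(REQUIRED_MODULES)
       print("Missing:", ", ".join(missing) if missing else "none")

The check only reports whether each module can be located; it does not verify versions or CUDA availability.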

Description
-----------

The ``StableDiffusionInpaintingFineTune`` class fine-tunes the Stable Diffusion model for the inpainting task. It can train the UNet alone or together with the text encoder, and it exposes settings that control the training process.

Constructor
^^^^^^^^^^^

.. code-block:: python

   __init__(self, pretrained_model_name_or_path, resolution, center_crop, ...)

- **pretrained_model_name_or_path**: The path or name of the pre-trained model.
- **resolution**: The resolution of the images.
- **center_crop**: Whether to apply center cropping during data preparation.
- **train_text_encoder**: Whether to train the text encoder.
- **dataset**: The dataset object.
- **learning_rate**: The initial learning rate.
- **max_training_steps**: The maximum number of training steps.
- **save_steps**: The number of steps between saving checkpoints.
- **train_batch_size**: The batch size.
- **gradient_accumulation_steps**: The number of steps over which gradients are accumulated before an optimizer update.
- **mixed_precision**: The mixed-precision mode ("fp16", "bf16", or None).
- **gradient_checkpointing**: Whether to use gradient checkpointing.
- **use_8bit_adam**: Whether to use the 8-bit Adam optimizer.
- **seed**: The random seed for reproducibility.
- **output_dir**: The directory for saving results.
- **push_to_hub**: Whether to upload the results to the Hugging Face Hub.
- **repo_id**: The repository ID on Hugging Face Hub.
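
How the batch and step parameters interact can be illustrated with a little arithmetic. The sketch below is not the package's code: the ``TrainingSchedule`` class is hypothetical, and it assumes that one training step is one optimizer update and that a checkpoint is written every ``save_steps`` steps, including the final step when it is a multiple of ``save_steps``.

.. code-block:: python

   from dataclasses import dataclass

   @dataclass
   class TrainingSchedule:
       """Hypothetical helper mirroring four of the constructor arguments."""
       train_batch_size: int
       gradient_accumulation_steps: int
       max_training_steps: int
       save_steps: int

       def effective_batch_size(self, num_processes: int = 1) -> int:
           # Samples that contribute to a single optimizer update.
           return (self.train_batch_size
                   * self.gradient_accumulation_steps
                   * num_processes)

       def checkpoint_steps(self) -> list:
           # Assumed schedule: save at every multiple of save_steps.
           return list(range(self.save_steps,
                             self.max_training_steps + 1,
                             self.save_steps))

   schedule = TrainingSchedule(train_batch_size=2,
                               gradient_accumulation_steps=4,
                               max_training_steps=1000,
                               save_steps=250)
   print(schedule.effective_batch_size())
   print(schedule.checkpoint_steps())

Raising ``gradient_accumulation_steps`` increases the effective batch size without increasing the memory needed for a single forward pass.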

Methods
^^^^^^^

- **prepare_mask_and_masked_image(image, mask)**: Prepares the mask and masked image.
- **random_mask(im_shape, ratio=1, mask_full_image=False)**: Generates a random mask.
- **load_args_for_training()**: Loads the necessary components of the model for training.
- **collate_fn(examples)**: Forms a batch of data for the model.
- **__call__(self, \*args, \*\*kwargs)**: The main method for running the training process.
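
The mask-related steps can be illustrated with a NumPy-only sketch. This is not the package's implementation: the function names ``random_rect_mask`` and ``apply_mask`` are hypothetical, only rectangular masks are drawn, and the sketch assumes the common convention that a mask value of 1 marks the region to be inpainted and that images are scaled to ``[-1, 1]``.

.. code-block:: python

   import numpy as np

   def random_rect_mask(height, width, ratio=1.0, rng=None):
       """Return a (height, width) float32 mask containing one random rectangle of ones."""
       rng = np.random.default_rng() if rng is None else rng
       # Largest allowed rectangle, clamped to the image size.
       max_h = min(height, max(1, int(height * ratio)))
       max_w = min(width, max(1, int(width * ratio)))
       h = int(rng.integers(1, max_h + 1))
       w = int(rng.integers(1, max_w + 1))
       top = int(rng.integers(0, height - h + 1))
       left = int(rng.integers(0, width - w + 1))
       mask = np.zeros((height, width), dtype=np.float32)
       mask[top:top + h, left:left + w] = 1.0
       return mask

   def apply_mask(image, mask):
       """Binarize the mask and blank out the masked pixels.

       image: (H, W, 3) uint8 array; mask: (H, W) array with values in [0, 1].
       Returns the binary mask and the image scaled to [-1, 1] with the
       masked pixels set to 0.
       """
       scaled = image.astype(np.float32) / 127.5 - 1.0
       binary = (mask >= 0.5).astype(np.float32)
       masked_image = scaled * (1.0 - binary)[..., None]
       return binary, masked_image

   rng = np.random.default_rng(0)
   image = np.full((64, 64, 3), 255, dtype=np.uint8)
   mask, masked_image = apply_mask(image, random_rect_mask(64, 64, rng=rng))
   print(mask.shape, masked_image.shape)

In an inpainting setup of this kind, the mask and the masked image are typically passed to the model alongside the noisy input, so the model sees which region it is expected to fill.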

Usage
-----

To start training, create an instance of the ``StableDiffusionInpaintingFineTune`` class with the necessary arguments, then call the instance, which invokes its ``__call__`` method.

.. code-block:: python

   model = StableDiffusionInpaintingFineTune(
       pretrained_model_name_or_path="path_to_model",
       resolution=512,
       center_crop=True,
       # ... remaining constructor arguments
   )

   model()

License
-------

The project is distributed under the MIT License.
