

Dataset Rising

A toolchain for creating and training Stable Diffusion 1.x, Stable Diffusion 2.x, and Stable Diffusion XL models with custom datasets.

With this toolchain, you can:

  • Crawl and download metadata and images from 'booru' style image boards
  • Combine multiple sources of images (including your own custom sources)
  • Build datasets based on your personal preferences and filters
  • Train Stable Diffusion models with your datasets
  • Convert models into Stable Diffusion WebUI compatible models
  • Use only the parts you need – the toolchain uses a modular design, YAML configuration files, and JSONL data exchange formats
  • Work with confidence that the end-to-end tooling has been tested with Nvidia's RTX30x0, RTX40x0, A100, and H100 GPUs

Requirements

  • Python >=3.8
  • Docker >=22.0.0

Tested With

  • MacOS 13 (M1)
  • Ubuntu 22 (x86_64)

Full Example

Below is a summary of each step in the dataset generation process. For a full production-quality example, see e621-rising-configs (NSFW).

0. Installation

# install DatasetRising
pip3 install DatasetRising

# start MongoDB database; use `dr-db-down` to stop
dr-db-up

1. Download Metadata (Posts, Tags, ...)

Dataset Rising has a crawler (dr-crawl) for downloading metadata (posts and tags) from booru-style image boards.

You must select a unique user agent string for your crawler (--agent AGENT_STRING). This string will be passed to the image board with every HTTP request. If you don't pick a user agent that uniquely identifies you, the image boards will likely block your requests. For example:

--agent 'my-imageboard-crawler/1.0 (user @my-username-on-the-imageboard)'

The crawler will automatically manage rate limits and retries. If you want to automatically resume a previous (failed) crawl, use --recover.

## download tag metadata to /tmp/e926.net-tags.jsonl
dr-crawl --output /tmp/e926.net-tags.jsonl --type tags --source e926 --recover --agent '<AGENT_STRING>'

## download posts metadata to /tmp/e926.net-posts.jsonl
dr-crawl --output /tmp/e926.net-posts.jsonl --type posts --source e926 --recover --agent '<AGENT_STRING>'

2. Import Metadata

This section requires a running MongoDB database, which you can start with the dr-db-up command.

Once you have enough post and tag metadata, it's time to import the data into a database.

Dataset Rising uses MongoDB as a store for the post and tag metadata. Use dr-import to import the metadata downloaded in the previous step into MongoDB.

If you want to adjust how the tag metadata is treated during the import, review files in <dataset-rising>/examples/tag_normalizer and set the optional parameters --prefilter FILE, --rewrites FILE, --aspect-ratios FILE, --category-weights FILE, and --symbols FILE accordingly.

dr-import --tags /tmp/e926.net-tags.jsonl --posts /tmp/e926.net-posts.jsonl --source e926
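
If you do override the tag normalizer defaults, a hedged sketch of the import might look like this (the file paths are hypothetical placeholders for your edited copies of the files in <dataset-rising>/examples/tag_normalizer):

# hypothetical: import with custom tag normalizer settings
dr-import \
  --tags /tmp/e926.net-tags.jsonl \
  --posts /tmp/e926.net-posts.jsonl \
  --source e926 \
  --prefilter ./my-tag-normalizer/prefilter.yaml \
  --rewrites ./my-tag-normalizer/rewrites.yaml \
  --aspect-ratios ./my-tag-normalizer/aspect_ratios.yaml \
  --category-weights ./my-tag-normalizer/category_weights.yaml \
  --symbols ./my-tag-normalizer/symbols.yaml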

3. Preview Selectors

This section requires a running MongoDB database, which you can start with the dr-db-up command.

After the metadata has been imported into the database, you can use selector files to select a subset of the posts for a dataset.

Your goal is not to include every image, but to produce a set of high-quality samples; the selectors are the mechanism for that.

Each selector contains a positive and a negative list of tags. A post is included by the selector if it contains at least one tag from the positive list and none of the tags from the negative list.

Note that a great dataset contains both positive and negative examples. If you train your model with only positive samples, it will not respond well to negative prompts. That's why the examples below include four different types of selectors.

Dataset Rising has example selectors available in <dataset-rising>/examples/select.
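
As a rough illustration of the positive/negative mechanism only, a hypothetical selector could be sketched as below; the key names are assumptions rather than the authoritative schema, so check the files under <dataset-rising>/examples/select for the real format.

# hypothetical selector sketch; the real schema lives in examples/select
cat > /tmp/my-selector.yaml <<'EOF'
# tags that qualify a post for inclusion (assumed key name)
positive:
  - solo
  - detailed_background
# tags that disqualify a post (assumed key name)
negative:
  - low_quality
  - watermark
EOF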

To make sure your selectors are producing the kind of samples you want, use the dr-preview script:

# generate an HTML preview of how the selector will perform (note: --aggregate is required):
dr-preview --selector ./examples/select/tier-1/tier-1.yaml --output /tmp/preview/tier-1 --limit 1000 --aggregate

# generate an HTML preview of how each sub-selector will perform:
dr-preview --selector ./examples/select/tier-1/helpers/artists.yaml --output /tmp/preview/tier-1-artists

4. Select Images For a Dataset

This section requires a running MongoDB database, which you can start with the dr-db-up command.

When you're confident that the selectors are producing the right kind of samples, it's time to select the posts for building a dataset. Use dr-select to select posts from the database and store them in a JSONL file.

cd <dataset-rising>/database

dr-select --selector ./examples/select/tier-1/tier-1.yaml --output /tmp/tier-1.jsonl
dr-select --selector ./examples/select/tier-2/tier-2.yaml --output /tmp/tier-2.jsonl

5. Build a Dataset

After selecting the posts for the dataset, use dr-join to combine the selections and dr-build to download the images and build the actual dataset.

By default, the build script prunes all tags that have fewer than 100 samples. To adjust this limit, use --min-posts-per-tag LIMIT.

The build script will also prune all images that have fewer than 10 tags. To adjust this limit, use --min-tags-per-post LIMIT.

Adding a percentage at the end of a --samples argument tells the join script what share of the final dataset should come from that source, e.g. --samples ./my.jsonl:50%.

dr-join \
  --samples '/tmp/tier-1.jsonl:80%' \
  --samples '/tmp/tier-2.jsonl:20%' \
  --output '/tmp/joined.jsonl'

dr-build \
  --source '/tmp/joined.jsonl' \
  --output '/tmp/my-dataset' \
  --upload-to-hf 'username/dataset-name' \
  --upload-to-s3 's3://some-bucket/some/path'
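
To loosen the pruning thresholds described above, a hedged variant of the build command might look like this (the limit values are arbitrary illustrations, and the upload flags are omitted):

dr-build \
  --source '/tmp/joined.jsonl' \
  --output '/tmp/my-dataset' \
  --min-posts-per-tag 50 \
  --min-tags-per-post 5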

6. Train a Model

The dataset built by the dr-build script is ready to be used for training as is. Dataset Rising uses Huggingface Accelerate to train Stable Diffusion models.

To train a model, you will need to pick a base model to start from. The base model can be any Diffusers-compatible model, such as stabilityai/stable-diffusion-xl-base-1.0 (used in the example below).

Note that your training results will improve significantly if the training resolution (--resolution below) matches the resolution the base model was trained with.

This example does not scale to multiple GPUs. See the Advanced Topics section for multi-GPU training.

dr-train \
  --pretrained-model-name-or-path 'stabilityai/stable-diffusion-xl-base-1.0' \
  --dataset-name 'username/dataset-name' \
  --output '/tmp/dataset-rising-v3-model' \
  --resolution 1024 \
  --maintain-aspect-ratio \
  --reshuffle-tags \
  --tag-separator ' ' \
  --random-flip \
  --train-batch-size 32 \
  --learning-rate 4e-6 \
  --use-ema \
  --max-grad-norm 1 \
  --checkpointing-steps 1000 \
  --lr-scheduler constant \
  --lr-warmup-steps 0

7. Generate Samples

After training, you can use the dr-generate script to verify that the model is working as expected.

dr-generate \
  --model '/tmp/dataset-rising-v3-model' \
  --output '/tmp/samples' \
  --prompt 'cat playing chess with a horse' \
  --samples 100

8. Use the Model with Stable Diffusion WebUI

In order to use the model with Stable Diffusion WebUI, it has to be converted to the safetensors format.

# Stable Diffusion XL models:
dr-convert-sdxl \
  --model_path '/tmp/dataset-rising-v3-model' \
  --checkpoint_path '/tmp/dataset-rising-v3-model.safetensors' \
  --use_safetensors

# Other Stable Diffusion models:
dr-convert-sd \
  --model_path '/tmp/dataset-rising-v3-model' \
  --checkpoint_path '/tmp/dataset-rising-v3-model.safetensors' \
  --use_safetensors
  
# Copy the model to the WebUI models directory:
cp '/tmp/dataset-rising-v3-model.safetensors' '<webui-root>/models/Stable-diffusion'

# Copy the model configuration file to WebUI models directory:
cp '/tmp/dataset-rising-v3-model.yaml' '<webui-root>/models/Stable-diffusion'

Uninstall

The only part of Dataset Rising that requires uninstallation is the MongoDB database. You can uninstall the database with the following commands:

# Shut down MongoDB instance
dr-db-down

# Remove MongoDB container and its data -- warning! data loss will occur
dr-db-uninstall

Advanced Topics

Resetting the Database

To reset the database, run the following commands.

Warning: You will lose all data in the database.

dr-db-uninstall && dr-db-up && dr-db-create

Importing Posts from Multiple Sources

The dr-append script allows you to import posts from additional sources.

Use dr-import to import the first source and define the tag namespace, then use dr-append for each additional source.

# main sources and tags
dr-import ...

# additional sources
dr-append --input /tmp/gelbooru-posts.jsonl --source gelbooru

Multi-GPU Training

Multi-GPU training can be carried out with the Huggingface Accelerate library.

Before training, run accelerate config to set up your Multi-GPU environment.

cd <dataset-rising>/train

# set up environment
accelerate config

# run training
accelerate launch \
  --multi_gpu \
  --mixed_precision=${PRECISION} \
  dr_train.py \
    --pretrained-model-name-or-path 'stabilityai/stable-diffusion-xl-base-1.0' \
    --dataset-name 'username/dataset-name' \
    --resolution 1024 \
    --maintain-aspect-ratio \
    --reshuffle-tags \
    --tag-separator ' ' \
    --random-flip \
    --train-batch-size 32 \
    --learning-rate 4e-6 \
    --use-ema \
    --max-grad-norm 1 \
    --checkpointing-steps 1000 \
    --lr-scheduler constant \
    --lr-warmup-steps 0

Setting Up a Training Machine

  • Install dataset-rising
  • Install Huggingface CLI
  • Install Accelerate CLI
  • Configure Huggingface CLI (huggingface-cli login)
  • Configure Accelerate CLI (accelerate config)
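
A rough sketch of these steps, assuming a pip-based setup (adjust package managers and versions to your environment):

# install DatasetRising plus the Huggingface and Accelerate CLIs
pip3 install DatasetRising
pip3 install -U 'huggingface_hub[cli]' accelerate

# authenticate with the Huggingface Hub
huggingface-cli login

# configure Accelerate for this machine (GPU count, mixed precision, etc.)
accelerate config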

Optional

Troubleshooting

NCCL Errors

Some configurations will require NCCL_P2P_DISABLE=1 and/or NCCL_IB_DISABLE=1 environment variables to be set.

export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1

dr-train ...

Cache Directories

Use HF_DATASETS_CACHE and HF_MODULES_CACHE to control where Huggingface stores its cache files.

export HF_DATASETS_CACHE=/workspace/cache/huggingface/datasets
export HF_MODULES_CACHE=/workspace/cache/huggingface/modules

dr-train ...

Developers

Setting Up

Creates a virtual environment, installs packages, and sets up a MongoDB database on Docker.

cd <dataset-rising>
./up.sh

Shutting Down

Stops the MongoDB database container. The database can be restarted by running ./up.sh again.

cd <dataset-rising>
./down.sh

Uninstall

Warning: This step removes the MongoDB database container and all data stored on it.

cd <dataset-rising>
./uninstall.sh

Deployments

python3 -m pip install --upgrade build twine
python3 -m build 
python3 -m twine upload dist/*

Architecture

flowchart TD
    CRAWL[Crawl/Download posts, tags, and tag aliases] -- JSONL --> IMPORT
    IMPORT[Import posts, tags, and tag aliases] --> STORE
    APPEND[Append additional posts] --> STORE
    STORE[Database] --> PREVIEW
    STORE --> SELECT1
    STORE --> SELECT2
    STORE --> SELECT3
    PREVIEW[Preview selectors] --> HTML(HTML)
    SELECT1[Select samples] -- JSONL --> JOIN
    SELECT2[Select samples] -- JSONL --> JOIN
    SELECT3[Select samples] -- JSONL --> JOIN
    JOIN[Join and prune samples] -- JSONL --> BUILD 
    BUILD[Build dataset] -- HF Dataset/Parquet --> TRAIN
    TRAIN[Train model] --> MODEL[Model]
