
Simple, multi-GPU diffusion model fine-tuning library (Schneewolf Labs)


🎨 Atelier 🔨

A simple, multi-GPU diffusion model fine-tuning library. One training loop, pluggable adapters and loss functions.

Sister project to Grimoire (LLM fine-tuning). Both serve as training engines for Merlina.

Why

Diffusion model training scripts tend to be monolithic: model loading, data processing, the training loop, and architecture-specific forward passes all tangled together. Switching from SDXL to Qwen-Image-Edit means rewriting the whole script.

Atelier separates what varies (model architecture, training objective) from what doesn't (the training loop, multi-GPU, checkpointing, logging). Adding a new model means writing an adapter. Adding a new training objective means writing a loss function. The trainer never changes.

Install

pip install -e .

# With optional dependencies
pip install -e ".[quantization]"   # bitsandbytes for 8-bit optimizers
pip install -e ".[logging]"        # wandb
pip install -e ".[all]"            # everything

Quick start

Qwen-Image-Edit LoRA (flow matching)

from peft import LoraConfig
from atelier import AtelierTrainer, TrainingConfig
from atelier.adapters import QwenEditAdapter
from atelier.losses import FlowMatchingLoss
from atelier.data import EditingDataset, cache_embeddings

# Load adapter (handles model, VAE, text encoder, scheduler)
adapter = QwenEditAdapter("Qwen/Qwen-Image-Edit")

# Pre-compute embeddings to save VRAM during training
text_emb, target_emb, control_emb = cache_embeddings(
    raw_dataset, adapter, cache_dir="./output/cache",
)
adapter.free_encoders()  # reclaim VRAM

dataset = EditingDataset(
    raw_dataset,
    cached_text_embeddings=text_emb,
    cached_target_embeddings=target_emb,
    cached_control_embeddings=control_emb,
)

trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(
        output_dir="./output",
        num_epochs=50,
        batch_size=1,
        learning_rate=1e-4,
        gradient_accumulation_steps=2,
    ),
    loss_fn=FlowMatchingLoss(),
    train_dataset=dataset,
    peft_config=LoraConfig(
        r=64,
        lora_alpha=128,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    ),
)

trainer.train()
trainer.save_model("./my-lora")
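
Pre-computing embeddings trades a one-off encoding pass for lower VRAM during training: once every embedding is on disk, the encoders can be freed. A minimal sketch of that caching idea, assuming a one-JSON-file-per-prompt layout keyed by a SHA-256 of the prompt; `cached_encode` and `fake_encoder` are hypothetical, and Atelier's `cache_embeddings` may store its cache quite differently.

```python
import hashlib
import json
import os
import tempfile
from typing import Callable, List


def cached_encode(
    prompts: List[str],
    encode_fn: Callable[[str], List[float]],
    cache_dir: str,
) -> List[List[float]]:
    """Encode each prompt once and reuse the on-disk result afterwards (hypothetical helper)."""
    os.makedirs(cache_dir, exist_ok=True)
    embeddings = []
    for prompt in prompts:
        # Key the cache file on the content so identical prompts share one entry.
        key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        path = os.path.join(cache_dir, f"{key}.json")
        if os.path.exists(path):
            with open(path) as f:
                embeddings.append(json.load(f))
            continue
        emb = encode_fn(prompt)  # the expensive encoder call happens only on a cache miss
        with open(path, "w") as f:
            json.dump(emb, f)
        embeddings.append(emb)
    return embeddings


calls = []


def fake_encoder(text: str) -> List[float]:
    # Stand-in for a text encoder; records every call so cache hits are visible.
    calls.append(text)
    return [float(len(text))]


with tempfile.TemporaryDirectory() as cache_dir:
    first = cached_encode(["a cat", "a dog"], fake_encoder, cache_dir)
    second = cached_encode(["a cat", "a dog"], fake_encoder, cache_dir)  # all cache hits
```

The second pass never touches the encoder, which is what allows the encoders to be released before training starts.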

Qwen-Image LoRA (text-to-image, flow matching)

Same loss as Qwen-Image-Edit; the adapter differs because the text encoder is not vision-conditioned and the transformer sees only the noised target (no control image concat).

from peft import LoraConfig
from atelier import AtelierTrainer, TrainingConfig
from atelier.adapters import QwenImageAdapter
from atelier.losses import FlowMatchingLoss
from atelier.data import EditingDataset, cache_embeddings

adapter = QwenImageAdapter("Qwen/Qwen-Image")

# Dataset only needs (prompt, chosen); no "rejected" column.
text_emb, target_emb, _ = cache_embeddings(
    raw_dataset, adapter, cache_dir="./output/cache",
)
adapter.free_encoders()

dataset = EditingDataset(
    raw_dataset,
    cached_text_embeddings=text_emb,
    cached_target_embeddings=target_emb,
)

trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(output_dir="./output", num_epochs=8, batch_size=1),
    loss_fn=FlowMatchingLoss(),
    train_dataset=dataset,
    peft_config=LoraConfig(
        r=32, lora_alpha=64,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    ),
)
trainer.train()
trainer.save_model("./my-qwen-image-lora")
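
The difference between the two Qwen adapters lies mostly in how the transformer input is assembled. A sketch with plain lists standing in for packed latent tokens, assuming (as we understand the diffusers Qwen-Image-Edit pipeline) that control-image tokens are appended after the noised target tokens along the sequence axis and that predictions at the control positions are dropped before the loss; `build_input` and `split_prediction` are hypothetical names.

```python
from typing import List, Optional, Tuple

Token = List[float]


def build_input(
    noised_target: List[Token], control: Optional[List[Token]] = None
) -> Tuple[List[Token], int]:
    """Return the transformer input sequence and how many leading tokens are target tokens."""
    num_target = len(noised_target)
    if control is None:
        # Text-to-image: the transformer sees only the noised target.
        return list(noised_target), num_target
    # Image editing: control tokens are concatenated after the target tokens.
    return list(noised_target) + list(control), num_target


def split_prediction(prediction: List[Token], num_target: int) -> List[Token]:
    """Keep only outputs at target positions; control positions carry no training signal."""
    return prediction[:num_target]


target = [[0.1, 0.2], [0.3, 0.4]]
control = [[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]]

t2i_seq, n = build_input(target)
edit_seq, n_edit = build_input(target, control)
# Pretend the transformer is the identity, then slice as the loss would.
edit_pred = split_prediction(edit_seq, n_edit)
```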

SDXL DPO (preference optimization)

Same trainer, different adapter and loss function.

from atelier import AtelierTrainer, TrainingConfig
from atelier.adapters import SDXLAdapter
from atelier.losses import DiffusionDPOLoss
from atelier.data import GenerationDataset

adapter = SDXLAdapter(
    "stabilityai/stable-diffusion-xl-base-1.0",
    weights="/path/to/model.safetensors",
)
adapter.freeze_layers(strategy="color_blocks", layers="0,1")

dataset = GenerationDataset(
    raw_dataset,
    tokenizer=adapter.tokenizer,
    tokenizer_2=adapter.tokenizer_2,
)

trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(
        output_dir="./output",
        num_epochs=10,
        batch_size=1,
        learning_rate=2e-6,
        optimizer="adamw_8bit",
        mixed_precision="fp16",
    ),
    loss_fn=DiffusionDPOLoss(beta=0.4, sft_weight=0.3),
    train_dataset=dataset,
)

trainer.train()
trainer.save_model("./my-sdxl")
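
To show what `beta` and `sft_weight` trade off, here is a sketch of a Diffusion-DPO objective written on per-sample denoising errors (MSE between predicted and true noise). It assumes the standard Diffusion-DPO form, a log-sigmoid of how much more the trained model prefers chosen over rejected than a frozen reference model does, and it assumes the SFT term is `sft_weight` times the trained model's error on chosen images. Atelier's `DiffusionDPOLoss` may scale or average these terms differently.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def diffusion_dpo_loss(
    model_err_w: Sequence[float],  # trained model's MSE on chosen (winning) images
    model_err_l: Sequence[float],  # trained model's MSE on rejected (losing) images
    ref_err_w: Sequence[float],    # frozen reference MSE on chosen images
    ref_err_l: Sequence[float],    # frozen reference MSE on rejected images
    beta: float,
    sft_weight: float = 0.0,
) -> float:
    n = len(model_err_w)
    dpo = 0.0
    for mw, ml, rw, rl in zip(model_err_w, model_err_l, ref_err_w, ref_err_l):
        # Positive when the model lowers its chosen-vs-rejected error gap more than the reference.
        inside = -0.5 * beta * ((mw - ml) - (rw - rl))
        dpo += softplus(-inside)  # equals -log(sigmoid(inside))
    dpo /= n
    # Assumed SFT regularizer: plain denoising error on the chosen images.
    sft = sum(model_err_w) / n
    return dpo + sft_weight * sft


# Model identical to the reference: the DPO term is exactly log(2).
same = diffusion_dpo_loss([0.5], [0.5], [0.5], [0.5], beta=0.4)
# Model fits the chosen image better than the reference did: the loss drops below log(2).
better = diffusion_dpo_loss([0.3], [0.5], [0.5], [0.5], beta=0.4)
```

Larger `beta` sharpens the preference signal; `sft_weight` keeps the model anchored to reproducing the chosen images.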

With LoRA

Pass a peft_config and Atelier handles the rest.

from peft import LoraConfig

trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(...),
    loss_fn=FlowMatchingLoss(),
    train_dataset=dataset,
    peft_config=LoraConfig(
        r=64,
        lora_alpha=128,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    ),
)
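
Per targeted linear layer, LoRA leaves the pretrained weight W frozen and adds a trainable low-rank update scaled by `lora_alpha / r` (PEFT's default scaling, so `r=64, lora_alpha=128` gives a factor of 2). A dependency-free sketch of that forward pass; B is zero-initialised as PEFT does by default, so training starts from exactly the base model's output. `lora_forward` is illustrative, not PEFT's code.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, r: int, lora_alpha: float) -> Vector:
    """y = W x + (lora_alpha / r) * B (A x); W is frozen, only A (r x d_in) and B (d_out x r) train."""
    scaling = lora_alpha / r
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scaling * d for b, d in zip(base, delta)]


# d_in = 3, d_out = 2, rank r = 2
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
A = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
B_zero = [[0.0, 0.0], [0.0, 0.0]]  # initial B: the update contributes nothing yet
B_eye = [[1.0, 0.0], [0.0, 1.0]]   # after some training, B is no longer zero
x = [1.0, 2.0, 3.0]
```

Only `r * (d_in + d_out)` parameters per layer are trained, which is why LoRA runs fit where full fine-tuning does not.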

Guides

YAML config + CLI

For orchestrators (e.g. Merlina) or when you just don't want to write a Python wrapper per run, train from a YAML config:

pip install -e ".[yaml]"
python -m atelier.train --config configs/qwen_image_lora_example.yaml

# Override anything on the CLI (JSON-decoded values):
python -m atelier.train --config configs/my.yaml \
    --set training.num_epochs=2 \
    --set training.output_dir=./output-quick \
    --set 'peft.target_modules=["to_q","to_v"]'

The YAML schema mirrors the Python API one-for-one: model.adapter picks an adapter from atelier.registry.ADAPTERS, loss.type picks from LOSSES, peft becomes a LoraConfig, training becomes a TrainingConfig, and dataset accepts an HF hub name, a local JSONL, or a load_from_disk path. See atelier/train.py for the full schema and configs/qwen_image_lora_example.yaml for a worked example.
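
The `--set key=value` overrides can be understood as: split on the first `=`, walk the dotted key into the nested config, and JSON-decode the value. A sketch under two assumptions the docs do not confirm: values that are not valid JSON (such as `./output-quick`) are kept as plain strings, and missing intermediate keys are created. `apply_override` is a hypothetical name, not a function from atelier/train.py.

```python
import json
from typing import Any, Dict


def apply_override(config: Dict[str, Any], assignment: str) -> None:
    """Apply one 'a.b.c=value' override to a nested dict in place."""
    key, sep, raw = assignment.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {assignment!r}")
    try:
        value = json.loads(raw)  # "2" -> 2, '["to_q","to_v"]' -> list, "true" -> True
    except json.JSONDecodeError:
        value = raw  # assumed fallback: keep bare strings like ./output-quick
    *parents, leaf = key.split(".")
    node = config
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


config = {"training": {"num_epochs": 8, "output_dir": "./output"}, "peft": {"r": 32}}
for override in [
    "training.num_epochs=2",
    "training.output_dir=./output-quick",
    'peft.target_modules=["to_q","to_v"]',
]:
    apply_override(config, override)
```

Splitting only on the first `=` matters: JSON values may themselves contain `=` characters.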

Multi-GPU

No code changes. Configure with accelerate and launch:

accelerate config
accelerate launch --multi_gpu --num_processes 4 train.py
accelerate launch --use_deepspeed --deepspeed_config_file ds_config.json train.py
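
With data parallelism each process draws its own micro-batch, so the number of samples behind one optimizer update scales with the GPU count. The standard arithmetic, as a small sketch (helper names are hypothetical; how a final partial accumulation window is handled varies between trainers, so the per-epoch count is approximate):

```python
import math


def effective_batch_size(batch_size: int, gradient_accumulation_steps: int, num_processes: int) -> int:
    """Samples contributing to one optimizer update under data parallelism."""
    return batch_size * gradient_accumulation_steps * num_processes


def optimizer_steps_per_epoch(
    num_samples: int, batch_size: int, gradient_accumulation_steps: int, num_processes: int
) -> int:
    """Approximate updates per epoch: each process sees about 1/num_processes of the data."""
    micro_batches_per_process = math.ceil(num_samples / (batch_size * num_processes))
    return math.ceil(micro_batches_per_process / gradient_accumulation_steps)


# The Qwen-Image-Edit quick start (batch_size=1, gradient_accumulation_steps=2) on 4 GPUs,
# with a hypothetical 1000-sample dataset.
eff = effective_batch_size(1, 2, 4)
steps = optimizer_steps_per_epoch(1000, 1, 2, 4)
```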

Callbacks

Subclass TrainerCallback and override the hooks you need:

from atelier import TrainerCallback

class MyCallback(TrainerCallback):
    def on_step_end(self, trainer, step, loss, metrics):
        if should_stop():  # should_stop() stands for your own stopping condition
            trainer.request_stop()

    def on_log(self, trainer, metrics):
        print(f"Step {trainer.global_step}: {metrics}")

trainer = AtelierTrainer(..., callbacks=[MyCallback()])

Available hooks: on_train_begin, on_train_end, on_epoch_begin, on_epoch_end, on_step_end, on_log, on_evaluate, on_save.
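
How `request_stop()` interacts with the hooks can be illustrated with a toy loop. `ToyCallback`, `ToyTrainer`, and `StopBelow` are hypothetical stand-ins that only mimic the hook names and the `on_step_end(trainer, step, loss, metrics)` signature shown above; stopping once the current step completes is an assumption about Atelier's behaviour, not documented.

```python
from typing import Dict, List


class ToyCallback:
    """Every hook is a no-op, so subclasses override only what they need."""

    def on_train_begin(self, trainer) -> None:
        pass

    def on_step_end(self, trainer, step, loss, metrics) -> None:
        pass

    def on_train_end(self, trainer) -> None:
        pass


class ToyTrainer:
    def __init__(self, losses: List[float], callbacks: List[ToyCallback]):
        self.losses = losses  # stand-in for the losses a real run would produce
        self.callbacks = callbacks
        self.global_step = 0
        self._stop_requested = False

    def request_stop(self) -> None:
        self._stop_requested = True

    def _fire(self, hook: str, *args) -> None:
        for callback in self.callbacks:
            getattr(callback, hook)(self, *args)

    def train(self) -> None:
        self._fire("on_train_begin")
        for loss in self.losses:
            self.global_step += 1
            metrics: Dict[str, float] = {"loss": loss}
            self._fire("on_step_end", self.global_step, loss, metrics)
            if self._stop_requested:
                break  # the current step has finished; skip the remaining ones
        self._fire("on_train_end")


class StopBelow(ToyCallback):
    """Early stopping once the loss drops under a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def on_step_end(self, trainer, step, loss, metrics) -> None:
        if loss < self.threshold:
            trainer.request_stop()


trainer = ToyTrainer([0.9, 0.5, 0.08, 0.07], callbacks=[StopBelow(0.1)])
trainer.train()
```

Requesting a stop instead of raising lets `on_train_end` (and, in a real trainer, a final save) still run.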

Configuration

TrainingConfig fields with defaults:

| Field | Default | Description |
|---|---|---|
| output_dir | "./output" | Checkpoints and saved models |
| num_epochs | 3 | Number of training epochs |
| batch_size | 1 | Per-device batch size |
| gradient_accumulation_steps | 1 | Steps before optimizer update |
| learning_rate | 1e-4 | Peak learning rate |
| weight_decay | 0.01 | L2 regularization |
| warmup_ratio | 0.1 | Fraction of steps for LR warmup |
| warmup_steps | 0 | Overrides warmup_ratio if > 0 |
| max_grad_norm | 1.0 | Gradient clipping |
| mixed_precision | "bf16" | "no", "fp16", or "bf16" |
| gradient_checkpointing | True | Trade compute for memory |
| optimizer | "adamw" | See supported optimizers below |
| lr_scheduler | "cosine" | "linear", "cosine", "constant", "constant_with_warmup" |
| logging_steps | 10 | Log metrics every N steps |
| eval_steps | None | Evaluate every N steps |
| save_steps | None | Checkpoint every N steps |
| save_total_limit | 2 | Max checkpoints to keep |
| save_on_epoch_end | True | Checkpoint after each epoch |
| resume_from_checkpoint | None | Path to resume from |
| seed | 42 | Random seed |
| log_with | None | "wandb" for W&B tracking |

Supported optimizers: adamw, adamw_8bit, paged_adamw_8bit, adafactor, sgd
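
How `warmup_ratio`, `warmup_steps`, and `lr_scheduler="cosine"` interact, as a sketch. It assumes the common Hugging Face convention (linear warmup from 0, then a half-cosine decay to 0, with warmup length `ceil(total_steps * warmup_ratio)`); Atelier may compute the warmup length or the decay floor differently. Both helper names are hypothetical.

```python
import math


def warmup_length(total_steps: int, warmup_ratio: float = 0.1, warmup_steps: int = 0) -> int:
    # warmup_steps overrides warmup_ratio when it is > 0, per the table above
    if warmup_steps > 0:
        return warmup_steps
    return math.ceil(total_steps * warmup_ratio)


def cosine_lr(step: int, total_steps: int, peak_lr: float, warmup: int) -> float:
    """Learning rate at a given optimizer step."""
    if step < warmup:
        return peak_lr * step / max(1, warmup)  # linear ramp from 0 to peak_lr
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))  # peak_lr -> 0


total = 1000
warmup = warmup_length(total)  # defaults: warmup_ratio=0.1, warmup_steps=0
schedule = [cosine_lr(s, total, peak_lr=1e-4, warmup=warmup) for s in range(total + 1)]
```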

Architecture

atelier/
├── trainer.py           # AtelierTrainer: the training loop
├── config.py            # TrainingConfig dataclass
├── callbacks.py         # TrainerCallback base class
├── adapters/
│   ├── base.py          # ModelAdapter protocol
│   ├── qwen_edit.py     # Qwen-Image-Edit (DiT + flow matching, image-conditioned)
│   ├── qwen_image.py    # Qwen-Image (DiT + flow matching, text-to-image)
│   └── sdxl.py          # SDXL (UNet + DDPM)
├── losses/
│   ├── flow_matching.py # Flow matching MSE
│   └── diffusion_dpo.py # DPO + SFT regularization
└── data/
    ├── editing.py       # Paired image editing dataset
    ├── generation.py    # Text-to-image dataset
    └── cache.py         # Embedding pre-computation

How it fits together

The adapter encapsulates everything that varies per model architecture: loading, encoding, the forward pass, and saving. In Grimoire (LLM training), every model has the same forward signature (model(input_ids) → logits). In diffusion training, forward passes vary wildly: Qwen-Image-Edit needs latent packing, control image concatenation, and RoPE shapes; SDXL needs dual CLIP conditioning and time embeddings. The adapter hides this.
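
Latent packing, one of the details the adapter hides, turns a VAE latent of shape (C, H, W) into a token sequence by folding each 2x2 spatial patch into the channel dimension. A dependency-free sketch, assuming the 2x2 patch layout used by the Flux and Qwen-Image pipelines in diffusers (feature index `c*4 + dy*2 + dx`); `pack_latents` and `unpack_latents` are illustrative reimplementations, not Atelier's code.

```python
from typing import List

Latent = List[List[List[float]]]  # indexed [channel][row][col]


def pack_latents(latent: Latent) -> List[List[float]]:
    """(C, H, W) -> (H/2 * W/2 tokens, C*4 features), patches in row-major order."""
    channels, height, width = len(latent), len(latent[0]), len(latent[0][0])
    if height % 2 or width % 2:
        raise ValueError("latent height and width must be even")
    tokens = []
    for py in range(height // 2):
        for px in range(width // 2):
            token = []
            for c in range(channels):
                for dy in range(2):
                    for dx in range(2):
                        token.append(latent[c][2 * py + dy][2 * px + dx])
            tokens.append(token)
    return tokens


def unpack_latents(tokens: List[List[float]], channels: int, height: int, width: int) -> Latent:
    """Inverse of pack_latents."""
    latent = [[[0.0] * width for _ in range(height)] for _ in range(channels)]
    for i, token in enumerate(tokens):
        py, px = divmod(i, width // 2)
        for c in range(channels):
            for dy in range(2):
                for dx in range(2):
                    latent[c][2 * py + dy][2 * px + dx] = token[c * 4 + dy * 2 + dx]
    return latent


# 2 channels, 4x4 latent with distinct values c*100 + y*10 + x
C, H, W = 2, 4, 4
latent = [[[float(c * 100 + y * 10 + x) for x in range(W)] for y in range(H)] for c in range(C)]
tokens = pack_latents(latent)
```

Packing quarters the sequence length the transformer attends over, at the cost of wider tokens.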

The loss function orchestrates the training objective: sampling noise and timesteps, calling the adapter's forward pass, and computing the loss. Flow matching predicts the velocity field; DPO compares noise predictions for chosen vs rejected images.
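
The flow-matching objective fits in a few lines. This sketch assumes the rectified-flow convention used by diffusers' flow-matching training scripts (x_t = (1 - t) * x0 + t * noise with t in [0, 1], velocity target noise - x0); Atelier's `FlowMatchingLoss` may sample or weight t differently. Plain lists stand in for tensors and all names are hypothetical.

```python
import random
from typing import Callable, List

Vec = List[float]


def interpolate(x0: Vec, noise: Vec, t: float) -> Vec:
    """Point on the straight path from data (t=0) to noise (t=1)."""
    return [(1.0 - t) * a + t * n for a, n in zip(x0, noise)]


def velocity_target(x0: Vec, noise: Vec) -> Vec:
    """d(x_t)/dt along that path is constant: noise - x0."""
    return [n - a for a, n in zip(x0, noise)]


def flow_matching_loss(predict: Callable[[Vec, float], Vec], x0: Vec, rng: random.Random) -> float:
    """Single-sample MSE between predicted and true velocity."""
    t = rng.random()  # uniform t in [0, 1); real trainers often reweight or shift t
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = interpolate(x0, noise, t)
    target = velocity_target(x0, noise)
    pred = predict(x_t, t)
    return sum((p - y) ** 2 for p, y in zip(pred, target)) / len(x0)


x0 = [0.5, -1.0, 2.0]


def oracle(x_t: Vec, t: float) -> Vec:
    # Knows x0, so it recovers the velocity exactly: (x_t - x0) / t == noise - x0
    return [(xt - a) / t for xt, a in zip(x_t, x0)]


loss = flow_matching_loss(oracle, x0, random.Random(0))
```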

The trainer owns the loop: optimizer, gradient accumulation, checkpointing, logging. It calls loss_fn(adapter, model, batch) and never needs to know which model architecture or training objective is being used.
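
That division of labour can be sketched as a loop that only knows the `loss_fn(adapter, model, batch)` contract. This toy version counts optimizer updates under gradient accumulation instead of computing gradients, and flushing a final partial window is an assumption; `run_epoch`, `mse_loss`, and `ToyAdapter` are hypothetical and show the shape of the loop, not Atelier's implementation.

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple

LossFn = Callable[[Any, Any, Dict[str, Any]], Tuple[float, Dict[str, float]]]


def run_epoch(
    loss_fn: LossFn,
    adapter: Any,
    model: Any,
    batches: Iterable[Dict[str, Any]],
    gradient_accumulation_steps: int,
) -> Tuple[int, List[float]]:
    """Return (optimizer updates taken, mean loss per update); never inspects adapter or model."""
    updates, window, logged = 0, [], []
    for batch in batches:
        loss, _metrics = loss_fn(adapter, model, batch)
        window.append(loss)  # a PyTorch trainer would call (loss / accum).backward() here
        if len(window) == gradient_accumulation_steps:
            updates += 1  # ...and clip, optimizer.step(), zero_grad() here
            logged.append(sum(window) / len(window))
            window = []
    if window:  # assumed: flush a final partial accumulation window
        updates += 1
        logged.append(sum(window) / len(window))
    return updates, logged


def mse_loss(adapter, model, batch):
    """Any objective plugs in as long as it returns (loss, metrics)."""
    pred = adapter.forward(model, batch["x"])
    loss = (pred - batch["y"]) ** 2
    return loss, {"mse": loss}


class ToyAdapter:
    def forward(self, model, x):
        return model["w"] * x  # the only architecture-specific code lives here


batches = [{"x": float(i), "y": 2.0 * i} for i in range(5)]
steps, losses = run_epoch(mse_loss, ToyAdapter(), {"w": 1.0}, batches, gradient_accumulation_steps=2)
```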

Loss function interface

class MyLoss:
    def __call__(self, adapter, model, batch, training=True):
        # Use adapter for noise sampling, forward pass, target computation
        return loss, metrics_dict

    def create_collator(self):
        return MyCollator()
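
A loss that satisfies this interface can be very small. The sketch assumes hypothetical adapter method signatures (`sample_timesteps(n)`, `add_noise(x0, noise, t)`, `compute_target(x0, noise)`, `forward(model, x_t, t)`: the names come from the adapter interface, the arguments are invented) and a collator that turns a list of per-sample dicts into a dict of lists. `NoisePredictionLoss`, `DictCollator`, and `ToyAdapter` are all hypothetical.

```python
import random
from typing import Any, Dict, List, Tuple


class DictCollator:
    """Turn [{"latent": a}, {"latent": b}] into {"latent": [a, b]}."""

    def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        return {key: [sample[key] for sample in samples] for key in samples[0]}


class NoisePredictionLoss:
    """Hypothetical loss following the (loss, metrics) contract."""

    def __init__(self, seed: int = 0):
        self.rng = random.Random(seed)

    def __call__(self, adapter, model, batch, training=True) -> Tuple[float, Dict[str, float]]:
        # `training` is accepted for interface compatibility; unused in this sketch.
        latents = batch["latent"]
        timesteps = adapter.sample_timesteps(len(latents))
        errors = []
        for x0, t in zip(latents, timesteps):
            noise = [self.rng.gauss(0.0, 1.0) for _ in x0]
            x_t = adapter.add_noise(x0, noise, t)       # adapter decides how to corrupt
            target = adapter.compute_target(x0, noise)  # adapter decides what to predict
            pred = adapter.forward(model, x_t, t)       # adapter runs the architecture
            errors.append(sum((p - y) ** 2 for p, y in zip(pred, target)) / len(x0))
        loss = sum(errors) / len(errors)
        return loss, {"loss": loss, "mean_t": sum(timesteps) / len(timesteps)}

    def create_collator(self) -> DictCollator:
        return DictCollator()


class ToyAdapter:
    """Simplified epsilon-prediction stand-in (x_t = x0 + t * noise), not a real schedule."""

    def __init__(self, seed: int = 1):
        self.rng = random.Random(seed)

    def sample_timesteps(self, n: int) -> List[float]:
        return [self.rng.random() for _ in range(n)]

    def add_noise(self, x0, noise, t):
        return [a + t * e for a, e in zip(x0, noise)]

    def compute_target(self, x0, noise):
        return list(noise)

    def forward(self, model, x_t, t):
        return [0.0 for _ in x_t]  # an untrained "model" that always predicts zero


loss_fn = NoisePredictionLoss()
collate = loss_fn.create_collator()
batch = collate([{"latent": [0.1, 0.2]}, {"latent": [0.3, -0.4]}])
loss, metrics = loss_fn(ToyAdapter(), None, batch)
```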

Adapter interface

from atelier.adapters.base import ModelAdapter

class MyAdapter(ModelAdapter):
    def model(self):            ...  # The trainable model
    def encode_images(self):    ...  # VAE encode
    def encode_text(self):      ...  # Text encode
    def sample_timesteps(self): ...  # Timestep sampling
    def add_noise(self):        ...  # Create noisy input
    def compute_target(self):   ...  # What model should predict
    def forward(self):          ...  # Architecture-specific forward
    def save_lora(self):        ...  # Save LoRA weights
    def save_model(self):       ...  # Save full model
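
Because losses only call adapter methods, the contract can also be expressed structurally. A sketch using `typing.Protocol` with hypothetical argument lists (the interface above shows only method names); `runtime_checkable` lets `isinstance` confirm that the methods exist, though it does not check their signatures. `AdapterLike`, `HalfAdapter`, and `FullAdapter` are illustrative names.

```python
from typing import Any, List, Protocol, runtime_checkable


@runtime_checkable
class AdapterLike(Protocol):
    """Hypothetical structural version of the adapter interface (arguments invented)."""

    def sample_timesteps(self, n: int) -> List[float]: ...
    def add_noise(self, x0: Any, noise: Any, t: float) -> Any: ...
    def compute_target(self, x0: Any, noise: Any) -> Any: ...
    def forward(self, model: Any, x_t: Any, t: float) -> Any: ...


class HalfAdapter:
    """Implements only one method, so it does not satisfy the protocol."""

    def sample_timesteps(self, n: int) -> List[float]:
        return [0.5] * n


class FullAdapter(HalfAdapter):
    def add_noise(self, x0, noise, t):
        return x0 + t * noise

    def compute_target(self, x0, noise):
        return noise

    def forward(self, model, x_t, t):
        return model(x_t)


full_ok = isinstance(FullAdapter(), AdapterLike)
half_ok = isinstance(HalfAdapter(), AdapterLike)
```

A shared base class, as in the interface above, can additionally supply default implementations; a structural check like this is mainly useful for validating duck-typed adapters early.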

License

MIT


Download files


Source Distribution

atelier_diffusion-0.1.0.tar.gz (48.2 kB)

Uploaded Source

Built Distribution


atelier_diffusion-0.1.0-py3-none-any.whl (47.3 kB)

Uploaded Python 3

File details

Details for the file atelier_diffusion-0.1.0.tar.gz.

File metadata

  • Download URL: atelier_diffusion-0.1.0.tar.gz
  • Upload date:
  • Size: 48.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for atelier_diffusion-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | b2e4bb50602f294cfed3dc33e0148f44063c9522844d483af7407298f593f5b4 |
| MD5 | 8e732071576bb9981c0a54e709f1811c |
| BLAKE2b-256 | 43645b890131b20f92013a5c1ee0fdbea5d86b1b0fd5b7a01e38941c045d994d |


Provenance

The following attestation bundles were made for atelier_diffusion-0.1.0.tar.gz:

Publisher: release.yml on Schneewolf-Labs/atelier

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file atelier_diffusion-0.1.0-py3-none-any.whl.


File hashes

Hashes for atelier_diffusion-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | ab1c65fa0c1eb316e2cac6d86fce2c30048e4fc417d26b708a0027b51b34ce6e |
| MD5 | f8fab0da29164147fc64226d4ba1135a |
| BLAKE2b-256 | d45c13dda4c6a574d80f557bc91e89c0f6cbf147512361814dba68bcdeff88fc |


Provenance

The following attestation bundles were made for atelier_diffusion-0.1.0-py3-none-any.whl:

Publisher: release.yml on Schneewolf-Labs/atelier

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
