Simple, multi-GPU diffusion model fine-tuning library (Schneewolf Labs)
🎨 Atelier 🎨
A simple, multi-GPU diffusion model fine-tuning library. One training loop, pluggable adapters and loss functions.
Sister project to Grimoire (LLM fine-tuning). Both serve as training engines for Merlina.
Why
Diffusion model training scripts tend to be monolithic: model loading, data processing, the training loop, and architecture-specific forward passes are all tangled together. Switching from SDXL to Qwen-Image-Edit means rewriting the whole script.
Atelier separates what varies (model architecture, training objective) from what doesn't (the training loop, multi-GPU, checkpointing, logging). Adding a new model means writing an adapter. Adding a new training objective means writing a loss function. The trainer never changes.
Install
pip install -e .
# With optional dependencies
pip install -e ".[quantization]" # bitsandbytes for 8-bit optimizers
pip install -e ".[logging]" # wandb
pip install -e ".[all]" # everything
Quick start
Qwen-Image-Edit LoRA (flow matching)
from peft import LoraConfig
from atelier import AtelierTrainer, TrainingConfig
from atelier.adapters import QwenEditAdapter
from atelier.losses import FlowMatchingLoss
from atelier.data import EditingDataset, cache_embeddings
# raw_dataset is your own dataset object; it is not defined in this snippet.
# Load adapter (handles model, VAE, text encoder, scheduler)
adapter = QwenEditAdapter("Qwen/Qwen-Image-Edit")
# Pre-compute embeddings to save VRAM during training
text_emb, target_emb, control_emb = cache_embeddings(
    raw_dataset, adapter, cache_dir="./output/cache",
)
adapter.free_encoders()  # reclaim VRAM
dataset = EditingDataset(
    raw_dataset,
    cached_text_embeddings=text_emb,
    cached_target_embeddings=target_emb,
    cached_control_embeddings=control_emb,
)
trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(
        output_dir="./output",
        num_epochs=50,
        batch_size=1,
        learning_rate=1e-4,
        gradient_accumulation_steps=2,
    ),
    loss_fn=FlowMatchingLoss(),
    train_dataset=dataset,
    peft_config=LoraConfig(
        r=64,
        lora_alpha=128,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    ),
)
trainer.train()
trainer.save_model("./my-lora")
Qwen-Image LoRA (text-to-image, flow matching)
Same loss as Qwen-Image-Edit; the adapter differs because the text encoder is not vision-conditioned and the transformer sees only the noised target (no control image concat).
from peft import LoraConfig
from atelier import AtelierTrainer, TrainingConfig
from atelier.adapters import QwenImageAdapter
from atelier.losses import FlowMatchingLoss
from atelier.data import EditingDataset, cache_embeddings
adapter = QwenImageAdapter("Qwen/Qwen-Image")
# Dataset only needs (prompt, chosen); no "rejected" column.
text_emb, target_emb, _ = cache_embeddings(
    raw_dataset, adapter, cache_dir="./output/cache",
)
adapter.free_encoders()
dataset = EditingDataset(
    raw_dataset,
    cached_text_embeddings=text_emb,
    cached_target_embeddings=target_emb,
)
trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(output_dir="./output", num_epochs=8, batch_size=1),
    loss_fn=FlowMatchingLoss(),
    train_dataset=dataset,
    peft_config=LoraConfig(
        r=32, lora_alpha=64,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    ),
)
trainer.train()
trainer.save_model("./my-qwen-image-lora")
SDXL DPO (preference optimization)
Same trainer, different adapter and loss function.
from atelier import AtelierTrainer, TrainingConfig
from atelier.adapters import SDXLAdapter
from atelier.losses import DiffusionDPOLoss
from atelier.data import GenerationDataset
adapter = SDXLAdapter(
    "stabilityai/stable-diffusion-xl-base-1.0",
    weights="/path/to/model.safetensors",
)
adapter.freeze_layers(strategy="color_blocks", layers="0,1")
dataset = GenerationDataset(
    raw_dataset,
    tokenizer=adapter.tokenizer,
    tokenizer_2=adapter.tokenizer_2,
)
trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(
        output_dir="./output",
        num_epochs=10,
        batch_size=1,
        learning_rate=2e-6,
        optimizer="adamw_8bit",
        mixed_precision="fp16",
    ),
    loss_fn=DiffusionDPOLoss(beta=0.4, sft_weight=0.3),
    train_dataset=dataset,
)
trainer.train()
trainer.save_model("./my-sdxl")
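To make the beta and sft_weight arguments above more concrete, here is a minimal sketch of a Diffusion-DPO style objective on scalar errors. It is not Atelier's implementation: the function names, the exact scaling of beta, and the way the SFT term is added (as sft_weight times the chosen-image error) are all assumptions for illustration. Each err_* value stands for a mean squared noise-prediction error, from the trained model or from a frozen reference model.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def diffusion_dpo_loss(
    err_chosen: float,
    err_rejected: float,
    ref_err_chosen: float,
    ref_err_rejected: float,
    beta: float = 0.4,
    sft_weight: float = 0.3,
) -> float:
    """Hypothetical DPO + SFT objective on per-sample denoising errors."""
    # How much better the model denoises "chosen" than "rejected" ...
    model_diff = err_chosen - err_rejected
    # ... compared with how the frozen reference model does on the same pair.
    ref_diff = ref_err_chosen - ref_err_rejected
    dpo = -log_sigmoid(-beta * (model_diff - ref_diff))
    # Assumed SFT regularizer: plain denoising error on the chosen image.
    return dpo + sft_weight * err_chosen


# Model identical to the reference: the DPO term equals log(2).
baseline = diffusion_dpo_loss(0.2, 0.2, 0.2, 0.2)
# Model now prefers the chosen image more than the reference does.
improved = diffusion_dpo_loss(0.1, 0.5, 0.3, 0.3)
```

In this sketch a larger beta penalizes deviation from the reference preference more sharply, while sft_weight keeps the model anchored to ordinary denoising quality on the chosen images.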
With LoRA
Pass a peft_config and Atelier handles the rest.
from peft import LoraConfig
trainer = AtelierTrainer(
    adapter=adapter,
    config=TrainingConfig(...),
    loss_fn=FlowMatchingLoss(),
    train_dataset=dataset,
    peft_config=LoraConfig(
        r=64,
        lora_alpha=128,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    ),
)
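For a rough sense of what r and lora_alpha mean, the arithmetic below may help. It assumes standard LoRA as implemented in PEFT without rank-stabilized scaling: the low-rank update is scaled by lora_alpha / r, and each adapted linear layer gains r * (d_in + d_out) trainable parameters. The helper names and the 3072 hidden size are hypothetical, chosen only for the example.

```python
def lora_scaling(r: int, lora_alpha: int) -> float:
    """Scale applied to the low-rank update B @ A (standard LoRA)."""
    return lora_alpha / r


def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added to one linear layer: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


# Both quick-start configs use lora_alpha = 2 * r, so the scale is the same.
scale_edit = lora_scaling(r=64, lora_alpha=128)
scale_t2i = lora_scaling(r=32, lora_alpha=64)

# Hypothetical 3072 x 3072 attention projection (size is an assumption).
hidden = 3072
added = lora_param_count(hidden, hidden, r=64)
fraction = added / (hidden * hidden)  # share of the frozen layer's size
```

Halving r while keeping lora_alpha / r fixed halves the added parameters without changing the update scale.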
Guides
- Loss Formulas: Math for flow matching and diffusion DPO
- Adapters: Writing a custom adapter for a new model architecture
- Callbacks: Hooking into the training loop
- Multi-GPU and DeepSpeed: Distributed training setup
YAML config + CLI
For orchestrators (e.g. Merlina) or when you just don't want to write a Python wrapper per run, train from a YAML config:
pip install -e ".[yaml]"
python -m atelier.train --config configs/qwen_image_lora_example.yaml
# Override anything on the CLI (JSON-decoded values):
python -m atelier.train --config configs/my.yaml \
--set training.num_epochs=2 \
--set training.output_dir=./output-quick \
--set 'peft.target_modules=["to_q","to_v"]'
The YAML schema mirrors the Python API one-for-one: model.adapter picks an
adapter from atelier.registry.ADAPTERS, loss.type picks from LOSSES,
peft becomes a LoraConfig, training becomes a TrainingConfig, and
dataset accepts an HF hub name, a local JSONL, or a load_from_disk path.
See atelier/train.py for the full schema and configs/qwen_image_lora_example.yaml
for a worked example.
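The --set overrides above combine a dotted key path with a JSON-decoded value. A minimal sketch of how such an override could be applied to a nested config dict follows; it is not taken from atelier/train.py, the function names are hypothetical, and falling back to a plain string when the value is not valid JSON is an assumption.

```python
import json
from typing import Any


def parse_override(arg: str) -> tuple[list[str], Any]:
    """Split 'a.b.c=value' into a key path and a decoded value."""
    key, sep, raw = arg.partition("=")
    if not sep or not key:
        raise ValueError(f"expected key=value, got {arg!r}")
    try:
        value = json.loads(raw)
    except json.JSONDecodeError:
        # Assumed fallback: bare strings such as ./output-quick are not valid JSON.
        value = raw
    return key.split("."), value


def apply_override(config: dict, arg: str) -> None:
    """Set the value in place, creating intermediate dicts as needed."""
    path, value = parse_override(arg)
    node = config
    for part in path[:-1]:
        node = node.setdefault(part, {})
    node[path[-1]] = value


config = {"training": {"num_epochs": 8}, "peft": {}}
for arg in [
    "training.num_epochs=2",
    "training.output_dir=./output-quick",
    'peft.target_modules=["to_q","to_v"]',
]:
    apply_override(config, arg)
```

JSON decoding is what lets num_epochs arrive as an integer and target_modules as a list rather than as strings.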
Multi-GPU
No code changes. Configure with accelerate and launch:
accelerate config
accelerate launch --multi_gpu --num_processes 4 train.py
accelerate launch --use_deepspeed --deepspeed_config ds_config.json train.py
Callbacks
Subclass TrainerCallback and override the hooks you need:
from atelier import AtelierTrainer, TrainerCallback
class MyCallback(TrainerCallback):
    def on_step_end(self, trainer, step, loss, metrics):
        # should_stop() is a placeholder for your own stop condition.
        if should_stop():
            trainer.request_stop()
    def on_log(self, trainer, metrics):
        print(f"Step {trainer.global_step}: {metrics}")
trainer = AtelierTrainer(..., callbacks=[MyCallback()])
Available hooks: on_train_begin, on_train_end, on_epoch_begin, on_epoch_end, on_step_end, on_log, on_evaluate, on_save.
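As an illustration of the on_step_end hook and trainer.request_stop(), here is a self-contained sketch of a plateau-based early-stopping callback. So that it runs without Atelier installed, TrainerCallback and FakeTrainer below are stand-ins written for this example, and PlateauStop is a hypothetical name; only the hook signature and request_stop() come from the snippet above.

```python
class TrainerCallback:
    """Stand-in for atelier.TrainerCallback: hooks default to no-ops."""

    def on_step_end(self, trainer, step, loss, metrics):
        pass


class FakeTrainer:
    """Stand-in trainer that replays a list of losses and fires on_step_end."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.global_step = 0
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True

    def run(self, losses):
        for loss in losses:
            self.global_step += 1
            for callback in self.callbacks:
                callback.on_step_end(self, self.global_step, loss, {"loss": loss})
            if self.stop_requested:
                break


class PlateauStop(TrainerCallback):
    """Request a stop after `patience` consecutive steps without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.stale = 0

    def on_step_end(self, trainer, step, loss, metrics):
        if loss < self.best - self.min_delta:
            self.best = loss
            self.stale = 0
        else:
            self.stale += 1
            if self.stale >= self.patience:
                trainer.request_stop()


trainer = FakeTrainer([PlateauStop(patience=2)])
trainer.run([1.0, 0.8, 0.9, 0.85, 0.7, 0.6])
```

With the real library, the same PlateauStop class would subclass the imported TrainerCallback and be passed via callbacks=[...]. Per-step losses are noisy in diffusion training, so a smoothed loss would likely be a better signal in practice.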
Configuration
TrainingConfig fields with defaults:
| Field | Default | Description |
|---|---|---|
| output_dir | "./output" | Checkpoints and saved models |
| num_epochs | 3 | Number of training epochs |
| batch_size | 1 | Per-device batch size |
| gradient_accumulation_steps | 1 | Steps before optimizer update |
| learning_rate | 1e-4 | Peak learning rate |
| weight_decay | 0.01 | L2 regularization |
| warmup_ratio | 0.1 | Fraction of steps for LR warmup |
| warmup_steps | 0 | Overrides warmup_ratio if > 0 |
| max_grad_norm | 1.0 | Gradient clipping |
| mixed_precision | "bf16" | "no", "fp16", or "bf16" |
| gradient_checkpointing | True | Trade compute for memory |
| optimizer | "adamw" | See supported optimizers below |
| lr_scheduler | "cosine" | "linear", "cosine", "constant", "constant_with_warmup" |
| logging_steps | 10 | Log metrics every N steps |
| eval_steps | None | Evaluate every N steps |
| save_steps | None | Checkpoint every N steps |
| save_total_limit | 2 | Max checkpoints to keep |
| save_on_epoch_end | True | Checkpoint after each epoch |
| resume_from_checkpoint | None | Path to resume from |
| seed | 42 | Random seed |
| log_with | None | "wandb" for W&B tracking |
Supported optimizers: adamw, adamw_8bit, paged_adamw_8bit, adafactor, sgd
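Several of these fields interact: batch_size is per device, gradient_accumulation_steps multiplies it, and warmup_steps takes precedence over warmup_ratio. The sketch below works through that arithmetic. The helper names are hypothetical and the rounding choices (ceil for partial batches, int() for the warmup fraction) are assumptions, not necessarily what the trainer does.

```python
import math


def effective_batch_size(batch_size, gradient_accumulation_steps, num_processes=1):
    """Samples contributing to one optimizer update across all devices."""
    return batch_size * gradient_accumulation_steps * num_processes


def optimizer_steps(num_examples, num_epochs, batch_size,
                    gradient_accumulation_steps, num_processes=1):
    """Total optimizer updates, assuming partial batches are kept."""
    batches_per_epoch = math.ceil(num_examples / (batch_size * num_processes))
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    return updates_per_epoch * num_epochs


def resolve_warmup_steps(total_steps, warmup_ratio=0.1, warmup_steps=0):
    """warmup_steps wins when positive; otherwise use the ratio."""
    if warmup_steps > 0:
        return warmup_steps
    return int(total_steps * warmup_ratio)


# Example: 1000 samples, 50 epochs, batch_size=1, accumulation=2, 4 GPUs.
eff = effective_batch_size(1, 2, num_processes=4)
total = optimizer_steps(1000, 50, 1, 2, num_processes=4)
warmup_default = resolve_warmup_steps(1000)
warmup_explicit = resolve_warmup_steps(1000, warmup_steps=50)
```

When scaling from one GPU to several, the effective batch size grows with the process count, so the learning rate may need revisiting.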
Architecture
atelier/
├── trainer.py            # AtelierTrainer: the training loop
├── config.py             # TrainingConfig dataclass
├── callbacks.py          # TrainerCallback base class
├── adapters/
│   ├── base.py           # ModelAdapter protocol
│   ├── qwen_edit.py      # Qwen-Image-Edit (DiT + flow matching, image-conditioned)
│   ├── qwen_image.py     # Qwen-Image (DiT + flow matching, text-to-image)
│   └── sdxl.py           # SDXL (UNet + DDPM)
├── losses/
│   ├── flow_matching.py  # Flow matching MSE
│   └── diffusion_dpo.py  # DPO + SFT regularization
└── data/
    ├── editing.py        # Paired image editing dataset
    ├── generation.py     # Text-to-image dataset
    └── cache.py          # Embedding pre-computation
How it fits together
The adapter encapsulates everything that varies per model architecture: loading, encoding, the forward pass, and saving. In Grimoire (LLM training), every model has the same forward signature (model(input_ids) → logits). In diffusion training, forward passes vary wildly: Qwen-Image-Edit needs latent packing, control image concatenation, and RoPE shapes; SDXL needs dual CLIP conditioning and time embeddings. The adapter hides this.
The loss function orchestrates the training objective: sampling noise and timesteps, calling the adapter's forward pass, and computing the loss. Flow matching predicts the velocity field; DPO compares noise predictions for chosen vs rejected images.
The trainer owns the loop: optimizer, gradient accumulation, checkpointing, logging. It calls loss_fn(adapter, model, batch) and never needs to know what model architecture or training objective is being used.
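The three-way split can be sketched end to end with a toy in pure Python. Everything here is a stand-in, not Atelier code: "latents" are floats, the model is a two-parameter dict, gradients come from finite differences instead of autograd, and the flow-matching convention (x_t = (1 - t) * x0 + t * noise, target velocity noise - x0) is one common choice assumed for the example.

```python
import random


class ToyAdapter:
    """Stand-in adapter: hides how inputs are noised and how the model is called."""

    def sample_timesteps(self, n, rng):
        return [rng.random() for _ in range(n)]

    def add_noise(self, x0, noise, t):
        # Linear interpolation between data (t=0) and noise (t=1).
        return [(1 - ti) * a + ti * e for a, e, ti in zip(x0, noise, t)]

    def compute_target(self, x0, noise):
        # Velocity of the interpolation path above.
        return [e - a for a, e in zip(x0, noise)]

    def forward(self, model, x_t, t):
        # Toy "architecture": one affine map shared across samples.
        return [model["w"] * xi + model["b"] for xi in x_t]


class ToyFlowMatchingLoss:
    """Stand-in loss: samples noise and timesteps, then scores the prediction."""

    def __init__(self, seed=0):
        self.seed = seed

    def __call__(self, adapter, model, batch, training=True):
        rng = random.Random(self.seed)  # same draws on every call keep the toy deterministic
        noise = [rng.gauss(0.0, 1.0) for _ in batch]
        t = adapter.sample_timesteps(len(batch), rng)
        x_t = adapter.add_noise(batch, noise, t)
        target = adapter.compute_target(batch, noise)
        pred = adapter.forward(model, x_t, t)
        loss = sum((p - y) ** 2 for p, y in zip(pred, target)) / len(batch)
        return loss, {"loss": loss}


def toy_train(adapter, model, loss_fn, batches, lr=0.02, eps=1e-4):
    """Loop that only ever calls loss_fn(adapter, model, batch)."""
    history = []
    for batch in batches:
        loss, _ = loss_fn(adapter, model, batch)
        grads = {}
        for name, value in list(model.items()):
            # Central finite differences stand in for autograd here.
            model[name] = value + eps
            up, _ = loss_fn(adapter, model, batch)
            model[name] = value - eps
            down, _ = loss_fn(adapter, model, batch)
            model[name] = value
            grads[name] = (up - down) / (2 * eps)
        for name, grad in grads.items():
            model[name] -= lr * grad
        history.append(loss)
    return history


model = {"w": 0.0, "b": 0.0}
batch = [1.0, -0.5, 0.25, 2.0]  # one batch, repeated, to keep the toy simple
history = toy_train(ToyAdapter(), model, ToyFlowMatchingLoss(), [batch] * 200)
```

Swapping in a different adapter or loss object leaves toy_train untouched, which is the property the real trainer is built around.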
Loss function interface
class MyLoss:
    def __call__(self, adapter, model, batch, training=True):
        # Use adapter for noise sampling, forward pass, target computation
        return loss, metrics_dict
    def create_collator(self):
        # MyCollator is a placeholder for your own batch collator.
        return MyCollator()
Adapter interface
# Import path assumed from the Architecture tree (adapters/base.py).
from atelier.adapters.base import ModelAdapter
class MyAdapter(ModelAdapter):
    def model(self): ...             # The trainable model
    def encode_images(self): ...     # VAE encode
    def encode_text(self): ...       # Text encode
    def sample_timesteps(self): ...  # Timestep sampling
    def add_noise(self): ...         # Create noisy input
    def compute_target(self): ...    # What model should predict
    def forward(self): ...           # Architecture-specific forward
    def save_lora(self): ...         # Save LoRA weights
    def save_model(self): ...        # Save full model
License
MIT