DRLX is a library for distributed training of diffusion models via RL
Diffuser Reinforcement Learning X
DRLX is designed to wrap around 🤗 Hugging Face's Diffusers library and uses Accelerate for multi-GPU and multi-node training (multi-node support is as yet untested).
Setup
You can install the library from PyPI:
pip install drlx
or from source:
pip install git+https://github.com/CarperAI/DRLX.git
How to use
So far the library has only been tested with Stable Diffusion 1.4, but because it is plug-and-play, any denoiser from any pipeline should in principle be usable. Models saved with DRLX remain compatible with the pipeline they originated from and can be loaded like any other pretrained model. DDPO (Denoising Diffusion Policy Optimization) is currently the only supported training algorithm. A minimal training run looks like this:
# Import a reward model, a prompt pipeline, the trainer and its config
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

# Prompt pipeline that supplies Pick-a-Pic prompts for sampling
pipe = PickAPicPrompts()
# Training settings are read from a YAML file
config = DRLXConfig.load_yaml("configs/my_cfg.yml")
trainer = DDPOTrainer(config)
# Aesthetics() is the reward model that scores the generated images
trainer.train(pipe, Aesthetics())
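DDPO frames the denoising process as a multi-step decision problem, treating each denoising step as an action, and optimizes it with a PPO-style clipped policy-gradient objective. The sketch below is a framework-free illustration of that objective, not DRLX's implementation. It assumes the per-step log-probabilities of each sampled trajectory have already been computed under both the current policy and the policy that generated the samples, it turns rewards into advantages by standardizing them across the batch (one simple choice among several), and its `clip_range` value is illustrative only. The function names `normalize_rewards` and `ddpo_clipped_loss` are hypothetical.

```python
import math
from statistics import mean, pstdev


def normalize_rewards(rewards, eps=1e-8):
    """Hypothetical helper: standardize raw rewards across the batch to get advantages."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def ddpo_clipped_loss(logp_new, logp_old, advantages, clip_range=0.2):
    """Hypothetical sketch of a DDPO-style clipped surrogate loss.

    logp_new[i][t] / logp_old[i][t] are the log-probabilities of the denoising
    step taken at timestep t of sample i under the current policy / the policy
    that produced the sample. Every timestep of a sample shares that sample's
    advantage. clip_range is an illustrative value, not a DRLX default.
    """
    losses = []
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        for new, old in zip(lp_new, lp_old):
            # Importance ratio between the current and the sampling policy
            ratio = math.exp(new - old)
            unclipped = -adv * ratio
            clipped = -adv * min(max(ratio, 1 - clip_range), 1 + clip_range)
            # Take the pessimistic (larger) loss, as in PPO
            losses.append(max(unclipped, clipped))
    # Average over all samples and denoising steps
    return mean(losses)


advantages = normalize_rewards([1.0, 3.0])
print(advantages)  # approximately [-1.0, 1.0]
```

Because the ratio is clipped to `[1 - clip_range, 1 + clip_range]` and the larger of the two loss terms is kept, a single round of updates cannot push the policy far from the one that generated the samples.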
To run inference with a trained model:
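from diffusers import StableDiffusionPipeline

# Point this at the directory DRLX saved the trained model to
pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp")
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]
image.save("test.jpeg")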
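Accelerated Training
To train on multiple GPUs, first run accelerate config to describe your hardware, then launch your training script as a module with accelerate launch: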
accelerate config
accelerate launch -m [your module]
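When scaling out with accelerate launch, it helps to know how many samples contribute to each optimizer update. The helper below is a hypothetical back-of-the-envelope sketch that assumes a standard data-parallel setup, where each process works on its own micro-batch and gradients may be accumulated over several micro-batches; DRLX's own config fields may be named and structured differently.

```python
def effective_batch_size(per_device_batch: int, num_processes: int, grad_accum_steps: int = 1) -> int:
    """Hypothetical helper: samples per optimizer step in plain data-parallel training."""
    if min(per_device_batch, num_processes, grad_accum_steps) < 1:
        raise ValueError("all arguments must be positive integers")
    # Each process handles its own micro-batch, and gradients are accumulated
    # over grad_accum_steps micro-batches before each update.
    return per_device_batch * num_processes * grad_accum_steps


# Example: 4 GPUs, 8 samples per GPU, 2 accumulation steps
print(effective_batch_size(8, 4, 2))  # 64
```

Keeping this number fixed when changing the GPU count, for example by adjusting the accumulation steps, generally makes runs on different hardware easier to compare.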
Download files
File details
Details for the file drlx-0.0.2.tar.gz.
File metadata
- Download URL: drlx-0.0.2.tar.gz
- Upload date:
- Size: 32.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9a38b62e1261fc12c2f260ab992ff158bf5b37c8ddda27b2975ca4fd33d7c4fa
MD5 | de449e801e813b275ba871f55da37be4
BLAKE2b-256 | 6237c1f03cd241f94cf958e61578bc8d1c9553096f9b1a2e1951aab169fa7f13
File details
Details for the file drlx-0.0.2-py3-none-any.whl.
File metadata
- Download URL: drlx-0.0.2-py3-none-any.whl
- Upload date:
- Size: 30.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3db6a1c395832b6518898a4548753aceea502f314343124600cd14f57265ca13
MD5 | 7dcd951c8bc8f9b17281149915110eca
BLAKE2b-256 | 2cd399dcc44374035f7665499f9b0afd639c109edb7f4c14793f4c04d5754571