
DRLX is a library for distributed training of diffusion models via RL


Diffuser Reinforcement Learning X

DRLX is a library for distributed training of diffusion models via RL. It wraps 🤗 Hugging Face's Diffusers library and uses Accelerate for multi-GPU and multi-node training (multi-node is as yet untested).


Setup

You can install the library from PyPI:

pip install drlx

or from source:

pip install git+https://github.com/CarperAI/DRLX.git
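To confirm that either install worked, a small standard-library check can report the installed version. This is a sketch rather than part of DRLX; it assumes the distribution is named `drlx` (as on PyPI), and the helper name `installed_version` is hypothetical.

```python
import importlib.metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints the installed drlx version, or a hint if it is missing
    version = installed_version("drlx")
    print(f"drlx {version}" if version else "drlx is not installed")
```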

How to use

So far, we have only tested the library with Stable Diffusion 1.4, but its plug-and-play design means that, in principle, any denoiser from any Diffusers pipeline should be usable. Models saved with DRLX remain compatible with the pipeline they originated from and can be loaded like any other pretrained model. Currently, DDPO is the only supported training algorithm.

# Import a reward model, a prompt pipeline, the trainer and the config
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

pipe = PickAPicPrompts()  # prompt pipeline built on Pick-a-Pic prompts
config = DRLXConfig.load_yaml("configs/my_cfg.yml")  # path to your own YAML config
trainer = DDPOTrainer(config)

# Fine-tune with the aesthetic score as the reward signal
trainer.train(pipe, Aesthetics())
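As a rough illustration of what a DDPO update computes, here is a minimal sketch following the formulation in the DDPO paper (Black et al., 2023), not DRLX's actual implementation: each denoising step x_t -> x_{t-1} is treated as sampling from a Gaussian policy, rewards are normalized per prompt into advantages, and a PPO-style clipped importance ratio is applied per step. All function names are hypothetical, plain lists stand in for latent tensors, and the per-prompt normalization uses only the current batch (a simplification of the paper's per-prompt statistics tracking).

```python
import math
from collections import defaultdict
from statistics import mean, pstdev
from typing import Dict, List, Sequence


def gaussian_log_prob(x: Sequence[float], mu: Sequence[float], sigma: float) -> float:
    """Log-density of x under an isotropic Gaussian N(mu, sigma^2 I).

    In DDPO each denoising step is treated as sampling from such a policy.
    """
    d = len(x)
    sq = sum((xi - mi) ** 2 for xi, mi in zip(x, mu))
    return -sq / (2 * sigma ** 2) - d * math.log(sigma) - 0.5 * d * math.log(2 * math.pi)


def per_prompt_advantages(prompts: List[str], rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Normalize each reward by the mean/std of rewards for the same prompt (batch only)."""
    groups: Dict[str, List[float]] = defaultdict(list)
    for p, r in zip(prompts, rewards):
        groups[p].append(r)
    stats = {p: (mean(rs), pstdev(rs)) for p, rs in groups.items()}
    return [(r - stats[p][0]) / (stats[p][1] + eps) for p, r in zip(prompts, rewards)]


def clipped_step_loss(new_log_prob: float, old_log_prob: float,
                      advantage: float, clip_range: float = 0.2) -> float:
    """PPO-style clipped surrogate loss for one denoising step (to be minimized).

    clip_range=0.2 is an illustrative default, not a DRLX setting.
    """
    ratio = math.exp(new_log_prob - old_log_prob)  # importance ratio pi_new / pi_old
    unclipped = -advantage * ratio
    clipped = -advantage * min(max(ratio, 1 - clip_range), 1 + clip_range)
    return max(unclipped, clipped)  # pessimistic bound, as in PPO


if __name__ == "__main__":
    prompts = ["a cat", "a cat", "a dog", "a dog"]
    rewards = [5.0, 7.0, 4.0, 4.5]
    advs = per_prompt_advantages(prompts, rewards)
    # One denoising step of a toy 2-D "latent" under old and updated policy means
    x_prev, mu_old, mu_new, sigma = [0.1, -0.2], [0.0, 0.0], [0.05, -0.1], 0.5
    old_lp = gaussian_log_prob(x_prev, mu_old, sigma)
    new_lp = gaussian_log_prob(x_prev, mu_new, sigma)
    print(advs, clipped_step_loss(new_lp, old_lp, advs[0]))
```

In a full trainer these per-step losses would be averaged over sampled trajectories and timesteps and backpropagated through the denoiser that produces `mu`.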

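And then, to use a trained model for inference:

from diffusers import StableDiffusionPipeline

# Load the fine-tuned checkpoint like any other pretrained pipeline
pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp")
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]
image.save("test.jpeg")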

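Accelerated Training

To train on multiple GPUs, first configure Accelerate, then launch your training module through it:
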
accelerate config
accelerate launch -m [your module]
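Under multi-GPU data parallelism each process trains on its own batch, so the number of samples contributing to one optimizer step grows with the process count. A small arithmetic sketch, assuming standard data-parallel training with gradient accumulation; the function and parameter names are illustrative and are not DRLX config keys.

```python
def effective_batch_size(per_device_batch: int, num_processes: int, grad_accum_steps: int = 1) -> int:
    """Samples contributing to one optimizer step under data parallelism."""
    if min(per_device_batch, num_processes, grad_accum_steps) < 1:
        raise ValueError("all arguments must be positive")
    return per_device_batch * num_processes * grad_accum_steps


if __name__ == "__main__":
    # 4 images per GPU, 8 GPUs, 2 accumulation steps
    print(effective_batch_size(4, 8, 2))  # 64
```

Keeping this product fixed when changing the number of GPUs is a common way to keep runs comparable.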
