
Finegrain Refiners Library

The simplest way to train and run adapters on top of foundational models



Motivation

At Finegrain, we're on a mission to automate product photography. Given our "no human in the loop" approach, nailing the quality of the outputs we generate is paramount to our success.

That's why we're building Refiners.

It's a framework to easily bridge the last-mile quality gap of foundational models like Stable Diffusion or the Segment Anything Model (SAM) by adapting them to specific tasks with lightweight, trainable and composable patches.

We decided to build Refiners in the open.

It's because model adaptation is a new paradigm that goes beyond our specific use cases. Our hope is to help people looking to create their own adapters save time, whatever foundation model they're using.

Design

We are huge fans of PyTorch (we actually were core committers to Torch in another life), but we felt it was too low-level for the specific task of model adaptation: PyTorch models are generally hard to understand, and adapting them requires intricate ad hoc code.

Instead, we needed:

  • A model structure that's human-readable, so that you know what models do and how they work right here, right now
  • A mechanism to easily inject parameters into some target layers, or between them
  • A way to easily pass data (like a conditioning input) between layers, even when deeply nested
  • Native support for iconic adapter types like LoRAs and their community-trained incarnations (hosted on Civitai and the likes)

Refiners is designed to tackle all these challenges while remaining just one abstraction away from our beloved PyTorch.
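
For a concrete feel, here is a minimal, illustrative sketch (not taken from the library's own examples; the dimensions are arbitrary) of a small feed-forward block with a skip-connection, written declaratively with the same layers used in the ViT example below:

import torch
import refiners.fluxion.layers as fl

# The structure below is the entire model definition: a residual two-layer MLP.
ffn = fl.Chain(
    fl.Residual(
        fl.Linear(128, 512),
        fl.GeLU(),
        fl.Linear(512, 128),
    ),
)

x = torch.randn(2, 128)
y = ffn(x)  # same shape as x, thanks to the skip-connection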

Downsides

As they say, there is no free lunch. Since Refiners comes with a new model structure, it can only work with models implemented that way. For now, we support Stable Diffusion 1.5, but more are in the works (SDXL, SAM, ...) - stay tuned.

Overview

The Refiners library is made of:

  1. An abstraction layer (called Fluxion) on top of PyTorch to easily build models
  2. A zoo of compatible foundational models
  3. Adapter APIs to easily patch supported foundational models
  4. Training utils to train concrete adapters
  5. Conversion scripts to easily use existing community adapters

Key Concepts

The Chain class

The Chain class is at the core of Refiners. It basically lets you express models as a composition of basic layers (linear, convolution, attention, etc.) in a declarative way.

For example, this is what a Vision Transformer (ViT) looks like with Refiners:

import torch
import refiners.fluxion.layers as fl

class ViT(fl.Chain):
    # The Vision Transformer model structure is entirely defined in the constructor. It is
    # ready-to-use right after i.e. no need to implement any forward function or add extra logic
    def __init__(
        self,
        embedding_dim: int = 512,
        patch_size: int = 16,
        image_size: int = 384,
        num_layers: int = 12,
        num_heads: int = 8,
    ):
        num_patches = (image_size // patch_size)
        super().__init__(
            fl.Conv2d(in_channels=3, out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size),
            fl.Reshape(num_patches**2, embedding_dim),
            # The Residual layer implements the so-called skip-connection, i.e. x + F(x).
            # Here the patch embeddings (x) are summed with the position embeddings (F(x)) whose
            # weights are stored in the Parameter layer (note: there is no extra classification
            # token in this toy example)
            fl.Residual(fl.Parameter(num_patches**2, embedding_dim)),
            # These are the transformer encoders:
            *(
                fl.Chain(
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        # The Parallel layer is used to pass multiple inputs to a downstream
                        # layer, here multiheaded self-attention
                        fl.Parallel(
                            fl.Identity(),
                            fl.Identity(),
                            fl.Identity()
                        ),
                        fl.Attention(
                            embedding_dim=embedding_dim,
                            num_heads=num_heads,
                            key_embedding_dim=embedding_dim,
                            value_embedding_dim=embedding_dim,
                        ),
                    ),
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        fl.Linear(embedding_dim, embedding_dim * 4),
                        fl.GeLU(),
                        fl.Linear(embedding_dim * 4, embedding_dim),
                    ),
                )
                for _ in range(num_layers)
            ),
            fl.Reshape(embedding_dim, num_patches, num_patches),
        )

vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
x = torch.randn(2, 3, 224, 224)
y = vit(x)
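# y is the final patch feature map; with this input it has shape (2, 768, 14, 14)
# (fluxion's Reshape layers keep the batch dimension)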

The Context API

The Chain class has a context provider that allows you to pass data to layers even when deeply nested.

For example, to implement cross-attention you would just need to modify the ViT model as in the toy example below:

@@ -21,8 +21,8 @@
                     fl.Residual(
                         fl.Parallel(
                             fl.Identity(),
-                            fl.Identity(),
-                            fl.Identity()
+                            fl.UseContext(context="cross_attention", key="my_embed"),
+                            fl.UseContext(context="cross_attention", key="my_embed"),
                         ),  # used to pass multiple inputs to a layer
                         fl.Attention(
                             embedding_dim=embedding_dim,
@@ -49,5 +49,6 @@
         )

 vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
+vit.set_context("cross_attention", {"my_embed": torch.randn(2, 196, 768)})
 x = torch.randn(2, 3, 224, 224)
 y = vit(x)
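
The same mechanism can be sketched on a toy chain outside the ViT example. The snippet below is illustrative rather than taken from the documentation; the context name "my_ctx" and key "bias" are arbitrary:

import torch
import refiners.fluxion.layers as fl

# A toy chain that adds a tensor supplied via the context to its input.
# UseContext fetches the value at call time, however deeply it is nested.
toy = fl.Chain(
    fl.Residual(
        fl.UseContext(context="my_ctx", key="bias"),
    ),
)

toy.set_context("my_ctx", {"bias": torch.ones(2, 4)})
y = toy(torch.zeros(2, 4))  # a tensor of ones: the input plus the context value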

The Adapter API

The Adapter API lets you easily patch models by injecting parameters in targeted layers. It comes with built-in support for canonical adapter types like LoRA, but you can also implement your custom adapters with it.

For example, to inject LoRA layers into every attention layer's linear layers:

from refiners.adapters.lora import LoraAdapter

for layer in vit.layers(fl.Attention):
    for linear, parent in layer.walk(fl.Linear):
        adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
        adapter.inject(parent)

# ... and load existing weights if the LoRAs are pretrained ...

Getting Started

Install

# inference only
pip install refiners

Or:

# inference + training
pip install 'refiners[training]'

Hello World

Here is how to perform text-to-image inference using the Stable Diffusion 1.5 foundational model patched with a Pokemon LoRA:

Step 1: prepare the model weights in refiners' format:

python scripts/convert-clip-weights.py --output-file CLIPTextEncoderL.safetensors
python scripts/convert-sd-lda-weights.py --output-file lda.safetensors
python scripts/convert-sd-unet-weights.py --output-file unet.safetensors

Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5, which takes some time. If you already have this repo cloned locally, use the --from /path/to/stable-diffusion-v1-5 option instead.

Step 2: download and convert a community Pokemon LoRA, e.g. pcuenq/pokemon-lora from Hugging Face:

curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
python scripts/convert-lora-weights.py \
  --from pytorch_lora_weights.bin \
  --output-file pokemon_lora.safetensors

Step 3: run inference using the GPU:

from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import LoraWeights
from refiners.fluxion.utils import load_from_safetensors, manual_seed
import torch


sd15 = StableDiffusion_1(device="cuda")
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("CLIPTextEncoderL.safetensors"))
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))

# This uses the LoraAdapter internally and takes care of injecting it where it should go
lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)
lora_weights.patch(sd15, scale=1.0)

prompt = "a cute cat"

with torch.no_grad():
    clip_text_embedding = sd15.compute_text_embedding(prompt)

sd15.set_num_inference_steps(30)

manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=sd15.device)

with torch.no_grad():
    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.decode_latents(x)
    predicted_image.save("pokemon_cat.png")

You should get:

pokemon cat output

Training

Refiners comes with built-in training utilities and provides scripts that can be used as a starting point.

For example, to train a LoRA on top of Stable Diffusion, copy and edit configs/finetune-lora.toml to suit your needs and launch the training as follows:

python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml

Credits

We took inspiration from these great projects:

Citation

@misc{the-finegrain-team-2023-refiners,
  author = {Benjamin Trom and Pierre Chapuis and Cédric Deltheil},
  title = {Refiners: The simplest way to train and run adapters on top of foundational models},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/finegrain-ai/refiners}}
}

