Bayesian LoRA adapters for Language Models

Project description

Bayesian LoRA

Code for the paper Bayesian Low-Rank Adaptation for Large Language Models.

See the explanatory blog post and documentation.

Installation

pip install bayesian-lora

Example

We provide a comprehensive example in examples/example_usage.py, running through the main methods using Phi-2 on ARC-E.

Note that running this requires a local installation with a few extra dependencies. Run:

git clone https://github.com/MaximeRobeyns/bayesian_lora
cd bayesian_lora
pip install -e ".[examples]"

and then

python ./examples/example_usage.py

The main functions this library provides are for calculating Kronecker factors, the marginal likelihood, and the posterior predictive distribution. We show how to use these in the examples below.

Calculating (low-rank) Kronecker factors

First, wrap your model call in a function that takes a batch from your data loader, and returns the relevant logits. For a CausalLM from HuggingFace:

from typing import Any

import torch as t
from torch import nn

def fwd_call(model: nn.Module, batch_prompts: Any) -> t.Tensor:
    inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt").to(device)
    outputs = model(**inputs)
    logits = outputs.logits[:, -1]  # Get the last token logits
    return logits
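
Here, batch_prompts is simply whatever your data loader yields. As a purely illustrative sketch (the prompts and loader below are placeholders, not part of the library):

from torch.utils.data import DataLoader

# Placeholder prompts standing in for a real dataset such as ARC-E
prompts = [
    "Question: Which gas do plants absorb from the atmosphere? Answer:",
    "Question: Which force pulls objects towards the Earth? Answer:",
]
train_loader = DataLoader(prompts, batch_size=2)

for batch_prompts in train_loader:
    logits = fwd_call(model, batch_prompts)  # last-token logits, as above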

You can now call our calculate_kronecker_factors function:

from bayesian_lora import calculate_kronecker_factors

factors = calculate_kronecker_factors(
    model,            # Your model (not necessarily PEFT)
    fwd_call,         # Model call wrapper, defined above
    train_loader,     # Your training data loader
    cfg.n_kfac,       # (Optional) rank to use
    cfg.lr_threshold, # (Optional) threshold for low-rank approximation
    ["lora"],         # modules to target
    use_tqdm=True,    # (Optional) use tqdm for progress bar
)

In the above, the ["lora"] argument is a case-insensitive list of keywords used to identify the modules to target. Since we're working with a LoRA model, we use "lora" to target, for example, layers.0.q_proj.lora_A and similar modules.

The returned factors are a dictionary whose keys are the full names of the targeted modules and whose values are tuples of two tensors: the first is the (possibly low-rank) Kronecker factor corresponding to the input activations, and the second is the (possibly low-rank) factor corresponding to the output gradients.
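
For instance, you can iterate over the returned dictionary to inspect the factors; a minimal sketch, assuming the structure described above:

for name, (activations, gradients) in factors.items():
    # name is the full module name (e.g. one containing layers.0.q_proj.lora_A),
    # activations and gradients are the two (possibly low-rank) Kronecker factors
    print(name, tuple(activations.shape), tuple(gradients.shape))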

See the K-FAC docs for more detail.

Model Evidence

We provide a function called model_evidence which returns the evidence / marginal likelihood.

from bayesian_lora import model_evidence

evidence = model_evidence(
    model,           # Your model
    log_likelihood,  # A Tensor with model's log likelihood on some eval dataset
    factors,         # Kronecker factors, as calculated above
    n_lora,          # rank used in the LoRA adapters
    n_kfac,          # rank used in the Kronecker factors
    prior_var,       # prior variance hyperparameter, as a tensor
)

You can then use evidence as the loss in a normal training loop, assuming your hyperparameters (e.g. prior_var) have gradients.
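
A minimal sketch of such a loop, maximising the evidence with respect to the prior variance (the optimiser, learning rate, and the use of the negated evidence as the loss are assumptions here, not prescribed by the library):

import torch as t

# In practice you may prefer to parametrise e.g. the log of prior_var to keep it positive
prior_var = t.tensor(0.1, requires_grad=True)
opt = t.optim.AdamW([prior_var], lr=1e-2)

for _ in range(100):
    opt.zero_grad()
    evidence = model_evidence(
        model, log_likelihood, factors, n_lora, n_kfac, prior_var
    )
    (-evidence).backward()  # maximise the evidence by minimising its negative
    opt.step()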

Posterior Predictive Distribution

To get the parameters of the Gaussian over the logits, use the jacobian_mean and variance functions.

with t.no_grad():
    for batch in validation_loader:
        prompts, classes = batch

        batch_inputs = tokenizer(prompts, padding=True, return_tensors="pt")

        # Predict the output logit locations
        # target_ids is a tensor containing the indices of the target tokens
        # e.g. [354, 355, 356].
        jacobian, f_mu = jacobian_mean(
            model, batch_inputs, target_ids
        )

        # Predict the output logit variances
        f_var = variance(
            batch_inputs,     # inputs
            jacobian,         # the Jacobian dictionary, obtained above
            factors,          # Kronecker factors, as calculated above
            prior_var,        # prior variance hyperparameter, as a tensor
            classes.size(-1), # number of classes to predict
            n_lora,           # rank of the LoRA adapters
            n_kfac,           # rank of the Kronecker factors
            device,           # device to use
        )

        # Now use the parameters to e.g. sample logits from the Gaussian
        # predictive, parametrised by f_mu, f_var
        L = t.linalg.cholesky(f_var)
        samples = 100_000
        f_mu = f_mu.expand(samples, *f_mu.shape)
        L = L.expand(samples, *L.shape)
        eps = t.randn_like(f_mu)
        logits = (f_mu + L @ eps).squeeze(-1).softmax(-1).mean(0)

The above is a minimal example; see this section of the documentation for more detail.
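
As one example of what to do with these probabilities, you can summarise per-example predictions and predictive uncertainty (a generic sketch, not a library function; logits here refers to the averaged probabilities computed in the snippet above):

probs = logits                  # averaged predictive probabilities from above
predictions = probs.argmax(-1)  # most likely class for each example
entropy = -(probs * probs.clamp_min(1e-12).log()).sum(-1)  # predictive uncertainty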

Development

This library is intentionally very small and hackable. It has two main files and three dependencies (torch, tqdm, and jaxtyping).

  • main.py contains methods specific to the paper,
  • kfac.py contains relatively portable K-FAC methods.

Feel free to directly copy the code into your projects and hack on it.

Download files

Download the file for your platform.

Source Distribution

bayesian_lora-0.0.3.tar.gz (19.9 kB)

Uploaded Source

Built Distribution

bayesian_lora-0.0.3-py3-none-any.whl (16.8 kB)

Uploaded Python 3

File details

Details for the file bayesian_lora-0.0.3.tar.gz.

File metadata

  • Download URL: bayesian_lora-0.0.3.tar.gz
  • Upload date:
  • Size: 19.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for bayesian_lora-0.0.3.tar.gz

  • SHA256: 4bfa2cf6634495b94bf24b35e0547e3c38af62b2d685c2a0b7ca9d71322402d5
  • MD5: f06800b9a0dd4948050bfe346375bb5d
  • BLAKE2b-256: ed246cb8dc876e00f5e05b66a6767d332ef50554584836e280b64419e672d471

See more details on using hashes here.

File details

Details for the file bayesian_lora-0.0.3-py3-none-any.whl.

File metadata

File hashes

Hashes for bayesian_lora-0.0.3-py3-none-any.whl

  • SHA256: 0c366c4a7502d8b1424ec3ddfc2d79707bf2844872470c6fb6e9b44d3a833073
  • MD5: fb4ce99ef84a11b35b0b5dcb6023c16f
  • BLAKE2b-256: 3be2d9fc41ba15fa0c1439cc50c410721d5b60da5baf24044f6963a101bd5196

See more details on using hashes here.
