lora-pytorch

A simple but robust implementation of LoRA (Low-Rank Adaptation) for PyTorch, which depends only on PyTorch itself! No dependence on transformers or other packages.

  • Compatible with LLMs, CNNs, MLPs, and other model types ✔️
  • Strongly typed ✔️
  • Fully tested ✔️

Install

PyPI:

pip install lora-pytorch

From source:

pip install "lora-pytorch @ git+ssh://git@github.com/fkodom/lora-pytorch.git"

For contributors:

# Clone repository
gh repo clone fkodom/lora-pytorch
# Install all dev dependencies (tests etc.)
cd lora-pytorch
pip install -e ".[all]"
# Setup pre-commit hooks
pre-commit install

Usage

import torch
from lora_pytorch import LoRA
from torchvision.models import resnet18, ResNet

# Wrap your model with LoRA
model = resnet18()
lora_model = LoRA.from_module(model, rank=5)

print(lora_model)
# LoRA(
#   (module): ResNet(
#     (conv1): LoRA(
#       (module): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#       (lora_module): Conv2dLoRAModule(
#         (in_conv): Conv2d(3, 5, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         (out_conv): Conv2d(5, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
#         (dropout): Dropout(p=0.0, inplace=False)
#       )
#     )
#     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (relu): ReLU(inplace=True)
# ...

# Train or predict as usual.
x = torch.randn(1, 3, 224, 224)
y = lora_model(x)
# compute loss, backprop, etc...

# Merge LoRA weights into the original model.
new_model = lora_model.merge_lora(inplace=False)  # default: inplace=False

# NOTE: new_model has the same type as the original model!  Inference is just as
# fast as in the original model.
assert isinstance(new_model, ResNet)
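
Since LoRA does not track gradients from the original model (see the note under Advanced Usage below), a fine-tuning loop only needs to optimize the parameters that still require gradients. A minimal training sketch, assuming a standard classification setup; the optimizer and loss here are illustrative and not part of lora-pytorch:

# Collect only the trainable (LoRA) parameters.
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

x = torch.randn(8, 3, 224, 224)
target = torch.randint(0, 1000, (8,))  # resnet18 has 1000 output classes

logits = lora_model(x)
loss = torch.nn.functional.cross_entropy(logits, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()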

Advanced Usage

Enable or disable LoRA as needed (e.g. to access the original model).

NOTE: LoRA will not track gradients from the original model.

# Disable
lora_model.disable_lora()
y = lora_model(x)
print(y.requires_grad)
# False

# Re-enable
lora_model.enable_lora()
y = lora_model(x)
print(y.requires_grad)
# True

Remove LoRA from the model.

NOTE: The original model weights will be unchanged.

# Remove
original_model = lora_model.remove_lora(inplace=False)  # default: inplace=False
assert isinstance(original_model, ResNet)

Supported Layers

Layer                         Supported
nn.Linear                     ✔️
nn.Embedding                  ✔️
nn.MultiheadAttention         ✔️
nn.TransformerEncoder         ✔️
nn.TransformerEncoderLayer    ✔️
nn.TransformerDecoder         ✔️
nn.TransformerDecoderLayer    ✔️
nn.Transformer                ✔️
nn.Conv1d                     ✔️
nn.Conv2d                     ✔️
nn.Conv3d                     ✔️
nn.ConvTranspose1d            ❌
nn.ConvTranspose2d            ❌
nn.ConvTranspose3d            ❌

NOTE: Activation, normalization, dropout, and similar layers are not affected by LoRA. They are not listed here, but you shouldn't have any problems using them.
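
Because LoRA is applied recursively through child modules (as the ResNet example above shows), container models built from the layers listed here can be wrapped directly. A minimal sketch using nn.TransformerEncoder; the hyperparameters are illustrative only:

import torch
from torch import nn
from lora_pytorch import LoRA

encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

# Attention and feed-forward projections receive LoRA adapters; LayerNorm and
# Dropout submodules pass through unchanged.
lora_encoder = LoRA.from_module(encoder, rank=4)

x = torch.randn(2, 16, 64)  # (batch, sequence, d_model)
y = lora_encoder(x)
print(y.shape)
# torch.Size([2, 16, 64])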

TODO

  • Add support for ConvTranspose layers.
  • Experiments with large, pretrained models
    • Specifically, models that are not covered by LoRA in huggingface/transformers.
  • Lots of CV examples: ResNet, ViT, DETR, U-Net, DeepLab, etc.
