lora-pytorch

A simple but robust implementation of LoRA (Low-Rank Adaptation) for PyTorch that depends only on PyTorch itself, with no dependence on transformers or other packages!

  • Compatible with LLMs, CNNs, MLPs, and other model types ✔️
  • Strongly typed ✔️
  • Fully tested ✔️
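For intuition, LoRA freezes a pretrained weight W and learns a low-rank update, so a wrapped linear layer computes y = Wx + b + (alpha / r) · B(Ax). The sketch below is a minimal, hypothetical `LoRALinear` written in plain PyTorch to illustrate the idea. It is not this package's implementation, and the `alpha / rank` scaling and zero-initialized `B` are common conventions assumed here.

```python
import torch
from torch import nn


class LoRALinear(nn.Module):
    """Hypothetical LoRA wrapper for nn.Linear (illustration only, not lora-pytorch's class)."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 1.0) -> None:
        super().__init__()
        self.base = base
        # Freeze the pretrained weights; only the low-rank factors are trained.
        for p in self.base.parameters():
            p.requires_grad_(False)
        # A projects down to `rank` dimensions, B projects back up.
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        # Zero-init B so the wrapped layer starts out identical to the base layer.
        nn.init.zeros_(self.lora_b.weight)
        self.scale = alpha / rank  # assumed scaling convention

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + b + scale * B(A x)
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))


base = nn.Linear(64, 32)
layer = LoRALinear(base, rank=4)
x = torch.randn(8, 64)

# Before training, the LoRA branch contributes nothing because B is zero.
assert torch.allclose(layer(x), base(x))

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(trainable, total)  # 384 2464
```

Only the two small factors are trainable here: 384 of 2,464 parameters in this toy layer, and the ratio shrinks further as layers get wider.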

Install

PyPI:

pip install lora-pytorch

From source:

pip install "lora-pytorch @ git+ssh://git@github.com/fkodom/lora-pytorch.git"

For contributors:

# Clone repository
gh repo clone fkodom/lora-pytorch
# Install all dev dependencies (tests etc.)
cd lora-pytorch
pip install -e ".[all]"
# Setup pre-commit hooks
pre-commit install

Usage

import torch
from lora_pytorch import LoRA
from torchvision.models import resnet18, ResNet

# Wrap your model with LoRA
model = resnet18()
lora_model = LoRA.from_module(model, rank=5)

print(lora_model)
# LoRA(
#   (module): ResNet(
#     (conv1): LoRA(
#       (module): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#       (lora_module): Conv2dLoRAModule(
#         (in_conv): Conv2d(3, 5, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         (out_conv): Conv2d(5, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
#         (dropout): Dropout(p=0.0, inplace=False)
#       )
#     )
#     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (relu): ReLU(inplace=True)
# ...

# Train or predict as usual.
x = torch.randn(1, 3, 224, 224)
y = lora_model(x)
# compute loss, backprop, etc...

# Merge LoRA weights into the original model.
new_model = lora_model.merge_lora(inplace=False)  # default: inplace=False

# NOTE: new_model has the same type as the original model!  Inference is just as
# fast as in the original model.
assert isinstance(new_model, ResNet)
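Why can merging return a plain ResNet? The printout above shows each Conv2d gaining an `in_conv` (same kernel, stride, and padding, but only `rank` output channels) followed by a 1x1 `out_conv`. Because both are linear, that branch collapses into a single weight delta for the original Conv2d. The sketch below checks this numerically in plain PyTorch, assuming the LoRA branch is exactly that conv pair with no extra scaling (dropout is p=0.0 in the printout). It is an illustration, not the library's `merge_lora` code, and the LoRA weights are left at their random initialization so the check is non-trivial.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

# Base layer shaped like ResNet-18's conv1 in the printout.
base = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
rank = 5
# LoRA branch: rank-reducing conv with the base geometry, then a 1x1 conv back up.
in_conv = nn.Conv2d(3, rank, kernel_size=7, stride=2, padding=3, bias=False)
out_conv = nn.Conv2d(rank, 64, kernel_size=1, bias=False)

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    y_lora = base(x) + out_conv(in_conv(x))

    # The 1x1 conv is just a (64, rank) matrix, so the low-rank update
    # folds into one (64, 3, 7, 7) weight delta.
    delta = torch.einsum("or,rihw->oihw", out_conv.weight[:, :, 0, 0], in_conv.weight)
    merged = copy.deepcopy(base)
    merged.weight += delta
    y_merged = merged(x)

print(type(merged).__name__, torch.allclose(y_lora, y_merged, atol=1e-5))
```

The merged layer is an ordinary `nn.Conv2d` with the original shape, which is why inference after merging costs the same as in the original model.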

Advanced Usage

Enable or disable LoRA as needed (e.g., to access the original model).

NOTE: LoRA does not track gradients for the original model's weights, so when LoRA is disabled the output does not require gradients.

# Disable
lora_model.disable_lora()
y = lora_model(x)
print(y.requires_grad)
# False

# Re-enable
lora_model.enable_lora()
y = lora_model(x)
print(y.requires_grad)
# True
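The `requires_grad` behavior above follows from freezing the original weights: with the LoRA branch skipped, nothing in the forward pass is trainable. A minimal sketch, assuming a hypothetical `ToggleLoRALinear` with an `enabled` flag (neither is part of this package's API):

```python
import torch
from torch import nn


class ToggleLoRALinear(nn.Module):
    """Hypothetical wrapper illustrating enable/disable; not the library's code."""

    def __init__(self, base: nn.Linear, rank: int = 4) -> None:
        super().__init__()
        self.base = base.requires_grad_(False)  # frozen original weights
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        self.enabled = True  # hypothetical flag

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.base(x)
        if self.enabled:
            # Only this branch has trainable parameters.
            y = y + self.lora_b(self.lora_a(x))
        return y


layer = ToggleLoRALinear(nn.Linear(16, 8))
x = torch.randn(4, 16)

layer.enabled = False
print(layer(x).requires_grad)  # False

layer.enabled = True
print(layer(x).requires_grad)  # True
```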

Remove LoRA from the model.

NOTE: The original model weights will be unchanged.

# Remove
original_model = lora_model.remove_lora(inplace=False)  # default: inplace=False
assert isinstance(original_model, ResNet)
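Removing LoRA can leave the original weights unchanged because training only ever updates the LoRA factors. The sketch below demonstrates this with plain PyTorch: a frozen `nn.Linear` plus two hypothetical low-rank factors trained for a few SGD steps, after which the frozen layer's state is compared with a saved copy. The toy data, sizes, and learning rate are arbitrary assumptions.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

base = nn.Linear(16, 8).requires_grad_(False)  # frozen original layer
original_state = copy.deepcopy(base.state_dict())

# Hypothetical LoRA factors; only these are handed to the optimizer.
lora_a = nn.Linear(16, 2, bias=False)
lora_b = nn.Linear(2, 8, bias=False)
optimizer = torch.optim.SGD([*lora_a.parameters(), *lora_b.parameters()], lr=0.1)

x, target = torch.randn(32, 16), torch.randn(32, 8)
for _ in range(5):
    loss = nn.functional.mse_loss(base(x) + lora_b(lora_a(x)), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# "Removing" LoRA here just means keeping `base`; its weights were never touched.
unchanged = all(torch.equal(base.state_dict()[k], v) for k, v in original_state.items())
print(unchanged)  # True
```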

Supported Layers

Layer                        Supported
nn.Linear                    ✔️
nn.Embedding                 ✔️
nn.MultiheadAttention        ✔️
nn.TransformerEncoder        ✔️
nn.TransformerEncoderLayer   ✔️
nn.TransformerDecoder        ✔️
nn.TransformerDecoderLayer   ✔️
nn.Transformer               ✔️
nn.Conv1d                    ✔️
nn.Conv2d                    ✔️
nn.Conv3d                    ✔️
nn.ConvTranspose1d           ❌ (see TODO)
nn.ConvTranspose2d           ❌ (see TODO)
nn.ConvTranspose3d           ❌ (see TODO)

NOTE: Activation, normalization, dropout, and similar layers are not modified by LoRA, so they are not listed here; models that contain them work as usual.
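To get a rough idea of which layers in your own model fall into the layer types listed above, you can scan `named_modules()`. The helper below, `find_lora_candidates`, is hypothetical and only checks types; the library's own traversal may differ (for example, in how it handles the Transformer container modules, which are omitted here). ConvTranspose layers are left out because the TODO list still mentions adding support for them.

```python
from torch import nn

# Leaf layer types from the table (Transformer containers and ConvTranspose* omitted).
SUPPORTED_TYPES = (
    nn.Linear,
    nn.Embedding,
    nn.MultiheadAttention,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
)


def find_lora_candidates(model: nn.Module) -> list[tuple[str, str]]:
    """Hypothetical helper: (name, type) of every submodule with a listed type."""
    return [
        (name, type(module).__name__)
        for name, module in model.named_modules()
        if isinstance(module, SUPPORTED_TYPES)
    ]


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.BatchNorm2d(8),
    nn.ConvTranspose2d(8, 3, 3),  # not a subclass of nn.Conv2d, so it is skipped
    nn.Flatten(),
    nn.Linear(10, 2),
)
print(find_lora_candidates(model))  # [('0', 'Conv2d'), ('5', 'Linear')]
```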

TODO

  • Add support for ConvTranspose layers.
  • Experiments with large, pretrained models
    • Specifically, models that are not covered by LoRA in huggingface/transformers.
    • Lots of CV examples: ResNet, ViT, DETR, UNET, DeepLab, etc.

Download files

Source Distribution

lora_pytorch-0.2.0.tar.gz (14.1 kB)

Uploaded Source

Built Distribution

lora_pytorch-0.2.0-py3-none-any.whl (12.8 kB)

Uploaded Python 3

File details

Details for the file lora_pytorch-0.2.0.tar.gz.

File metadata

  • Download URL: lora_pytorch-0.2.0.tar.gz
  • Upload date:
  • Size: 14.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for lora_pytorch-0.2.0.tar.gz
Algorithm Hash digest
SHA256 2d15145198429fe0134245acbf411dc480d747d15a1ccffe92650634ecd914b1
MD5 80a2c1068d8d27d254b3226fb7bb556a
BLAKE2b-256 ce8b39068c1710ace1982685d9e710d70ef8bdc4df34be74dfa647e5d5ca4ae9

File details

Details for the file lora_pytorch-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: lora_pytorch-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 12.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for lora_pytorch-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 cb0395d7413510a36b2e85bcf448e52f9eabd4af8da01ab7fb2c84c5d08a77e7
MD5 cb69cc6c431654c6d4e31c5a941876f7
BLAKE2b-256 aba659170ff46d9cdd5a0705257c2345a922c875a682caf63b0e9320cc47a167
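To check a downloaded file against the SHA256 digests listed above, a few lines of standard-library Python are enough. This is a generic sketch; `EXPECTED_SHA256` and `verify_download` are illustrative names, and the file is assumed to keep its original filename.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the "File hashes" tables.
EXPECTED_SHA256 = {
    "lora_pytorch-0.2.0.tar.gz": "2d15145198429fe0134245acbf411dc480d747d15a1ccffe92650634ecd914b1",
    "lora_pytorch-0.2.0-py3-none-any.whl": "cb0395d7413510a36b2e85bcf448e52f9eabd4af8da01ab7fb2c84c5d08a77e7",
}


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: Path) -> bool:
    """Compare a file against the published digest for its filename (False if unknown)."""
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

For example, `verify_download(Path("lora_pytorch-0.2.0-py3-none-any.whl"))` should return `True` for an intact wheel in the current directory. pip can also enforce these digests itself in hash-checking mode, via `--hash=sha256:...` entries in a requirements file.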
