# lora-pytorch
A simple but robust implementation of LoRA (Low-Rank Adaptation) for PyTorch, which depends only on PyTorch itself! No dependence on `transformers` or other packages.
- Compatible with LLMs, CNNs, MLPs, and other model types ✔️
- Strongly typed ✔️
- Fully tested ✔️
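For background, LoRA freezes the original weights `W` and learns a low-rank update `B @ A` that is added to the layer's output. A minimal sketch of the idea for a linear layer (illustration only, not this library's internals):

```python
import torch
import torch.nn as nn

class LinearWithLoRA(nn.Module):
    """Illustrative LoRA wrapper for a linear layer (not lora-pytorch's own class)."""

    def __init__(self, linear: nn.Linear, rank: int, alpha: float = 1.0):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze the original weights
        # Trainable low-rank factors: the effective weight becomes W + alpha * (B @ A)
        self.lora_a = nn.Parameter(torch.randn(rank, linear.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(linear.out_features, rank))
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x @ W.T + alpha * (x @ A.T @ B.T)
        return self.linear(x) + self.alpha * (x @ self.lora_a.T @ self.lora_b.T)
```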
## Install
PyPI:

```bash
pip install lora-pytorch
```

From source:

```bash
pip install "lora-pytorch @ git+ssh://git@github.com/fkodom/lora-pytorch.git"
```
For contributors:

```bash
# Clone the repository
gh repo clone fkodom/lora-pytorch
cd lora-pytorch

# Install all dev dependencies (tests etc.)
pip install -e ".[all]"

# Set up pre-commit hooks
pre-commit install
```
## Usage
```python
import torch
from lora_pytorch import LoRA
from torchvision.models import resnet18, ResNet

# Wrap your model with LoRA
model = resnet18()
lora_model = LoRA.from_module(model, rank=5)

print(lora_model)
# LoRA(
#   (module): ResNet(
#     (conv1): LoRA(
#       (module): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#       (lora_module): Conv2dLoRAModule(
#         (in_conv): Conv2d(3, 5, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         (out_conv): Conv2d(5, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
#         (dropout): Dropout(p=0.0, inplace=False)
#       )
#     )
#     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (relu): ReLU(inplace=True)
#     ...

# Train or predict as usual.
x = torch.randn(1, 3, 224, 224)
y = lora_model(x)
# compute loss, backprop, etc...

# Merge LoRA weights into the original model.
new_model = lora_model.merge_lora(inplace=False)  # default: inplace=False
# NOTE: new_model has the same type as the original model!  Inference is just as
# fast as in the original model.
assert isinstance(new_model, ResNet)
```
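During training, only the LoRA parameters require gradients, so a common pattern is to hand just the trainable parameters to the optimizer. A brief sketch using standard PyTorch (the loss here is a placeholder):

```python
import torch

# Only the LoRA parameters are trainable; the wrapped model's original
# weights are frozen, so filtering by requires_grad selects just the adapters.
trainable = [p for p in lora_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

y = lora_model(torch.randn(1, 3, 224, 224))
loss = y.sum()  # placeholder loss for illustration
loss.backward()
optimizer.step()
optimizer.zero_grad()
```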
## Advanced Usage
Enable or disable LoRA as needed (e.g. to access the original model).

NOTE: LoRA will not track gradients from the original model.
```python
# Disable
lora_model.disable_lora()
y = lora_model(x)
print(y.requires_grad)
# False

# Re-enable
lora_model.enable_lora()
y = lora_model(x)
print(y.requires_grad)
# True
```
Remove LoRA from the model.

NOTE: The original model weights will be unchanged.
```python
# Remove
original_model = lora_model.remove_lora(inplace=False)  # default: inplace=False
assert isinstance(original_model, ResNet)
```
## Supported Layers
| Layer | Supported |
| --- | --- |
| `nn.Linear` | ✅ |
| `nn.Embedding` | ✅ |
| `nn.MultiheadAttention` | ✅ |
| `nn.TransformerEncoder` | ✅ |
| `nn.TransformerEncoderLayer` | ✅ |
| `nn.TransformerDecoder` | ✅ |
| `nn.TransformerDecoderLayer` | ✅ |
| `nn.Transformer` | ✅ |
| `nn.Conv1d` | ✅ |
| `nn.Conv2d` | ✅ |
| `nn.Conv3d` | ✅ |
| `nn.ConvTranspose1d` | ❌ |
| `nn.ConvTranspose2d` | ❌ |
| `nn.ConvTranspose3d` | ❌ |
NOTE: Activation, normalization, dropout, etc. layers are not affected by LoRA. Those are not listed here, but you shouldn't have any problems using them.
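Since `nn.Transformer` and its sublayers are supported, `LoRA.from_module` should wrap a built-in transformer directly, adapting the supported sublayers and leaving normalization and dropout untouched. A quick sketch (the `rank` and model sizes are arbitrary):

```python
import torch
import torch.nn as nn
from lora_pytorch import LoRA

encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
lora_encoder = LoRA.from_module(encoder, rank=4)

x = torch.randn(8, 16, 64)  # (batch, seq_len, d_model)
y = lora_encoder(x)
```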
## TODO

- Add support for `ConvTranspose` layers.
- Experiments with large, pretrained models
    - Specifically, models that are not covered by LoRA in huggingface/transformers.
    - Lots of CV examples: ResNet, ViT, DETR, UNet, DeepLab, etc.