lora-pytorch
A simple but robust implementation of LoRA (Low-Rank Adaptation) for PyTorch, which depends only on PyTorch itself! No dependency on transformers or other packages.
- Compatible with LLMs, CNNs, MLPs, and other model types ✔️
- Strongly typed ✔️
- Fully tested ✔️
Install
PyPI:
pip install lora-pytorch
From source:
pip install "lora-pytorch @ git+ssh://git@github.com/fkodom/lora-pytorch.git"
For contributors:
# Clone repository
gh repo clone fkodom/lora-pytorch
# Install all dev dependencies (tests etc.)
cd lora-pytorch
pip install -e ".[all]"
# Setup pre-commit hooks
pre-commit install
Usage
import torch
from lora_pytorch import LoRA
from torchvision.models import resnet18, ResNet
# Wrap your model with LoRA
model = resnet18()
lora_model = LoRA.from_module(model, rank=5)
print(lora_model)
# LoRA(
# (module): ResNet(
# (conv1): LoRA(
# (module): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# (lora_module): Conv2dLoRAModule(
# (in_conv): Conv2d(3, 5, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# (out_conv): Conv2d(5, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (dropout): Dropout(p=0.0, inplace=False)
# )
# )
# (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (relu): ReLU(inplace=True)
# ...
# Train or predict as usual.
x = torch.randn(1, 3, 224, 224)
y = lora_model(x)
# compute loss, backprop, etc...
# Merge LoRA weights into the original model.
new_model = lora_model.merge_lora(inplace=False) # default: inplace=False
# NOTE: new_model has the same type as the original model! Inference is just as
# fast as in the original model.
assert isinstance(new_model, ResNet)
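Because the wrapped model's weights are frozen, you can pass just the trainable (low-rank) parameters to your optimizer. A minimal fine-tuning sketch, where the optimizer choice, learning rate, and random labels are illustrative rather than part of the library:
optimizer = torch.optim.AdamW(
    (p for p in lora_model.parameters() if p.requires_grad), lr=1e-3
)
labels = torch.randint(0, 1000, (1,))  # hypothetical targets for the demo input
loss = torch.nn.functional.cross_entropy(lora_model(x), labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()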
Advanced Usage
Enable or disable LoRA as needed (e.g. to access the original model).
NOTE: LoRA will not track gradients from the original model.
# Disable
lora_model.disable_lora()
y = lora_model(x)
print(y.requires_grad)
# False
# Re-enable
lora_model.enable_lora()
y = lora_model(x)
print(y.requires_grad)
# True
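To see which parameters will actually be updated, inspect requires_grad directly. A quick sketch, assuming the adapter parameters live under the lora_module attributes shown in the printed structure above:
trainable = [name for name, p in lora_model.named_parameters() if p.requires_grad]
print(all("lora" in name for name in trainable))
# True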
Remove LoRA from the model.
NOTE: The original model weights will be unchanged.
# Remove
original_model = lora_model.remove_lora(inplace=False) # default: inplace=False
assert isinstance(original_model, ResNet)
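Since disabling LoRA recovers the original model's behavior, the removed model should produce the same outputs. A sanity-check sketch, assuming eval mode so that batch norm (and any dropout) is deterministic:
lora_model.eval()
original_model.eval()
lora_model.disable_lora()
print(torch.allclose(lora_model(x), original_model(x)))
# True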
Supported Layers
Layer | Supported |
---|---|
nn.Linear | ✅ |
nn.Embedding | ✅ |
nn.MultiheadAttention | ✅ |
nn.TransformerEncoder | ✅ |
nn.TransformerEncoderLayer | ✅ |
nn.TransformerDecoder | ✅ |
nn.TransformerDecoderLayer | ✅ |
nn.Transformer | ✅ |
nn.Conv1d | ✅ |
nn.Conv2d | ✅ |
nn.Conv3d | ✅ |
nn.ConvTranspose1d | ❌ |
nn.ConvTranspose2d | ❌ |
nn.ConvTranspose3d | ❌ |
NOTE: Activation, normalization, dropout, etc. layers are not affected by LoRA. Those are not listed here, but you shouldn't have any problems using them.
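The same from_module call works on the transformer layers above, not just convolutions. A quick sketch with arbitrary layer sizes:
import torch.nn as nn

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
    num_layers=2,
)
lora_encoder = LoRA.from_module(encoder, rank=4)
out = lora_encoder(torch.randn(2, 10, 64))  # (batch, sequence, d_model)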
TODO
- Add support for ConvTranspose layers.
- Experiments with large, pretrained models
  - Specifically, models that are not covered by LoRA in huggingface/transformers.
  - Lots of CV examples: ResNet, ViT, DETR, UNET, DeepLab, etc.