# lora-pytorch
A simple but robust implementation of LoRA (Low-Rank Adaptation) for PyTorch, which depends only on PyTorch itself! No dependency on transformers or other packages.
- Compatible with LLMs, CNNs, MLPs, and other model types ✔️
- Strongly typed ✔️
- Fully tested ✔️
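For context, the core idea behind LoRA is small: a frozen weight matrix is augmented with a trainable low-rank update. Below is a minimal pure-PyTorch sketch of that idea (illustrative only; names like `lora_a`/`lora_b` are not this library's API):

```python
import torch
import torch.nn as nn

# Illustrative sketch (not this library's internals): for a frozen weight W
# of shape (out, in), LoRA adds y = W x + B (A x), where A is (rank, in) and
# B is (out, rank), so only rank * (in + out) extra parameters are trained.
torch.manual_seed(0)
in_features, out_features, rank = 16, 32, 5

base = nn.Linear(in_features, out_features, bias=False)
base.weight.requires_grad_(False)  # the pretrained weight stays frozen

lora_a = nn.Linear(in_features, rank, bias=False)   # "A": project down to rank
lora_b = nn.Linear(rank, out_features, bias=False)  # "B": project back up
nn.init.zeros_(lora_b.weight)  # common choice: start as a zero (identity) update

x = torch.randn(4, in_features)
y = base(x) + lora_b(lora_a(x))

# With B initialized to zero, the adapted output equals the frozen output.
assert torch.allclose(y, base(x))
```

Because only `A` and `B` are trained, the adapter adds `rank * (in + out)` parameters instead of `in * out`.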
## Install

PyPI:

```bash
pip install lora-pytorch
```

From source:

```bash
pip install "lora-pytorch @ git+ssh://git@github.com/fkodom/lora-pytorch.git"
```
For contributors:

```bash
# Clone repository
gh repo clone fkodom/lora-pytorch
# Install all dev dependencies (tests etc.)
cd lora-pytorch
pip install -e ".[all]"
# Setup pre-commit hooks
pre-commit install
```
## Usage

```python
import torch
from lora_pytorch import LoRA
from torchvision.models import resnet18, ResNet

# Wrap your model with LoRA
model = resnet18()
lora_model = LoRA.from_module(model, rank=5)

print(lora_model)
# LoRA(
#   (module): ResNet(
#     (conv1): LoRA(
#       (module): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#       (lora_module): Conv2dLoRAModule(
#         (in_conv): Conv2d(3, 5, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         (out_conv): Conv2d(5, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
#         (dropout): Dropout(p=0.0, inplace=False)
#       )
#     )
#     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (relu): ReLU(inplace=True)
#     ...

# Train or predict as usual.
x = torch.randn(1, 3, 224, 224)
y = lora_model(x)
# compute loss, backprop, etc...

# Merge LoRA weights into the original model.
new_model = lora_model.merge_lora(inplace=False)  # default: inplace=False
# NOTE: new_model has the same type as the original model! Inference is just as
# fast as in the original model.
assert isinstance(new_model, ResNet)
```
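To see why a merged model runs exactly as fast as the original, note that the low-rank update can be folded into the base weight. A pure-PyTorch sketch of the arithmetic (illustrative only, not this library's internals):

```python
import torch
import torch.nn as nn

# Sketch of merging: the low-rank product B @ A is added into the base
# weight, W' = W + B @ A, yielding a plain layer with identical outputs
# and no extra compute at inference time.
torch.manual_seed(0)
base = nn.Linear(16, 32, bias=False)
a = torch.randn(5, 16) * 0.01  # rank-5 down-projection
b = torch.randn(32, 5) * 0.01  # rank-5 up-projection

x = torch.randn(4, 16)
y_lora = base(x) + x @ a.T @ b.T  # adapted forward pass: W x + B (A x)

merged = nn.Linear(16, 32, bias=False)
with torch.no_grad():
    merged.weight.copy_(base.weight + b @ a)  # fold the update into W

# The merged layer reproduces the adapted outputs exactly (up to float error).
assert torch.allclose(merged(x), y_lora, atol=1e-5)
```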
## Advanced Usage

Enable or disable LoRA as needed (e.g. to access the original model).

NOTE: LoRA will not track gradients from the original model.

```python
# Disable
lora_model.disable_lora()
y = lora_model(x)
print(y.requires_grad)
# False

# Re-enable
lora_model.enable_lora()
y = lora_model(x)
print(y.requires_grad)
# True
```
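The `requires_grad` behaviour above follows from standard autograd rules. A pure-PyTorch sketch of the same effect (illustrative only; `adapter_a`/`adapter_b` are hypothetical names, not this library's API):

```python
import torch
import torch.nn as nn

# With the base weight frozen, the output only carries gradients when the
# trainable adapter participates in the forward pass.
base = nn.Linear(8, 8, bias=False)
base.weight.requires_grad_(False)
adapter_a = nn.Linear(8, 2, bias=False)
adapter_b = nn.Linear(2, 8, bias=False)

x = torch.randn(1, 8)  # plain input, requires_grad=False

y_disabled = base(x)  # adapter "disabled": frozen path only
print(y_disabled.requires_grad)
# False

y_enabled = base(x) + adapter_b(adapter_a(x))  # adapter "enabled"
print(y_enabled.requires_grad)
# True
```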
Remove LoRA from the model.

NOTE: The original model weights will be unchanged.

```python
# Remove
original_model = lora_model.remove_lora(inplace=False)  # default: inplace=False
assert isinstance(original_model, ResNet)
```
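Removal can recover the original model exactly because training touches only the adapter. A pure-PyTorch sketch of why the frozen base weight survives training untouched (illustrative only, hypothetical names):

```python
import torch
import torch.nn as nn

# Only the adapter parameters are passed to the optimizer, so the frozen
# base weight is identical before and after training.
torch.manual_seed(0)
base = nn.Linear(8, 8, bias=False)
base.weight.requires_grad_(False)
adapter_a = nn.Linear(8, 2, bias=False)
adapter_b = nn.Linear(2, 8, bias=False)

w_before = base.weight.clone()
opt = torch.optim.SGD(
    list(adapter_a.parameters()) + list(adapter_b.parameters()), lr=0.1
)
for _ in range(5):  # a few dummy training steps on the adapter only
    x = torch.randn(4, 8)
    loss = (base(x) + adapter_b(adapter_a(x))).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

assert torch.equal(base.weight, w_before)  # base weights untouched
```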
## Supported Layers

| Layer | Supported |
|---|---|
| `nn.Linear` | ✅ |
| `nn.Embedding` | ✅ |
| `nn.MultiheadAttention` | ✅ |
| `nn.TransformerEncoder` | ✅ |
| `nn.TransformerEncoderLayer` | ✅ |
| `nn.TransformerDecoder` | ✅ |
| `nn.TransformerDecoderLayer` | ✅ |
| `nn.Transformer` | ✅ |
| `nn.Conv1d` | ✅ |
| `nn.Conv2d` | ✅ |
| `nn.Conv3d` | ✅ |
| `nn.ConvTranspose1d` | ❌ |
| `nn.ConvTranspose2d` | ❌ |
| `nn.ConvTranspose3d` | ❌ |
NOTE: Activation, normalization, dropout, etc. layers are not affected by LoRA. Those are not listed here, but you shouldn't have any problems using them.
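For convolutions, the low-rank factorization mirrors the `Conv2dLoRAModule` structure shown in the Usage output above: a rank-channel conv with the original kernel geometry, followed by a 1x1 conv back to the full channel count. A pure-PyTorch sketch (illustrative only, not this library's internals):

```python
import torch
import torch.nn as nn

# Low-rank Conv2d update, mirroring the printed structure: in_conv keeps the
# original kernel size/stride/padding but outputs only `rank` channels, and
# a 1x1 out_conv restores the full 64 channels.
rank = 5
in_conv = nn.Conv2d(3, rank, kernel_size=7, stride=2, padding=3, bias=False)
out_conv = nn.Conv2d(rank, 64, kernel_size=1, bias=False)
base = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

x = torch.randn(1, 3, 224, 224)
y = base(x) + out_conv(in_conv(x))  # low-rank update alongside the frozen conv
print(y.shape)
# torch.Size([1, 64, 112, 112])
```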
## TODO

- Add support for `ConvTranspose` layers.
- Experiments with large, pretrained models
    - Specifically, models that are not covered by LoRA in huggingface/transformers.
    - Lots of CV examples: ResNet, ViT, DETR, UNET, DeepLab, etc.