MLoRAx is a minimalist library for low-rank adaptation (LoRA), designed to enable parameter-efficient training of Transformer-based models with minimal code changes.
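To make "parameter-efficient" concrete, here is a minimal NumPy sketch of the low-rank adaptation idea itself, not of mlorax's internals: a frozen weight W is left untouched and only two thin factors A and B are trained, so the layer computes x @ W + (alpha / rank) * x @ A @ B. The helper lora_dense, the 512×512 shape, and alpha = 16 are illustrative assumptions; the zero-initialized B and the alpha / rank scaling follow the original LoRA formulation, and mlorax's own choices may differ.

```python
import numpy as np

# Illustrative (assumed) sizes: a 512x512 projection adapted with rank 16,
# matching the rank=16 used in the Usage section.
d_in, d_out, rank, alpha = 512, 512, 16, 16.0

rng = np.random.default_rng(0)
W = rng.normal(size=(d_in, d_out))        # frozen pretrained weight
A = rng.normal(size=(d_in, rank)) * 0.01  # trainable down-projection, small random init
B = np.zeros((rank, d_out))               # trainable up-projection, zero init


def lora_dense(x, W, A, B, alpha, rank):
    """Frozen base projection plus the scaled low-rank update (hypothetical helper)."""
    return x @ W + (alpha / rank) * (x @ A @ B)


x = rng.normal(size=(4, d_in))
# Because B starts at zero, the adapted layer initially reproduces the base layer.
assert np.allclose(lora_dense(x, W, A, B, alpha, rank), x @ W)

full_params = W.size           # 512 * 512
lora_params = A.size + B.size  # 512 * 16 + 16 * 512
print(full_params, lora_params, lora_params / full_params)  # 262144 16384 0.0625
```

With rank 16, the two factors hold 6.25% of the parameters of the full 512×512 weight, and the pretrained behavior is preserved at the start of training.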
Installation
Choice 1: Just copy the code
You can simply copy the code from mlorax.py and paste it into your project. This is the easiest way to use mlorax if you do not need future updates.
Choice 2: Install with pip
You can also install mlorax with pip. This is the recommended way to use mlorax if you want to receive future updates.
```shell
pip install mlorax
```
Choice 3: Install from source
You can also install mlorax from source. You only need to do this if you want to contribute to the project.
```shell
git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .
```
Usage
Converting a Flax model into a LoRA model with mlorax takes only a few lines of code. The following snippet shows the changes needed to convert a T5 model based on HuggingFace's FlaxT5ForConditionalGeneration class:
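```diff
  import optax
  from flax.training.train_state import TrainState
  from transformers import FlaxT5ForConditionalGeneration
+ import mlorax

  model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
  tx = optax.adamw(learning_rate=1e-4)  # example optimizer; use your own
  state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
```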
That's it! You can now train the model as usual.
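The rules argument of LoRASpec chooses which weights receive LoRA factors. The exact matching logic is not described here, so the sketch below assumes, for illustration only, substring matching against dot-joined parameter paths. The toy tree loosely mimics HuggingFace Flax T5 naming (SelfAttention and EncDecAttention modules with q, k, v, o projections); flatten and select are hypothetical helpers, not mlorax functions.

```python
# Toy parameter tree loosely mimicking HuggingFace Flax T5 naming (hypothetical, heavily trimmed).
params = {
    "encoder": {"block": {"0": {"layer": {
        "0": {"SelfAttention": {n: {"kernel": None} for n in ("q", "k", "v", "o")}},
        "1": {"DenseReluDense": {n: {"kernel": None} for n in ("wi", "wo")}},
    }}}},
    "decoder": {"block": {"0": {"layer": {
        "1": {"EncDecAttention": {n: {"kernel": None} for n in ("q", "k", "v", "o")}},
    }}}},
}


def flatten(tree, prefix=()):
    """Yield (dotted_path, leaf) pairs for every leaf of a nested dict."""
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield ".".join(path), value


def select(tree, rules):
    """Return sorted dotted paths containing any rule as a substring (assumed matching scheme)."""
    return sorted(path for path, _ in flatten(tree) if any(rule in path for rule in rules))


selected = select(params, ["Attention.q", "Attention.v"])
for path in selected:
    print(path)
# decoder.block.0.layer.1.EncDecAttention.q.kernel
# decoder.block.0.layer.1.EncDecAttention.v.kernel
# encoder.block.0.layer.0.SelfAttention.q.kernel
# encoder.block.0.layer.0.SelfAttention.v.kernel
```

Under this assumption, 'Attention.q' matches both the self-attention and the cross-attention query projections, while k, o, and the feed-forward weights stay unadapted.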
Principles
Whenever possible, use the returned apply_fn for the forward pass. Otherwise, pass params=merge_fn(params) so that the function receives the merged parameters. For example, to generate text with model.generate:
```diff
- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
```
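Functions such as model.generate expect parameters in the original model's layout, which is why they need merge_fn. Conceptually, merging folds each low-rank update into its base weight, W' = W + (alpha / rank) * A @ B, so a plain dense layer using W' reproduces the adapted layer's outputs. The NumPy sketch below checks that identity for a single layer; merge, lora_forward, and all sizes are hypothetical, and mlorax's merge_fn, which operates on the whole parameter tree, may be implemented differently.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 64, 32, 4, 8.0  # illustrative sizes
scale = alpha / rank

W = rng.normal(size=(d_in, d_out))  # frozen base weight
A = rng.normal(size=(d_in, rank))   # LoRA factors after some (simulated) training
B = rng.normal(size=(rank, d_out))


def lora_forward(x):
    """Unmerged forward pass: base path plus scaled low-rank path."""
    return x @ W + scale * (x @ A @ B)


def merge(W, A, B, scale):
    """Fold the low-rank update into a single dense weight: W + scale * A @ B."""
    return W + scale * (A @ B)


W_merged = merge(W, A, B, scale)
x = rng.normal(size=(8, d_in))
# The merged weight has the original shape and gives the same outputs as the adapted layer.
assert np.allclose(x @ W_merged, lora_forward(x))
print(W_merged.shape)  # (64, 32)
```

The merged weight keeps the base weight's shape, so code that knows nothing about LoRA can consume it unchanged.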
Example and Results
For the code used in the example, check out the examples folder.
Citation
If you find MLoRAx useful, please cite the following paper:
```bibtex
@software{hao2023mlorax,
  author = {Yongchang Hao},
  title = {{ML}o{RA}x: a minimalist library for low-rank adaptation for {T}ransformer-based models},
  year = {2023},
  url = {https://github.com/yongchanghao/MLoRAx},
  version = {0.9.0}
}
```
File details
Details for the file mlorax-0.9.0-py3-none-any.whl
File metadata
- Download URL: mlorax-0.9.0-py3-none-any.whl
- Upload date:
- Size: 8.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7356dc29d3038eb85ebd6b6256d5b333290a05468199de50ae6229996ca4bedd
MD5 | 5c83a090803044c92ba1ab283dc0b208
BLAKE2b-256 | e719a6822a8f3cf99a31ae2df07366cf4f8929dcb359dd72e290115389121a37