MLoRAx is a minimalist library for low-rank adaptation (LoRA), designed to enable parameter-efficient training of Transformer-based models with minimal code changes.
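To make "low-rank adaptation" concrete, here is a minimal NumPy sketch of the standard LoRA formulation. It is not MLoRAx's own code. A frozen pretrained weight W of shape (d_in, d_out) is left untouched, and two small trainable factors A (d_in, r) and B (r, d_out) add the update x @ A @ B. The helper names lora_dense and lora_param_count and the scale argument are illustrative assumptions; jax.numpy supports the same operations. For a 512 x 512 projection (the size of the query and value matrices in t5-small) at rank 16, the factors hold 16,384 trainable values against 262,144 in the full matrix.

```python
import numpy as np


def lora_dense(x, w, a, b, scale=1.0):
    """Frozen dense layer plus a low-rank update: x @ w + scale * (x @ a) @ b."""
    return x @ w + scale * ((x @ a) @ b)


def lora_param_count(d_in, d_out, rank):
    """Trainable values in the factors a (d_in x rank) and b (rank x d_out)."""
    return d_in * rank + rank * d_out


rng = np.random.default_rng(0)
d_in, d_out, rank = 512, 512, 16
w = rng.normal(size=(d_in, d_out))        # pretrained weight, kept frozen
a = rng.normal(size=(d_in, rank)) * 0.01  # trainable down-projection
b = np.zeros((rank, d_out))               # trainable up-projection, zero-initialised
x = rng.normal(size=(4, d_in))            # a batch of 4 input vectors

# With b == 0 the adapted layer reproduces the pretrained layer exactly.
assert np.allclose(lora_dense(x, w, a, b), x @ w)

print(lora_param_count(d_in, d_out, rank), d_in * d_out)  # 16384 262144
```

Initialising B to zeros, as in the sketch, means training starts from exactly the pretrained model's behaviour.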
Installation
Choice 1: Install with pip
You can install mlorax with pip. This is the recommended way to use mlorax if you want to receive future updates.
pip install mlorax
Choice 2: Just copy the code
You can also copy the code from mlorax.py directly into your project. This is the simplest way to use mlorax if you do not need future updates.
Choice 3: Install from source
You can also install mlorax from source. This is only necessary if you want to contribute to the project.
git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .
Usage
mlorax can convert any Flax model to a LoRA model with a few lines of code. The following snippet converts a T5 model built on Hugging Face's FlaxT5ForConditionalGeneration class; lines starting with + are added and lines starting with - are removed relative to ordinary Flax training code.
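+ import mlorax
import optax
from flax.training.train_state import TrainState
from transformers import FlaxT5ForConditionalGeneration
model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
tx = optax.adamw(learning_rate=1e-4)  # any optax optimizer; the learning rate is only an example
state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)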
That's it! You can now train the model as usual.
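The snippet does not spell out how rules such as 'Attention.q' select parameters. One plausible reading, and it is only an assumption that has not been checked against mlorax.py, is substring matching against the dot-joined path of each parameter in the nested Flax parameter dict. The hypothetical helpers flatten_paths and select_lora_targets below sketch that reading on a toy dict loosely shaped like one T5 decoder block. Under this reading, 'Attention.q' would cover both SelfAttention.q and EncDecAttention.q.

```python
from collections.abc import Mapping


def flatten_paths(tree, prefix=""):
    """Yield (dot-joined path, leaf) pairs for a nested mapping of parameters."""
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, Mapping):  # also covers flax FrozenDict
            yield from flatten_paths(value, path)
        else:
            yield path, value


def select_lora_targets(params, rules):
    """Hypothetical matcher: keep every path that contains any rule as a substring."""
    return [path for path, _ in flatten_paths(params) if any(rule in path for rule in rules)]


# Toy tree loosely shaped like one T5 decoder block; leaf values are placeholders.
toy_params = {
    "decoder": {
        "block": {
            "0": {
                "layer": {
                    "0": {"SelfAttention": {"q": {"kernel": 0}, "k": {"kernel": 0}, "v": {"kernel": 0}}},
                    "1": {"EncDecAttention": {"q": {"kernel": 0}, "k": {"kernel": 0}, "v": {"kernel": 0}}},
                    "2": {"DenseReluDense": {"wi": {"kernel": 0}, "wo": {"kernel": 0}}},
                }
            }
        }
    }
}

for path in select_lora_targets(toy_params, ["Attention.q", "Attention.v"]):
    print(path)
```

Under the substring assumption, the loop prints the query and value kernels of both attention sublayers and skips the key and feed-forward weights.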
Principles
Always use the returned apply_fn for the forward pass when possible. If a function cannot take apply_fn, pass it the merged parameters via params=merge_fn(params) instead. For example, to generate text with model.generate, do the following:
- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
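Functions such as model.generate expect ordinary dense weights rather than separate LoRA factors. Under the standard LoRA formulation (an assumption here, since the README does not show MLoRAx's internals), merging folds the update into the frozen weight as W + scale * A @ B, and the merged layer then produces the same outputs as the adapted one. The hypothetical helpers merge_lora and lora_forward below sketch this in NumPy.

```python
import numpy as np


def lora_forward(x, w, a, b, scale=1.0):
    """Adapted layer: frozen weight w plus the low-rank update scale * a @ b."""
    return x @ w + scale * ((x @ a) @ b)


def merge_lora(w, a, b, scale=1.0):
    """Fold the low-rank factors into one dense weight with the original shape."""
    return w + scale * (a @ b)


rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32))  # frozen pretrained weight
a = rng.normal(size=(64, 4))   # trained rank-4 factors
b = rng.normal(size=(4, 32))
x = rng.normal(size=(8, 64))

merged = merge_lora(w, a, b, scale=0.5)
assert merged.shape == w.shape  # drop-in replacement for the original weight
assert np.allclose(x @ merged, lora_forward(x, w, a, b, scale=0.5))
```

Because the merged weight keeps the original shape, it can presumably be passed anywhere the pretrained parameters were accepted, which is what params=merge_fn(params) relies on.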
Example and Results
Please refer to the examples folder for details.
Citation
If you find MLoRAx useful, please cite the following:
@software{hao2023lmrax,
author = {Yongchang Hao},
title = {{T}he {LMR}ax {E}cosystem: a minimalist library for training {T}ransformer models with {JAX}},
year = {2023},
url = {https://github.com/yongchanghao/LMRax},
version = {0.9.5}
}
File details
Details for the file mlorax-0.9.5.tar.gz.
File metadata
- Download URL: mlorax-0.9.5.tar.gz
- Upload date:
- Size: 8.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8a86b67318acae0c481c22c0e24943891f279e73d189b084d3f996afaa36a4f2
MD5 | b02c219ce4e9f0b568b2b47774a7fda5
BLAKE2b-256 | e7d7b087069da2a70847f2190d84f861013c5d62650a1c8789c37e67e2ebbdb8
File details
Details for the file mlorax-0.9.5-py3-none-any.whl.
File metadata
- Download URL: mlorax-0.9.5-py3-none-any.whl
- Upload date:
- Size: 9.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | dd487aef30bac752cde032500db6ecf71bd5c5582e56b18d72a61388d2f78599
MD5 | 761f319641a8c67507ca77ee779d2a2a
BLAKE2b-256 | 6987f50e455ff749cc9125f93ff836ddceabe3e0d64c636b77df71b239ae643b