MLoRAx is a minimalist library for low-rank adaptation designed to effortlessly enable parameter-efficient training for Transformer-based models.

Installation

Choice 1: Install with pip

You can install mlorax with pip. This is the recommended way to use mlorax if you want to receive future updates.

pip install mlorax

Choice 2: Just copy the code

You can also directly copy the code from mlorax.py and paste it into your project. This is the easiest way to use mlorax if you do not care about future updates.

Choice 3: Install from source

You can also install mlorax from source. You only need to do this if you want to contribute to the project.

git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .

Usage

mlorax makes it easy to convert any Flax model to a LoRA model. The following snippet shows how to convert a T5 model to a LoRA model, based on HuggingFace's FlaxT5ForConditionalGeneration class.

+ import mlorax
  model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
  state = TrainState(apply_fn=apply_fn, params=params, tx=tx, **kwargs)

That's it! You can now train the model as usual.
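To see why this is parameter-efficient, here is a minimal sketch of the idea behind low-rank adaptation (illustrative only, not mlorax's internal code): a frozen weight matrix `W` is augmented with a trainable low-rank update `B @ A`, so only the two small factors receive gradients.

```python
import numpy as np

# Sketch of the LoRA idea: freeze W, train only the low-rank factors A and B.
d_in, d_out, rank = 512, 512, 16

rng = np.random.default_rng(0)
W = rng.standard_normal((d_in, d_out))         # frozen pretrained weight
A = rng.standard_normal((rank, d_out)) * 0.01  # trainable factor
B = np.zeros((d_in, rank))                     # trainable factor, zero-init

def lora_forward(x):
    # Equivalent to x @ (W + B @ A), without materializing the merged matrix.
    return x @ W + (x @ B) @ A

full_params = W.size           # 262144 parameters if fine-tuned fully
lora_params = A.size + B.size  # 16384 trainable parameters with rank-16 LoRA
print(full_params, lora_params)
```

With rank 16 on a 512x512 matrix, LoRA trains about 6% as many parameters as full fine-tuning; zero-initializing `B` means the adapted model starts out exactly equal to the pretrained one.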

Principles

Always use the returned apply_fn for model forwarding when possible. Otherwise, pass the merged parameters with params=merge_fn(params). For example, to use model.generate for text generation, do the following:

- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
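Conceptually, merging folds the low-rank update back into the base weights so that any stock forward path (such as model.generate) sees ordinary parameters. A minimal sketch, assuming the standard LoRA formulation W' = W + B @ A (the `merge` helper below is hypothetical, not mlorax's merge_fn):

```python
import numpy as np

def merge(W, A, B):
    # Fold the low-rank update into the base weight; after this, a plain
    # x @ W_merged matmul reproduces the two-path LoRA forward pass.
    return W + B @ A

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 64))
A = rng.standard_normal((8, 64))
B = rng.standard_normal((64, 8))
x = rng.standard_normal((4, 64))

W_merged = merge(W, A, B)
# Merged weight is equivalent to base path plus low-rank path.
assert np.allclose(x @ W_merged, x @ W + (x @ B) @ A)
```

This is why code that never calls the returned apply_fn still works correctly once it receives merge_fn(params): the adaptation lives entirely inside the merged weight values.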

Examples and Results

Please refer to the examples folder for details.

Citation

If you find MLoRAx useful, please cite the following paper:

@software{hao2023lmrax,
  author = {Yongchang Hao},
  title = {{T}he {LMR}ax {E}cosystem: a minimalist library for training {T}ransformer models with {JAX}},
  year = {2023},
  url = {https://github.com/yongchanghao/LMRax},
  version = {0.9.3}
}
