
MLoRAx is a minimalist library for low-rank adaptation (LoRA), designed to effortlessly enable parameter-efficient training of Transformer-based models.

Project description

Installation

Choice 1: Install with pip

You can install mlorax with pip. This is the recommended way to use mlorax if you want to receive future updates.

pip install mlorax

Choice 2: Just copy the code

You can also directly copy the code from mlorax.py and paste it into your project. This is the easiest way to use mlorax if you do not care about future updates.

Choice 3: Install from source

You can also install mlorax from source. You only need to do this if you want to contribute to the project.

git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .

Usage

It is extremely easy to use mlorax to convert any Flax model to a LoRA model. The following code snippet shows how to convert a T5 model to a LoRA model based on HuggingFace's FlaxT5ForConditionalGeneration class.

+ import mlorax
  model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
  state = TrainState(apply_fn=apply_fn, params=params, tx=tx, **kwargs)

That's it! You can now train the model as usual.
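For intuition, LoRA reparameterizes each targeted weight matrix W as W + BA, where A and B are small low-rank factors and are the only parameters that get trained. The following NumPy sketch illustrates this idea only; it is not mlorax's internal implementation, and the dimensions are made up for the example:

```python
import numpy as np

rng = np.random.default_rng(0)

d_out, d_in, rank = 8, 8, 2          # toy dimensions
W = rng.normal(size=(d_out, d_in))   # frozen pretrained weight

# LoRA factors: A gets a small random init, B starts at zero so the
# adapted model initially matches the pretrained one exactly.
A = rng.normal(scale=0.01, size=(rank, d_in))
B = np.zeros((d_out, rank))

def lora_forward(x, W, A, B):
    # Equivalent to (W + B @ A) @ x, but never materializes the sum,
    # which is how LoRA keeps the trainable parameter count low.
    return W @ x + B @ (A @ x)

x = rng.normal(size=(d_in,))
# With B = 0 the LoRA branch contributes nothing:
assert np.allclose(lora_forward(x, W, A, B), W @ x)
```

During training only A and B receive gradients, so the optimizer state and checkpoints stay small relative to full fine-tuning.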

Principles

Always use the returned apply_fn for the forward pass when possible. Otherwise, pass the merged parameters to the function with params=merge_fn(params). For example, to use model.generate for text generation:

- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
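Conceptually, merging folds the low-rank update back into the base weight, so code paths that expect ordinary dense parameters (such as generate) see the adapted model. A toy NumPy illustration of the idea; this is not mlorax's actual merge_fn:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 4))   # frozen base weight
A = rng.normal(size=(2, 4))   # trained LoRA factors
B = rng.normal(size=(4, 2))

def merge(W, A, B):
    # Fold the low-rank update into a single dense matrix of the
    # same shape as the original weight.
    return W + B @ A

x = rng.normal(size=(4,))
# The merged matrix reproduces the adapted forward pass exactly:
assert np.allclose(merge(W, A, B) @ x, W @ x + B @ (A @ x))
```

This is why merging costs nothing at inference time: after the fold, the model has the same shapes and compute as the original.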

Example and Results

Please refer to the examples folder for details.

Citation

If you find MLoRAx useful, please cite the following paper:

@software{hao2023lmrax,
  author = {Yongchang Hao},
  title = {{T}he {LMR}ax {E}cosystem: a minimalist library for training {T}ransformer models with {JAX}},
  year = {2023},
  url = {https://github.com/yongchanghao/LMRax},
  version = {0.9.4}
}
