
MLoRAx is a minimalist library for low-rank adaptation, designed to effortlessly enable parameter-efficient training for Transformer-based models.


Installation

Choice 1: Just copy the code

You can simply copy the code from mlorax.py and paste it into your project. This is the easiest way to use mlorax if you do not care about future updates.

Choice 2: Install with pip

You can also install mlorax with pip. This is the recommended way to use mlorax if you want to receive future updates.

pip install mlorax

Choice 3: Install from source

You can also install mlorax from source. You only need to do this if you want to contribute to the project.

git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .

Usage

It is extremely easy to use mlorax to convert any Flax model to a LoRA model. The following code snippet shows how to convert a T5 model to a LoRA model based on HuggingFace's FlaxT5ForConditionalGeneration class.

+ import mlorax
  from flax.training.train_state import TrainState
  from transformers import FlaxT5ForConditionalGeneration

  model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
  state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx, **kwargs)

That's it! You can now train the model as usual.
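To see what low-rank adaptation does under the hood, here is a conceptual sketch in plain NumPy (this is not mlorax code; the names `W`, `B`, `A`, and `lora_forward` are illustrative). LoRA freezes a pretrained weight matrix `W` and trains only two small factors `B` and `A`, so the effective weight becomes `W + B @ A`:

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretrained weight of one attention projection (frozen during training).
d_model, rank = 8, 2
W = rng.normal(size=(d_model, d_model))

# Low-rank factors: only B and A receive gradients. A starts at zero,
# so at initialization the adapted model matches the pretrained one.
B = rng.normal(size=(d_model, rank)) * 0.01
A = np.zeros((rank, d_model))

def lora_forward(x, W, B, A):
    # Base projection plus the low-rank update (equivalent to x @ (W + B @ A)).
    return x @ W + (x @ B) @ A

x = rng.normal(size=(4, d_model))

# With A = 0, the LoRA forward pass reproduces the frozen model exactly.
assert np.allclose(lora_forward(x, W, B, A), x @ W)
```

With `rank` much smaller than `d_model`, the trainable parameter count drops from `d_model * d_model` to `2 * rank * d_model` per adapted matrix, which is the source of LoRA's memory savings.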

Principles

Always use the returned apply_fn for model forwarding when possible. Otherwise, pass the merged parameters to the function with params=merge_fn(params). For example, to use model.generate for text generation:

- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
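Merging works because the low-rank update can be folded back into the base weight, yielding a plain dense matrix that behaves identically. The identity can be checked with NumPy (a conceptual sketch, not mlorax internals; `W_merged` is an illustrative name):

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, rank = 8, 2
W = rng.normal(size=(d_model, d_model))  # frozen base weight
B = rng.normal(size=(d_model, rank))     # trained LoRA factor
A = rng.normal(size=(rank, d_model))     # trained LoRA factor
x = rng.normal(size=(4, d_model))

# Folding the trained factors into the base weight produces a single
# dense matrix, so code paths that expect ordinary parameters
# (such as generate) work unchanged.
W_merged = W + B @ A

assert np.allclose(x @ W_merged, x @ W + (x @ B) @ A)
```

This is why merge_fn(params) can be handed to any unmodified code path: after merging, the parameter tree has the same structure as the original pretrained model.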

Examples and Results

Please refer to the examples folder for more examples and results.

Citation

If you find MLoRAx useful, please cite the following paper:

@software{hao2023mlorax,
  author = {Yongchang Hao},
  title = {{ML}o{RA}x: a minimalist library for low-rank adaptation for {T}ransformer-based models},
  year = {2023},
  url = {https://github.com/yongchanghao/MLoRAx},
  version = {0.9.1}
}

