MLoRAx is a minimalist library for low-rank adaptation (LoRA), designed to effortlessly enable parameter-efficient training of Transformer-based models.

Project description

Installation

Choice 1: Install with pip

You can install mlorax with pip. This is the recommended way to use mlorax if you want to receive future updates.

pip install mlorax

Choice 2: Just copy the code

You can also directly copy the code from mlorax.py and paste it into your project. This is the easiest way to use mlorax if you do not care about future updates.

Choice 3: Install from source

You can also install mlorax from source. You only need to do this if you want to contribute to the project.

git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .

Usage

It is extremely easy to convert any Flax model to a LoRA model with mlorax. The following snippet shows how to convert a T5 model into a LoRA model, based on HuggingFace's FlaxT5ForConditionalGeneration class.

+ import mlorax
  model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
  state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx, **kwargs)

That's it! You can now train the model as usual.
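Conceptually, LoRA freezes each selected pretrained weight W and learns a low-rank update B·A alongside it, with B initialized to zero so the adapted model starts out identical to the pretrained one. The sketch below illustrates the idea in plain NumPy; it is a conceptual illustration only, not mlorax's internals, and the names W, A, B, and lora_forward are hypothetical:

```python
import numpy as np

rank = 16
d_in, d_out = 64, 64
rng = np.random.default_rng(0)

W = rng.standard_normal((d_in, d_out))         # frozen pretrained weight
A = rng.standard_normal((rank, d_out)) * 0.01  # trainable low-rank factor
B = np.zeros((d_in, rank))                     # trainable, zero-initialized

def lora_forward(x):
    # Base path plus low-rank adaptation path.
    return x @ W + x @ B @ A

x = rng.standard_normal((2, d_in))
# At initialization, B is zero, so the LoRA branch contributes nothing:
assert np.allclose(lora_forward(x), x @ W)

# "Merging" folds the low-rank update back into the base weight:
W_merged = W + B @ A
assert np.allclose(x @ W_merged, lora_forward(x))
```

Only the factors A and B are trained, i.e. rank * (d_in + d_out) parameters per adapted matrix instead of d_in * d_out, which is where the parameter savings come from.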

Principles

Always use the returned apply_fn for model forwarding when possible. Otherwise, pass params=merge_fn(params) so the function receives the merged parameters. For example, if you want to use model.generate for text generation, do the following:

- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
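Functions that bypass the returned apply_fn, such as model.generate, expect an ordinary parameter tree, which is why the low-rank factors must be folded back into the base weights first. Below is a minimal sketch of what merging means for a single weight entry; the dictionary layout and key names here are hypothetical, and mlorax's real merge_fn operates over the whole parameter tree:

```python
import numpy as np

# Hypothetical parameter entry holding a base kernel plus LoRA factors.
params = {
    "kernel": np.ones((4, 4)),        # frozen base weight
    "lora_B": np.full((4, 2), 0.5),   # trainable low-rank factor
    "lora_A": np.ones((2, 4)),        # trainable low-rank factor
}

def merge(p):
    # Fold the low-rank update into the base weight so downstream
    # code sees an ordinary dense kernel.
    return {"kernel": p["kernel"] + p["lora_B"] @ p["lora_A"]}

merged = merge(params)
x = np.ones((1, 4))
# Applying the merged kernel equals the base path plus the LoRA path:
assert np.allclose(
    x @ merged["kernel"],
    x @ params["kernel"] + x @ params["lora_B"] @ params["lora_A"],
)
```

After merging, any code path that consumes raw parameters (generation, export, evaluation harnesses) behaves as if it were given a standard fine-tuned model.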

Example and Results

Please refer to the examples folder for details.

Citation

If you find MLoRAx useful, please cite it as follows:

@software{hao2023lmrax,
  author = {Yongchang Hao},
  title = {{T}he {LMR}ax {E}cosystem: a minimalist library for training {T}ransformer models with {JAX}},
  year = {2023},
  url = {https://github.com/yongchanghao/LMRax},
  version = {0.9.5}
}

Download files

Download the file for your platform.

Source Distribution

mlorax-0.9.5.tar.gz (8.4 kB)

Uploaded Source

Built Distribution

mlorax-0.9.5-py3-none-any.whl (9.1 kB)

Uploaded Python 3

File details

Details for the file mlorax-0.9.5.tar.gz.

File metadata

  • Download URL: mlorax-0.9.5.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for mlorax-0.9.5.tar.gz
Algorithm Hash digest
SHA256 8a86b67318acae0c481c22c0e24943891f279e73d189b084d3f996afaa36a4f2
MD5 b02c219ce4e9f0b568b2b47774a7fda5
BLAKE2b-256 e7d7b087069da2a70847f2190d84f861013c5d62650a1c8789c37e67e2ebbdb8

File details

Details for the file mlorax-0.9.5-py3-none-any.whl.

File metadata

  • Download URL: mlorax-0.9.5-py3-none-any.whl
  • Upload date:
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for mlorax-0.9.5-py3-none-any.whl
Algorithm Hash digest
SHA256 dd487aef30bac752cde032500db6ecf71bd5c5582e56b18d72a61388d2f78599
MD5 761f319641a8c67507ca77ee779d2a2a
BLAKE2b-256 6987f50e455ff749cc9125f93ff836ddceabe3e0d64c636b77df71b239ae643b
