
💐 Flora: Low-Rank Adapters Are Secretly Gradient Compressors

This is the official repository for the ICML 2024 paper Flora: Low-Rank Adapters Are Secretly Gradient Compressors. It contains the code for the experiments in the paper.

Flora dramatically decreases the GPU memory needed to pre-train and fine-tune models without compromising quality or causing significant slowdown.

[Figure: Flora graphs]

Installation

You can install the library using pip:

pip install 'flora-opt[torch]' # for PyTorch

or

pip install 'flora-opt[jax]' # for JAX

Usage

The library is designed to be compatible with the Hugging Face ecosystem and can be used as a drop-in replacement. Here is an example for PyTorch:

- optimizer = transformers.AdamW(model.parameters(), lr=1e-5)
+ optimizer = flora_opt.Flora(model.parameters(), lr=1e-5)
- accelerator = accelerate.Accelerator(**kwargs)
+ accelerator = flora_opt.FloraAccelerator(**kwargs)

Everything else remains the same. You can find more examples in the examples folder.
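
Below is a minimal end-to-end sketch. Only flora_opt.Flora and flora_opt.FloraAccelerator come from this library; the model, data, and the rank value are illustrative placeholders, not recommendations.

import flora_opt
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model and data; any Hugging Face model and dataloader work the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

optimizer = flora_opt.Flora(model.parameters(), lr=1e-5, rank=16)  # rank=16 is an arbitrary example
accelerator = flora_opt.FloraAccelerator()  # drop-in for accelerate.Accelerator
model, optimizer = accelerator.prepare(model, optimizer)

batch = tokenizer("Flora compresses optimizer states.", return_tensors="pt")
outputs = model(**batch, labels=batch["input_ids"])
accelerator.backward(outputs.loss)
optimizer.step()
optimizer.zero_grad()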

How it works

In a typical deep learning training setup, memory is consumed by three components: the model parameters, the optimizer states, and the activations. Taking the Adam optimizer as an example, the overall procedure is as follows:

[Figure: training procedure and memory usage with Adam]

Since Adam stores first- and second-moment estimates for every parameter, its optimizer states alone can take roughly twice as much memory as the model itself.
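
As a back-of-the-envelope illustration (the 7B-parameter model and fp32 states here are an assumed example, not a figure from the paper):

n_params = 7e9               # a hypothetical 7B-parameter model
bytes_per_value = 4          # fp32
params_gb = n_params * bytes_per_value / 1e9  # ~28 GB of parameters
adam_states_gb = 2 * params_gb                # first moment (m) + second moment (v): ~56 GB
print(f"parameters: {params_gb:.0f} GB, Adam states: {adam_states_gb:.0f} GB")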

In our work, instead of maintaining the full optimizer states, we compress the moments with low-rank random projections. The overall procedure becomes:

[Figure: training procedure and memory usage with Flora]

The low-rank random projection reduces the memory usage by a factor of rank / d, where rank is the rank of the projection and d is the dimension of the model parameters. In addition, the projection can be regenerated on the fly from a random seed, so it never needs to be stored, which reduces memory usage further.
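
Conceptually, the compression looks like the following sketch (an illustration of the idea, not the library's internal code; the sizes and seed are arbitrary):

import torch

d, rank, seed = 1024, 16, 0  # illustrative dimensions

def projection(seed: int) -> torch.Tensor:
    # Regenerated from the seed on the fly, so it never has to be stored.
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(d, rank, generator=gen) / rank ** 0.5

grad = torch.randn(d, d)                        # a full d x d gradient
compressed = grad @ projection(seed)            # stored state is only d x rank (factor rank / d)
decompressed = compressed @ projection(seed).T  # approximately recover the moment when updating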

Moreover, Flora is compatible with existing memory-efficient training techniques. For example, it can be combined with activation checkpointing (AC) and layer-by-layer update (LOMO) to reduce memory usage even further, as shown below:

[Figure: training procedure and memory usage when combining Flora with AC and LOMO]
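
For instance, activation checkpointing can be enabled through the standard Hugging Face API alongside Flora (a sketch; the model choice and rank value are illustrative):

import flora_opt
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.gradient_checkpointing_enable()  # activation checkpointing: recompute activations during backward
optimizer = flora_opt.Flora(model.parameters(), lr=1e-5, rank=16)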

In summary, Flora is a simple yet effective method for compressing optimizer states and can be easily integrated into existing training frameworks.

Paper Replications (with JAX)

To replicate the major experiments in the paper, run the following commands:

pip install -r examples/flax/requirements.txt
sh replicate.sh

You can also run individual experiments by selecting the corresponding script in the file replicate.sh.

Explanation

The arguments for the Flora optimizer (for PyTorch) are explained below:

flora_opt.Flora(
    params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],  # model parameters
    lr: float = None,  # learning rate
    rank: int = None,  # rank of the low-rank random projections
    kappa: int = 1000,  # the interval for updating the low-rank random projections
    eps: tuple[float, float] = (1e-30, 1e-3),  # Adafactor parameter
    clip_threshold: float = 1.0,  # Adafactor parameter
    decay_rate: float = -0.8,  # decay rate in Adafactor
    beta1: Optional[float] = None,  # decay rate for the first moment
    weight_decay: float = 0.0,  # weight decay coefficient
    scale_parameter: bool = True,  # Adafactor parameter
    relative_step: bool = False,  # Adafactor parameter
    warmup_init: bool = False,  # Adafactor parameter
    factorize_second_moment: bool = True,  # use Adafactor or Adam
    seed: int = 0,  # random seed to generate the low-rank random projections
    quantization: bool = False,  # whether to quantize the states
)
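
For example, an explicit construction (the toy module and the rank/kappa values below are illustrative):

import torch
import flora_opt

model = torch.nn.Linear(1024, 1024)  # any torch module; a toy layer keeps the example self-contained
optimizer = flora_opt.Flora(
    model.parameters(),
    lr=1e-5,
    rank=16,     # rank of the low-rank random projections (illustrative value)
    kappa=1000,  # regenerate the projections every 1000 steps
    seed=0,      # seed used to regenerate the projections on the fly
)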

For JAX, the arguments are similar. The translation can be found in flora_opt/optimizers/torch/__init__.py.

Citation

@inproceedings{hao2024flora,
  title={Flora: Low-Rank Adapters Are Secretly Gradient Compressors},
  author={Hao, Yongchang and Cao, Yanshuai and Mou, Lili},
  booktitle={Forty-first International Conference on Machine Learning},
  url={https://arxiv.org/abs/2402.03293},
  year={2024}
}

