💐 Flora: Low-Rank Adapters Are Secretly Gradient Compressors
This is the official repository for the paper Flora: Low-Rank Adapters Are Secretly Gradient Compressors, published at ICML 2024. It contains the code for the experiments in the paper.
Flora dramatically decreases the GPU memory needed to pre-train and fine-tune models without compromising quality or causing significant slowdown.
Installation
You can install the library using pip:
pip install 'flora-opt[torch]' # for PyTorch
or
pip install 'flora-opt[jax]' # for JAX
Usage
The library is designed to be compatible with the Hugging Face libraries and can be used as a drop-in replacement. Here is an example for PyTorch:
- optimizer = transformers.AdamW(model.parameters(), lr=1e-5)
+ optimizer = flora_opt.Flora(model.parameters(), lr=1e-5)
- accelerator = accelerate.Accelerator(**kwargs)
+ accelerator = flora_opt.FloraAccelerator(**kwargs)
Everything else remains the same. You can find more examples in the examples folder.
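For reference, here is a minimal end-to-end sketch of the PyTorch usage above. The toy model and data are placeholders, rank=8 is just an example value, and the loop assumes FloraAccelerator exposes the same prepare/backward interface as accelerate.Accelerator, as the drop-in replacement suggests.
import torch
from torch.utils.data import DataLoader, TensorDataset
import flora_opt

# Toy model and data, only to show the wiring (replace with your own).
model = torch.nn.Linear(128, 2)
dataset = TensorDataset(torch.randn(256, 128), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=32)

optimizer = flora_opt.Flora(model.parameters(), lr=1e-5, rank=8)  # rank=8 is an example value
accelerator = flora_opt.FloraAccelerator()  # used like accelerate.Accelerator

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
loss_fn = torch.nn.CrossEntropyLoss()

for inputs, labels in loader:
    loss = loss_fn(model(inputs), labels)
    accelerator.backward(loss)  # assumed to mirror accelerate's backward()
    optimizer.step()
    optimizer.zero_grad()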
How it works
Normally, three components consume memory during deep learning training: the model parameters, the optimizer states, and the activations. Taking the Adam optimizer as an example, the overall procedure is as follows:
(Figure: the training procedure with Adam and the corresponding memory usage.)
Since Adam stores first- and second-order moment estimates for each parameter, its optimizer states alone can take roughly twice as much memory as the model itself.
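To see this concretely, the snippet below (an illustration, not part of the library) measures the size of Adam's states for a toy model; exp_avg and exp_avg_sq together take about twice the parameter memory.
import torch

# Illustration only: Adam keeps exp_avg and exp_avg_sq for every parameter,
# so its state is roughly twice the size of the parameters themselves.
model = torch.nn.Linear(4096, 4096)
optimizer = torch.optim.Adam(model.parameters())

loss = model(torch.randn(8, 4096)).sum()
loss.backward()
optimizer.step()  # Adam allocates its states lazily on the first step

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
state_bytes = sum(t.numel() * t.element_size()
                  for state in optimizer.state.values()
                  for t in state.values() if torch.is_tensor(t))
print(f"parameters: {param_bytes / 2**20:.1f} MiB, Adam states: {state_bytes / 2**20:.1f} MiB")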
In our work, instead of maintaining the full optimizer states, we propose using low-rank random projections to compress the moments. The overall procedure is as follows:
(Figure: the training procedure with Flora and the corresponding memory usage.)
The low-rank random projection reduces the memory usage by a factor of rank / d, where rank is the rank of the random projection and d is the dimension of the model parameters. In addition, the projection matrices can be generated on the fly from a random seed, which further reduces the memory usage.
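The following toy snippet (not the library's actual code) illustrates the idea: a moment-shaped matrix is compressed with a seeded random projection, and the projection itself is regenerated from the seed rather than stored.
import torch

# Toy illustration of moment compression with a seeded random projection.
d, m, rank, seed = 1024, 1024, 16, 0

moment = torch.randn(d, m)                          # full moment: d x m floats

gen = torch.Generator().manual_seed(seed)           # regenerate the projection from the seed
proj = torch.randn(rank, d, generator=gen) / rank ** 0.5

compressed = proj @ moment                          # stored state: rank x m floats
reconstructed = proj.t() @ compressed               # decompress when applying the update

print(compressed.numel() / moment.numel())          # memory ratio is roughly rank / d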
Moreover, Flora is compatible with existing memory-efficient training techniques. For example, it can be combined with activation checkpointing (AC) and layer-by-layer updates (LOMO) to further reduce memory usage, as shown below:
(Figure: the training procedure with Flora combined with AC and LOMO, and the corresponding memory usage.)
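As a rough sketch of how Flora sits next to such techniques, the snippet below enables standard PyTorch activation checkpointing alongside the Flora optimizer; it does not show how the library implements the LOMO-style layer-by-layer update, and the model and rank are placeholders.
import torch
from torch.utils.checkpoint import checkpoint_sequential
import flora_opt

# Activation checkpointing (a standard PyTorch feature) combined with Flora.
model = torch.nn.Sequential(*[torch.nn.Linear(512, 512) for _ in range(8)])
optimizer = flora_opt.Flora(model.parameters(), lr=1e-5, rank=8)

x = torch.randn(32, 512, requires_grad=True)
out = checkpoint_sequential(model, 4, x)  # activations are recomputed in the backward pass
out.sum().backward()
optimizer.step()
optimizer.zero_grad()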
In summary, Flora is a simple yet effective method for compressing the optimizer states, and it can be easily integrated into existing training frameworks.
Paper Replications (with JAX)
To replicate major experiments in the paper, run the following commands:
pip install -r examples/flax/requirements.txt
sh replicate.sh
You can also run individual experiments by selecting the corresponding script in the file replicate.sh.
Explanation
The arguments for the Flora optimizer (for PyTorch) are explained below:
flora_opt.Flora(
params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]], # model parameters
lr: float = None, # learning rate
rank: int = None, # rank of the low-rank random projections
kappa: int = 1000, # the interval for updating the low-rank random projections
eps: tuple[float, float] = (1e-30, 1e-3), # Adafactor parameter
clip_threshold: float = 1.0, # Adafactor parameter
decay_rate: float = -0.8, # decay rate in Adafactor
beta1: Optional[float] = None, # decay rate for the first moment
weight_decay: float = 0.0, # weight decay coefficient
scale_parameter: bool = True, # Adafactor parameter
relative_step: bool = False, # Adafactor parameter
warmup_init: bool = False, # Adafactor parameter
    factorize_second_moment: bool = True, # factorize the second moment (Adafactor-style) instead of storing it fully (Adam-style)
seed: int = 0, # random seed to generate the low-rank random projections
quantization: bool = False, # whether to quantize the states
)
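For instance, a non-default configuration might look as follows. The model and all values are illustrative only, not recommendations from the paper.
import torch
import flora_opt

model = torch.nn.Linear(512, 512)  # placeholder model

optimizer = flora_opt.Flora(
    model.parameters(),
    lr=3e-4,
    rank=16,            # compress the moments to rank 16
    kappa=1000,         # update the low-rank random projection every 1000 steps
    beta1=0.9,          # keep a first moment with decay rate 0.9
    weight_decay=0.01,
    quantization=True,  # quantize the stored states
)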
For JAX, the arguments are similar. The translation can be found in flora_opt/optimizers/torch/__init__.py.
Citation
@inproceedings{hao2024flora,
title={Flora: Low-Rank Adapters Are Secretly Gradient Compressors},
author={Hao, Yongchang and Cao, Yanshuai and Mou, Lili},
booktitle={Forty-first International Conference on Machine Learning},
url={https://arxiv.org/abs/2402.03293},
year={2024}
}
File details
Details for the file flora_opt-0.0.1.tar.gz.
File metadata
- Download URL: flora_opt-0.0.1.tar.gz
- Upload date:
- Size: 18.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.10.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 43836a62793a9bea16e9bc314813e298ae7e2fdefae7eb6d7b3efef38b8a80f8
MD5 | f189e1f1248188b6a94927bf87b45f3b
BLAKE2b-256 | 77866da49a5cc95c9d66edd33ed5e10b00add6d510f75e67d58f80b42ef138d5
File details
Details for the file flora_opt-0.0.1-py3-none-any.whl.
File metadata
- Download URL: flora_opt-0.0.1-py3-none-any.whl
- Upload date:
- Size: 18.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.10.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | d84b62cf7c8b060043a1bbae1d06dd6a0b9ca541f22c0c3696a171e12825ce60
MD5 | 807a75e367b8c52b2402bd8f9ea9ef26
BLAKE2b-256 | 1ecb2def687092cb47c75ed7ca05f7140eeec70b16eeef11c8a5e981b79693de