RaanA quantization algorithm
RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm
This repo contains the implementation of the paper "RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm".
Installation
- To install from PyPI:
pip install raana
- To install from source:
pip install build
git clone https://github.com/FFTYYY/RaanA
cd RaanA
python -m build
The generated .whl files will be in dist/.
Quick Start
from transformers import AutoTokenizer, LlamaForCausalLM
from raana import quantize, zeroshot_calibration, trick_centralize, trick_norm_row
# initialize your model
model = LlamaForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)
# quantization
quantized_model = quantize(
    model, # the model to quantize
    b_candidates = list(range(1,9)), # allowed bit-widths
    calibrate_data = zeroshot_calibration(tokenizer), # use zero-shot calibration
    avg_bits = 3.3, # average number of bits
)["model"]
# evaluate your model
evaluate(quantized_model) # placeholder for your own evaluation routine
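The evaluation call at the end of the snippet above is a placeholder rather than something the package is documented to provide. One common choice for language models is perplexity. The sketch below assumes a hypothetical callable `logits_fn` that maps token ids of shape (batch, seq) to logits of shape (batch, seq, vocab); the function name `perplexity` is also an assumption made for illustration.

```python
import math

import torch
import torch.nn.functional as F


@torch.no_grad()
def perplexity(logits_fn, input_ids: torch.Tensor) -> float:
    """Perplexity of next-token prediction on one batch of token ids.

    logits_fn is a hypothetical callable:
    (batch, seq) token ids -> (batch, seq, vocab) logits.
    """
    logits = logits_fn(input_ids).float()
    # Position t predicts token t + 1, so drop the last logit and the first label.
    pred = logits[:, :-1, :].reshape(-1, logits.size(-1))
    target = input_ids[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)  # mean negative log-likelihood
    return math.exp(loss.item())
```

For a Hugging Face causal language model, `logits_fn` could be `lambda ids: quantized_model(ids).logits`.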
A Complete Example
To run the example quantization of Llama 2 on WikiText-2 (and reproduce the result reported in the paper):
pip install raana
git clone https://github.com/FFTYYY/RaanA
cd RaanA/examples
python wikitext2.py --model=meta-llama/Llama-2-7b-hf --avgbits=3.3
See examples/wikitext2.py for a complete usage example.
Detailed Usage
The entry point of raana is raana.quantize.
from torch.nn import Module
from torch import Tensor
from typing import Callable
from raana.task_adaptor import TaskAdaptor
from raana.rotations import RandomRotation, default_rotation
from raana.tricks import Trick
from raana.select_layers import default_linear_selector
from raana.quantized_linear import default_weightbias_extractor, default_matmul
from raana.tricks import trick_centralize, trick_norm_col
quantize(
    model : Module,
    b_candidates : list[float],
    calibrate_data : TaskAdaptor,
    avg_bits : float,
    linear_selector : Callable[[Module], bool] = default_linear_selector,
    rotation_maker : Callable[[], RandomRotation] = default_rotation,
    trick_makers : list[Callable[[], Trick]] = [trick_centralize, trick_norm_col],
    weightbias_extractor: Callable[[Module], tuple[Tensor, Tensor | None]] = default_weightbias_extractor,
    matmul : Callable[[Tensor, Tensor, Tensor, int], Tensor] = default_matmul,
)
Required Arguments
model: torch.nn.Module
- The PyTorch model to be quantized.
b_candidates: list[float]
- Candidate number of bits allowed for each layer.
- Can optionally include fractional values between 0 and 1, which enables less-than-one-bit quantization.
- Example: [0.5, 0.75, 1, 2, 3, 4].
calibrate_data: raana.task_adaptor.TaskAdaptor
- The calibration data used for quantization.
- For language modeling tasks, use raana.task_adaptor.LMAdaptor(data: list[str], tokenizer: PreTrainedTokenizer).
- For zero-shot calibration in language modeling, use raana.zeroshot_calibration(tokenizer).
- For non-language-modeling tasks, write your own TaskAdaptor class.
avg_bits: float
- Target average number of bits per quantized linear layer. The quantizer will search for the optimal bit allocation under this constraint.
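To illustrate what searching for a bit allocation under an average-bit constraint means, here is a toy brute-force sketch. It is not RaanA's allocation algorithm (see the paper for that). The per-layer error table `errors`, the parameter counts `sizes`, and the helper name `allocate_bits` are all made up for illustration, and the average is assumed here to be weighted by parameter count.

```python
from itertools import product


def allocate_bits(errors, sizes, b_candidates, avg_bits):
    """Pick one bit-width per layer to minimise the total error.

    errors[i][b] : hypothetical quantization error of layer i at b bits
    sizes[i]     : number of parameters in layer i
    Constraint   : sum(b_i * sizes[i]) <= avg_bits * sum(sizes)
    """
    budget = avg_bits * sum(sizes)
    best, best_err = None, float("inf")
    # Exhaustive search: fine for a toy, exponential in the number of layers.
    for choice in product(b_candidates, repeat=len(sizes)):
        cost = sum(b * n for b, n in zip(choice, sizes))
        if cost > budget:
            continue
        err = sum(errors[i][b] for i, b in enumerate(choice))
        if err < best_err:
            best, best_err = list(choice), err
    return best, best_err


# Toy inputs: error shrinks with bit-width; layer 0 is far more sensitive.
errors = [
    {2: 8.0, 3: 4.0, 4: 1.0},
    {2: 1.0, 3: 0.5, 4: 0.25},
]
sizes = [100, 100]
bits, err = allocate_bits(errors, sizes, [2, 3, 4], avg_bits=3.0)
```

In this toy case the search gives the sensitive layer 4 bits and the robust layer 2 bits, which still averages to 3 bits.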
Optional Arguments
linear_selector: Callable[[torch.nn.Module], bool]
- A function to choose which sub-modules to quantize.
- Different model implementations use different types of linear modules (e.g. some models use nn.Linear while others use nn.Conv1d), so this function lets the user specify which linear modules to quantize.
- Default: select all torch.nn.Linear layers.
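As an illustration, a custom selector might quantize only the larger nn.Linear layers. This is a sketch under assumptions: the factory name `make_selector` and the threshold `min_features` are hypothetical, and it assumes the selector is called once per sub-module, as the signature above suggests.

```python
import torch.nn as nn


def make_selector(min_features: int = 1024):
    """Build a selector that accepts nn.Linear layers with enough input features."""
    def selector(module: nn.Module) -> bool:
        return isinstance(module, nn.Linear) and module.in_features >= min_features
    return selector


# Toy model to show which sub-modules would be selected.
model = nn.Sequential(nn.Linear(16, 2048), nn.ReLU(), nn.Linear(2048, 16))
selector = make_selector(min_features=1024)
selected = [name for name, m in model.named_modules() if selector(m)]
```

The resulting function would then be passed as quantize(..., linear_selector = make_selector(1024)).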
rotation_maker: Callable[[], raana.rotations.RandomRotation]
- A function to construct a random rotation.
- This parameter leaves flexibility for users to specify their own random rotation implementation.
- The default implementation is the randomized Hadamard transformation, as described in the paper.
- The Hadamard transformation in the default implementation is simply a matrix multiplication with the Hadamard matrix generated by scipy.linalg.hadamard. To minimize the dependencies of raana, the default implementation does not use any fast GPU Hadamard kernels. Users are encouraged to install fast Hadamard kernels themselves and pass them to the quantizer through this parameter.
- We encourage users to install the fast Hadamard implementation from Dao-AILab and pass it to raana:
from torch import Tensor
from fast_hadamard_transform import hadamard_transform
from raana.rotations import PiecewiseHadamard
def hadamard(X: Tensor):
    # normalize by sqrt(d) to make it an orthonormal operator.
    return hadamard_transform(X) / (X.size(-1) ** 0.5)
quantize(
    ...,
    rotation_maker = lambda: PiecewiseHadamard(hadamard = hadamard)
)
- Default: randomized Hadamard transformation, using scipy.linalg.hadamard as the implementation of the Hadamard transformation.
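For intuition, a randomized Hadamard transformation flips the sign of each coordinate at random and then applies a normalized Hadamard matrix, which makes it an orthonormal map. The pure-Python sketch below is for illustration only and is not how raana implements it; `fwht` and `randomized_hadamard` are hypothetical helper names.

```python
import math
import random


def fwht(x):
    """Unnormalized fast Walsh-Hadamard transform; len(x) must be a power of two."""
    x = list(x)
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = x[i], x[i + h]
                x[i], x[i + h] = a + b, a - b
        h *= 2
    return x


def randomized_hadamard(x, signs):
    """Apply y = H D x / sqrt(d), where D = diag(signs) holds +1/-1 entries."""
    d = len(x)
    flipped = [s * v for s, v in zip(signs, x)]
    return [v / math.sqrt(d) for v in fwht(flipped)]


rng = random.Random(0)
x = [3.0, -1.0, 0.5, 2.0, 0.0, 4.0, -2.5, 1.0]
signs = [rng.choice((-1.0, 1.0)) for _ in x]
y = randomized_hadamard(x, signs)

# An orthonormal map preserves the Euclidean norm.
norm_x = math.sqrt(sum(v * v for v in x))
norm_y = math.sqrt(sum(v * v for v in y))
```

Because the map is orthonormal, `norm_x` and `norm_y` agree up to floating-point error, and the rotation can be undone exactly by applying the transform again and re-flipping the signs.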
trick_makers: list[Callable[[], raana.tricks.Trick]]
- List of functions to construct tricks. See the paper for the definition of "trick" here.
- Four tricks are currently implemented: trick_centralize, trick_pca, trick_norm_row, trick_norm_col.
- Default: [trick_centralize, trick_norm_col].
weightbias_extractor: Callable[[nn.Module], tuple[Tensor, Tensor | None]]
- A function to extract weight and bias matrices from a linear module and transform them into the standard size.
- The function should return the extracted weight and bias of the provided layer. weight should be a tensor of size (d_in, d_out), and bias should be None or a tensor of size (d_out,).
- Default: lambda layer: (layer.weight.t().data, layer.bias.data)
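The default as written above reads layer.bias.data, which presumes the layer has a bias. If your model contains layers created without one, an extractor along the following lines returns None for the bias instead. This is a sketch for nn.Linear-style layers; the name `safe_extractor` is hypothetical.

```python
import torch.nn as nn


def safe_extractor(layer: nn.Module):
    """Return (weight, bias) with weight of size (d_in, d_out); bias may be None."""
    # nn.Linear stores its weight as (d_out, d_in), so transpose it.
    weight = layer.weight.t().data
    bias = None if layer.bias is None else layer.bias.data
    return weight, bias


w, b = safe_extractor(nn.Linear(8, 4, bias=False))
```

It would be passed as quantize(..., weightbias_extractor = safe_extractor).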
matmul: Callable[[Tensor, Tensor, Tensor, int], Tensor]
- A function to perform low-precision matrix multiplication.
- Since there is no official implementation of low-precision uint-float matrix multiplication, and we want to minimize the dependencies of raana, we leave the implementation of matrix multiplication to users.
- The input parameters are X, qW, rescale, B. X is a float tensor, qW is a B-bit uint tensor, and rescale is a float rescale tensor. The return value of this function should be equal to (X@qW - ((2**B-1)/2.*X.sum(dim=-1)).view(-1,1)) * rescale.view(1,-1).
- Default: transform everything to float32 and do standard matrix multiplication. Below is the default implementation.
import torch as tc # tc is the alias for torch used in this snippet
def default_matmul(X: tc.Tensor, qW: tc.Tensor, rescale: tc.Tensor, B: int):
    dtype = X.dtype
    X = X.to(tc.float32)
    rescale = rescale.to(tc.float32).view(1, -1)
    # offset that re-centers the unsigned codes around zero
    q_bias = (float(2 ** B - 1) / 2. * X.sum(dim = -1)).view(-1, 1)
    Z = (X @ qW.to(tc.float32)) * rescale
    Z = Z - q_bias * rescale
    return Z.to(dtype)
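Expanding the required return value entry by entry shows what a custom matmul has to reproduce. Here Q stands for qW, r for rescale, and c for the offset (2**B-1)/2; these symbols are shorthand introduced for this derivation only.

```latex
Z_{ij}
  = r_j \Big( \sum_k X_{ik} Q_{kj} - c \sum_k X_{ik} \Big)
  = \sum_k X_{ik} \, r_j \, (Q_{kj} - c),
\qquad c = \frac{2^B - 1}{2}.
```

So the result equals X multiplied by a dequantized weight matrix whose entries are r_j (Q_kj - c): the unsigned codes are shifted to be centered at zero and rescaled per output column. Computing it in the first form keeps the matrix product on the unsigned codes and applies the offset and rescaling afterwards.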
Returns
{
    "model" : torch.nn.Module, # quantized model
    "bits" : list[int], # allocated bitwidth per layer
    "losses": list[float] # calibration loss per calibration data
}