
RaanA quantization algorithm

Project description

RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm

This repo contains the implementation of the paper "RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm".

Installation

  1. To install from PyPI: pip install raana
  2. To install from source:
    pip install build
    git clone https://github.com/FFTYYY/RaanA
    cd RaanA
    python -m build

    The generated .whl files will be in dist/.

Quick Start

from transformers import AutoTokenizer, LlamaForCausalLM
from raana import quantize, zeroshot_calibration

# initialize your model
model     = LlamaForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# quantization
quantized_model = quantize(
    model,                                              # the model to quantize
    b_candidates    = list(range(1,9)),                 # candidate bit-widths per layer
    calibrate_data  = zeroshot_calibration(tokenizer),  # use zero-shot calibration
    avg_bits        = 3.3,                              # average number of bits
)["model"]

# evaluate your model (`evaluate` stands for your own evaluation routine)
evaluate(quantized_model)

A Complete Example

To run the example quantization of Llama 2 on WikiText-2 (and reproduce the result reported in the paper):

pip install raana
git clone https://github.com/FFTYYY/RaanA
cd RaanA/examples
python wikitext2.py --model=meta-llama/Llama-2-7b-hf --avgbits=3.3

See examples/wikitext2.py for a complete example usage.

Detailed Usage

The entry point of raana is raana.quantize, with the following signature:

from torch.nn   import Module
from torch      import Tensor
from typing     import Callable

from raana.task_adaptor     import TaskAdaptor
from raana.rotations        import RandomRotation, default_rotation
from raana.tricks           import Trick
from raana.select_layers    import default_linear_selector
from raana.quantized_linear import default_weightbias_extractor, default_matmul
from raana.tricks           import trick_centralize, trick_norm_col

quantize(
    model               : Module,
    b_candidates        : list[float],
    calibrate_data      : TaskAdaptor,
    avg_bits            : float,
    linear_selector     : Callable[[Module], bool]          = default_linear_selector,
    rotation_maker      : Callable[[], RandomRotation]      = default_rotation,
    trick_makers        : list[Callable[[], Trick]]         = [trick_centralize, trick_norm_col],
    weightbias_extractor: Callable[[Module], tuple[Tensor, Tensor | None]] = default_weightbias_extractor,
    matmul              : Callable[[Tensor, Tensor, Tensor, int], Tensor]  = default_matmul,
)

Required Arguments

model: torch.nn.Module

  • The pytorch model to be quantized.

b_candidates: list[float]

  • Candidate number of bits allowed for each layer.
  • May also include fractional values between 0 and 1, which enables sub-1-bit quantization.
  • Example: [0.5, 0.75, 1, 2, 3, 4].

calibrate_data: raana.task_adaptor.TaskAdaptor

  • The calibration data used for quantization.
  • For language modeling tasks, you can use raana.task_adaptor.LMAdaptor(data: list[str], tokenizer: PreTrainedTokenizer).
  • For zero-shot calibration in language modeling, use raana.zeroshot_calibration(tokenizer).
  • For non-language-modeling tasks, you can write your own TaskAdaptor class.

avg_bits: float

  • Target average number of bits per quantized linear layer. The quantizer will search for the optimal bit allocation under this constraint.
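As a toy illustration of what a budget-constrained allocation looks like, the sketch below brute-forces every assignment for a few layers. It rests on two assumptions that are not taken from raana: the average is weighted by each layer's parameter count, and each layer's error follows a hypothetical model s_i * 4**(-b_i). The search procedure RaanA actually uses is described in the paper; this only shows the shape of the problem.

```python
from itertools import product
from typing import List, Sequence, Tuple


def average_bits(bits: Sequence[float], sizes: Sequence[int]) -> float:
    """Parameter-weighted average bit-width (an assumed definition of the budget)."""
    return sum(b * n for b, n in zip(bits, sizes)) / sum(sizes)


def allocate_bits(
    sizes: Sequence[int],            # number of weights in each layer
    sensitivities: Sequence[float],  # hypothetical per-layer error scale s_i
    b_candidates: Sequence[float],
    avg_bits: float,
) -> Tuple[List[float], float]:
    """Brute-force toy allocator: minimise sum_i s_i * 4**(-b_i) under the budget."""
    best: List[float] = []
    best_loss = float("inf")
    for bits in product(b_candidates, repeat=len(sizes)):
        if average_bits(bits, sizes) > avg_bits + 1e-9:
            continue  # this assignment exceeds the bit budget
        loss = sum(s * 4.0 ** (-b) for s, b in zip(sensitivities, bits))
        if loss < best_loss:
            best, best_loss = list(bits), loss
    return best, best_loss


bits, loss = allocate_bits(
    sizes=[100, 100], sensitivities=[1.0, 16.0], b_candidates=[2, 3, 4], avg_bits=3.0
)
print(bits, loss)  # [2, 4] 0.125
```

The more sensitive second layer receives more bits, paid for by the first layer, while the parameter-weighted average stays at exactly 3.0. Exhaustive search grows exponentially with the number of layers, so it is only practical for a handful of layers.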

Optional Arguments

linear_selector: Callable[[torch.nn.Module], bool]

  • A function to choose which sub-modules to quantize.
  • Different model implementations use different types of linear modules (e.g. some models use nn.Linear while others use nn.Conv1d), so this function lets the user specify which modules to quantize.
  • Default: selects all torch.nn.Linear layers.
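A selector only needs to map a module to a bool, so one convenient pattern is a small factory that accepts several module classes at once. The make_type_selector helper below is hypothetical (not part of raana), and the Fake* classes are stand-ins so the sketch runs without PyTorch.

```python
from typing import Callable, Tuple, Type


def make_type_selector(*module_types: Type) -> Callable[[object], bool]:
    """Hypothetical helper: build a linear_selector that accepts any sub-module
    whose class is (a subclass of) one of `module_types`."""
    allowed: Tuple[Type, ...] = tuple(module_types)

    def selector(module: object) -> bool:
        return isinstance(module, allowed)

    return selector


# Stand-in classes so the sketch runs without PyTorch installed.
class FakeLinear: ...
class FakeConv1D: ...
class FakeLayerNorm: ...


select = make_type_selector(FakeLinear, FakeConv1D)
print([select(m) for m in (FakeLinear(), FakeConv1D(), FakeLayerNorm())])  # [True, True, False]
```

With PyTorch, the same helper could be called as make_type_selector(torch.nn.Linear, transformers.pytorch_utils.Conv1D) for a GPT-2-style Hugging Face model and passed as linear_selector=... . Note that Conv1D stores its weight as (d_in, d_out), so such layers would likely also need a matching weightbias_extractor that skips the transpose.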

rotation_maker: Callable[[], raana.rotations.RandomRotation]

  • A function to construct a random rotation.
  • This parameter leaves flexibility for users to specify their own random rotation implementation.
  • The default implementation is randomized Hadamard Transformation, as described in the paper.
  • The Hadamard transformation in the default implementation is simply a matrix multiplication with the Hadamard matrix generated by scipy.linalg.hadamard. To minimize raana's dependencies, the default implementation does not use any fast GPU Hadamard kernels. Users are encouraged to install fast Hadamard kernels themselves and pass them to the quantizer through this parameter.
  • For example, we encourage users to install the fast Hadamard implementation from Dao-AILab and pass it to raana:
    from torch import Tensor
    from fast_hadamard_transform import hadamard_transform
    from raana.rotations import PiecewiseHadamard
    
    def hadamard(X: Tensor):
        # normalize by sqrt(d) to make it an orthonormal operator
        return hadamard_transform(X) / (X.size(-1) ** 0.5)
    
    quantize(
        ..., 
        rotation_maker = lambda: PiecewiseHadamard( hadamard = hadamard )
    )
    
  • Default: randomized Hadamard transformation. Uses scipy.linalg.hadamard as the implementation of Hadamard Transformation.
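To illustrate what the default rotation computes, here is a minimal pure-Python sketch of a randomized Hadamard transform: flip the signs of the input with a fixed random ±1 vector, then apply the Hadamard matrix normalized by sqrt(d). The fast Walsh-Hadamard recursion gives the same result as multiplying by scipy.linalg.hadamard(d) / sqrt(d) (both use Sylvester ordering), but in O(d log d). The RandomizedHadamard class is a hypothetical illustration, not raana's RandomRotation, and it assumes d is a power of two.

```python
import random
from typing import List, Sequence


def fwht(x: Sequence[float]) -> List[float]:
    """Orthonormal fast Walsh-Hadamard transform (Sylvester ordering, as in
    scipy.linalg.hadamard), i.e. hadamard(d) @ x / sqrt(d), in O(d log d)."""
    y = list(x)
    d = len(y)
    if d == 0 or d & (d - 1):
        raise ValueError("length must be a power of two")
    h = 1
    while h < d:
        for start in range(0, d, 2 * h):
            for j in range(start, start + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    scale = d ** -0.5
    return [v * scale for v in y]


class RandomizedHadamard:
    """Sketch of x -> H D x / sqrt(d), with D a fixed random +-1 diagonal
    (illustrative only, not raana's RandomRotation class)."""

    def __init__(self, d: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.signs = [rng.choice((-1.0, 1.0)) for _ in range(d)]

    def rotate(self, x: Sequence[float]) -> List[float]:
        return fwht([s * v for s, v in zip(self.signs, x)])

    def inverse(self, y: Sequence[float]) -> List[float]:
        # H / sqrt(d) is symmetric and orthonormal, hence its own inverse;
        # the sign flips D are undone by applying them again.
        return [s * v for s, v in zip(self.signs, fwht(y))]


rot = RandomizedHadamard(8, seed=42)
x = [3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
y = rot.rotate(x)
print([round(v, 4) for v in y])  # the single spike is spread evenly over all 8 coordinates
```

One commonly cited reason for using such a rotation before quantization is visible here: a single large entry is spread evenly across all coordinates, while the norm is preserved and the transform stays exactly invertible.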

trick_makers: list[Callable[[], raana.tricks.Trick]]

  • List of functions that construct tricks. See the paper for the definition of a "trick".
  • Four tricks are currently implemented: trick_centralize, trick_pca, trick_norm_row, trick_norm_col.
  • Default: [trick_centralize, trick_norm_col].

weightbias_extractor: Callable[[nn.Module], tuple[Tensor, Tensor | None]]

  • A function that extracts the weight and bias matrices from a linear module and transforms them into the standard shape.
  • It should return the extracted weight and bias of the given layer: weight should be a tensor of shape (d_in, d_out), and bias should be either None or a tensor of shape (d_out, ).
  • Default: lambda layer: (layer.weight.t().data, layer.bias.data). torch.nn.Linear stores its weight with shape (d_out, d_in), so the transpose yields the required (d_in, d_out). As written, this default assumes the layer has a bias (layer.bias is not None).
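If some selected layers have no bias, a bias-safe extractor can be passed instead. The function below is a sketch under the assumption that every selected layer is a torch.nn.Linear; it is not raana's default_weightbias_extractor.

```python
from typing import Optional, Tuple

from torch import Tensor, nn


def linear_weightbias_extractor(layer: nn.Module) -> Tuple[Tensor, Optional[Tensor]]:
    """Bias-safe extractor for torch.nn.Linear (a sketch, not part of raana)."""
    # nn.Linear stores its weight as (d_out, d_in); transpose to (d_in, d_out).
    weight = layer.weight.detach().t()
    # Layers built with bias=False have layer.bias set to None.
    bias = layer.bias.detach() if layer.bias is not None else None
    return weight, bias


w, b = linear_weightbias_extractor(nn.Linear(16, 4, bias=False))
print(tuple(w.shape), b)  # (16, 4) None
```

It would be passed as quantize(..., weightbias_extractor=linear_weightbias_extractor). For layer types that already store the weight as (d_in, d_out), the .t() would be dropped.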

matmul: Callable[[Tensor, Tensor, Tensor, int], Tensor]

  • A function to perform low-precision matrix multiplication.
  • Since there is no official implementation of low-precision uint-float matrix multiplication, and we want to minimize raana's dependencies, we leave the implementation of matrix multiplication to users.
  • The input parameters are X, qW, rescale, B: X is a float tensor, qW is a B-bit uint tensor, and rescale is a float rescale tensor. The return value of this function should equal (X@qW - ((2**B-1)/2.*X.sum(dim=-1)).view(-1,1)) * rescale.view(1,-1).
  • Default: cast everything to float32 and perform a standard matrix multiplication. The default implementation is shown below.
    import torch as tc

    def default_matmul(X: tc.Tensor, qW: tc.Tensor, rescale: tc.Tensor, B: int):
        dtype   = X.dtype                                      # remember the input dtype
        X       = X.to(tc.float32)
        rescale = rescale.to(tc.float32).view(1, -1)           # one scale per output column
        q_bias  = (float(2 ** B - 1) / 2. * X.sum(dim = -1)).view(-1, 1)  # offset from the (2**B-1)/2 zero point
        Z = (X @ qW.to(tc.float32)) * rescale
        Z = Z - q_bias * rescale
        return Z.to(dtype)
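
The offset term in this formula corresponds to a zero point of c = (2**B-1)/2: since X @ (c * ones) equals c times the row sums of X, the result equals X @ W_hat with W_hat[i][j] = (qW[i][j] - c) * rescale[j]. The sketch below checks this numerically in pure Python. The quantize_columns helper is a hypothetical symmetric per-column quantizer chosen to match that zero point; it is not claimed to be how raana computes qW and rescale.

```python
import random
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_columns(W: Matrix, B: int) -> Tuple[List[List[int]], List[float]]:
    """Hypothetical symmetric B-bit quantizer with one scale per output column."""
    c = (2 ** B - 1) / 2.0  # zero point implied by the matmul formula
    rescale = []
    for j in range(len(W[0])):
        m = max(abs(row[j]) for row in W)
        rescale.append(m / c if m > 0 else 1.0)
    qW = [[min(2 ** B - 1, max(0, round(v / rescale[j] + c))) for j, v in enumerate(row)]
          for row in W]
    return qW, rescale


def formula_matmul(X: Matrix, qW, rescale: List[float], B: int) -> Matrix:
    """Direct transcription of (X@qW - c * X.sum(-1)) * rescale, with c = (2**B-1)/2."""
    c = (2 ** B - 1) / 2.0
    out = []
    for x in X:
        s = sum(x)  # row sum of X
        out.append([(sum(x[k] * qW[k][j] for k in range(len(x))) - c * s) * rescale[j]
                    for j in range(len(rescale))])
    return out


def dequantize(qW, rescale: List[float], B: int) -> Matrix:
    """W_hat[i][j] = (qW[i][j] - c) * rescale[j]."""
    c = (2 ** B - 1) / 2.0
    return [[(q - c) * rescale[j] for j, q in enumerate(row)] for row in qW]


def matmul(X: Matrix, W: Matrix) -> Matrix:
    return [[sum(x[k] * W[k][j] for k in range(len(W))) for j in range(len(W[0]))] for x in X]


rng = random.Random(0)
B = 3
W = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(6)]  # (d_in, d_out) = (6, 4)
X = [[rng.gauss(0, 1) for _ in range(6)] for _ in range(2)]  # 2 input rows
qW, rescale = quantize_columns(W, B)
lhs = formula_matmul(X, qW, rescale, B)
rhs = matmul(X, dequantize(qW, rescale, B))
print(max(abs(a - b) for ra, rb in zip(lhs, rhs) for a, b in zip(ra, rb)))  # ~0, floating-point error only
```

Under this assumed quantizer, each dequantized weight is within rescale[j] / 2 of the original, so a custom matmul only has to reproduce the formula above to stay consistent with the default.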
    

Returns

{
    "model" : torch.nn.Module,  # quantized model
    "bits"  : list[int],        # allocated bitwidth per layer
    "losses": list[float]       # calibration loss per calibration data
}



Download files


Source Distribution

raana-0.1.0.tar.gz (16.5 kB view details)

Uploaded Source

Built Distribution


raana-0.1.0-cp310-cp310-win_amd64.whl (138.1 kB view details)

Uploaded CPython 3.10, Windows x86-64

File details

Details for the file raana-0.1.0.tar.gz.

File metadata

  • Download URL: raana-0.1.0.tar.gz
  • Upload date:
  • Size: 16.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.8

File hashes

Hashes for raana-0.1.0.tar.gz
Algorithm Hash digest
SHA256 1198b4bcaa866e0ec008d5ef3f50218cc197875f3e3eae6e9b8b382d91b125b6
MD5 dd711662d7b3146aa02ebff536b20cd1
BLAKE2b-256 5b7adaf36d0fbaae7aa9917cc23086edef3dd7b431b9252017f3f3ac10ee21b6


File details

Details for the file raana-0.1.0-cp310-cp310-win_amd64.whl.

File metadata

  • Download URL: raana-0.1.0-cp310-cp310-win_amd64.whl
  • Upload date:
  • Size: 138.1 kB
  • Tags: CPython 3.10, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.8

File hashes

Hashes for raana-0.1.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 e1a3bf9fbeee710f1597e84db4e9e68643c4bef25903810996810c75034be97d
MD5 00a0aba9b05cc928dc07bbaada7a7e77
BLAKE2b-256 ac0c2df4bdeab20a4d9c95cc2099fa6c509a0f0478553a77104bf012fb3c6ca9

