
Cutting transformer layers


✂️ Short Transformers

Normalized angular distance from initial layer l (x-axis) with block size n (y-axis).


Installation:

pip install short-transformers

Required additional dependencies: torch, transformers, datasets, accelerate.
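
If these are not already present, they can be installed with pip, for example:

pip install torch transformers datasets accelerate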

Quickstart:

from short_transformers import ShortTransformer
from datasets import load_dataset

# load from path/hf_hub
model_name = "meta-llama/Meta-Llama-3-8B"
model = ShortTransformer.from_pretrained(model_name)

# or use hf model
# model = ShortTransformer.from_model(hf_model)

# load hf dataset
dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

# remove 5 layers, use the dataset to find the least important layers to remove
short_model = model.remove_layers(block_size=5, dataset=dataset, limit=1000)

# continue training to heal after the cut
# ...

# save as hf model
output_path = "model_output_dir"
short_model.save_pretrained(output_path)

Both short_model and the saved model are fully compatible with transformers. See examples/basic.py for a complete working example.
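
For instance, the saved checkpoint can be reloaded with plain transformers. A minimal sketch, reusing model_name and output_path from the quickstart above:

from transformers import AutoModelForCausalLM, AutoTokenizer

# reload the pruned checkpoint as an ordinary transformers model
pruned = AutoModelForCausalLM.from_pretrained(output_path)

# pruning does not change the tokenizer, so load it from the original model
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = pruned.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))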

Pruning in steps:

Pruning can be composed step by step and customized:

  1. Analyze model layers:
from datasets import load_dataset
from short_transformers import ShortTransformer
from short_transformers.utils import (
    draw_diagram,
    get_scored_blocks,
    get_best_pruning_start,
)
from transformers import AutoTokenizer

# load from path/hf_hub
model_name = "meta-llama/Meta-Llama-3-8B"

model = ShortTransformer.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

# calculate distances between the inputs and outputs of blocks of model layers
# returns a triangular numpy array of shape (layer_count, layer_count),
# where results[x, y] is the averaged distance for a block of size x starting at layer y
results = model.analyse_layers(
    dataset=dataset,
    tokenizer=tokenizer,
    use_chat_template=False,
    key="text",
    limit=100,
    max_length=1000,
)

# draw results
# the diagram style matches the original article,
# "The Unreasonable Ineffectiveness of the Deeper Layers"
draw_diagram(results, "results.png", title="Meta-Llama-3-8B", normalized=True)

Example output: a normalized distance heatmap written to results.png.

  2. Find the optimal block_size and start_layer:
# find the optimal starting layer for a block of size 'block_size' to prune
start_layer = get_best_pruning_start(results, block_size=5)

# evaluate all possible block sizes to prune;
# for each block size returns a score in [0, 1]:
# the distance between a block's input and output, averaged over samples
block_score = get_scored_blocks(results, return_md=True, threshold=0.3)

Example output:

| Block_size | Removed_layers | Score (avg dist) |
|------------|----------------|------------------|
| 1          | 25-25          | 0.123            |
| 2          | 24-25          | 0.155            |
| 3          | 25-27          | 0.181            |
| 4          | 24-27          | 0.204            |
| 5          | 23-27          | 0.226            |
| 6          | 22-27          | 0.248            |
| 7          | 22-28          | 0.268            |
| 8          | 20-27          | 0.291            |
  3. Pruning layers:
# prune a 5-layer block
model.prune(start_layer=start_layer, block_size=5)

# save the pruned model
model.save_pretrained("model_output_dir")

See examples/prune_in_steps.py for a complete working example.
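
As a quick sanity check, the saved config can be inspected with plain transformers to confirm the layer count dropped by block_size (a sketch, assuming the pruned depth is reflected in num_hidden_layers):

from transformers import AutoConfig

# the pruned checkpoint should report block_size fewer hidden layers
# than the original (e.g. 32 - 5 = 27 for Meta-Llama-3-8B)
config = AutoConfig.from_pretrained("model_output_dir")
print(config.num_hidden_layers)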

  4. Changing the pruning method:

The default pruning method is based on the angular distance of the last token. The distance metric can be overridden by calling model.set_metric(some_callable) before model.analyse_layers().

# ...
from short_transformers.dist import get_angular_distance_ith_token
from transformers import AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
model = ShortTransformer.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# choose metric
# calculate distances based on the angular distance of the i=0 token
model.set_metric(get_angular_distance_ith_token(i=0))

# load dataset ...

results = model.analyse_layers(
    dataset=dataset,
    tokenizer=tokenizer,
    key="text",
    limit=1,
    max_length=1000,
)
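
A custom metric can be sketched in the same spirit. The exact callable signature expected by set_metric is not documented here, so the example below assumes it receives a block's input and output hidden states (shape [batch, seq_len, hidden_dim]) and returns a per-sample distance, mirroring the built-in metrics; treat it as an illustration rather than the library's API:

import torch

# hypothetical metric: cosine distance between mean-pooled hidden states
# (assumed signature: block input/output tensors in, per-sample distances out)
def mean_pooled_cosine_distance(input_hidden, output_hidden):
    a = input_hidden.float().mean(dim=1)
    b = output_hidden.float().mean(dim=1)
    cosine = torch.nn.functional.cosine_similarity(a, b, dim=-1)
    return 1.0 - cosine

# plug it in before model.analyse_layers(), just like the built-in metrics
model.set_metric(mean_pooled_cosine_distance)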

Supported metrics for layer importance calculation: Euclidian Dist Last Token, Relative Magnitude, Bi Score, Linear Approximation Last Token, Angular Distance All Tokens, Angular Distance Last Token.

Example outputs:

Meta-Llama-3-8B-Instruct: layerwise and blockwise distance diagrams for each supported metric, in raw and normalised form.

Yi-1.5-9B-Chat-16K: layerwise and blockwise distance diagrams for each supported metric except Angular Distance All Tokens, in raw and normalised form.

Citing:

If you use Short Transformers in your research, please cite it with the following BibTeX:

@misc{russak2024shorttransformers,
    title  = {ShortTransformers, optimal layer pruning tools},
    author = {Melisa Russak},
    url    = {https://github.com/melisa/short-transformers},
    year   = {2024}
}
@misc{gromov2024unreasonable,
      title={The Unreasonable Ineffectiveness of the Deeper Layers}, 
      author={Andrey Gromov and Kushal Tirumala and Hassan Shapourian and Paolo Glorioso and Daniel A. Roberts},
      year={2024},
      eprint={2403.17887},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{razzhigaev2024transformer,
      title={Your Transformer is Secretly Linear}, 
      author={Anton Razzhigaev and Matvey Mikhalchuk and Elizaveta Goncharova and Nikolai Gerasimenko and Ivan Oseledets and Denis Dimitrov and Andrey Kuznetsov},
      year={2024},
      eprint={2405.12250},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{men2024shortgpt,
      title={ShortGPT: Layers in Large Language Models are More Redundant Than You Expect}, 
      author={Xin Men and Mingyu Xu and Qingyu Zhang and Bingning Wang and Hongyu Lin and Yaojie Lu and Xianpei Han and Weipeng Chen},
      year={2024},
      eprint={2403.03853},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{samragh2023weight,
      title={Weight subcloning: direct initialization of transformers using larger pretrained ones}, 
      author={Mohammad Samragh and Mehrdad Farajtabar and Sachin Mehta and Raviteja Vemulapalli and Fartash Faghri and Devang Naik and Oncel Tuzel and Mohammad Rastegari},
      year={2023},
      eprint={2312.09299},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
