
Cutting transformer layers


:scissors: Short Transformers

Normalized angular distance from initial layer l (x-axis) with block size n (y-axis).


Installation:

pip install short-transformers

Required additional dependencies: torch, transformers, datasets, accelerate.
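
These can be installed in one step if they are not already present, for example:

pip install torch transformers datasets accelerate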

Quickstart:

from short_transformers import ShortTransformer
from datasets import load_dataset

# load from path/hf_hub (the model name below is an example)
model_name = "meta-llama/Meta-Llama-3-8B"
model = ShortTransformer.from_pretrained(model_name)

# or use hf model
# model = ShortTransformer.from_model(hf_model)

# load hf dataset
dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

# remove 5 layers, using the dataset to find the least important layers to remove
short_model = model.remove_layers(block_size=5, dataset=dataset, limit=1000)

# continue training to heal after the cut
# ...

# save as hf model
short_model.save_pretrained(output_path)

Both short_model and the saved model are fully compatible with transformers. See examples/basic.py for a complete working example.
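
Because the saved checkpoint is a standard transformers model, it can be reloaded without short-transformers at all. A minimal sketch, assuming output_path is the directory saved above:

from transformers import AutoModelForCausalLM

# reload the pruned checkpoint with plain transformers
pruned = AutoModelForCausalLM.from_pretrained(output_path)
# should reflect the reduced layer count after the cut
print(pruned.config.num_hidden_layers)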

Pruning in steps:

Pruning can be composed step by step and customized:

  1. Analyze model layers:
from datasets import load_dataset
from transformers import AutoTokenizer

from short_transformers import ShortTransformer
from short_transformers.utils import (
    draw_diagram,
    get_scored_blocks,
    get_best_pruning_start,
)

# load from path/hf_hub
model_name = "meta-llama/Meta-Llama-3-8B"

model = ShortTransformer.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

# calculate distances between inputs/outputs from/to model layers
# results in a triangular numpy array of shape (layer_count, layer_count)
# results[x, y] - averaged distances for block of size x starting at layer y
results = model.analyse_layers(
    dataset=dataset,
    tokenizer=tokenizer,
    use_chat_template=False,
    key="text",
    limit=100,
    max_length=1000,
)

# draw the results
# the diagram style matches that of the original article,
# "The Unreasonable Ineffectiveness of the Deeper Layers"
draw_diagram(results, "results.png", title="Meta-Llama-3-8B", normalized=True)

Example output:
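
Since analyse_layers returns a plain numpy array, the raw distances can also be inspected or saved directly. A minimal sketch, using the indexing described in the comments above (the filename is illustrative):

import numpy as np

# results[x, y] holds the averaged distance for a block of size x starting at layer y
print(results.shape)   # (layer_count, layer_count)
print(results[5, 23])  # distance for a 5-layer block starting at layer 23

# persist the distance matrix for later comparison across models
np.save("meta_llama_3_8b_layer_distances.npy", results)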

  2. Find optimal block_size and start_layer:
# find the optimal start layer for a block of 'block_size' layers to prune
start_layer = get_best_pruning_start(results, block_size=5)

# evaluate all possible block sizes to prune;
# for each block size, returns a score in [0, 1]:
# the distance between the block's input and output, averaged over samples
block_score = get_scored_blocks(results, return_md=True, threshold=0.3)
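
With return_md=True, block_score comes back in a markdown-friendly format (an assumption based on the argument name), so it can be printed directly to produce a summary like the table below:

# inspect the per-block scores
print(block_score)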

Example output:

| Block_size | Removed_layers | Score (avg dist) |
|------------|----------------|------------------|
| 1          | 25-25          | 0.123            |
| 2          | 24-25          | 0.155            |
| 3          | 25-27          | 0.181            |
| 4          | 24-27          | 0.204            |
| 5          | 23-27          | 0.226            |
| 6          | 22-27          | 0.248            |
| 7          | 22-28          | 0.268            |
| 8          | 20-27          | 0.291            |
  3. Pruning layers:
# prune 5-layers block
model.prune(start_layer=start_layer, block_size=5)

# save the pruned model
model.save_pretrained("model_output_dir")

See examples/prune_in_steps.py for a complete working example.
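
After pruning and saving, a quick smoke test with plain transformers confirms that the checkpoint loads and still generates text. A minimal sketch; the prompt and generation settings are only illustrative, and the tokenizer is loaded from the base model since pruning does not change it:

from transformers import AutoModelForCausalLM, AutoTokenizer

# reload the pruned checkpoint saved above
pruned = AutoModelForCausalLM.from_pretrained("model_output_dir", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

inputs = tokenizer("The capital of France is", return_tensors="pt").to(pruned.device)
outputs = pruned.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))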

  4. Changing the pruning method:

The default pruning method is based on the angular distance of the last token. It is possible to override the distance metric by calling model.set_metric(some_callable) before model.analyse_layers().

# ...
from short_transformers.dist import get_angular_distance_ith_token

model_name = "meta-llama/Meta-Llama-3-8B"
model = ShortTransformer.from_pretrained(model_name, device_map="auto")

# choose metric
# calculate distances based on the angular distance of the i=0 token
model.set_metric(get_angular_distance_ith_token(i=0))

# load dataset ...

results = model.analyse_layers(
    dataset=dataset,
    tokenizer=tokenizer,
    key="text",
    limit=1,
    max_length=1000,
)
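
Any callable can be supplied in place of a built-in metric. The exact signature expected by set_metric is not documented here, so the sketch below assumes the callable receives the hidden states entering and leaving a block and returns a scalar distance; adapt it to match the built-in metrics in short_transformers.dist:

import torch

# hypothetical custom metric: 1 - cosine similarity of mean-pooled hidden states
# (the (block_input, block_output) signature and (batch, seq, hidden) shape are assumptions)
def mean_pooled_cosine_distance(block_input, block_output):
    a = block_input.mean(dim=1)
    b = block_output.mean(dim=1)
    cos = torch.nn.functional.cosine_similarity(a, b, dim=-1)
    return (1.0 - cos).mean().item()

model.set_metric(mean_pooled_cosine_distance)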

Supported metrics for layer importance calculation:

  • Euclidean Dist Last Token
  • Relative Magnitude
  • Bi Score
  • Linear Approximation Last Token
  • Angular Distance All Tokens
  • Angular Distance Last Token

Example outputs:

Meta-Llama-3-8B-Instruct

Layerwise distances:

Blockwise distances:

Euclidean Dist Last Token

Figure 1: Euclidean Dist Last Token. Figure 2: Euclidean Dist Last Token Normalised

Relative Magnitude

Figure 1: Relative Magnitude. Figure 2: Relative Magnitude Normalised

Bi Score

Figure 1: Bi Score. Figure 2: Bi Score Normalised

Linear Approximation Last Token

Figure 1: Linear Approximation Last Token. Figure 2: Linear Approximation Last Token Normalised

Angular Distance All Tokens

Figure 1: Angular Distance All Tokens. Figure 2: Angular Distance All Tokens Normalised

Angular Distance Last Token

Figure 1: Angular Distance Last Token. Figure 2: Angular Distance Last Token Normalised

Yi-1.5-9B-Chat-16K

Layerwise distances:

Blockwise distances:

Euclidean Dist Last Token

Figure 1: Euclidean Dist Last Token. Figure 2: Euclidean Dist Last Token Normalised

Relative Magnitude

Figure 1: Relative Magnitude. Figure 2: Relative Magnitude Normalised

Bi Score

Figure 1: Bi Score. Figure 2: Bi Score Normalised

Linear Approximation Last Token

Figure 1: Linear Approximation Last Token. Figure 2: Linear Approximation Last Token Normalised

Angular Distance Last Token

Figure 1: Angular Distance Last Token. Figure 2: Angular Distance Last Token Normalised

Citing:

If you use Short Transformers in your research, please cite it using the following BibTeX:

@misc{russak2024shorttransformers,
      title={ShortTransformers, optimal layer pruning tools},
      author={Melisa Russak},
      url={https://github.com/melisa/short-transformers},
      year={2024}
}
@misc{gromov2024unreasonable,
      title={The Unreasonable Ineffectiveness of the Deeper Layers}, 
      author={Andrey Gromov and Kushal Tirumala and Hassan Shapourian and Paolo Glorioso and Daniel A. Roberts},
      year={2024},
      eprint={2403.17887},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{razzhigaev2024transformer,
      title={Your Transformer is Secretly Linear}, 
      author={Anton Razzhigaev and Matvey Mikhalchuk and Elizaveta Goncharova and Nikolai Gerasimenko and Ivan Oseledets and Denis Dimitrov and Andrey Kuznetsov},
      year={2024},
      eprint={2405.12250},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{men2024shortgpt,
      title={ShortGPT: Layers in Large Language Models are More Redundant Than You Expect}, 
      author={Xin Men and Mingyu Xu and Qingyu Zhang and Bingning Wang and Hongyu Lin and Yaojie Lu and Xianpei Han and Weipeng Chen},
      year={2024},
      eprint={2403.03853},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{samragh2023weight,
      title={Weight subcloning: direct initialization of transformers using larger pretrained ones}, 
      author={Mohammad Samragh and Mehrdad Farajtabar and Sachin Mehta and Raviteja Vemulapalli and Fartash Faghri and Devang Naik and Oncel Tuzel and Mohammad Rastegari},
      year={2023},
      eprint={2312.09299},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

