Cutting transformer layers

:scissors: Short Transformers

Installation:

pip install short-transformers

Required additional dependencies: transformers, datasets.
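
If they are not already present, both can be installed from PyPI the same way:

pip install transformers datasets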

Quickstart:

from short_transformers import ShortTransformer

# load from path/hf_hub
model = ShortTransformer.from_pretrained(model_name)

# or use hf model
model = ShortTransformer.from_model(hf_model)

# remove n layers, use hf dataset to find the optimal cut
short_model = model.remove_layers(block_size=5, dataset=dataset, limit=1000)

# continue training to heal after the cut
# ...

# save as hf model
short_model.save_pretrained(output_path)

Example log output:

[2024-05-04 18:39:03] [short-transformers] Running inference on 1000 samples.
1000it [00:06, ?it/s]
[2024-05-04 18:39:09] [short-transformers] Inference results saved to `model.memory`.
[2024-05-04 18:39:09] [short-transformers] Choosing optimal 5-layers block to prune.
[2024-05-04 18:39:09] [short-transformers] Best 5-layers block to prune starts at layer: 23.
[2024-05-04 18:39:09] [short-transformers] Removing layers: [23, 24, 25, 26, 27]
[2024-05-04 18:39:09] [short-transformers] Changing model config to reflect changes:
        config.num_hidden_layers: 32 -> 27
        config._name_or_path: meta-llama/Meta-Llama-3-8B -> meta-llama/Meta-Llama-3-8B-27L
[2024-05-04 18:39:09] [short-transformers] Cleaning `model.memory`.

Both short_model and the saved model are fully compatible with transformers. See examples/basic.py for a complete working example.
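
Because the pruned model is a standard transformers model, the saved checkpoint can be reloaded later with transformers alone. A minimal sketch, reusing output_path and model_name from the quickstart above:

from transformers import AutoModelForCausalLM, AutoTokenizer

# load the pruned checkpoint as an ordinary transformers model
pruned = AutoModelForCausalLM.from_pretrained(output_path)

# pruning does not change the tokenizer, so the original one can be reused
tokenizer = AutoTokenizer.from_pretrained(model_name)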

Pruning in steps:

Pruning can be composed step by step and customized:

  1. Analyze model layers:
from datasets import load_dataset
from short_transformers import ShortTransformer
from short_transformers.graph import draw_diagram

# load from path/hf_hub
model_name = "meta-llama/Meta-Llama-3-8B"

model = ShortTransformer.from_pretrained(model_name, device_map="auto")

dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

# run inference on a subset of the dataset
model.analyse_layers(
    dataset=dataset,
    key="text",
    limit=100,
    max_length=1000,
)

# after this step, results will be saved to model.memory, which can be saved
model.save_memory("memory.npz")

# and visualized as a seaborn graph (see examples in the README.md)
draw_diagram("memory.npz", "memory.png")

Example output:

  2. Find the optimal block size and start_layer:
# finding the optimal block of size `block_size` to prune
start_layer = model.get_optimal_cut(block_size=5)

# evaluate all possible block sizes to prune;
# for each block, returns a score in [0, 1]:
# the distance between the block's input and output, averaged over samples
block_score = model.get_block_score_stats(return_md=True, threshold=0.3)

Example output:

Block_size   Removed_layers   Score (avg dist)
1            25-25            0.123
2            24-25            0.155
3            25-27            0.181
4            24-27            0.204
5            23-27            0.226
6            22-27            0.248
7            22-28            0.268
8            20-27            0.291
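
These stats can also drive an automated choice. As a rough illustration (a hypothetical helper, not part of the short-transformers API), one could pick the largest block whose average distance stays under a chosen threshold:

# hypothetical helper: pick the largest block whose average
# input/output distance is below `threshold`
def pick_block(stats, threshold=0.3):
    # `stats` is assumed to be a list of
    # (block_size, start_layer, end_layer, score) tuples
    best = None
    for block_size, start_layer, end_layer, score in stats:
        if score <= threshold and (best is None or block_size > best[0]):
            best = (block_size, start_layer, end_layer, score)
    return best

# values taken from the example output above
stats = [(1, 25, 25, 0.123), (5, 23, 27, 0.226), (8, 20, 27, 0.291)]
print(pick_block(stats))  # (8, 20, 27, 0.291)
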
  3. Pruning layers:
# pruning 5-layers block
# this will also clean `model.memory`
model.cut(start_layer=start_layer, block_size=5)

# or clean the model memory manually
model.clean_memory()

# save the pruned model
model.save_pretrained("model_path")

See examples/prune_in_steps.py for a complete working example.

Supported pruning methods:

  • based on layer input/output distances (see the sketch after this list):

    • angular distance of the last token (original)
    • averaged angular distances of all tokens
  • todo: based on layer linear replacement training loss
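
For reference, the angular distance used in the original method (Gromov et al., 2024) measures how little a block changes a hidden state: the arccos of the cosine similarity between the block's input and output, normalized by π. A minimal numpy sketch of the idea (an illustration, not the library's exact implementation):

import numpy as np

def angular_distance(x, y):
    # x, y: hidden states (e.g. of the last token) before and after a block
    cos = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
    return np.arccos(np.clip(cos, -1.0, 1.0)) / np.pi  # score in [0, 1]

# toy example with two random 4096-dimensional hidden states
rng = np.random.default_rng(0)
x, y = rng.standard_normal(4096), rng.standard_normal(4096)
print(angular_distance(x, y))

Averaging this score over samples (and, in the second variant, over all tokens) gives the block scores shown above; blocks with low scores change their inputs the least and are the best candidates for removal.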

Citing

If you use Short Transformers in your research, please cite it using the following BibTeX:

@misc{russak2024shorttransformers,
    title = {ShortTransformers, optimal layer pruning tools},
    author = {Melisa Russak},
    url = {https://github.com/melisa/short-transformers},
    year = {2024}
}
@misc{gromov2024unreasonable,
    title = {The Unreasonable Ineffectiveness of the Deeper Layers},
    author = {Andrey Gromov and Kushal Tirumala and Hassan Shapourian and Paolo Glorioso and Daniel A. Roberts},
    year = {2024},
    eprint = {2403.17887},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
