
Attach custom heads to transformer models.

Project description

Transformer Heads

This library aims to be an all-round toolkit for attaching, training, saving, and loading new heads for transformer models.
A new head could be:

  • A linear probe used to get an understanding of the information processing in a transformer architecture
  • A head to be fine-tuned jointly with the weights of a pretrained transformer model to perform a completely different kind of task.
    • E.g., a transformer pretrained for causal language modelling could get a sequence classification head attached and be fine-tuned for sentiment classification.
    • Or one could attach a regression head to turn a large language model into a value function for a reinforcement learning problem.

On top of that, attaching multiple heads at once makes multi-task learning easy, which in turn makes it possible to train very general models.

Check out the API documentation on Read the Docs.

Installation

Install from PyPI: pip install transformer-heads

Or clone this repo and, from its root, run: pip install -e .

Usage

Create head configurations

from transformer_heads import HeadConfig  # assuming HeadConfig is exported at the package root

head_config = HeadConfig(
    name="imdb_head_3",
    layer_hook=-3,  # attach to the output of the third-to-last transformer layer
    in_size=hidden_size,  # hidden dimension of the base model
    output_activation="linear",
    pred_for_sequence=True,  # one prediction per sequence rather than per token
    loss_fct="cross_entropy",
    num_outputs=2,
)
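
Here, hidden_size is the hidden dimension of the base model the head will attach to. One hedged way to obtain it, assuming a standard transformers config, is:

from transformers import AutoConfig

# Read the base model's hidden dimension from its pretrained config.
hidden_size = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf").hidden_size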

Create a model with your head from a pretrained transformer model

from transformers import LlamaForCausalLM
from transformer_heads import load_headed  # assuming load_headed is exported at the package root

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=[head_config],
)

Train your model using (for example) the easy-to-use Hugging Face Trainer interface:

from transformers import Trainer

trainer = Trainer(
    model,
    args=args,
    train_dataset=imdb_dataset["train"],
    data_collator=collator,
)
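
The snippet above assumes that args, imdb_dataset, and collator are already defined; how the dataset and collator are prepared, and how labels get routed to the head, is library-specific and covered in the notebooks. A hedged sketch of the remaining plumbing (hyperparameter values are purely illustrative):

from transformers import Trainer, TrainingArguments

# Hypothetical hyperparameters, purely illustrative.
args = TrainingArguments(
    output_dir="imdb_head_out",
    per_device_train_batch_size=4,
    learning_rate=2e-4,
    num_train_epochs=1,
)

trainer = Trainer(
    model,
    args=args,
    train_dataset=imdb_dataset["train"],  # tokenized dataset, prepared as in the notebooks
    data_collator=collator,               # data collator, prepared as in the notebooks
)
trainer.train()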

For a more in-depth introduction and a fully working example, check the linear probe notebook.

Joint training of multiple linear probes

Figure: architecture diagram for joint training of multiple linear probes.
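
A hedged sketch of what such a setup could look like, reusing hidden_size from above; the probe names and layer choices are arbitrary examples, and negative layer_hook values are assumed to index layers from the end, as suggested by the example in the Usage section:

from transformers import LlamaForCausalLM
from transformer_heads import HeadConfig, load_headed  # assuming top-level exports

# One classification probe per chosen layer (layer choices are illustrative).
probe_configs = [
    HeadConfig(
        name=f"imdb_probe_layer_{layer}",
        layer_hook=layer,
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,
        loss_fct="cross_entropy",
        num_outputs=2,
    )
    for layer in (-1, -3, -6, -9)
]

# All probes are attached to the same base model and trained jointly.
model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=probe_configs,
)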

Notebooks

This repository contains multiple Jupyter notebooks that serve as tutorials and illustrations of how to do certain things with this library. Here is an overview of which notebook to check out depending on your use case.

Joint multi-task training with different types of heads and QLoRA.

Figure: example architecture for joint multi-task training with multiple head types.
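
As a hedged sketch of combining different head types in one model (the loss_fct value "mse" and the regression head's settings are assumptions for illustration, and the QLoRA wiring itself is left to the notebook):

from transformers import LlamaForCausalLM
from transformer_heads import HeadConfig, load_headed  # assuming top-level exports

multi_task_configs = [
    # Sequence-level sentiment classifier, as in the Usage example.
    HeadConfig(
        name="sentiment_head",
        layer_hook=-1,
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,
        loss_fct="cross_entropy",
        num_outputs=2,
    ),
    # Scalar regression head, e.g. a value function for reinforcement learning.
    HeadConfig(
        name="value_head",
        layer_hook=-1,
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,
        loss_fct="mse",  # assumption: check the library docs for the exact loss names
        num_outputs=1,
    ),
]

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=multi_task_configs,
)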

Download files

Download the file for your platform.

Source Distribution

transformer_heads-0.0.3.tar.gz (382.1 kB)

Uploaded Source

Built Distribution

transformer_heads-0.0.3-py3-none-any.whl (20.3 kB)

Uploaded Python 3

File details

Details for the file transformer_heads-0.0.3.tar.gz.

File metadata

  • Download URL: transformer_heads-0.0.3.tar.gz
  • Size: 382.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.2

File hashes

Hashes for transformer_heads-0.0.3.tar.gz
Algorithm Hash digest
SHA256 507038ac6d7c37a78431a0b525ec692dc5baae60dd9c834019e2fa1f694ebf36
MD5 d84bc81c01ff54e7845885ef8178a58c
BLAKE2b-256 3968f3791b2c2e738bfc7e03ed8ee8b3b4576cf70b098b503f0f44a8e7eeeea0


File details

Details for the file transformer_heads-0.0.3-py3-none-any.whl.


File hashes

Hashes for transformer_heads-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 8bc0a1ca69f1c5d2edba14834df9e88fb4b8399acaecc3d35aaeca7d02016a2a
MD5 d8d0ae8a6d924d147b3db10133bdd3a1
BLAKE2b-256 ac8eba1060f057d19dd0d3c0178b396efee09a63c4dc203605be9fa018753e19

