
Attach custom heads to transformer models.

Project description

Documentation | Getting Started | Reddit Post with more info

Transformer Heads

This library aims to be an all-round toolkit for attaching, training, saving, and loading new heads for transformer models.
A new head could be:

  • A linear probe used to get an understanding of the information processing inside a transformer architecture
  • A head to be finetuned jointly with the weights of a pretrained transformer model to perform a completely different kind of task.
    • E.g., a transformer pretrained for causal language modelling could get a sequence classification head attached and be finetuned for sentiment classification.
    • Or one could attach a regression head to turn a large language model into a value function for a reinforcement learning problem.

On top of that, attaching multiple heads at once makes multi-task learning easy, which in turn makes it possible to train very general models.

Installation

Install from PyPI: pip install transformer-heads

Or clone this repo and run the following from its root: pip install -e .

Usage

Create head configurations

from transformer_heads import HeadConfig

head_config = HeadConfig(
    name="imdb_head_3",
    layer_hook=-3,  # Attach at the output of the third-to-last transformer block
    in_size=hidden_size,  # Hidden size of the base transformer model
    output_activation="linear",
    pred_for_sequence=True,  # One prediction per sequence rather than per token
    loss_fct="cross_entropy",
    num_outputs=2,
    target="label",  # The name of the ground-truth column in the dataset
)

Create a model with your head from a pretrained transformer model

from transformer_heads import load_headed
from transformers import LlamaForCausalLM

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=[head_config],
)

Train your model using (for example) the easy-to-use Hugging Face Trainer interface:

from transformers import Trainer

trainer = Trainer(
    model,
    args=args,  # A transformers.TrainingArguments instance
    train_dataset=imdb_dataset["train"],
    data_collator=collator,  # A data collator that pads batches for your tokenizer
)
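
With that in place, training should run the same way as for any other Hugging Face model:

trainer.train()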

For a more in-depth introduction and a fully working example, check the linear probe notebook.

Joint training of multiple linear probes

[Figure: multi_linear_probe.svg]
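
As a rough sketch of what this looks like in code (reusing the names from the Usage section above; the probe names and the choice of layers are made up for illustration), several linear probes hooked at different layers can be passed to load_headed together:

probe_configs = [
    HeadConfig(
        name=f"linear_probe_layer_{i}",
        layer_hook=-i,  # Hook one probe at each of the last three transformer blocks
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,
        loss_fct="cross_entropy",
        num_outputs=2,
        target="label",
    )
    for i in range(1, 4)
]

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=probe_configs,  # All probes are attached and trained jointly
)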

Notebooks

This repository contains multiple Jupyter notebooks that serve as tutorials/illustrations of how to do certain things with this library. Here is an overview of which notebook to check out depending on the use case you are interested in.

Joint multi-task training with different types of heads and QLoRA.

[Figure: example_architecture.svg]
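
As a sketch of the idea (not code from the notebook; the head names, the "reward" column, and the "mse" loss key are assumptions for illustration), heads of different types can be combined in a single call:

classification_head = HeadConfig(
    name="sentiment_head",  # Hypothetical head name
    layer_hook=-1,
    in_size=hidden_size,
    output_activation="linear",
    pred_for_sequence=True,
    loss_fct="cross_entropy",
    num_outputs=2,
    target="label",
)

regression_head = HeadConfig(
    name="value_head",  # Hypothetical head name
    layer_hook=-1,
    in_size=hidden_size,
    output_activation="linear",
    pred_for_sequence=True,
    loss_fct="mse",  # Assumed to be registered in loss_fct_map; see the next section if it is not
    num_outputs=1,
    target="reward",  # Hypothetical regression target column
)

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=[classification_head, regression_head],
)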

More custom loss functions and models

Out of the box, only a subset of loss functions and models is supported. At the time of writing, the supported models are Mistral-7b, LLaMA 2 (all model sizes) and gpt2. Check transformer_heads/constants.py for more up-to-date info.

However, it is not hard to add different loss functions or models. You just need to add the respective entries to loss_fct_map and model_type_map, both of which can be imported from transformer_heads.constants. To add a loss function, add a mapping from a string name to a torch loss class. To add a model, add a mapping from the model type to a 2-tuple of (the attribute name under which the base model is stored in the model class, the base model class). That may sound confusing, but in code it just means the following:

from transformer_heads.constants import model_type_map, loss_fct_map
import torch.nn as nn
from transformers import MistralModel

# Register a new loss function under the string key "bce"
loss_fct_map["bce"] = nn.BCELoss()
# MistralForCausalLM stores its base model in the attribute "model"
model_type_map["mistral"] = ("model", MistralModel)
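
Once registered, the new key can be referenced from a HeadConfig just like the built-in losses. A minimal sketch under assumptions: the "bce_logits" key, the head name, and the binary "label" column are made up here, and nn.BCEWithLogitsLoss is used so the head can keep a linear output:

loss_fct_map["bce_logits"] = nn.BCEWithLogitsLoss()  # Works on raw logits

binary_head = HeadConfig(
    name="toxicity_head",  # Hypothetical head name
    layer_hook=-1,
    in_size=hidden_size,
    output_activation="linear",  # BCEWithLogitsLoss applies the sigmoid itself
    pred_for_sequence=True,
    loss_fct="bce_logits",  # The key registered above
    num_outputs=1,
    target="label",  # Assumed binary 0/1 column
)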

Can my transformer architecture be supported?

One of the basic assumptions of my library is that there is a transformer class, such as Hugging Face's LlamaForCausalLM, that has an attribute pointing to a base model which outputs raw hidden states. If your transformer model is built up in a similar way, adding support may be as easy as adding an entry to the model_type_map with the name of that attribute and the class of the base model. You can either do that by importing from constants.py or by adding it directly and creating a pull request.
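
For instance (an untested illustration, not a confirmed integration): GPTNeoXForCausalLM stores its bare GPTNeoXModel in the attribute gpt_neox, so adding support could be attempted like this:

from transformer_heads.constants import model_type_map
from transformers import GPTNeoXModel

# "gpt_neox" is the model_type string of GPT-NeoX configs in transformers;
# the attribute name and base class mirror how GPTNeoXForCausalLM is built
model_type_map["gpt_neox"] = ("gpt_neox", GPTNeoXModel)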



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

transformer_heads-0.0.5.tar.gz (148.2 kB)

Uploaded Source

Built Distribution

transformer_heads-0.0.5-py3-none-any.whl (22.9 kB)

Uploaded Python 3

File details

Details for the file transformer_heads-0.0.5.tar.gz.

File metadata

  • Download URL: transformer_heads-0.0.5.tar.gz
  • Upload date:
  • Size: 148.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.2

File hashes

Hashes for transformer_heads-0.0.5.tar.gz

  • SHA256: 3681bca5b3ac2b069503792b8fd075f7ef7fc112c3fd7ffa09b0801b2c431f9c
  • MD5: 4f660ef5cc9e5cebc21a701d2df0eede
  • BLAKE2b-256: deeca908ed58a1f9934e57c58b1996213c478cb938e9e88c77798897e41fee7a


File details

Details for the file transformer_heads-0.0.5-py3-none-any.whl.

File metadata

File hashes

Hashes for transformer_heads-0.0.5-py3-none-any.whl

  • SHA256: 15c3fee912382c7cd8a5a5450c4f13720d1bc1aa4f6ff4cbff4164bc0c2e668e
  • MD5: cc5d03835abe914e1e8c565d5151f4cf
  • BLAKE2b-256: 07b2af97dd3d2b9b03b0725d1f6ff791d84ccdf283c448bf3e8a3b7bfa4517d9

