Attach custom heads to transformer models.
Transformer Heads
This library aims to be an all-round toolkit for attaching, training, saving, and loading new heads for transformer models.
A new head could be:
- A linear probe used to understand how information is processed inside a transformer architecture
- A head to be finetuned jointly with the weights of a pretrained transformer model to perform a completely different kind of task.
  - E.g., a transformer pretrained for causal language modelling could get a sequence classification head attached and be finetuned for sentiment classification.
  - Or one could attach a regression head to turn a large language model into a value function for a reinforcement learning problem.
On top of that, attaching multiple heads at once can make multi-task learning easy, making it possible to train very general models.
Installation
Install from PyPI: pip install transformer-heads
Or, clone this repo and run the following from the root of the repository:
pip install -e .
Usage
Create head configurations
# Import path assumed from the library's top-level exports; check the docs if this differs.
from transformer_heads import HeadConfig

head_config = HeadConfig(
    name="imdb_head_3",
    layer_hook=-3,               # hook the head onto the hidden states three layers from the end
    in_size=hidden_size,         # hidden size of the base model
    output_activation="linear",
    pred_for_sequence=True,      # one prediction per sequence rather than per token
    loss_fct="cross_entropy",
    num_outputs=2,
)
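The hidden_size above is not defined in the snippet. One way to obtain it (an assumption about your setup, using the standard transformers API rather than anything specific to this library) is to read it from the pretrained model's config:

from transformers import AutoConfig

# Hidden dimension of the base model the head will attach to.
hidden_size = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf").hidden_size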
Create a model with your head from a pretrained transformer model
# LlamaForCausalLM comes from transformers; load_headed is assumed to be a top-level export.
from transformers import LlamaForCausalLM
from transformer_heads import load_headed

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=[head_config],
)
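Attaching several heads or probes at once works the same way: just pass more than one config. A minimal sketch, assuming the same HeadConfig fields as above (the probe names and layer choices are made up for illustration):

# One linear probe per hooked layer (hypothetical names and layer choices).
probe_configs = [
    HeadConfig(
        name=f"imdb_probe_layer_{layer}",
        layer_hook=layer,
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,
        loss_fct="cross_entropy",
        num_outputs=2,
    )
    for layer in (-1, -2, -3)
]

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=probe_configs,
)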
Train your model using (for example) the easy-to-use Hugging Face Trainer interface:
from transformers import Trainer

trainer = Trainer(
    model,
    args=args,                            # a transformers TrainingArguments instance
    train_dataset=imdb_dataset["train"],
    data_collator=collator,
)
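The snippet above assumes that args, imdb_dataset, and collator are already defined. A minimal sketch of one way to set them up with standard Hugging Face tooling (the library's notebooks may do this differently, e.g. with their own collator; the dataset and hyperparameter choices here are illustrative):

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token  # Llama ships without a pad token by default

# Tokenize the IMDB sentiment dataset; its "label" column provides the head's targets.
imdb_dataset = load_dataset("imdb").map(
    lambda batch: tokenizer(batch["text"], truncation=True), batched=True
)

collator = DataCollatorWithPadding(tokenizer)
args = TrainingArguments(output_dir="imdb_head_output", per_device_train_batch_size=2, num_train_epochs=1)

With these in place, calling trainer.train() starts the run.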
For a more in-depth introduction and a fully working example, check the linear probe notebook.
[Figure: Joint training of multiple linear probes]
Notebooks
This repository contains several Jupyter notebooks that serve as tutorials and illustrations of how to do certain things with this library. Here is an overview of which notebook to check out depending on your use case.
- Linear Probes (understanding the inner workings of transformers)
  - Basic example with one probe for causal LM: notebooks/gpt2/linear_probe.ipynb
  - Train many probes for causal LM at once: notebooks/gpt2/multi_linear_probe.ipynb
  - Train many probes for text classification at once: notebooks/gpt2/text_classification_linear_probe.ipynb
- Finetuning on a new type of task (with a new head)
  - QLoRA: notebooks/gpt2/text_classification_qlora.ipynb
  - Full finetuning: notebooks/gpt2/text_classification_full_finetune.ipynb
- Joint multi-task learning
  - Many heads doing completely different tasks + QLoRA, all trained at the same time: notebooks/gpt2/joint_multitask_learning.ipynb
- Regression with pretrained transformers
  - Check the regression heads of this notebook: notebooks/gpt2/joint_multitask_learning.ipynb
- Saving and loading
[Figure: Joint multi-task training with different types of heads and QLoRA]
More custom loss functions and models
At the time of writing, only a subset of loss functions and models are supported out of the box: Mistral-7B, LLaMA 2 (all model sizes), and GPT-2. Check transformer_heads/constants.py for more up-to-date info.
However, it is not hard to add or use different loss functions and models. You just need to add their respective information to loss_fct_map and model_type_map, both of which can be imported from transformer_heads.constants. To add a loss function, add a mapping from a string name to a torch loss module. To add a model, add a mapping from the model type to a 2-tuple of (the attribute name under which the base model is stored in the model class, the base model class). That may sound confusing, but it just means the following:
from transformer_heads.constants import model_type_map, loss_fct_map
import torch.nn as nn
from transformers import MistralModel

# Register a new loss function under the name "bce".
loss_fct_map["bce"] = nn.BCELoss()

# Register a model type: "model" is the attribute holding the base model
# inside the causal LM class, MistralModel is the base model class.
model_type_map["mistral"] = ("model", MistralModel)