Large autoregressive language modeling helpers
transformer-utils
Utilities for the HuggingFace transformers library, focused on loading and using large pretrained autoregressive language models like GPT-2 and GPT-Neo.
This package is unofficial and not associated with HuggingFace.
Features:
- Load large (~2.7B) models in low-resource environments like Google Colab
- Get activations from any part of the model, without running parts you don't need
- Interpret models with the "logit lens"
  - For background, see:
    - "interpreting GPT: the logit lens" by nostalgebraist
    - "Finding the Words to Say: Hidden State Visualizations for Language Models" by Jay Alammar
Example usage
Load in a low-memory environment
Loading a 2.7B model:
import transformers

from transformer_utils.low_memory import enable_low_memory_load

enable_low_memory_load()

model = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B')
This works fine in an ordinary (non-Pro) Google Colab notebook, with ~12 GB RAM and a T4 GPU.
Inference will work up to the full context window length of 2048 tokens without memory issues.
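As a quick sanity check, here is a minimal sketch (not from the original docs) of running generation with the model loaded above. The prompt and sampling settings are arbitrary, and it assumes the model ends up on the GPU, as in the Colab setup described above.

import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')

input_ids = tokenizer.encode(
    "The low-memory loader makes it possible to",  # arbitrary example prompt
    return_tensors='pt',
).cuda()

with torch.no_grad():
    generated = model.generate(input_ids, max_length=64, do_sample=True, top_p=0.9)

print(tokenizer.decode(generated[0], skip_special_tokens=True))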
Logit lens
import torch
import transformers

from transformer_utils.logit_lens import plot_logit_lens
from transformer_utils.low_memory import enable_low_memory_load

enable_low_memory_load()
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2-xl')
def text_to_input_ids(text):
    toks = tokenizer.encode(text)
    return torch.as_tensor(toks).view(1, -1).cuda()
input_ids = text_to_input_ids("This is an example. You can probably think of a more fun text to use than this one.")
plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45) # logits
plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, probs=True) # probabilities
plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, kl=True) # K-L divergence
You can also do some other things that aren't in the original blog posts. This will break down the transformer blocks into their attention and MLP parts:
plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, include_subblocks=True)
You can also change the definition of the "decoder" to include some of the later blocks/subblocks of the model. This helps especially in interpreting GPT-Neo hidden states.
# assume we have a 48-layer model
# so 'h.47' is the final layer
# include last layer in decoder
plot_logit_lens(
    model, tokenizer, input_ids, start_ix=0, end_ix=45,
    decoder_layer_names=['h.47', 'final_layernorm', 'lm_head']
)

# include just the last MLP subblock in decoder
plot_logit_lens(
    model, tokenizer, input_ids, start_ix=0, end_ix=45,
    decoder_layer_names=['h.47.mlp', 'final_layernorm', 'lm_head']
)
Get activations from any part of the model
...and without running parts you don't need
from transformer_utils.partial_forward import partial_forward
output = partial_forward(
    model=model,            # your `transformers` model
    output_names=[
        'h.0',              # output of the 1st layer
        'h.2.attn.c_attn',  # query/key/value matrix from the 3rd layer
        'h.5.mlp.c_proj',   # feed-forward activations from the 6th layer
    ],
    input_ids=input_ids,    # the input to run
)
# each of these is a tensor
output['h.0']
output['h.2.attn.c_attn']
output['h.5.mlp.c_proj']
For efficiency, partial_forward doesn't run any part of the model later than the ones you specify in output_names.
For example, suppose the model above is GPT-2 XL. It has 48 layers, but the forward pass in the code above stops after the 6th of those 48 layers -- so the compute and memory cost is far lower than a full model.forward.
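To see the effect yourself, here is a rough, illustrative timing comparison (not part of transformer-utils). It reuses model and input_ids from above and assumes they live on the GPU; the timed helper is just for this sketch, and exact numbers depend on your hardware.

import time

import torch

output_names = ['h.0', 'h.2.attn.c_attn', 'h.5.mlp.c_proj']

def timed(fn):
    # crude wall-clock timer, good enough to see the order-of-magnitude difference
    torch.cuda.synchronize()
    start = time.time()
    out = fn()
    torch.cuda.synchronize()
    return out, time.time() - start

with torch.no_grad():
    _, t_partial = timed(lambda: partial_forward(model, output_names, input_ids))
    _, t_full = timed(lambda: model(input_ids))

print(f"partial_forward through layer 6:    {t_partial:.3f}s")
print(f"full forward through all 48 layers: {t_full:.3f}s")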
This makes it easy to write new "heads" that do further computation on the model's activations.
Some examples:
Using the first two layers of a model as feature extractors for binary classification
import torch
import torch.nn as nn
import torch.nn.functional as F

output_names = ['h.0', 'h.1',]
classifier_hidden_size = 768

feature_vector_size = base_model.config.n_embd * len(output_names)

classifier = nn.Sequential(
    nn.Linear(feature_vector_size, classifier_hidden_size),
    nn.ReLU(),
    nn.Linear(classifier_hidden_size, 2),
)

opt = torch.optim.Adam(classifier.parameters())

for input_ids, targets in dataset:  # `dataset` is your classification train data
    with torch.no_grad():
        hidden_states = partial_forward(
            base_model,
            output_names,
            input_ids,
        )

    # shape (batch, sequence, len(output_names) * model's hidden size)
    feature_vector = torch.cat(
        [hidden_states[name] for name in output_names],
        dim=-1
    )

    # shape (batch, sequence, 2)
    classifier_out = classifier(feature_vector)

    # simple avg pool over sequence dim -- in practice, attention works well for this step :)
    # shape (batch, 2)
    logits = classifier_out.mean(dim=1)

    loss = F.cross_entropy(input=logits, target=targets)

    loss.backward()
    opt.step()
    opt.zero_grad()
Finetuning the first two layers of a model
This is exactly the same as the above, with just two changes:
- Remove the with torch.no_grad() wrapper around the partial_forward call
- Optimize the base model's params too:
opt = torch.optim.Adam(list(classifier.parameters()) + list(base_model.parameters()))
If you want to train a model like these for real use, I recommend writing a custom nn.Module. See here for an example.
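For a rough idea of what that might look like, here is a hedged sketch that packages the pieces from the examples above into a single module. The class name and structure are mine, not from the linked example or from transformer-utils itself.

import torch
import torch.nn as nn

from transformer_utils.partial_forward import partial_forward


class PartialForwardClassifier(nn.Module):
    # Hypothetical example module, not part of transformer-utils:
    # early layers of a base model + a small classifier head.
    def __init__(self, base_model, output_names, classifier_hidden_size=768, n_classes=2):
        super().__init__()
        self.base_model = base_model  # registered as a submodule, so its params can be finetuned
        self.output_names = output_names

        feature_vector_size = base_model.config.n_embd * len(output_names)
        self.classifier = nn.Sequential(
            nn.Linear(feature_vector_size, classifier_hidden_size),
            nn.ReLU(),
            nn.Linear(classifier_hidden_size, n_classes),
        )

    def forward(self, input_ids):
        hidden_states = partial_forward(self.base_model, self.output_names, input_ids)

        # shape (batch, sequence, len(output_names) * hidden size)
        feature_vector = torch.cat(
            [hidden_states[name] for name in self.output_names],
            dim=-1,
        )

        # mean-pool over the sequence dimension, as in the training loop above
        # shape (batch, n_classes)
        return self.classifier(feature_vector).mean(dim=1)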