
PyTorch implementation of PASTA, a post-hoc attention steering approach that enables users to emphasize specific contexts for LLMs.


🍝 PASTA: Post-hoc Attention Steering for LLMs 🍝

Tell Your Model Where to Attend: Post-hoc Attention Steering for LLMs (Zhang et al. 2023).

PASTA allows a user to improve LLM controllability by simply emphasizing the part of the prompt (e.g., the instruction) that the LLM should focus on. It requires no changes to LLM weights and no increase in inference time.
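
Conceptually, PASTA reweights the attention that a selected set of heads pays to tokens outside the user-highlighted spans. Below is a minimal sketch of that reweighting rule as we understand it from the paper; the function name and tensor shapes are ours, not pastalib's API (the library applies the edit inside the model via forward hooks):

import torch

def steer_head(attn, emphasized, alpha=0.01):
    # attn:       (seq_len, seq_len) post-softmax attention of one steered head
    # emphasized: (seq_len,) bool mask, True at user-highlighted token positions
    # With scale_position="exclude", attention to unselected tokens is
    # scaled down by alpha, and each row is renormalized to sum to 1.
    scale = torch.full((attn.shape[-1],), alpha)
    scale[emphasized] = 1.0   # highlighted tokens keep full weight
    steered = attn * scale    # broadcasts over the key dimension
    return steered / steered.sum(dim=-1, keepdim=True)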

Quickstart -- use PASTA for improved inference

  1. Install pastalib:
pip install pastalib
# Alternatively, install from source:
# pip install git+https://github.com/QingruZhang/PASTA
# or clone the repo and run: pip install -e .
  2. Initialize a pre-trained LLM and PASTA.
from pastalib.pasta import PASTA 
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize pre-trained LLM
name = "huggyllama/llama-7b"
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# Select the attention heads to be steered, 
# following the format of {'layer_id': [head_ids]}: 
head_config = {
    "3": [17, 7, 6, 12, 18], "8": [28, 21, 24], "5": [24, 4], 
    "0": [17], "4": [3], "6": [14], "7": [13], "11": [16], 
}

# Initialize the PASTA steerer
pasta = PASTA(
    model=model,
    tokenizer=tokenizer,
    head_config=head_config, 
    alpha=0.01, # scaling coefficient
    scale_position="exclude", # downweighting unselected tokens
)
  3. Select specific input spans to emphasize, and then run inference as normal.
# Model Input 
texts = ["Mary is a doctor. She obtains her bachelor degree from UCSD. Answer the occupation of Mary and generate the answer as json format."]

# ===== Without PASTA =====
# inputs = tokenizer(texts, return_tensors="pt")
# outputs = model.generate(**inputs)
# ---------------------
# ["The answer should be in json format."]  # returns answer in the wrong format

# ===== With PASTA =====
inputs, offset_mapping = pasta.inputs_from_batch(texts)
# User highlights specific input spans
emphasized_texts = ["Answer the occupation of Mary and generate the answer as json format"]
# PASTA registers the pre_forward_hook to edit attention
with pasta.apply_steering(
    model=model, 
    strings=texts, 
    substrings=emphasized_texts, 
    model_input=inputs, 
    offsets_mapping=offset_mapping
) as steered_model: 
    outputs = steered_model.generate(**inputs, max_new_tokens=128)
# -------------------------------
# ['{"name": "Mary", "occupation": "Doctor", ...}']  # returns answer in the correct format

Additional Note

  1. pastalib works with any model that applies causal attention by adding an attention mask to the query-key inner products, and it can be applied to such LLMs in a plug-and-play manner. For example, LlamaForCausalLM and GPTJForCausalLM from transformers have attention modules that apply attention masks following torch.matmul(query, key) + attention_mask. However, pastalib currently only supports LLaMA, LLaMA-2, and GPT-J (more models in progress!). The first sketch after this list shows why the additive-mask form matters.

  2. We provide several head_config options for LLaMA-7B and GPT-J in the config/head_config folder, including multi-task, task-agnostic, and task-specific settings; a loading example follows this list. Please see the detailed discussion in our paper.
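
Because these models fold the attention mask additively into the pre-softmax logits, the same reweighting can be expressed as a mask edit: adding log(alpha) to the mask at unselected key positions multiplies their post-softmax weights by alpha, and the softmax renormalizes each row automatically. A minimal sketch, assuming a (seq_len,) boolean highlight mask (the function name is ours, not pastalib's API):

import math
import torch

def steer_additive_mask(attention_mask, emphasized, alpha=0.01):
    # attention_mask: additive mask broadcastable over the key dimension
    # emphasized:     (seq_len,) bool mask over key positions
    bias = torch.full(emphasized.shape, math.log(alpha))
    bias[emphasized] = 0.0    # highlighted tokens are left untouched
    return attention_mask + bias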
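
To try one of the provided configurations, load the JSON file and pass it to PASTA as in the quickstart. The filename below is hypothetical; substitute any file under config/head_config:

import json
from pastalib.pasta import PASTA

# Hypothetical filename -- pick any provided file under config/head_config.
with open("config/head_config/llama-7b/multitask.json") as f:
    head_config = json.load(f)

pasta = PASTA(
    model=model,              # pre-trained LLM from the quickstart
    tokenizer=tokenizer,
    head_config=head_config,
    alpha=0.01,
    scale_position="exclude",
)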

Overview

This repo is organized as follows:

  • pastalib: contains the source code of the PASTA library, which can be applied to models from Hugging Face transformers.
  • evaluation: consists of evaluation pipelines for different tasks, including data/model preprocessing and task evaluators/metrics.
  • config: includes the head_config files for steering the attention modules of LLaMA-7B and GPT-J with PASTA.
  • scripts: contains the running scripts for the four tasks: JSON Formatting, Pronouns Changing, Bias in Bios, and CounterFact.

Evaluation

The evaluation pipelines are mainly refactored from the REMEDI repo. Please see more details there.

Environment Setup

Set up the environment with the following commands:

conda create -n pasta python=3.10
conda activate pasta
pip install -r requirements.txt 
pip install -e . 

python -m spacy download en_core_web_sm
python -W ignore -m nltk.downloader punkt cmudict

By default, the preprocessed datasets, models, and results are saved in the local directories ./data, ./models, and ./results. You can change these directories by setting the following environment variables:

export CM_DATA_DIR=<data path> 
export CM_MODELS_DIR=<models path>
export CM_RESULTS_DIR=<results path>

Dataset Setup

  1. For CounterFact, our scripts can automatically download the dataset.

  2. For Bias in Bios, we cannot release the dataset without authorization. The dataset must be downloaded from the official release. After downloading the data examples into a BIOS.pkl file, you can run the following script:

python reformat_dataset.py \
--biasbios_raw_path <path of BIOS.pkl> \
--biasbios_save_file biasbios.json 

The preprocessed biasbios dataset will then be saved as CM_DATA_DIR/biasbios.json.
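
As a quick sanity check, the preprocessed file can be loaded like any other JSON dataset (the path handling below is illustrative, and we assume the file holds a list of examples):

import json
import os

data_dir = os.environ.get("CM_DATA_DIR", "./data")
with open(os.path.join(data_dir, "biasbios.json")) as f:
    biasbios = json.load(f)   # assumed to be a list of preprocessed examples
print(f"Loaded {len(biasbios)} Bias in Bios examples")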

Evaluation

Choose any head_config file from config/head_config and evaluate the performance of PASTA with the following commands.

JSON Formatting

python -m scripts.eval_biasbios_instruction \
--task json \
--apply_pasta \
--emphasized_text instruct \
--alpha 0.01 \
--scale_position exclude \
--pasta_head_config <head_config_path> \
--model huggyllama/llama-7b \
--prompt_idx 0 \
--batch_size 16 \
--max_new_tokens 128 \
--experiment_name llama_evaluation \
--device cuda 

Pronouns Changing

python -m scripts.eval_biasbios_instruction \
--task pronchange \
--apply_pasta \
--emphasized_text instruct \
--alpha 0.01 \
--scale_position exclude \
--pasta_head_config <head_config_path> \
--prompt_idx 0 \
--model huggyllama/llama-7b \
--max_new_tokens 128 \
--batch_size 16 \
--experiment_name llama_evaluation \
--device cuda 

Bias in Bios

python -m scripts.eval_bias_gen \
--model huggyllama/llama-7b \
--apply_pasta \
--alpha 0.01 \
--scale_position exclude \
--pasta_head_config <head_config_path> \
--max_length 256 \
--batch_size 16 \
--experiment_name llama_evaluation \
--device cuda 

CounterFact

python -m scripts.eval_fact_gen \
--model huggyllama/llama-7b \
--apply_pasta \
--alpha 0.01 \
--scale_position exclude \
--pasta_head_config <head_config_path> \
--add_unmediated_fact True \
--benchmarks efficacy paraphrase generation \
--experiment_name llama_evaluation

Contact

Please contact us or post an issue if you have any questions. To cite PASTA:

@misc{zhang2023tell,
    title={Tell Your Model Where to Attend: Post-hoc Attention Steering for LLMs}, 
    author={Qingru Zhang and Chandan Singh and Liyuan Liu and Xiaodong Liu and Bin Yu and Jianfeng Gao and Tuo Zhao},
    year={2023},
    eprint={2311.02262},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
