A package for constrained text generation during decoding
Project description
seq2seq_constrained_decoding
This project provides constrained-decoding utilities for structured text generation using Hugging Face seq2seq models.
Requirements
pip install torch transformers
The package is tested with transformers==4.15.0 and might break with earlier or later versions.
Use cases
Several use cases leverage pretrained sequence-to-sequence models, such as BART or T5, to generate a (possibly partially) structured text sequence. For example, you may want to fine-tune the model to generate a set of key points summarizing an input paragraph, or to produce a text sequence that follows certain strict regularities.
Below we detail the use cases supported by this project.
Word lists
The word_list_decoding.py module provides simple LogitsProcessor classes that enforce allowed / forbidden words during text generation (see the WhiteListLogitsProcessor and BlackListLogitsProcessor classes, respectively).
Example usage:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from constrained_decoding.word_list_decoding import BlackListLogitsProcessor, WhiteListLogitsProcessor
tokenizer = AutoTokenizer.from_pretrained('t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
bad_words = "here are words that should not occur in the generated text"
bad_word_ids = tokenizer.encode(bad_words)
black_list_processor = BlackListLogitsProcessor(bad_word_ids)
good_words = "only these words can occur in the generated text"
good_word_ids = tokenizer.encode(good_words)
white_list_processor = WhiteListLogitsProcessor(good_word_ids)
input_seq = "here are the input words to condition generated text upon"
input_ids = tokenizer.encode(input_seq, return_tensors='pt')
out = model.generate(input_ids, num_beams=10)
print(tokenizer.batch_decode(out))
# ['<pad> Hier are the input words to condition generated text upon</s>']
out = model.generate(input_ids, num_beams=10, logits_processor=[black_list_processor])
print(tokenizer.batch_decode(out))
# ['<pad> Voici voici les input mots to condition a condition a condition a condition a']
out = model.generate(input_ids, num_beams=10, logits_processor=[white_list_processor])
print(tokenizer.batch_decode(out))
# ['<pad> in the words in the words in the words in the words in the words in the words in']
Set decoding
In some scenarios, you may want to regard the output sequence as expressing a set of elements, each represented by a sub-sequence of tokens. For example, you might fine-tune your Seq2Seq model on a multi-label document classification task (e.g. generating the set of relation types occurring in the input document).
The set_decoding.SetDecodingLogitProcessor class can guarantee that no subsequence (e.g. a relation type) occurs more than once. Output subsequences are assumed to be delimited by a special separator token.
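The sketch below shows how such a processor might be plugged into generation. The import path mirrors the word-list example above, and the constructor argument (the separator token id) is an assumption for illustration; check set_decoding.py for the actual signature.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from constrained_decoding.set_decoding import SetDecodingLogitProcessor
tokenizer = AutoTokenizer.from_pretrained('t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
# Assumption: the processor is built from the id of the separator token that
# delimits output subsequences; the real constructor may take different arguments.
separator_id = tokenizer.encode(';', add_special_tokens=False)[0]
set_processor = SetDecodingLogitProcessor(separator_id)
input_ids = tokenizer.encode("an input document to classify", return_tensors='pt')
out = model.generate(input_ids, num_beams=4, logits_processor=[set_processor])
print(tokenizer.batch_decode(out))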
DFA-based constrained decoding
The most powerful and generic constrained-decoding mechanism we propose is based on a Deterministic Finite Automaton (DFA).
You can instantiate a DFA with the dfa.DFA class, defined over a dictionary of dictionaries.
For example, the following represents an automaton that accepts only binary numbers that are multiples of 3 (see illustration in the Wikipedia article on DFA):
from dfa import DFA
transitions = {0: {'0': 0, '1': 1},
               1: {'0': 2, '1': 0},
               2: {'0': 1, '1': 2}}
dfa = DFA(transitions, s0=0, accept_states=[0])
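For example, the string '110' (binary 6) traverses states 0 → 1 → 0 → 0 and ends in the accepting state 0, so it is accepted, whereas '10' (binary 2) ends in state 2 and is rejected.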
To define constrained decoding using a DFA, the automaton's alphabet should correspond to tokens in the model's vocabulary.
The DFA class supports translating a DFA that uses regular words or phrases as its alphabet into a tokenizer-adjusted DFA:
transitions = {0: {'John': 1, 'Mike': 1, 'Dan': 1},
               1: {'went': 2, 'ran': 2, 'jogged': 2},
               2: {'to': 3, 'in': 3},
               3: {'the': 4, 'a': 4},
               4: {'park': 5}}
words_dfa = DFA(transitions, s0=0, accept_states=[5])
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokens_dfa = words_dfa.adjust_for_tokenizer(tokenizer)
Finally, generation constraining is achieved by replacing the model's beam_search method with an adapted version (in dfa_constrained_beam_search.py) that forces every beam to follow the automaton's transitions. The probability of vocabulary entries that are not reachable according to the DFA is set to minus infinity.
For example:
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
from dfa_constrained_beam_search import set_model_beam_search_to_dfa_constrained
set_model_beam_search_to_dfa_constrained(model, tokens_dfa)
# The previous two steps can equivalently be done in one call:
# set_model_beam_search_to_dfa_constrained(model, words_dfa, tokenizer)
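After the patch, constrained generation is triggered simply by calling generate with beam search enabled; a short usage sketch continuing the example above (the input text and beam size are arbitrary):
input_ids = tokenizer("Dan jogged in a park", return_tensors="pt").input_ids
out = model.generate(input_ids, num_beams=5)
# Each beam can only traverse words_dfa transitions,
# i.e. the decoded text is a path through the automaton (e.g. "Mike ran to the park").
print(tokenizer.batch_decode(out, skip_special_tokens=True))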
Other supported utility functions within the DFA class include:
- DFA.from_slots - for constructing a linear DFA out of a list of "slots", where each slot is represented by a list of allowed words / phrases (see the sketch below).
- DFA.concat_two and DFA.concat - for concatenating two or more (linear) DFAs into one long DFA.
- as_cyclic - for converting a linear DFA into a cyclic one, by merging or connecting some end-states with the initial state.
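For instance, a linear DFA equivalent to the words_dfa above could be constructed from slots; the exact argument format shown here is an assumption, see dfa.py for details.
# Each slot lists the words allowed at the corresponding position (assumed format).
slots = [['John', 'Mike', 'Dan'],
         ['went', 'ran', 'jogged'],
         ['to', 'in'],
         ['the', 'a'],
         ['park']]
slots_dfa = DFA.from_slots(slots)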
QA-SRL
Our motivating use case is seq2seq-based QA-SRL parsing. In that project, we fine-tune BART/T5 on the Question-Answer driven Semantic Role Labeling (QA-SRL) task. Given a verb or nominalization in context, the task is to generate Question-Answer pairs capturing the participants of the verbal event.
To model the task in a seq2seq paradigm, the QAs are linearized into a single target sequence, using separators between QA pairs, between question and answer, and between multiple answers. Furthermore, QASRL questions must adhere to a slot-based template, while answers may only be contiguous spans copied from the input sentence.
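For illustration, given the sentence "The chef prepared the meal yesterday" and the target verb "prepared", a linearized target might look roughly like: who prepared something? the chef | what did someone prepare? the meal | when did someone prepare something? yesterday. The "|" separator here is only a placeholder; the actual separator tokens and the question-slot template are defined in qasrl.py.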
Check out the qasrl.py module to see how we leverage the DFA utilities for enforcing a valid QASRL output sequence.