
Implementation of a stop sequencer for Huggingface Transformers


Stop Sequencer

  • An implementation of a stop sequencer for Huggingface Transformers.
  • Because of a limitation in the implementation, it must be used together with post-processing.

1. Installation

pip install stop-sequencer

2. Usage

2.1. Generation without StopSequencer

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = tokenizer(
    "Kevin: Hello "
    "Ryan: Hi "
    "Kevin: What are you doing? "
    "Ryan: I am watching TV. you? "
    "Kevin: ",
    return_tensors="pt",
)["input_ids"]

outputs = model.generate(
    tokens,
    num_beams=5,
    no_repeat_ngram_size=4,
    repetition_penalty=1.5,
    max_length=100,
)

outputs = tokenizer.batch_decode(outputs[:, tokens.size(-1):], skip_special_tokens=True)[0]
print(outputs)
ive been watching TV for a long time. Ryan: I have been watching TV since I was 12 years old. Kevin: So what do you want me to do? Ryan: Well, I want you to watch TV. You know what I mean? I'm going to be watching TV. I'm not going to sit down and watch TV. I don't want to



2.2. Generation with StopSequencer

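  • In the example below, generation finishes shortly after the stop text "Ryan:" is produced, so the output ends with "Ryan: I have".
  • Due to a limitation of Huggingface Transformers, generation can only be terminated by checking a stop condition after a stop text has already been generated. The stop text (and possibly a few tokens after it) therefore remains in the output.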
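To make that limitation concrete, here is a minimal pure-Python sketch of a stop check inside a step-by-step decoding loop. It is not this package's implementation: the function toy_generate and the fixed list of "predicted" tokens standing in for a model are hypothetical. Because the check can only run after a token has been appended, the stop text is already part of the output by the time it fires.

```python
from typing import Iterable, List


def toy_generate(next_tokens: Iterable[str], stop_texts: List[str], max_steps: int = 50) -> str:
    """Append tokens one at a time and stop once any stop text appears.

    `next_tokens` is a hypothetical stand-in for a model's per-step predictions.
    """
    generated = ""
    for step, token in enumerate(next_tokens):
        if step >= max_steps:
            break
        # The token is committed to the output before it can be inspected.
        generated += token
        # The stop condition is checked only after the token is already in the output.
        if any(stop in generated for stop in stop_texts):
            break
    return generated


tokens = ["ive", " been", " watching", " TV.", " Ryan", ":", " I", " have"]
print(toy_generate(tokens, ["Ryan:", "Kevin:"]))  # ive been watching TV. Ryan:
```

The returned text still ends with "Ryan:", which is why a post-processing step is needed on top of the stop check.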
from stop_sequencer import StopSequencer

stop_texts = ["Ryan:", "Kevin:"]

stop_sequencer = StopSequencer(
    model,
    model_type="causal",  # or "seq2seq"
    tokenizer=tokenizer,
)

model = stop_sequencer.register_stop_texts(
    stop_texts=stop_texts,
    input_length=tokens.size(-1),
)

outputs = model.generate(
    tokens,
    num_beams=5,
    no_repeat_ngram_size=4,
    repetition_penalty=1.5,
    max_length=100,
)

outputs = tokenizer.batch_decode(outputs[:, tokens.size(-1):], skip_special_tokens=True)[0]
print(outputs)
ive been watching TV for a long time. Ryan: I have
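The examples slice outputs[:, tokens.size(-1):] because, for a causal (decoder-only) model, the output of generate begins with the prompt tokens. For encoder-decoder ("seq2seq") models, generate returns only decoder tokens, so the same slice would cut off real output. The helper below is a hypothetical sketch of that distinction using plain lists of token ids; continuation_ids is not part of this package.

```python
from typing import List


def continuation_ids(output_ids: List[int], input_length: int, model_type: str) -> List[int]:
    """Return only the newly generated token ids (hypothetical helper).

    - "causal": generate() output = prompt ids + new ids, so drop the prompt.
    - "seq2seq": generate() output holds decoder ids only, so keep everything.
    """
    if model_type == "causal":
        return output_ids[input_length:]
    if model_type == "seq2seq":
        return output_ids
    raise ValueError(f"unknown model_type: {model_type!r}")


prompt = [101, 102, 103]
print(continuation_ids(prompt + [7, 8], len(prompt), "causal"))  # [7, 8]
print(continuation_ids([0, 7, 8], len(prompt), "seq2seq"))       # [0, 7, 8]
```

For seq2seq models the first decoder id is usually a special start token, which skip_special_tokens=True removes during decoding.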



2.3. Generation with StopSequencer + post-processing

  • Because the stop text is still part of the generated text, post-processing must be performed to remove it completely, together with anything generated after it.
for s in stop_texts:
    outputs = outputs.split(s)[0].strip()

print(outputs)
ive been watching TV for a long time.
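The post-processing loop can also be wrapped in a reusable function. This is a minimal sketch; the name strip_stop_texts is hypothetical and not part of stop_sequencer. It cuts the text at the earliest occurrence of any stop text and trims surrounding whitespace, which for ordinary inputs like the one above gives the same result as the sequential split loop.

```python
from typing import List


def strip_stop_texts(text: str, stop_texts: List[str]) -> str:
    """Cut `text` at the earliest occurrence of any stop text, then trim whitespace."""
    cut = len(text)
    for stop in stop_texts:
        if not stop:
            continue  # an empty stop text would match at index 0; ignore it
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut].strip()


raw = "ive been watching TV for a long time. Ryan: I have"
print(strip_stop_texts(raw, ["Ryan:", "Kevin:"]))  # ive been watching TV for a long time.
```

Searching for the earliest match in a single pass makes the result independent of the order of stop_texts.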



License

Copyright 2021 Hyunwoong Ko.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
