Type-hinted interface to use several decoders on text-generation models

Project description

Decoder Magic

Concept

The fluency and usefulness of text-generation models depend on the decoder used to select tokens from the model's probabilities and build the text output.

Greedy decoding always selects the most probable token; random sampling considers all possible tokens, drawing each with its predicted probability.
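The contrast can be sketched in plain Python (the tokens and probabilities below are made up for illustration; a real model would produce a distribution over its whole vocabulary):

```python
import random

# Hypothetical next-token distribution from a language model
probs = {"dog": 0.5, "cat": 0.3, "fish": 0.2}

# Greedy decoding: always pick the single most probable token
greedy_token = max(probs, key=probs.get)  # "dog"

# Random sampling: draw a token with probability equal to its score,
# so even low-probability tokens are occasionally chosen
sampled_token = random.choices(list(probs), weights=list(probs.values()))[0]
```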

We would like a common API with type hints, helpful error messages and logs, restrictions on parameters and random seeds, etc., to make the method of decoding clear and reproducible. In the future this should support many decoder types.
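As a sketch of what such parameter and seed restrictions might look like, here is an illustrative config class with type hints and eager validation. This class and its checks are hypothetical, not the library's actual implementation:

```python
class SamplingConfig:
    """Illustrative decoder config: validates inputs at construction time."""

    def __init__(self, random_seed: int, typical_p: float) -> None:
        # Requiring an explicit integer seed keeps runs reproducible
        if not isinstance(random_seed, int):
            raise TypeError("random_seed must be an int for reproducibility")
        # typical_p is a probability mass cutoff, so it must lie in (0, 1)
        if not 0.0 < typical_p < 1.0:
            raise ValueError("typical_p must be strictly between 0 and 1")
        self.random_seed = random_seed
        self.typical_p = typical_p
```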

Documentation

I would like to expand the documentation with details on all of the decoder options, links to relevant papers, etc., to make this library and the overall decoder concept accessible to new users.

Supported methods

  • ContrastiveSearch (params: random_seed, penalty_alpha, top_k)
  • GreedyDecoder
  • RandomSampling (params: random_seed)
  • TypicalDecoder (params: random_seed, typical_p)

Writer Examples (text input and output)

from decoder import BasicWriter, RandomSampling

# Pair a HuggingFace model name with a decoder class
basic = BasicWriter('gpt2', RandomSampling)
writer_output = basic.write_text(
    prompt="Hello, my name is", max_length=20, early_stopping=True
)

Decoder Examples (with customization)

Start with a HuggingFace Transformers / PyTorch model and tokenized text:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
content = tokenizer.encode("Hello, my name is", return_tensors="pt")

Example with Transformers' default greedy decoder:

from decoder import GreedyDecoder

decoder1 = GreedyDecoder(model)
greedy_output = decoder1.generate_text(
    prompt=content, max_length=20, early_stopping=True
)
tokenizer.decode(greedy_output[0], skip_special_tokens=True)

Example with typical decoding, which requires a random_seed before generating text and a typical_p between 0 and 1:

from decoder import TypicalDecoder

decoder3 = TypicalDecoder(model, random_seed=603, typical_p=0.4)
typical_output = decoder3.generate_text(
    prompt=content, max_length=20, early_stopping=True
)

# new random seed
decoder3.set_random_seed(101)
typical_output_2 = decoder3.generate_text(
    prompt=content, max_length=20, early_stopping=True
)
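The point of re-seeding is reproducibility: the same seed yields the same sampling choices, while a new seed changes them. Python's own random module illustrates the idea (the library itself presumably seeds the PyTorch generator instead):

```python
import random

def sample_with_seed(seed: int) -> list:
    """Seed the generator, then draw three values from it."""
    random.seed(seed)
    return [random.random() for _ in range(3)]

# Identical seeds reproduce identical draws; a different seed diverges
assert sample_with_seed(101) == sample_with_seed(101)
assert sample_with_seed(101) != sample_with_seed(603)
```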

License

Apache license, chosen for compatibility with the Transformers library.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

decoder_ring-0.1.0.tar.gz (8.5 kB)

Uploaded Source

Built Distribution

decoder_ring-0.1.0-py3-none-any.whl (11.8 kB)

Uploaded Python 3

File details

Details for the file decoder_ring-0.1.0.tar.gz.

File metadata

  • Download URL: decoder_ring-0.1.0.tar.gz
  • Upload date:
  • Size: 8.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for decoder_ring-0.1.0.tar.gz

  • SHA256: e627f68aba2f31aa81aaa285384a902bb4e7455830342a7acbd263d43a195d47
  • MD5: e6291761adbe3d6b9a5390bee313fd57
  • BLAKE2b-256: f844e74079e26cf809f48965a4633916cf8930c56d4e3c2df276ce7501f614cb
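You can check a downloaded archive against the SHA256 digest above using only the standard library (the file path is whatever you saved the archive as):

```python
import hashlib

def sha256_of(path: str) -> str:
    """Return the hex SHA256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()

# Compare the result against the published digest, e.g.:
# sha256_of("decoder_ring-0.1.0.tar.gz") == "e627f68a..."
```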


File details

Details for the file decoder_ring-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: decoder_ring-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 11.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for decoder_ring-0.1.0-py3-none-any.whl

  • SHA256: 732941d64c0f5e3686bf7dd576b45081c70ba54ceadcf5d89c77b9bbe24062c1
  • MD5: ca84ee13fcd7291fef726a9bce2dcb63
  • BLAKE2b-256: 77a0b232092c4ecfbe40665696dd47169fea090e192162d22df05c36bdaedb63

