Type-hinted interface to use several decoders on text-generation models
Project description
Decoder Magic
Concept
The fluency and usefulness of a text-generation model depend on the decoder used to select tokens from its output probabilities and build the final text.
Greedy decoding always selects the most probable token, while random sampling draws from all possible tokens according to their probabilities.
We would like a common API with type hints, helpful error messages and logs, parameter and random-seed restrictions, etc., to make the method of decoding clear and reproducible. In the future this should support many decoder types.
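As a rough illustration of the difference between greedy decoding and random sampling (plain PyTorch, not this library's internal code):
import torch

probs = torch.tensor([0.6, 0.3, 0.1])       # next-token probabilities
greedy_token = torch.argmax(probs)           # greedy: always picks index 0
sampled_token = torch.multinomial(probs, 1)  # sampling: picks 0, 1, or 2, weighted by probability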
Documentation
I would like to expand the documentation with details on all of the decoder options, links to relevant papers, etc., to make this library and the overall decoder concept accessible to new users.
Supported methods
- ContrastiveSearch (params: random_seed, penalty_alpha, top_k)
- GreedyDecoder
- RandomSampling (params: random_seed)
- TypicalDecoder (params: random_seed, typical_p)
Writer Examples (text input and output)
from decoder import BasicWriter, RandomSampling

# BasicWriter takes a model name and a decoder class; write_text
# returns generated text for a prompt string.
basic = BasicWriter('gpt2', RandomSampling)
writer_output = basic.write_text(
    prompt="Hello, my name is", max_length=20, early_stopping=True
)
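Assuming write_text returns the generated string, as the text-in/text-out description above implies, the result can be used directly:
print(writer_output)  # output varies with the random seed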
Decoder Examples (with customization)
Start with a HuggingFace Transformers / PyTorch model and tokenized text:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
content = tokenizer.encode("Hello, my name is", return_tensors="pt")
Example with Transformers' default greedy decoder:
from decoder import GreedyDecoder

decoder1 = GreedyDecoder(model)
greedy_output = decoder1.generate_text(
    prompt=content, max_length=20, early_stopping=True
)
tokenizer.decode(greedy_output[0], skip_special_tokens=True)
Example with typical decoding, which requires a random_seed before generating text and a typical_p between 0 and 1:
from decoder import TypicalDecoder

decoder3 = TypicalDecoder(model, random_seed=603, typical_p=0.4)
typical_output = decoder3.generate_text(
    prompt=content, max_length=20, early_stopping=True
)

# new random seed
decoder3.set_random_seed(101)
typical_output_2 = decoder3.generate_text(
    prompt=content, max_length=20, early_stopping=True
)
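ContrastiveSearch is not shown above; here is a minimal sketch, assuming it follows the same pattern as the other decoders and takes the parameters listed under "Supported methods" (random_seed, penalty_alpha, top_k):
from decoder import ContrastiveSearch

# Assumed to mirror the other decoders' interface; parameter names are
# taken from the "Supported methods" list above.
decoder4 = ContrastiveSearch(model, random_seed=603, penalty_alpha=0.6, top_k=4)
contrastive_output = decoder4.generate_text(
    prompt=content, max_length=20, early_stopping=True
)
tokenizer.decode(contrastive_output[0], skip_special_tokens=True)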
License
Apache license for compatibility with the Transformers library
Project details
Download files
Download the file for your platform.
Source Distribution
Built Distribution
File details
Details for the file decoder_ring-0.1.0.tar.gz.
File metadata
- Download URL: decoder_ring-0.1.0.tar.gz
- Upload date:
- Size: 8.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | e627f68aba2f31aa81aaa285384a902bb4e7455830342a7acbd263d43a195d47
MD5 | e6291761adbe3d6b9a5390bee313fd57
BLAKE2b-256 | f844e74079e26cf809f48965a4633916cf8930c56d4e3c2df276ce7501f614cb
File details
Details for the file decoder_ring-0.1.0-py3-none-any.whl.
File metadata
- Download URL: decoder_ring-0.1.0-py3-none-any.whl
- Upload date:
- Size: 11.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 732941d64c0f5e3686bf7dd576b45081c70ba54ceadcf5d89c77b9bbe24062c1
MD5 | ca84ee13fcd7291fef726a9bce2dcb63
BLAKE2b-256 | 77a0b232092c4ecfbe40665696dd47169fea090e192162d22df05c36bdaedb63