This project has been archived by its maintainers; no new releases are expected.
Project description
autoregressive-language-model-generate
A generator-based, stateless autoregressive inference loop for language models compatible with HuggingFace's Transformers API. At each step it yields logits from the model and expects the caller to send back the predicted next tokens, which makes it easy to plug in custom sampling strategies (greedy, beam search, top-k/top-p, etc.).
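For context, a yield/send loop of this kind can be sketched in plain Python. This is an illustrative reimplementation, not the package's actual source; the `autoregressive_generate` name, the `toy_model` callable, and the token handling are all assumptions made for the sketch:

```python
def autoregressive_generate(model, input_ids):
    """Illustrative sketch: yield logits, receive next tokens via send().

    `model` is any callable mapping a token sequence to per-token logits;
    `input_ids` is a list of token ids. Stateless: the full sequence is
    re-fed to the model at every step (no KV cache is kept here).
    """
    tokens = list(input_ids)
    while True:
        logits = model(tokens)      # caller-defined forward pass
        next_token = yield logits   # caller samples and sends a token back
        tokens.append(next_token)

# Toy model over a 10-token vocabulary: the logits favor (last token + 1) mod 10.
def toy_model(tokens):
    scores = [0.0] * 10
    scores[(tokens[-1] + 1) % 10] = 1.0
    return scores

gen = autoregressive_generate(toy_model, [3])
logits = next(gen)  # prime the generator to get the first logits
for _ in range(3):
    next_token = max(range(len(logits)), key=logits.__getitem__)  # greedy
    logits = gen.send(next_token)
# Greedy decoding from [3] produces tokens 4, 5, 6 in turn.
```

The caller fully owns the sampling policy and the stopping condition; the generator only runs the forward passes.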
Usage
Assume you have:
- `model`, a language model compatible with the Transformers API
- `input_ids` and `attention_mask`, each of shape `(batch_size, seq_len)`
```python
import torch

from autoregressive_language_model_generate import autoregressive_language_model_generate

model = ...
input_ids = ...
attention_mask = ...

gen = autoregressive_language_model_generate(
    model,
    input_ids,
    attention_mask,
)

# Prime the generator with next() to obtain the first logits
logits = next(gen)

# Implement your sampling logic here; this example uses top-k sampling
next_token_logits = logits[:, -1, :]
top_k = 50
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_scores = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)

# `next_tokens` has shape `(batch_size,)`
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

# Send `next_tokens` back to the generator and receive the next step's `logits`
logits = gen.send(next_tokens)
```
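The top-k filtering and renormalization step above can also be written without torch, which makes the math explicit. This is a pure-Python sketch of the same idea for a single sequence; the example scores and `k` are made up:

```python
import math

def top_k_probs(scores, k):
    """Return a probability distribution restricted to the k highest scores.

    Scores below the k-th largest are masked to -inf (zero probability),
    then the survivors are passed through a numerically stable softmax.
    """
    threshold = sorted(scores, reverse=True)[k - 1]
    filtered = [s if s >= threshold else float("-inf") for s in scores]
    m = max(filtered)
    exps = [math.exp(s - m) for s in filtered]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

probs = top_k_probs([2.0, 0.5, 1.5, -1.0, 0.0], k=2)
# Only the two highest-scoring entries (indices 0 and 2) keep nonzero mass.
```

Sampling an index from `probs` (e.g. with `random.choices`) then plays the role of `torch.multinomial` in the batched version above.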
Contributing
Contributions are welcome! Please submit pull requests or open issues on the GitHub repository.
License
This project is licensed under the MIT License.
Download files
Source Distribution
Built Distribution
File details
Details for the file autoregressive_language_model_generate-0.1.0a0.tar.gz.
File metadata
- Download URL: autoregressive_language_model_generate-0.1.0a0.tar.gz
- Upload date:
- Size: 3.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `0a128b624c7e7a8a3dfeca5efbdb8f9b390b8a7d4df90fdf56f6afcdc0a9cd5a` |
| MD5 | `861b36318887f8050ad30a3151e101a1` |
| BLAKE2b-256 | `939c133b86cb9b7e26c8bfc893dc77e6a358b44a67f84ac6314135b145bbe87b` |
File details
Details for the file autoregressive_language_model_generate-0.1.0a0-py2.py3-none-any.whl.
File metadata
- Download URL: autoregressive_language_model_generate-0.1.0a0-py2.py3-none-any.whl
- Upload date:
- Size: 4.2 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `06253a091503125297487198ccb8981c56a678abbc1ade843d0f185967c9b1cb` |
| MD5 | `1683696fbe2339df2b4e1fbd192b1b7b` |
| BLAKE2b-256 | `f83fa664301eb43af32ffac7045efcf13581d33cbb5dc6a6f8796c8de101ccc3` |