
This project has been archived by its maintainers. No new releases are expected.

Project description

autoregressive-language-model-generate

A generator-based, stateless autoregressive inference loop for language models compatible with HuggingFace's Transformers API. At each step, it yields logits from the model and expects the caller to send back the predicted next tokens, which makes it easy to plug in custom sampling strategies (greedy, beam search, top-k/top-p, etc.).
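The yield/send protocol can be sketched with a minimal illustrative reimplementation. Note that `toy_generate` and `ToyModel` below are stand-ins written for this sketch, not the package's actual source: the toy model just returns zero logits of the right shape.

```python
import torch
from types import SimpleNamespace

class ToyModel:
    """Stand-in for a Transformers model: returns zero logits of the right shape."""
    vocab_size = 8

    def __call__(self, input_ids, attention_mask):
        batch_size, seq_len = input_ids.shape
        return SimpleNamespace(logits=torch.zeros(batch_size, seq_len, self.vocab_size))

def toy_generate(model, input_ids, attention_mask):
    """Illustrative sketch of the yield/send protocol (not the package's code)."""
    while True:
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Yield logits; the caller sends back next tokens of shape (batch_size,)
        next_tokens = yield logits
        # Append the chosen tokens and extend the attention mask
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones_like(next_tokens)[:, None]], dim=1
        )

# Driving the generator: prime with next(), then alternate sample/send
gen = toy_generate(
    ToyModel(),
    torch.zeros(2, 3, dtype=torch.long),  # input_ids, shape (batch_size=2, seq_len=3)
    torch.ones(2, 3, dtype=torch.long),   # attention_mask, same shape
)
logits = next(gen)                               # shape (2, 3, 8)
logits = gen.send(logits[:, -1, :].argmax(dim=-1))  # shape (2, 4, 8)
```

Because the generator is stateless from the caller's perspective, all sampling decisions live outside the loop, on the caller's side of each `send`.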

Usage

Assume you have:

  • model — a language model compatible with the Transformers API
  • input_ids and attention_mask — tensors of shape (batch_size, seq_len)
import torch
from autoregressive_language_model_generate import autoregressive_language_model_generate

model = ...
input_ids = ...
attention_mask = ...

gen = autoregressive_language_model_generate(
    model,
    input_ids,
    attention_mask
)

# Prime the generator; `logits` has shape `(batch_size, seq_len, vocab_size)`
logits = next(gen)

# Implement your sampling logic here (top-k sampling shown as an example)
next_token_logits = logits[:, -1, :]
top_k = 50
# Mask out every logit below the k-th largest before sampling
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_scores = next_token_logits.masked_fill(indices_to_remove, -float('Inf'))
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)

# `next_tokens` has shape `(batch_size,)`
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

# Send `next_tokens` to the generator, receive logits for the extended sequence
logits = gen.send(next_tokens)
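Putting the pieces together, a full decoding loop repeats the sample-and-send step until some budget or stop condition is met. The sketch below uses greedy decoding and a fixed step budget; `dummy_model`, `dummy_generate`, and `max_new_tokens` are illustrative stand-ins so the example runs without the package or a real model — in real use, `gen` would come from `autoregressive_language_model_generate`.

```python
import torch
from types import SimpleNamespace

# Stand-ins so the sketch runs without the package or a real model (illustrative only).
def dummy_model(input_ids, attention_mask):
    b, s = input_ids.shape
    return SimpleNamespace(logits=torch.randn(b, s, 16))

def dummy_generate(model, input_ids, attention_mask):
    # Plays the role of autoregressive_language_model_generate in this sketch.
    while True:
        next_tokens = yield model(input_ids=input_ids, attention_mask=attention_mask).logits
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones_like(next_tokens)[:, None]], dim=1
        )

input_ids = torch.zeros(2, 4, dtype=torch.long)
attention_mask = torch.ones(2, 4, dtype=torch.long)
gen = dummy_generate(dummy_model, input_ids, attention_mask)

logits = next(gen)       # prime the generator: logits for the prompt
generated = input_ids
max_new_tokens = 5       # illustrative step budget
for _ in range(max_new_tokens):
    # Greedy decoding: pick the most likely next token at each step
    next_tokens = logits[:, -1, :].argmax(dim=-1)
    generated = torch.cat([generated, next_tokens[:, None]], dim=1)
    logits = gen.send(next_tokens)
```

In practice you would also stop early when every sequence in the batch has produced an end-of-sequence token, e.g. by checking `next_tokens` against the tokenizer's EOS id before each `send`.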

Contributing

Contributions are welcome! Please submit pull requests or open issues on the GitHub repository.

License

This project is licensed under the MIT License.


File details

Details for the file autoregressive_language_model_generate-0.1.0a0.tar.gz.

File metadata

File hashes

Hashes for autoregressive_language_model_generate-0.1.0a0.tar.gz:

  • SHA256: 0a128b624c7e7a8a3dfeca5efbdb8f9b390b8a7d4df90fdf56f6afcdc0a9cd5a
  • MD5: 861b36318887f8050ad30a3151e101a1
  • BLAKE2b-256: 939c133b86cb9b7e26c8bfc893dc77e6a358b44a67f84ac6314135b145bbe87b


File details

Details for the file autoregressive_language_model_generate-0.1.0a0-py2.py3-none-any.whl.

File metadata

File hashes

Hashes for autoregressive_language_model_generate-0.1.0a0-py2.py3-none-any.whl:

  • SHA256: 06253a091503125297487198ccb8981c56a678abbc1ade843d0f185967c9b1cb
  • MD5: 1683696fbe2339df2b4e1fbd192b1b7b
  • BLAKE2b-256: f83fa664301eb43af32ffac7045efcf13581d33cbb5dc6a6f8796c8de101ccc3

