Microlib for sampling from an LLM
LLM Sampler
Install with:

```shell
pip install llm_sampler
```
Quick example
Sample from an LLM with temperature:
```python
import torch
from llm_sampler import sample

# Initializes the forward_func.
# This could be any function that returns logits when given input tokens,
# for example Hugging Face models, LLaMA, Falcon, etc.
# load_model() and tokenize_input() are placeholders for your own code.
forward_func = load_model()
input_ids = tokenize_input("Magnus Carlsen had won the World ")  # Tokenize the input
max_new_tokens = 10  # Number of new tokens to generate

generated_tokens = sample(
    forward_func=forward_func,
    input_ids=input_ids,
    max_new_tokens=max_new_tokens,
    temperature=0.6,
    warp_top_k=10,
)
for next_token in generated_tokens:
    print("Next token:", next_token)
```
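To make `temperature` and `warp_top_k` more concrete, here is a minimal pure-Python sketch of what temperature scaling plus top-k warping usually do inside an autoregressive sampling loop. It is not `llm_sampler`'s implementation: `warp_and_sample`, `sample_sketch` and `toy_forward` are hypothetical names, and for simplicity `forward_func` is assumed to take a list of token ids and return a list of next-token logits (a real model would work with tensors).

```python
import math
import random
from typing import Callable, Iterator, List, Optional


def warp_and_sample(logits: List[float], temperature: float, top_k: int,
                    rng: random.Random) -> int:
    """Pick one token id from raw logits using temperature and top-k warping."""
    if temperature <= 0:
        # In this sketch, temperature 0 is treated as greedy decoding.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Clamp top_k to a valid range, then keep only the top_k highest-scoring ids.
    top_k = max(1, min(top_k, len(logits)))
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Dividing by the temperature sharpens (<1) or flattens (>1) the distribution.
    scaled = [logits[i] / temperature for i in candidates]
    # Numerically stable softmax weights over the surviving candidates.
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(candidates, weights=weights, k=1)[0]


def sample_sketch(forward_func: Callable[[List[int]], List[float]],
                  input_ids: List[int], max_new_tokens: int,
                  temperature: float = 1.0, top_k: int = 10,
                  seed: Optional[int] = None) -> Iterator[int]:
    """Autoregressive loop: yield one newly sampled token id per step."""
    rng = random.Random(seed)
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        # Assumed interface: forward_func returns next-token logits for `ids`.
        next_id = warp_and_sample(forward_func(ids), temperature, top_k, rng)
        ids.append(next_id)
        yield next_id


def toy_forward(ids: List[int]) -> List[float]:
    """Hypothetical 5-token 'model' that strongly favours (last id + 1) mod 5."""
    favourite = (ids[-1] + 1) % 5
    return [10.0 if i == favourite else 0.0 for i in range(5)]


print(list(sample_sketch(toy_forward, [0], max_new_tokens=4, temperature=0.6, top_k=1)))
# prints [1, 2, 3, 4]
```

With `top_k=1` the loop reduces to greedy decoding, which is why the toy output is deterministic; larger `top_k` values and higher temperatures let lower-ranked tokens through more often.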
Sample from an LLM with multiple choice:
```python
from llm_sampler import sample_multiple_choice

# Initializes the forward_func.
# This could be any function that returns logits when given input tokens,
# for example Hugging Face models, LLaMA, Falcon, etc.
# load_model() and tokenize_input() are placeholders for your own code.
forward_func = load_model()

generator = sample_multiple_choice(
    forward_func=forward_func,
    input_ids=tokenize_input("The sentiment of the sentence 'I loved it' is '"),
    all_continuation_ids=[
        tokenize_input("positive"),
        tokenize_input("negative"),
    ],
)
raw_seqs = list(generator)  # one score per option, in the order given
# raw_seqs is now [tensor([0.2031], dtype=torch.bfloat16), tensor([-1.5781], dtype=torch.bfloat16)]
```
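One common way to score multiple-choice options is to sum the log-probabilities the model assigns to each continuation's tokens, given the prompt. The sketch below shows that idea in plain Python. It is an assumption, not a description of `sample_multiple_choice`: the positive 0.2031 in the example output above cannot be a sum of log-probabilities (those are always at most 0), so the library evidently returns a differently scaled score. `score_continuation`, `score_options` and the list-based `forward_func` interface are hypothetical.

```python
import math
from typing import Callable, List

# Assumed interface: forward_func(ids) returns one logit vector per position,
# and the vector at position t predicts the token at position t + 1.
ForwardFunc = Callable[[List[int]], List[List[float]]]


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax of a single logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def score_continuation(forward_func: ForwardFunc, input_ids: List[int],
                       continuation_ids: List[int]) -> float:
    """Sum of log-probabilities of continuation_ids given input_ids."""
    if not input_ids:
        raise ValueError("input_ids must contain at least one token")
    per_position_logits = forward_func(input_ids + continuation_ids)
    total = 0.0
    for offset, token_id in enumerate(continuation_ids):
        # The logits predicting this continuation token sit one position earlier.
        pos = len(input_ids) + offset - 1
        total += log_softmax(per_position_logits[pos])[token_id]
    return total


def score_options(forward_func: ForwardFunc, input_ids: List[int],
                  all_continuation_ids: List[List[int]]) -> List[float]:
    """One score per option, in the order the options were given."""
    return [score_continuation(forward_func, input_ids, c) for c in all_continuation_ids]


# Hypothetical 3-token model that always prefers token 0.
toy_forward = lambda ids: [[2.0, 0.0, 0.0] for _ in ids]
scores = score_options(toy_forward, input_ids=[1, 2], all_continuation_ids=[[0], [1]])
print(scores)
```

Summing log-probabilities tends to favour options with fewer tokens; dividing each score by `len(continuation_ids)` is a common variant when options differ in length.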
What is it
`llm_sampler` is a microlib that lets you sample from an LLM, or compute probability scores for sequences supplied by the user.
For example, if you supply the input:

Input: The sentiment of the sentence 'I loved it' is
- Option 0: positive
- Option 1: negative
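The library returns a score for each option, and a higher score appears to mean a more likely option: in the example above, "positive" scores 0.2031 and "negative" scores -1.5781. Note that these raw scores are not normalized probabilities. In that sense, `llm_sampler` can be used as a zero-shot classifier.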
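To turn the per-option scores into a distribution over the listed options, one possibility is a softmax over the scores, after converting each one-element tensor to a Python float (for example `[t.item() for t in raw_seqs]`). The sketch below assumes the scores behave like log-scale values, which the library does not state, so treat the results as relative confidences among the given options only; `scores_to_probs` and `classify` are hypothetical helpers.

```python
import math
from typing import Dict, List


def scores_to_probs(scores: List[float]) -> List[float]:
    """Numerically stable softmax: non-negative values that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def classify(labels: List[str], scores: List[float]) -> Dict[str, float]:
    """Map each label to its softmax-normalized score."""
    return dict(zip(labels, scores_to_probs(scores)))


# Scores taken from the multiple-choice example output above.
probs = classify(["positive", "negative"], [0.2031, -1.5781])
best = max(probs, key=probs.get)
print(best, round(probs[best], 3))  # positive 0.856
```

Picking the label with the highest score gives the same answer with or without the softmax; the normalization only matters when you want comparable confidence values.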
Download files
Source distribution: llm_sampler-0.1.1.tar.gz (3.9 kB)
Built distribution: llm_sampler-0.1.1-py3-none-any.whl

Hashes for llm_sampler-0.1.1-py3-none-any.whl:

| Algorithm | Hash digest |
|---|---|
| SHA256 | 177fa8007af7aa9ea28ed6186a92601c12c59f0a991f810a3178d02d865442f2 |
| MD5 | 317f79bf1c503192aa6ba924c4a4a98e |
| BLAKE2b-256 | 0f8e9d441c4ca1fa728086b621b71aec1af617bd1d969a1f4c2b22ce47a9bc75 |