token-difr
Verify LLM outputs using Gumbel-Max sampling verification.
Installation
pip install token-difr
Requirements
- Python >= 3.10
- PyTorch >= 2.0.0
- vLLM >= 0.11.0
- CUDA-capable GPU
Quick Start
from token_difr import verify_outputs, TokenSequence
# Create token sequences to verify
# (typically from an untrusted LLM provider)
outputs = [
    TokenSequence(
        prompt_token_ids=[128000, 2323, 374, 264, 1296],
        output_token_ids=[264, 1296, 13, 578, 4320],
    )
]
# Verify the outputs against a trusted model
# All sampling parameters must match what was used during generation
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0,  # Required: must match generation
    top_k=50,         # Required: must match generation
    top_p=0.95,       # Required: must match generation
    seed=42,          # Required: must match generation
)
# Check verification results
for seq_idx, token_metrics in enumerate(results):
    for tok_idx, metrics in enumerate(token_metrics):
        print(f"Seq {seq_idx}, Token {tok_idx}:")
        print(f"  exact_match: {metrics.exact_match}")
        print(f"  prob: {metrics.prob:.4f}")
        print(f"  margin: {metrics.margin:.4f}")
        print(f"  logit_rank: {metrics.logit_rank}")
        print(f"  gumbel_rank: {metrics.gumbel_rank}")
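A simple way to act on these per-token results is to require every output token to match exactly. This is an illustrative check built only on the exact_match field shown above, not an API provided by the package:

# Flag sequences where any verified token failed the exact-match check
for seq_idx, token_metrics in enumerate(results):
    match_rate = sum(m.exact_match for m in token_metrics) / len(token_metrics)
    if match_rate == 1.0:
        print(f"Seq {seq_idx}: all {len(token_metrics)} tokens verified")
    else:
        print(f"Seq {seq_idx}: only {match_rate:.1%} of tokens matched exactly")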
API Reference
verify_outputs
def verify_outputs(
    outputs: list[TokenSequence],
    model_name: str,
    *,
    temperature: float,  # Required
    top_k: int,          # Required
    top_p: float,        # Required
    seed: int,           # Required
    max_model_len: int | None = None,
    dtype: torch.dtype = torch.bfloat16,
    vllm_kwargs: dict[str, Any] | None = None,
    sampling_method: SamplingMethod = SamplingMethod.VLLM_GUMBEL_MAX,
    gpu_memory_utilization: float = 0.7,
    verbose: bool = True,
) -> list[list[TokenMetrics]]:
Parameters:
- outputs: List of TokenSequence objects containing prompt and output token IDs
- model_name: HuggingFace model name (e.g., "meta-llama/Llama-3.1-8B-Instruct")
- temperature: Sampling temperature used during generation. Required.
- top_k: Top-k sampling parameter. Required.
- top_p: Top-p (nucleus) sampling parameter. Required.
- seed: Random seed used during generation. Required.
- max_model_len: Maximum model context length. If None, auto-computed from outputs
- dtype: Model dtype, e.g., torch.bfloat16, torch.float16 (default: torch.bfloat16)
- vllm_kwargs: Additional kwargs for vLLM's LLM constructor (e.g., for quantization)
- sampling_method: Sampling method to verify against (default: VLLM_GUMBEL_MAX)
- gpu_memory_utilization: Fraction of GPU memory to use (default: 0.7)
- verbose: Whether to show progress and print summary (default: True)
Returns:
A list of lists of TokenMetrics: one inner list per output sequence, with one TokenMetrics entry per token in that sequence.
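The optional parameters can be combined freely with the required ones. A brief sketch of the non-default options; the specific values here are illustrative, not recommendations:

import torch
from token_difr import verify_outputs

results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0, top_k=50, top_p=0.95, seed=42,
    dtype=torch.float16,         # override the bfloat16 default
    gpu_memory_utilization=0.5,  # leave GPU headroom for other processes
    verbose=False,               # suppress progress output and the printed summary
)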
TokenSequence
@dataclass
class TokenSequence:
    prompt_token_ids: list[int]
    output_token_ids: list[int]
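If you start from text rather than token IDs, one way to build a TokenSequence is with the model's tokenizer. A minimal sketch assuming a Hugging Face tokenizer; the strings and add_special_tokens choices are illustrative and must reproduce exactly the token IDs the provider's generation actually used:

from transformers import AutoTokenizer
from token_difr import TokenSequence

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Encode the prompt with special tokens (e.g., BOS) and the completion without them.
prompt_ids = tokenizer.encode("This is a test")
output_ids = tokenizer.encode(" a test. The answer", add_special_tokens=False)

seq = TokenSequence(prompt_token_ids=prompt_ids, output_token_ids=output_ids)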
TokenMetrics
@dataclass
class TokenMetrics:
    exact_match: bool   # Whether token matches under verification
    prob: float         # Probability of the actual token
    margin: float       # Margin between max and actual token scores
    logit_rank: float   # Rank of actual token by logit value
    gumbel_rank: float  # Rank of actual token by Gumbel score
SamplingMethod
class SamplingMethod(Enum):
    VLLM_GUMBEL_MAX = "vllm_gumbel_max"
compute_metrics_summary
def compute_metrics_summary(results: list[list[TokenMetrics]]) -> dict:
    """Returns aggregate stats from verification results."""
    # Returns: {
    #     "total_tokens": int,
    #     "exact_match_rate": float,
    #     "avg_prob": float,
    #     "avg_margin": float,
    #     "avg_logit_rank": float,
    #     "avg_gumbel_rank": float,
    # }
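For example, assuming compute_metrics_summary is exported at the package top level like the other API entries, a verification run can be summarized like this:

from token_difr import compute_metrics_summary

summary = compute_metrics_summary(results)
print(f"Verified {summary['total_tokens']} tokens")
print(f"Exact-match rate: {summary['exact_match_rate']:.2%}")
print(f"Average prob: {summary['avg_prob']:.4f}, average margin: {summary['avg_margin']:.4f}")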
Advanced Usage
Using Quantization
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0, top_k=50, top_p=0.95, seed=42,
    vllm_kwargs={"quantization": "fp8"},
)
Using FP8 KV Cache
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0, top_k=50, top_p=0.95, seed=42,
    vllm_kwargs={
        "kv_cache_dtype": "fp8",
        "calculate_kv_scales": True,
    },
)
Custom Model Length
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0, top_k=50, top_p=0.95, seed=42,
    max_model_len=4096,
)
Environment Variables
The package automatically sets VLLM_USE_V1=1 (via setdefault, so a value you set before importing the package takes precedence). This enables the vLLM v1 engine features required for prompt logprobs.
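For example, to pin the variable yourself before the package sets it (the value shown only illustrates the override mechanics; v1 is required for prompt logprobs):

import os

# Set before importing token_difr; because the package uses setdefault,
# a value set here takes precedence over the package default.
os.environ["VLLM_USE_V1"] = "0"  # illustrative only

from token_difr import verify_outputs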
License
MIT
File details
Details for the file token_difr-0.1.1.tar.gz.
File metadata
- Download URL: token_difr-0.1.1.tar.gz
- Upload date:
- Size: 9.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 09164a99eb6df04b24fc72d40a4b9d91cdee9ae98480d811531436399825956f |
| MD5 | 9699a1d183881c6862398891442b02cc |
| BLAKE2b-256 | 76f43666775f2ca53890b7126dae00b76a0290a6bb4bc445142edf9b97a00ed6 |
File details
Details for the file token_difr-0.1.1-py3-none-any.whl.
File metadata
- Download URL: token_difr-0.1.1-py3-none-any.whl
- Upload date:
- Size: 9.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a669bccda24de2161522881f66b0a7ebdf9cb7de1fbb4761fba7eb8f4a2cea2c |
| MD5 | 2db79064e73b7c357b8735eb20e530cc |
| BLAKE2b-256 | 08c6ba43b987bdbea2059a0941fd4dfe7a6a3b3edf8bf2aab3c4cb67f1ae73c2 |