# token-difr

Verify LLM outputs using Gumbel-Max sampling verification.
## Installation

```bash
pip install token-difr
```
## Requirements
- Python >= 3.10
- PyTorch >= 2.0.0
- vLLM >= 0.6.0
- CUDA-capable GPU
## Quick Start

```python
from token_difr import verify_outputs, TokenSequence

# Create token sequences to verify
# (typically from an untrusted LLM provider)
outputs = [
    TokenSequence(
        prompt_token_ids=[128000, 2323, 374, 264, 1296],
        output_token_ids=[264, 1296, 13, 578, 4320],
    )
]

# Verify the outputs against a trusted model
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    seed=42,
)

# Check verification results
for seq_idx, token_metrics in enumerate(results):
    for tok_idx, metrics in enumerate(token_metrics):
        print(f"Seq {seq_idx}, Token {tok_idx}:")
        print(f"  exact_match: {metrics.exact_match}")
        print(f"  prob: {metrics.prob:.4f}")
        print(f"  margin: {metrics.margin:.4f}")
        print(f"  logit_rank: {metrics.logit_rank}")
        print(f"  gumbel_rank: {metrics.gumbel_rank}")
```
## API Reference

### verify_outputs

```python
def verify_outputs(
    outputs: list[TokenSequence],
    model_name: str,
    *,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    seed: int = 42,
    max_model_len: int | None = None,
    dtype: torch.dtype = torch.bfloat16,
    vllm_kwargs: dict[str, Any] | None = None,
    sampling_method: SamplingMethod = SamplingMethod.VLLM_GUMBEL_MAX,
    gpu_memory_utilization: float = 0.7,
    show_progress: bool = True,
) -> list[list[TokenMetrics]]: ...
```
**Parameters:**

- `outputs`: List of `TokenSequence` objects containing prompt and output token IDs
- `model_name`: HuggingFace model name (e.g., `"meta-llama/Llama-3.1-8B-Instruct"`)
- `temperature`: Sampling temperature used during generation (default: 1.0)
- `top_k`: Top-k sampling parameter (default: 50)
- `top_p`: Top-p (nucleus) sampling parameter (default: 0.95)
- `seed`: Random seed used during generation (default: 42)
- `max_model_len`: Maximum model context length. If `None`, auto-computed from `outputs`
- `dtype`: Model dtype, e.g., `torch.bfloat16`, `torch.float16` (default: `torch.bfloat16`)
- `vllm_kwargs`: Additional kwargs for vLLM's `LLM` constructor (e.g., for quantization)
- `sampling_method`: Sampling method to verify against (default: `VLLM_GUMBEL_MAX`)
- `gpu_memory_utilization`: Fraction of GPU memory to use (default: 0.7)
- `show_progress`: Whether to show progress bars (default: `True`)
**Returns:**

List of lists of `TokenMetrics`, one inner list per output sequence, with one entry per token.
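The `max_model_len=None` auto-computation can be illustrated with a short sketch: the context must at least fit the longest prompt-plus-output pair. The `auto_model_len` helper below is hypothetical (the library's exact computation may differ), and the `TokenSequence` stand-in keeps it runnable:

```python
from dataclasses import dataclass

@dataclass
class TokenSequence:  # stand-in mirroring token_difr.TokenSequence
    prompt_token_ids: list
    output_token_ids: list

def auto_model_len(outputs):
    """Smallest context length that fits every prompt + output pair.

    Hypothetical illustration of how max_model_len=None might be derived.
    """
    return max(len(s.prompt_token_ids) + len(s.output_token_ids)
               for s in outputs)

outputs = [TokenSequence([128000, 2323, 374, 264, 1296],
                         [264, 1296, 13, 578, 4320])]
print(auto_model_len(outputs))  # -> 10
```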
### TokenSequence

```python
@dataclass
class TokenSequence:
    prompt_token_ids: list[int]
    output_token_ids: list[int]
```
### TokenMetrics

```python
@dataclass
class TokenMetrics:
    exact_match: bool   # Whether the token matches under verification
    prob: float         # Probability of the actual token
    margin: float       # Margin between max and actual token scores
    logit_rank: float   # Rank of the actual token by logit value
    gumbel_rank: float  # Rank of the actual token by Gumbel score
```
### SamplingMethod

```python
class SamplingMethod(Enum):
    VLLM_GUMBEL_MAX = "vllm_gumbel_max"
```
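The idea behind Gumbel-Max verification can be sketched in plain Python: adding Gumbel(0, 1) noise to the logits and taking the argmax is equivalent to sampling from the softmax distribution, and with a fixed seed the draw is reproducible. The snippet below is a simplified illustration of the trick itself, not the package's actual implementation (which runs inside vLLM and applies temperature/top-k/top-p):

```python
import math
import random

def gumbel_max_sample(logits, seed):
    """Sample a token index via the Gumbel-Max trick with a fixed seed."""
    rng = random.Random(seed)
    # g = -log(-log(U)) is a Gumbel(0, 1) draw; argmax(logit + g)
    # is then a sample from softmax(logits).
    noisy = [l - math.log(-math.log(rng.random())) for l in logits]
    return max(range(len(noisy)), key=noisy.__getitem__)

logits = [2.0, 1.0, 0.5, -1.0]

# Because the noise is seed-determined, a verifier can re-derive the
# sampled token and compare it against the provider's claimed output.
token = gumbel_max_sample(logits, seed=42)
assert token == gumbel_max_sample(logits, seed=42)
```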
## Advanced Usage

### Using Quantization

```python
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    vllm_kwargs={"quantization": "fp8"},
)
```
### Using FP8 KV Cache

```python
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    vllm_kwargs={
        "kv_cache_dtype": "fp8",
        "calculate_kv_scales": True,
    },
)
```
### Custom Model Length

```python
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=4096,
)
```
## Environment Variables

The package automatically sets `VLLM_USE_V1=1` (using `setdefault`, so you can override it before importing). This enables the vLLM v1 features required for prompt logprobs.
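A minimal sketch of that override behavior, where the `setdefault` line stands in for the package's import-time call:

```python
import os

# Export the variable before importing token_difr; because the package
# uses setdefault, a pre-existing value is left untouched.
os.environ["VLLM_USE_V1"] = "0"

# Stand-in for what happens at import time inside the package:
os.environ.setdefault("VLLM_USE_V1", "1")

print(os.environ["VLLM_USE_V1"])  # -> 0 (the pre-set value wins)
```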
## License

MIT