
Project description

token-difr

Verify LLM outputs using Gumbel-Max sampling verification.
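
For background, the Gumbel-Max trick draws a sample from a categorical distribution as the argmax of the logits plus i.i.d. Gumbel noise. Because that noise is a deterministic function of the seed, a verifier holding the same seed and a trusted copy of the model can re-derive the draw and check whether a provider's claimed token is the one that would have been sampled. A minimal sketch of the trick itself, not of this package's internals:

import torch

# Gumbel-Max trick: argmax(logits + Gumbel noise) is a sample from
# softmax(logits). With a fixed seed the noise is reproducible, so the
# draw can be re-derived and compared against a claimed token.
gen = torch.Generator().manual_seed(42)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])

uniform = torch.rand(logits.shape, generator=gen)
gumbel = -torch.log(-torch.log(uniform))  # standard Gumbel(0, 1) noise

resampled = torch.argmax(logits + gumbel).item()
claimed = 0  # token ID reported by the untrusted provider
print("verified:", resampled == claimed)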

Installation

pip install token-difr

Requirements

  • Python >= 3.10
  • PyTorch >= 2.0.0
  • vLLM >= 0.11.0
  • CUDA-capable GPU

Quick Start

from token_difr import verify_outputs, TokenSequence

# Create token sequences to verify
# (typically from an untrusted LLM provider)
outputs = [
    TokenSequence(
        prompt_token_ids=[128000, 2323, 374, 264, 1296],
        output_token_ids=[264, 1296, 13, 578, 4320]
    )
]

# Verify the outputs against a trusted model
# All sampling parameters must match what was used during generation
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0,  # Required: must match generation
    top_k=50,         # Required: must match generation
    top_p=0.95,       # Required: must match generation
    seed=42,          # Required: must match generation
)

# Check verification results
for seq_idx, token_metrics in enumerate(results):
    for tok_idx, metrics in enumerate(token_metrics):
        print(f"Seq {seq_idx}, Token {tok_idx}:")
        print(f"  exact_match: {metrics.exact_match}")
        print(f"  prob: {metrics.prob:.4f}")
        print(f"  margin: {metrics.margin:.4f}")
        print(f"  logit_rank: {metrics.logit_rank}")
        print(f"  gumbel_rank: {metrics.gumbel_rank}")

API Reference

verify_outputs

def verify_outputs(
    outputs: list[TokenSequence],
    model_name: str,
    *,
    temperature: float,      # Required
    top_k: int,              # Required
    top_p: float,            # Required
    seed: int,               # Required
    max_model_len: int | None = None,
    dtype: torch.dtype = torch.bfloat16,
    vllm_kwargs: dict[str, Any] | None = None,
    sampling_method: SamplingMethod = SamplingMethod.VLLM_GUMBEL_MAX,
    gpu_memory_utilization: float = 0.7,
    verbose: bool = True,
) -> list[list[TokenMetrics]]:

Parameters:

  • outputs: List of TokenSequence objects containing prompt and output token IDs.
  • model_name: HuggingFace model name (e.g., "meta-llama/Llama-3.1-8B-Instruct").
  • temperature: Sampling temperature used during generation. Required.
  • top_k: Top-k sampling parameter used during generation. Required.
  • top_p: Top-p (nucleus) sampling parameter used during generation. Required.
  • seed: Random seed used during generation. Required.
  • max_model_len: Maximum model context length. If None, auto-computed from the outputs.
  • dtype: Model dtype, e.g., torch.bfloat16 or torch.float16 (default: torch.bfloat16).
  • vllm_kwargs: Additional kwargs for vLLM's LLM constructor (e.g., for quantization).
  • sampling_method: Sampling method to verify against (default: SamplingMethod.VLLM_GUMBEL_MAX).
  • gpu_memory_utilization: Fraction of GPU memory to use (default: 0.7).
  • verbose: Whether to show progress and print a summary (default: True).

Returns:

A list with one inner list per input TokenSequence; each inner list contains one TokenMetrics per output token.
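
For example, one simple acceptance policy is to require every token in a sequence to match exactly (a sketch over the documented return type; treat the all-tokens-match criterion as a choice, not a prescription):

for seq_idx, token_metrics in enumerate(results):
    verified = all(m.exact_match for m in token_metrics)
    print(f"Sequence {seq_idx}: {'verified' if verified else 'mismatch'}")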

TokenSequence

@dataclass
class TokenSequence:
    prompt_token_ids: list[int]
    output_token_ids: list[int]
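
TokenSequence carries raw token IDs, so any text has to be tokenized first. A sketch using a HuggingFace tokenizer for the same model (the transformers dependency and the encoding options here are assumptions; in practice you would reuse the exact IDs returned by the provider):

from transformers import AutoTokenizer

from token_difr import TokenSequence

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

seq = TokenSequence(
    prompt_token_ids=tokenizer.encode("This is a test"),
    output_token_ids=tokenizer.encode(" of the verifier.", add_special_tokens=False),
)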

TokenMetrics

@dataclass
class TokenMetrics:
    exact_match: bool   # Whether the claimed token matches the verifier's re-derived token
    prob: float         # Probability of the actual token under the trusted model
    margin: float       # Margin between the max and the actual token's scores
    logit_rank: float   # Rank of the actual token by logit value
    gumbel_rank: float  # Rank of the actual token by Gumbel score
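
The non-boolean fields grade how plausible a claimed token remains even when exact_match is False. One illustrative way to surface suspect tokens for inspection, with thresholds that are placeholders rather than recommendations:

for seq_idx, token_metrics in enumerate(results):
    for tok_idx, m in enumerate(token_metrics):
        # Placeholder thresholds; tune for your model and sampling settings.
        if not m.exact_match and (m.prob < 0.01 or m.gumbel_rank > 10):
            print(f"Suspect token: seq {seq_idx}, pos {tok_idx}, "
                  f"prob={m.prob:.4f}, gumbel_rank={m.gumbel_rank}")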

SamplingMethod

class SamplingMethod(Enum):
    VLLM_GUMBEL_MAX = "vllm_gumbel_max"

compute_metrics_summary

def compute_metrics_summary(results: list[list[TokenMetrics]]) -> dict:
    """Returns aggregate stats from verification results."""
    # Returns: {
    #     "total_tokens": int,
    #     "exact_match_rate": float,
    #     "avg_prob": float,
    #     "avg_margin": float,
    #     "avg_logit_rank": float,
    #     "avg_gumbel_rank": float,
    # }
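
Typical usage, assuming the helper is exported from the top-level package like verify_outputs:

from token_difr import compute_metrics_summary

summary = compute_metrics_summary(results)
print(f"Exact-match rate: {summary['exact_match_rate']:.2%} "
      f"over {summary['total_tokens']} tokens")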

Advanced Usage

Using Quantization

results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0, top_k=50, top_p=0.95, seed=42,
    vllm_kwargs={"quantization": "fp8"},
)

Using FP8 KV Cache

results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0, top_k=50, top_p=0.95, seed=42,
    vllm_kwargs={
        "kv_cache_dtype": "fp8",
        "calculate_kv_scales": True,
    },
)

Custom Model Length

results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0, top_k=50, top_p=0.95, seed=42,
    max_model_len=4096,
)

Environment Variables

The package sets VLLM_USE_V1=1 on import (via os.environ.setdefault, so a value you set before importing takes precedence). This enables the vLLM v1 engine features required for prompt logprobs.
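
Because setdefault only applies when the variable is unset, an override just has to happen before the first import (a sketch; the value shown is illustrative):

import os

# Must run before token_difr (and therefore vLLM) is imported;
# the package's setdefault will then leave this value alone.
os.environ["VLLM_USE_V1"] = "0"  # illustrative override; the package default is "1"

from token_difr import verify_outputs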

License

MIT

Download files

Download the file for your platform.

Source Distribution

token_difr-0.1.1.tar.gz (9.8 kB)


Built Distribution


token_difr-0.1.1-py3-none-any.whl (9.5 kB)


File details

Details for the file token_difr-0.1.1.tar.gz.

File metadata

  • Download URL: token_difr-0.1.1.tar.gz
  • Size: 9.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.10

File hashes

Hashes for token_difr-0.1.1.tar.gz

  • SHA256: 09164a99eb6df04b24fc72d40a4b9d91cdee9ae98480d811531436399825956f
  • MD5: 9699a1d183881c6862398891442b02cc
  • BLAKE2b-256: 76f43666775f2ca53890b7126dae00b76a0290a6bb4bc445142edf9b97a00ed6


File details

Details for the file token_difr-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: token_difr-0.1.1-py3-none-any.whl
  • Size: 9.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.10

File hashes

Hashes for token_difr-0.1.1-py3-none-any.whl

  • SHA256: a669bccda24de2161522881f66b0a7ebdf9cb7de1fbb4761fba7eb8f4a2cea2c
  • MD5: 2db79064e73b7c357b8735eb20e530cc
  • BLAKE2b-256: 08c6ba43b987bdbea2059a0941fd4dfe7a6a3b3edf8bf2aab3c4cb67f1ae73c2

