
Project description

token-difr

Verify LLM outputs using Gumbel-Max sampling verification.
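
The core idea: drawing a token from softmax(logits / T) is distributionally identical to adding i.i.d. Gumbel(0, 1) noise to the temperature-scaled logits and taking the argmax (the Gumbel-Max trick). A verifier that knows the provider's sampling seed can regenerate the same noise against a trusted model's logits and check whether each claimed token is the argmax that noise would have produced. A minimal sketch of the trick itself, not of token-difr's internal implementation:

import torch

def gumbel_max_sample(logits: torch.Tensor, generator: torch.Generator) -> int:
    """Sample from softmax(logits) via the Gumbel-Max trick."""
    # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1),
    # so logits + gumbel simplifies to logits - log(-log(u)).
    u = torch.rand(logits.shape, generator=generator)
    return int(torch.argmax(logits - torch.log(-torch.log(u))))

# Sampler and verifier share a seed, so the verifier can reproduce
# the noise and confirm the claimed token is the argmax.
logits = torch.tensor([1.0, 2.0, 0.5])
claimed = gumbel_max_sample(logits, torch.Generator().manual_seed(42))
assert gumbel_max_sample(logits, torch.Generator().manual_seed(42)) == claimed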

Installation

pip install token-difr

Requirements

  • Python >= 3.10
  • PyTorch >= 2.0.0
  • vLLM >= 0.6.0
  • CUDA-capable GPU

Quick Start

from token_difr import verify_outputs, TokenSequence

# Create token sequences to verify
# (typically from an untrusted LLM provider)
outputs = [
    TokenSequence(
        prompt_token_ids=[128000, 2323, 374, 264, 1296],
        output_token_ids=[264, 1296, 13, 578, 4320]
    )
]

# Verify the outputs against a trusted model
results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    seed=42,
)

# Check verification results
for seq_idx, token_metrics in enumerate(results):
    for tok_idx, metrics in enumerate(token_metrics):
        print(f"Seq {seq_idx}, Token {tok_idx}:")
        print(f"  exact_match: {metrics.exact_match}")
        print(f"  prob: {metrics.prob:.4f}")
        print(f"  margin: {metrics.margin:.4f}")
        print(f"  logit_rank: {metrics.logit_rank}")
        print(f"  gumbel_rank: {metrics.gumbel_rank}")
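
results mirrors the input structure: one inner list per sequence, one TokenMetrics per output token. A quick hypothetical way to summarize a verification run, continuing from the loop above:

# Overall exact-match rate across all verified tokens
total = sum(len(toks) for toks in results)
matched = sum(m.exact_match for toks in results for m in toks)
print(f"exact-match rate: {matched / total:.2%}")

# Sequences containing any mismatch deserve a closer look
suspect = [i for i, toks in enumerate(results) if not all(m.exact_match for m in toks)]
print(f"suspect sequences: {suspect}")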

API Reference

verify_outputs

def verify_outputs(
    outputs: list[TokenSequence],
    model_name: str,
    *,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    seed: int = 42,
    max_model_len: int | None = None,
    dtype: torch.dtype = torch.bfloat16,
    vllm_kwargs: dict[str, Any] | None = None,
    sampling_method: SamplingMethod = SamplingMethod.VLLM_GUMBEL_MAX,
    gpu_memory_utilization: float = 0.7,
    show_progress: bool = True,
) -> list[list[TokenMetrics]]:

Parameters:

  • outputs: List of TokenSequence objects containing prompt and output token IDs
  • model_name: HuggingFace model name (e.g., "meta-llama/Llama-3.1-8B-Instruct")
  • temperature: Sampling temperature used during generation (default: 1.0)
  • top_k: Top-k sampling parameter (default: 50)
  • top_p: Top-p (nucleus) sampling parameter (default: 0.95)
  • seed: Random seed used during generation (default: 42)
  • max_model_len: Maximum model context length. If None, auto-computed from outputs
  • dtype: Model dtype, e.g., torch.bfloat16, torch.float16 (default: torch.bfloat16)
  • vllm_kwargs: Additional kwargs for vLLM's LLM constructor (e.g., for quantization)
  • sampling_method: Sampling method to verify against (default: VLLM_GUMBEL_MAX)
  • gpu_memory_utilization: Fraction of GPU memory to use (default: 0.7)
  • show_progress: Whether to show progress bars (default: True)

Returns:

A list with one inner list per input sequence; each inner list holds one TokenMetrics per output token.

TokenSequence

@dataclass
class TokenSequence:
    prompt_token_ids: list[int]
    output_token_ids: list[int]
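
TokenSequence carries raw token IDs rather than text. If you start from strings, you can build one with the model's tokenizer; a sketch assuming HuggingFace transformers is installed and that the provider tokenized with the same model:

from transformers import AutoTokenizer

from token_difr import TokenSequence

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

prompt_ids = tok.encode("This is a test")
# add_special_tokens=False so no BOS token is inserted in the
# middle of the reconstructed prompt + output sequence.
output_ids = tok.encode(" and only a test.", add_special_tokens=False)

seq = TokenSequence(prompt_token_ids=prompt_ids, output_token_ids=output_ids)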

TokenMetrics

@dataclass
class TokenMetrics:
    exact_match: bool   # Whether token matches under verification
    prob: float         # Probability of the actual token
    margin: float       # Margin between max and actual token scores
    logit_rank: float   # Rank of actual token by logit value
    gumbel_rank: float  # Rank of actual token by Gumbel score
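
The continuous fields help judge borderline tokens: a mismatch where the claimed token lost by a tiny margin may just reflect numerical noise (different dtype, kernels, or hardware), while a large margin is harder to explain away. A hypothetical filter along those lines:

def is_suspicious(m: TokenMetrics, margin_tol: float = 1e-3) -> bool:
    # Tolerate near-ties, which plausibly stem from numerical
    # differences between the provider's and verifier's stacks.
    return (not m.exact_match) and m.margin > margin_tol

flagged = [
    (seq_idx, tok_idx)
    for seq_idx, toks in enumerate(results)
    for tok_idx, m in enumerate(toks)
    if is_suspicious(m)
]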

SamplingMethod

class SamplingMethod(Enum):
    VLLM_GUMBEL_MAX = "vllm_gumbel_max"

Advanced Usage

Using Quantization

results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    vllm_kwargs={"quantization": "fp8"},
)

Using FP8 KV Cache

results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    vllm_kwargs={
        "kv_cache_dtype": "fp8",
        "calculate_kv_scales": True,
    },
)

Custom Model Length

results = verify_outputs(
    outputs,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=4096,
)

Environment Variables

The package sets VLLM_USE_V1=1 via os.environ.setdefault, so a value you set before importing token_difr takes precedence. vLLM v1 enables the prompt-logprobs features the verifier relies on.
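
A minimal sketch of overriding the default (keep in mind that verification depends on the v1 prompt-logprobs support, so disabling it may break verify_outputs):

import os

# Must run before token_difr is imported, because the package only
# calls os.environ.setdefault("VLLM_USE_V1", "1") at import time.
os.environ["VLLM_USE_V1"] = "0"

import token_difr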

License

MIT

Download files

Download the file for your platform.

Source Distribution

token_difr-0.1.0.tar.gz (9.2 kB, source)

Built Distribution


token_difr-0.1.0-py3-none-any.whl (8.8 kB, Python 3 wheel)

File details

Details for the file token_difr-0.1.0.tar.gz.

File metadata

  • Download URL: token_difr-0.1.0.tar.gz
  • Size: 9.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.10

File hashes

Hashes for token_difr-0.1.0.tar.gz:

  • SHA256: 6035c65baaafee85857cb7a39a49e6793f0ab9a12f95a214021d0998c4748b66
  • MD5: b106af5b5bccededa7fa3b3aaa3c1c65
  • BLAKE2b-256: 4f0a08413ceeec75eb733de958b2e9476f22e3fdd3eda6563fa3cb12e96451f3


File details

Details for the file token_difr-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: token_difr-0.1.0-py3-none-any.whl
  • Size: 8.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.10

File hashes

Hashes for token_difr-0.1.0-py3-none-any.whl:

  • SHA256: e62e63b07e90f8af90e3915a2ccb26a9bc55fe0f6993d4c6bc0548f5a8e38839
  • MD5: a8502af2181b940b731c5af147e79ac7
  • BLAKE2b-256: 708adb6bfeab940a59fe8e1ac0d69dde5640ddd4e9707052024358d14ff3265d

