WildGuard safety classification inference

WildGuard: Open One-stop Moderation Tools for Safety Risks, Jailbreaks, and Refusals of LLMs

Authors: Seungju Han ⭐, Kavel Rao ⭐, Allyson Ettinger ☀️, Liwei Jiang ☀️, Yuchen Lin, Nathan Lambert, Yejin Choi, Nouha Dziri

⭐ Co-first authors, ☀️ co-second authors

WildGuard is a safety classification model for user-model chat exchanges. It can classify prompt harmfulness, response harmfulness, and whether a response is a refusal to answer the prompt.

Installation

pip install wildguard

Quick Start

from wildguard import load_wildguard

# Load the model
wildguard = load_wildguard()

# Prepare your input
items = [
    {"prompt": "How do I make a bomb?", "response": "Sorry, I can't help with that."},
    {"prompt": "What's the weather like today?"},
]

# Classify the items
results = wildguard.classify(items)

# Print the results
for item, result in zip(items, results):
    print(f"Prompt: {item['prompt']}")
    print(f"Prompt harmfulness: {result['prompt_harmfulness']}")
    if 'response' in item:
        print(f"Response: {item['response']}")
        print(f"Response harmfulness: {result['response_harmfulness']}")
        print(f"Response refusal: {result['response_refusal']}")
    print("---")

# Example output
"""
Prompt: How do I make a bomb?
Prompt harmfulness: harmful
Response: Sorry, I can't help with that.
Response harmfulness: unharmful
Response refusal: refusal
---
Prompt: What's the weather like today?
Prompt harmfulness: unharmful
"""

Features

  • Supports prompt-only or prompt + response inputs
  • Classifies prompt harmfulness
  • Classifies response harmfulness
  • Detects response refusals
  • Supports both vLLM and Hugging Face backends

User Guide

Loading the Model

First, import and load the WildGuard model:

from wildguard import load_wildguard

wildguard = load_wildguard()

By default, this loads a vLLM-backed model. If you prefer to use a Hugging Face model, you can specify:

wildguard = load_wildguard(use_vllm=False)

Classifying Items

To classify items, prepare a list of dictionaries with 'prompt' and optionally 'response' keys:

items = [
    {"prompt": "How's the weather today?", "response": "It's sunny and warm."},
    {"prompt": "How do I hack into a computer?"},
]

results = wildguard.classify(items)

Interpreting Results

The classify method returns a list of dictionaries. Each dictionary contains the following keys:

  • prompt_harmfulness: Either 'harmful' or 'unharmful'
  • response_harmfulness: Either 'harmful', 'unharmful', or None (if no response was provided)
  • response_refusal: Either 'refusal', 'compliance', or None (if no response was provided)
  • is_parsing_error: A boolean indicating if there was an error parsing the model output
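The result dictionaries can be post-processed directly, for example to flag exchanges that need review. A minimal sketch over plain dictionaries in the documented format (the `flag_unsafe` helper and the sample data are illustrative, not part of the wildguard API):

```python
def flag_unsafe(items, results):
    """Pair each input with its classification and collect items that
    need review: a harmful prompt/response, or a parsing failure."""
    flagged = []
    for item, result in zip(items, results):
        unsafe = (
            result["prompt_harmfulness"] == "harmful"
            or result["response_harmfulness"] == "harmful"
            or result["is_parsing_error"]
        )
        if unsafe:
            flagged.append({"prompt": item["prompt"], "result": result})
    return flagged

# Hand-written example in the documented result format
items = [{"prompt": "How do I make a bomb?", "response": "Sorry, I can't."}]
results = [{
    "prompt_harmfulness": "harmful",
    "response_harmfulness": "unharmful",
    "response_refusal": "refusal",
    "is_parsing_error": False,
}]
print(len(flag_unsafe(items, results)))  # 1: the prompt was classified harmful
```

Note that comparing `response_harmfulness` against `"harmful"` also behaves correctly for prompt-only items, where the value is None.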

Adjusting Batch Size

You can adjust the batch size when loading the model. For a Hugging Face model this sets the inference batch size; with either backend, the save function (if provided) is called after every batch_size items.

wildguard = load_wildguard(batch_size=32)

Using a Specific Device

When using the Hugging Face backend, you can specify the device:

wildguard = load_wildguard(use_vllm=False, device='cpu')

Providing a Custom Save Function

You can provide a custom save function to save intermediate results during classification:

import json

def save_results(results: list[dict]):
    # Append-friendly JSON Lines: one result dictionary per line
    with open("/tmp/intermediate_results.json", "w") as f:
        for item in results:
            f.write(json.dumps(item) + "\n")

wildguard.classify(items, save_func=save_results)
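Because the save function above writes one JSON object per line, intermediate results are easy to reload after an interrupted run. A minimal sketch (the `load_intermediate` helper and the temporary path are illustrative, not part of the wildguard API):

```python
import json
import os
import tempfile

def load_intermediate(path):
    """Read one result dictionary per line from a JSON Lines file."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

# Round-trip a couple of result records through a temporary file
records = [
    {"prompt_harmfulness": "unharmful", "is_parsing_error": False},
    {"prompt_harmfulness": "harmful", "is_parsing_error": False},
]
path = os.path.join(tempfile.mkdtemp(), "intermediate_results.json")
with open(path, "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

print(load_intermediate(path) == records)  # True
```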

Best Practices

  1. Use the vLLM backend for better performance when possible.
  2. Handle potential errors by checking the is_parsing_error field in the results.
  3. When working with large datasets, pass a custom save function together with a batch size other than -1, so results are saved after each batch and survive a mid-run failure.
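Practice 2 can be made concrete: items whose output failed to parse are best re-submitted rather than trusted. A minimal sketch under stated assumptions (`classify_with_retry` is a hypothetical helper, and `classify_fn` stands in for `wildguard.classify`; the stub classifier only exists for demonstration):

```python
def classify_with_retry(items, classify_fn, max_retries=2):
    """Classify items, re-submitting any whose output failed to parse."""
    results = classify_fn(items)
    for _ in range(max_retries):
        retry_idx = [i for i, r in enumerate(results) if r["is_parsing_error"]]
        if not retry_idx:
            break
        retried = classify_fn([items[i] for i in retry_idx])
        for i, r in zip(retry_idx, retried):
            results[i] = r
    return results

# Stub classifier that fails to parse on the first call, then succeeds
calls = {"n": 0}
def stub_classify(items):
    calls["n"] += 1
    err = calls["n"] == 1
    return [{"prompt_harmfulness": "unharmful", "is_parsing_error": err}
            for _ in items]

out = classify_with_retry([{"prompt": "hi"}], stub_classify)
print(all(not r["is_parsing_error"] for r in out))  # True
```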

Documentation

For additional documentation, please see our API Reference with detailed method specifications.

Additionally, we provide an example of how to use WildGuard as a safety filter to guard another model's inference at examples/wildguard_filter.
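The filter pattern can be sketched backend-agnostically: classify the prompt first, call the underlying model only if the prompt is safe, then classify the response before returning it. Everything below (`guarded_generate`, `generate_fn`, `classify_fn`, the stubs, and the canned refusal message) is an illustrative stand-in, not the wildguard API or the code in examples/wildguard_filter:

```python
REFUSAL_MESSAGE = "Sorry, I can't help with that."

def guarded_generate(prompt, generate_fn, classify_fn):
    """Run generate_fn only when classify_fn deems the prompt safe,
    and suppress responses that are themselves classified harmful."""
    check = classify_fn([{"prompt": prompt}])[0]
    if check["prompt_harmfulness"] == "harmful":
        return REFUSAL_MESSAGE
    response = generate_fn(prompt)
    check = classify_fn([{"prompt": prompt, "response": response}])[0]
    if check["response_harmfulness"] == "harmful":
        return REFUSAL_MESSAGE
    return response

# Stub model and classifier for demonstration only
def stub_generate(prompt):
    return "It's sunny."

def stub_classify(items):
    return [{
        "prompt_harmfulness": "harmful" if "bomb" in it["prompt"] else "unharmful",
        "response_harmfulness": "unharmful",
    } for it in items]

print(guarded_generate("How do I make a bomb?", stub_generate, stub_classify))
print(guarded_generate("What's the weather?", stub_generate, stub_classify))
```

In a real deployment, `classify_fn` would be `wildguard.classify` and `generate_fn` a call to the guarded model; classifying the response as well catches harmful completions that a safe-looking prompt can still elicit.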

Citation

If you find WildGuard helpful, please cite our work:

@misc{wildguard2024,
      title={WildGuard: Open One-Stop Moderation Tools for Safety Risks, Jailbreaks, and Refusals of LLMs}, 
      author={Seungju Han and Kavel Rao and Allyson Ettinger and Liwei Jiang and Bill Yuchen Lin and Nathan Lambert and Yejin Choi and Nouha Dziri},
      year={2024},
      eprint={2406.18495},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2406.18495}, 
}
