A library for generating semantic IDs from embeddings using RQ-KMeans and other algorithms.

Semantic ID 🌟

Turn your vectors into meaningful strings.

Semantic ID is a friendly Python library that helps you transform continuous vector embeddings (like those from OpenAI, BERT, or ResNet) into discrete, human-readable semantic strings. It uses algorithms like RQ-KMeans (Residual Quantization K-Means) and RQ-VAE (Residual Quantization Variational Autoencoder) to hierarchically cluster your data, giving you IDs that actually mean something!

Imagine turning [0.12, -0.88, 0.04, ...] into "cars-suv-landrover". Okay, maybe more like "12-4-9-1", but you get the idea—it preserves semantic similarity!
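Conceptually, residual quantization is simple: run K-Means on your vectors, subtract each vector's nearest centroid, cluster the leftovers, and repeat once per level. Here's a minimal sketch of the idea using scikit-learn (illustrative only; the library's RQKMeans does this for you, with more care):

import numpy as np
from sklearn.cluster import KMeans

X = np.random.randn(100, 16)        # 100 vectors, 16 dimensions
residual = X.copy()
codes = []
for level in range(3):              # one K-Means per level
    km = KMeans(n_clusters=10, n_init=10, random_state=level).fit(residual)
    codes.append(km.labels_)        # this level's digit for every vector
    residual -= km.cluster_centers_[km.labels_]  # keep only the leftover

codes = np.stack(codes, axis=1)     # shape (100, 3)
print("-".join(map(str, codes[0]))) # e.g. "3-9-1"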

💡 Inspiration

This project is heavily inspired by prior work on residual quantization and semantic IDs. We aim to make these powerful techniques accessible and easy to use for everyone.

🗺️ Explore Your Embeddings

Before you start clustering, it's super helpful to "see" your data. We love Apple's Embedding Atlas and suggest everyone try it out! It's a great way to visualize your high-dimensional vectors and understand the landscape of your data, and just as useful for evaluating your results after training your RQ model.
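For example, you can dump your embeddings and their semantic IDs to a Parquet file and point Embedding Atlas at it. A rough sketch (the exact column layout Embedding Atlas expects is an assumption here, so check its docs):

import pandas as pd

# X: your embeddings, sids: the semantic ID strings from your trained model.
# Writing Parquet requires pyarrow (or fastparquet) to be installed.
df = pd.DataFrame({
    "embedding": [v.tolist() for v in X],  # one vector per row
    "semantic_id": sids,                   # handy as a label/color while exploring
})
df.to_parquet("embeddings.parquet")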


✨ Features

  • RQ-KMeans: Hierarchical residual quantization with K-Means on CPU & GPU.
  • RQ-VAE: Neural network-based quantization with learnable codebooks.
  • Balanced Clustering: Constrained K-Means for evenly distributed codes.
  • Uniqueness: Automatic collision resolution (suffix-based and Sinkhorn re-encoding).
  • Custom Formats: User-defined formatter callbacks for any ID format, plus custom item IDs for collision resolution.
  • Evaluation: Built-in metrics — collision rate, recall@K, NDCG@K, distance correlation, code utilization, entropy, quantization MSE.
  • LLM-Friendly Tokens: Output IDs in <a_3><b_9><c_1> format for language models.
  • Persistence: Save/Load models and full engine pipelines.

📦 Installation

pip install semantic-id

Or, for a development install from a source checkout:

pip install -e .

To enable GPU acceleration (recommended!):

pip install torch

To use balanced clustering:

pip install k-means-constrained

🚀 Quick Start

1. The Basics (RQ-KMeans)

Let's generate some simple IDs. We'll use a small number of clusters (10 per level) so the IDs are short and sweet.

import numpy as np
from semantic_id import RQKMeans

# 1. Generate some dummy data (100 vectors, 16 dimensions)
X = np.random.randn(100, 16)

# 2. Initialize the model
# We'll use 3 levels with 10 clusters each.
# This means our IDs will look like "X-Y-Z" where numbers are 0-9.
model = RQKMeans(n_levels=3, n_clusters=10, random_state=42)

# 3. Train the model
model.fit(X)

# 4. Generate Semantic IDs
# This converts vectors -> codes -> strings
codes = model.encode(X)     # shape (100, 3)
sids = model.semantic_id(codes)

print(f"Vector: {X[0][:3]}...")
print(f"Semantic ID: {sids[0]}")  # Output: e.g., "3-9-1"

2. Go Fast with GPU 🏎️

Got a GPU? Let's use it! The PyTorch backend is compatible with cuda and mps.

device = "cuda" # or "mps" for Mac, or "cpu"

model = RQKMeans(n_levels=3, n_clusters=10)
model.fit(X, device=device)
codes = model.encode(X, device=device)
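If you'd rather not hard-code the device string, a small auto-detect keeps the same script portable across machines (assumes torch is installed):

import torch

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"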

3. Ensure Uniqueness (The Engine)

In the real world, two different items might end up in the same cluster. The SemanticIdEngine handles this gracefully by appending a counter to duplicates.

from semantic_id import SemanticIdEngine, RQKMeans, UniqueIdResolver, SQLiteCollisionStore

# Setup the algorithm
encoder = RQKMeans(n_levels=3, n_clusters=10)

# Setup the persistence (saves collision counts to a file)
store = SQLiteCollisionStore("collisions.db")
resolver = UniqueIdResolver(store=store)

# Create the engine
engine = SemanticIdEngine(encoder=encoder, unique_resolver=resolver)

# Train and Get Unique IDs
engine.fit(X)
unique_ids = engine.unique_ids(X)

print(unique_ids[0]) # e.g., "3-9-1"
# If another item has code (3, 9, 1), it becomes "3-9-1-1" automatically!

Tip: For quick experiments, skip the store setup entirely — SemanticIdEngine uses an InMemoryCollisionStore by default:

engine = SemanticIdEngine(encoder=encoder)  # zero-config uniqueness

4. Neural Networks (RQ-VAE) 🧠

For complex data, a simple K-Means might not be enough. RQ-VAE uses a neural network to learn the optimal codebooks.

from semantic_id import RQVAE

model = RQVAE(
    in_dim=16,                # Input dimension of your vectors
    num_emb_list=[32, 32, 32], # 32 clusters per level
    e_dim=16,                 # Codebook dimension
    layers=[32, 16],          # Hidden layers
    device="cpu"
)

model.fit(X)
ids = model.semantic_id(model.encode(X))

5. Evaluate Your IDs 📊

Use the built-in evaluate() function to measure how well your IDs preserve the structure of the original embeddings.

from semantic_id import evaluate

metrics = evaluate(X, codes, encoder=model)
print(metrics)
# {
#   'n_samples': 100,
#   'n_unique_codes': 87,
#   'collision_rate': 0.13,
#   'collision_rate_per_level': [0.9, 0.45, 0.13],
#   'recall_at_10': 0.42,
#   'ndcg_at_10': 0.38,
#   'distance_correlation': 0.65,
#   'code_utilization_per_level': [1.0, 0.95, 0.87],
#   'code_entropy_per_level': [2.30, 2.25, 2.10],
#   'quantization_mse': 0.003
# }
What each metric measures:

  • collision_rate: Fraction of items sharing an ID with another item (lower is better)
  • collision_rate_per_level: Collision rate at each prefix depth — shows where uniqueness breaks down
  • recall_at_10: How well code-space neighbors match embedding-space neighbors (higher is better)
  • ndcg_at_10: Ranking quality of code-space neighbors vs embedding-space neighbors (higher is better)
  • distance_correlation: Spearman correlation between embedding distances and code distances (higher is better)
  • code_utilization_per_level: Fraction of codebook entries used at each level (higher is better)
  • code_entropy_per_level: Shannon entropy of the code distribution per level (higher = more uniform)
  • quantization_mse: Reconstruction error from decode() (lower is better; requires an encoder with decode())
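As a sanity check, the headline collision_rate lines up with n_unique_codes in the sample output above (1 - 87/100 = 0.13). Here's one plausible way to reproduce that reading yourself; evaluate() remains the source of truth:

import numpy as np

codes = model.encode(X)               # shape (n_samples, n_levels)
unique = {tuple(row) for row in codes}
print(len(unique))                    # n_unique_codes
print(1 - len(unique) / len(codes))   # collision_rate, e.g. 0.13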

6. LLM-Friendly Token Format 🤖

When feeding semantic IDs into a language model, the token format wraps each level in angle brackets with a level letter:

codes = model.encode(X)

# Standard format (default)
plain_ids = model.semantic_id(codes)               # ["3-9-1", "0-5-7", ...]

# Token format for LLMs
token_ids = model.semantic_id(codes, fmt="token")   # ["<a_3><b_9><c_1>", ...]
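Going back from token strings to integer codes is a one-liner with a regex. Note that this parse_token_id helper is hypothetical, not part of the library:

import re

def parse_token_id(token_id):
    # "<a_3><b_9><c_1>" -> (3, 9, 1)
    return tuple(int(m) for m in re.findall(r"<[a-z]+_(\d+)>", token_id))

print(parse_token_id("<a_3><b_9><c_1>"))  # (3, 9, 1)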

7. Custom ID Formats 🎨

Define your own format function for full control over how codes become strings:

# Custom format for your LLM
def my_llm_format(codes):
    return "".join(f"[item_L{i}_{c}]" for i, c in enumerate(codes))

ids = model.semantic_id(codes, formatter=my_llm_format)
# ["[item_L0_3][item_L1_9][item_L2_1]", ...]

# Works through the engine too
engine = SemanticIdEngine(encoder=model)
engine.fit(X)
uids = engine.unique_ids(X, formatter=my_llm_format)

8. Use Your Own Item IDs 🏷️

Instead of auto-incremented suffixes (-1, -2), attach your own identifiers:

db_keys = ["SKU001", "SKU002", "SKU003", ...]
uids = engine.unique_ids(X, item_ids=db_keys)
# Collisions become "3-9-1-SKU042" instead of "3-9-1-1"

# Custom separator for the suffix too
uids = engine.unique_ids(X, item_ids=db_keys, sep="/")
# "3/9/1/SKU042"

9. Balanced Clustering ⚖️

Use implementation="constrained" to enforce roughly equal cluster sizes. This reduces collision rates but requires the k-means-constrained package.

model = RQKMeans(
    n_levels=3,
    n_clusters=10,
    implementation="constrained",  # balanced clusters
    random_state=42
)
model.fit(X)
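Whether the constraint actually helps on your data is easy to check with evaluate(). A quick before/after sketch using only the API shown above (the constrained run needs k-means-constrained installed):

from semantic_id import RQKMeans, evaluate

default = RQKMeans(n_levels=3, n_clusters=10, random_state=42)
default.fit(X)
balanced = RQKMeans(n_levels=3, n_clusters=10, implementation="constrained", random_state=42)
balanced.fit(X)

for name, m in [("default", default), ("balanced", balanced)]:
    rate = evaluate(X, m.encode(X), encoder=m)["collision_rate"]
    print(name, rate)  # expect the balanced model to collide less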

🔄 Reproducibility & Persistence

We know how annoying it is when IDs change between machines. To ensure identical Semantic IDs across different environments (e.g., training on GPU -> inference on CPU):

  1. Train (fit) once on your training machine.
  2. Save the model.
  3. Load on your production machine.

Do not re-train on the second machine, as random initialization will differ!

# Save a single encoder
model.save("my_model")
loaded = RQKMeans.load("my_model")

# Save the full engine (encoder + collision store)
engine.save("my_engine")
loaded_engine = SemanticIdEngine.load("my_engine")

Both RQKMeans and RQVAE support save()/load(). The engine also persists the collision store so suffix counters are preserved.
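After loading, a one-line sanity check confirms you get identical codes on the new machine:

import numpy as np

assert np.array_equal(model.encode(X), loaded.encode(X))  # same codes, same IDs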

🗺️ Project Status

We are actively building! Here is what's ready for you today:

  • RQ-KMeans: Core algorithm working on CPU & GPU.
  • RQ-VAE: Neural network-based quantization with training history tracking.
  • Balanced Clustering: Constrained K-Means for even code distribution.
  • Uniqueness: Suffix-based and Sinkhorn-based collision resolution.
  • Custom Formats: User-defined formatter callbacks and item IDs for collision resolution.
  • Evaluation: Comprehensive metrics including NDCG, code utilization, entropy, and hierarchical distance.
  • Token Format: LLM-friendly ID output.
  • Persistence: Save/Load models and engines.
