
Rust and WebAssembly library for late interaction models.


pylate-rs

Efficient Inference for PyLate

 

⭐️ Overview

pylate-rs is a high-performance inference engine for PyLate models, meticulously crafted in Rust for optimal speed and efficiency.

While model training is handled by PyLate, which supports a variety of late interaction models, pylate-rs is engineered to execute these models at maximum speed.

  • Accelerated Performance: Experience significantly faster model loading and rapid cold starts, making it ideal for serverless environments and low-latency applications.

  • Lightweight Design: Built on the Candle ML framework, pylate-rs maintains a minimal footprint suitable for resource-constrained systems like serverless functions and edge computing.

  • Broad Hardware Support: Optimized for diverse hardware, with dedicated builds for standard CPUs, Intel (MKL), Apple Silicon (Accelerate & Metal), and NVIDIA GPUs (CUDA).

  • Cross-Platform Integration: Seamlessly integrate pylate-rs into your projects with bindings for Python, Rust, and JavaScript/WebAssembly.

For a complete, high-performance multi-vector search pipeline, pair pylate-rs with its companion library, FastPlaid, at inference time.

Explore our WebAssembly live demo.

 

💻 Installation

Install the version of pylate-rs that matches your hardware for optimal performance.

Python

Target Hardware        Installation Command
Standard CPU           pip install pylate-rs
Apple CPU (macOS)      pip install pylate-rs-accelerate
Intel CPU (MKL)        pip install pylate-rs-mkl
Apple GPU (M1/M2/M3)   pip install pylate-rs-metal

Python GPU support

To install pylate-rs with GPU support, please build it from source using the following command:

pip install git+https://github.com/lightonai/pylate-rs.git

or by cloning the repository and installing it locally:

git clone https://github.com/lightonai/pylate-rs.git
cd pylate-rs
pip install .

Any help to pre-build and distribute the CUDA wheels would be greatly appreciated.
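
Once built, a quick smoke test can confirm the CUDA backend is usable. This is a minimal sketch reusing the API shown in the Quick Start below; the model name is the one used throughout this README:

from pylate_rs import models

# Loading on "cuda" should fail fast if the build was compiled without CUDA support.
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cuda",
)

# Encode a single query as a smoke test.
embeddings = model.encode(sentences=["hello world"], is_query=True)
print("CUDA build looks good, encoded", len(embeddings), "query.")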

 

Rust

Add pylate-rs to your Cargo.toml by enabling the feature flag that corresponds to your backend.

Feature      Target Hardware        Installation Command
(default)    Standard CPU           cargo add pylate-rs
accelerate   Apple CPU (macOS)      cargo add pylate-rs --features accelerate
mkl          Intel CPU (MKL)        cargo add pylate-rs --features mkl
metal        Apple GPU (M1/M2/M3)   cargo add pylate-rs --features metal
cuda         NVIDIA GPU (CUDA)      cargo add pylate-rs --features cuda

 

⚡️ Quick Start

Python

Get started in just a few lines of Python.

from pylate_rs import models

# Initialize the model for your target device ("cpu", "cuda", or "mps")
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cuda"
)

# Encode queries and documents
queries_embeddings = model.encode(
    sentences=["What is the capital of France?", "How big is the sun?"],
    is_query=True
)

documents_embeddings = model.encode(
    sentences=["Paris is the capital of France.", "The sun is a star."],
    is_query=False
)

# Calculate similarity scores
similarities = model.similarity(queries_embeddings, documents_embeddings)

print(f"Similarity scores:\n{similarities}")

# Use hierarchical pooling to reduce document embedding size and speed up downstream tasks
pooled_documents_embeddings = model.encode(
    sentences=["Paris is the capital of France.", "The sun is a star."],
    is_query=False,
    pool_factor=2, # Halves the number of token embeddings
)

similarities_pooled = model.similarity(queries_embeddings, pooled_documents_embeddings)

print(f"Similarity scores with pooling:\n{similarities_pooled}")
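
Under the hood, the similarity above is the late interaction (MaxSim) operator: for each query token embedding, take the maximum dot product over all document token embeddings, then sum over query tokens. As an illustrative sketch only (it assumes encode returns one array of token embeddings per sentence, as in the example above), the same score can be recomputed with NumPy:

import numpy as np

def maxsim(query_tokens, document_tokens):
    # query_tokens: (num_query_tokens, dim), document_tokens: (num_doc_tokens, dim)
    token_scores = np.asarray(query_tokens) @ np.asarray(document_tokens).T
    # Best-matching document token for each query token, summed over query tokens.
    return token_scores.max(axis=1).sum()

score = maxsim(queries_embeddings[0], documents_embeddings[0])
print(f"MaxSim score for the first query/document pair: {score}")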

 

Rust

use anyhow::Result;
use candle_core::Device;
use pylate_rs::{hierarchical_pooling, ColBERT};

fn main() -> Result<()> {
    // Set the device (e.g., Cpu, Cuda, Metal)
    let device = Device::Cpu;

    // Initialize the model
    let mut model: ColBERT = ColBERT::from("lightonai/GTE-ModernColBERT-v1")
        .with_device(device)
        .try_into()?;

    // Encode queries and documents
    let queries = vec!["What is the capital of France?".to_string()];
    let documents = vec!["Paris is the capital of France.".to_string()];

    let query_embeddings = model.encode(&queries, true)?;
    let document_embeddings = model.encode(&documents, false)?;

    // Calculate similarity
    let similarities = model.similarity(&query_embeddings, &document_embeddings)?;
    println!("Similarity score: {}", similarities.data[0][0]);

    // Use hierarchical pooling
    let pooled_document_embeddings = hierarchical_pooling(&document_embeddings, 2)?;
    let pooled_similarities = model.similarity(&query_embeddings, &pooled_document_embeddings)?;
    println!("Similarity score after hierarchical pooling: {}", pooled_similarities.data[0][0]);

    Ok(())
}

 

📊 Benchmarks

Device   Backend     Queries per second   Documents per second   Model loading time (s)
cpu      PyLate      350.10               32.16                  2.06
cpu      pylate-rs   386.21 (+10%)        42.15 (+31%)           0.07 (-97%)
cuda     PyLate      2236.48              882.66                 3.62
cuda     pylate-rs   4046.88 (+81%)       976.23 (+11%)          1.95 (-46%)
mps      PyLate      580.81               103.10                 1.95
mps      pylate-rs   291.71 (-50%)        23.26 (-77%)           0.08 (-96%)

Benchmarks were run from Python. pylate-rs provides a significant performance improvement, especially in scenarios requiring fast startup times. On a Mac, loading a model with the Transformers backend and encoding a single query takes up to 5 seconds, while pylate-rs completes the same task in just 0.11 seconds, making it ideal for low-latency applications. Do not expect pylate-rs to be much faster than PyLate when encoding large amounts of content at once, as PyTorch is heavily optimized for that workload.
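
For reference, here is an illustrative way to measure load and encode times yourself (not the exact benchmark script; it assumes both pylate and pylate-rs are installed):

import time

from pylate import models as pylate_models
from pylate_rs import models as pylate_rs_models

sentences = ["What is the capital of France?"] * 32

for name, module in [("PyLate", pylate_models), ("pylate-rs", pylate_rs_models)]:
    start = time.perf_counter()
    model = module.ColBERT(model_name_or_path="lightonai/GTE-ModernColBERT-v1", device="cpu")
    loading = time.perf_counter() - start

    start = time.perf_counter()
    model.encode(sentences=sentences, is_query=True)
    encoding = time.perf_counter() - start

    print(f"{name}: load {loading:.2f}s, encode {encoding:.2f}s for {len(sentences)} queries")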

 

📦 Using Custom Models

pylate-rs is compatible with any model saved in the PyLate format, whether from the Hugging Face Hub or a local directory. PyLate itself is compatible with a wide range of models, including those from Sentence Transformers, Hugging Face Transformers, and custom models. So before using pylate-rs, ensure your model is saved in the PyLate format. You can easily convert and upload your own models using PyLate.

Pushing a model to the Hugging Face Hub in PyLate format is straightforward. Here’s how you can do it:

pip install pylate

Then, you can use the following Python code snippet to push your model:

from pylate import models

# Load your model
model = models.ColBERT(model_name_or_path="your-base-model-on-hf")

# Push in PyLate format
model.push_to_hub(
    repo_id="YourUsername/YourModelName",
    private=False,
    token="YOUR_HUGGINGFACE_TOKEN",
)

If you want to save a model in PyLate format locally, you can do so with the following code snippet:

from pylate import models

# Load your model
model = models.ColBERT(model_name_or_path="your-base-model-on-hf")

# Save in PyLate format
model.save_pretrained("path/to/save/GTE-ModernColBERT-v1-pylate")

An existing set of models compatible with pylate-rs is available on the Hugging Face Hub under the LightOn namespace.
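
Whether the model comes from the LightOn namespace or from a local directory like the one saved above, it loads the same way in pylate-rs. A minimal sketch (the local path is the one used in the save example above):

from pylate_rs import models

# A Hugging Face Hub repository id works here as well, e.g. "lightonai/GTE-ModernColBERT-v1".
model = models.ColBERT(
    model_name_or_path="path/to/save/GTE-ModernColBERT-v1-pylate",
    device="cpu",
)

embeddings = model.encode(sentences=["A quick check."], is_query=True)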

 

Retrieval pipeline

pip install pylate-rs fast-plaid

Here is sample code for running ColBERT with pylate-rs and fast-plaid.

import torch
from fast_plaid import search
from pylate_rs import models

model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu", # mps or cuda
)

documents = [
    "1st Arrondissement: Louvre, Tuileries Garden, Palais Royal, historic, tourist.",
    "2nd Arrondissement: Bourse, financial, covered passages, Sentier, business.",
    "3rd Arrondissement: Marais, Musée Picasso, galleries, trendy, historic.",
    "4th Arrondissement: Notre-Dame, Marais, Hôtel de Ville, LGBTQ+.",
    "5th Arrondissement: Latin Quarter, Sorbonne, Panthéon, student, intellectual.",
    "6th Arrondissement: Saint-Germain-des-Prés, Luxembourg Gardens, chic, artistic, cafés.",
    "7th Arrondissement: Eiffel Tower, Musée d'Orsay, Les Invalides, affluent, prestigious.",
    "8th Arrondissement: Champs-Élysées, Arc de Triomphe, luxury, shopping, Élysée.",
    "9th Arrondissement: Palais Garnier, department stores, shopping, theaters.",
    "10th Arrondissement: Gare du Nord, Gare de l'Est, Canal Saint-Martin.",
    "11th Arrondissement: Bastille, nightlife, Oberkampf, revolutionary, hip.",
    "12th Arrondissement: Bois de Vincennes, Opéra Bastille, Bercy, residential.",
    "13th Arrondissement: Chinatown, Bibliothèque Nationale, modern, diverse, street-art.",
    "14th Arrondissement: Montparnasse, Catacombs, residential, artistic, quiet.",
    "15th Arrondissement: Residential, family, populous, Parc André Citroën.",
    "16th Arrondissement: Trocadéro, Bois de Boulogne, affluent, elegant, embassies.",
    "17th Arrondissement: Diverse, Palais des Congrès, residential, Batignolles.",
    "18th Arrondissement: Montmartre, Sacré-Cœur, Moulin Rouge, artistic, historic.",
    "19th Arrondissement: Parc de la Villette, Cité des Sciences, canals, diverse.",
    "20th Arrondissement: Père Lachaise, Belleville, cosmopolitan, artistic, historic.",
]

# Encoding documents
documents_embeddings = model.encode(
    sentences=documents,
    is_query=False,
    pool_factor=2, # Let's divide the number of embeddings by 2.
)

# Creating the FastPlaid index
fast_plaid = search.FastPlaid(index="index")


fast_plaid.create(
    documents_embeddings=[torch.tensor(embedding) for embedding in documents_embeddings]
)

We can then load the existing index and search for the most relevant documents:

import torch
from fast_plaid import search
from pylate_rs import models

# Re-create the model to encode the queries (the index was built above).
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu", # mps or cuda
)

fast_plaid = search.FastPlaid(index="index")

queries = [
    "arrondissement with the Eiffel Tower and Musée d'Orsay",
    "Latin Quarter and Sorbonne University",
    "arrondissement with Sacré-Cœur and Moulin Rouge",
    "arrondissement with the Louvre and Tuileries Garden",
    "arrondissement with Notre-Dame Cathedral and the Marais",
]

queries_embeddings = model.encode(
    sentences=queries,
    is_query=True,
)

scores = fast_plaid.search(
    queries_embeddings=torch.tensor(queries_embeddings),
    top_k=3,
)

print(scores)

📝 Citation

If you use pylate-rs in your research or project, please cite it as follows:

@misc{PyLate,
  title={PyLate: Flexible Training and Retrieval for Late Interaction Models},
  author={Chaffin, Antoine and Sourty, Raphaël},
  url={https://github.com/lightonai/pylate},
  year={2024}
}

 

WebAssembly

For JavaScript and TypeScript projects, install the WASM package from npm.

npm install pylate-rs

Load the model by fetching the required files from a local path or the Hugging Face Hub.

import { ColBERT } from "pylate-rs";

const REQUIRED_FILES = [
  "tokenizer.json",
  "model.safetensors",
  "config.json",
  "config_sentence_transformers.json",
  "1_Dense/model.safetensors",
  "1_Dense/config.json",
  "special_tokens_map.json",
];

async function loadModel(modelRepo) {
  const fetchAllFiles = async (basePath) => {
    const responses = await Promise.all(
      REQUIRED_FILES.map((file) => fetch(`${basePath}/${file}`))
    );
    for (const response of responses) {
      if (!response.ok) throw new Error(`File not found: ${response.url}`);
    }
    return Promise.all(
      responses.map((res) => res.arrayBuffer().then((b) => new Uint8Array(b)))
    );
  };

  try {
    let modelFiles;
    try {
      // Attempt to load from a local `models` directory first
      modelFiles = await fetchAllFiles(`models/${modelRepo}`);
    } catch (e) {
      console.warn(
        `Local model not found, falling back to Hugging Face Hub.`,
        e
      );
      // Fallback to fetching directly from the Hugging Face Hub
      modelFiles = await fetchAllFiles(
        `https://huggingface.co/${modelRepo}/resolve/main`
      );
    }

    const [
      tokenizer,
      model,
      config,
      stConfig,
      dense,
      denseConfig,
      tokensConfig,
    ] = modelFiles;

    // Instantiate the model with the loaded files
    const colbertModel = new ColBERT(
      model,
      dense,
      tokenizer,
      config,
      stConfig,
      denseConfig,
      tokensConfig,
      32
    );

    // You can now use `colbertModel` for encoding
    console.log("Model loaded successfully!");
    return colbertModel;
  } catch (error) {
    console.error("Model Loading Error:", error);
  }
}
