
Rust and WebAssembly library for late interaction models.


pylate-rs

Efficient Inference for PyLate

 

⭐️ Overview

pylate-rs is a high-performance inference engine for PyLate models, meticulously crafted in Rust for optimal speed and efficiency.

While model training is handled by PyLate, which supports a variety of late interaction models, pylate-rs is engineered to execute these models at maximum speed.

  • Accelerated Performance: Experience significantly faster model loading and rapid cold starts, making it ideal for serverless environments and low-latency applications.

  • Lightweight Design: Built on the Candle ML framework, pylate-rs maintains a minimal footprint suitable for resource-constrained systems like serverless functions and edge computing.

  • Broad Hardware Support: Optimized for diverse hardware, with dedicated builds for standard CPUs, Intel (MKL), Apple Silicon (Accelerate & Metal), and NVIDIA GPUs (CUDA).

  • Cross-Platform Integration: Seamlessly integrate pylate-rs into your projects with bindings for Python, Rust, and JavaScript/WebAssembly.

For a complete, high-performance multi-vector search pipeline, pair pylate-rs with its companion library, FastPlaid, at inference time.

Explore our WebAssembly live demo.

 

💻 Installation

Install the version of pylate-rs that matches your hardware for optimal performance.

Python

Target Hardware         Installation Command
Standard CPU            pip install pylate-rs
Apple CPU (macOS)       pip install pylate-rs-accelerate
Intel CPU (MKL)         pip install pylate-rs-mkl
Apple GPU (M1/M2/M3)    pip install pylate-rs-metal
NVIDIA GPU (CUDA)       pip install git+https://github.com/lightonai/pylate-rs

CUDA wheels are not yet available on PyPI because of their size, but you can install them directly from the repository as shown above. Publishing them to PyPI is a work in progress.

 

Rust

Add pylate-rs to your Cargo.toml by enabling the feature flag that corresponds to your backend.

Feature        Target Hardware         Installation Command
(default)      Standard CPU            cargo add pylate-rs
accelerate     Apple CPU (macOS)       cargo add pylate-rs --features accelerate
mkl            Intel CPU (MKL)         cargo add pylate-rs --features mkl
metal          Apple GPU (M1/M2/M3)    cargo add pylate-rs --features metal
cuda           NVIDIA GPU (CUDA)       cargo add pylate-rs --features cuda

 

⚡️ Quick Start

Python

Get started in just a few lines of Python.

from pylate_rs import models

# Initialize the model for your target device ("cpu", "cuda", or "mps")
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cuda"
)

# Encode queries and documents
queries_embeddings = model.encode(
    sentences=["What is the capital of France?", "How big is the sun?"],
    is_query=True
)

documents_embeddings = model.encode(
    sentences=["Paris is the capital of France.", "The sun is a star."],
    is_query=False
)

# Calculate similarity scores
similarities = model.similarity(queries_embeddings, documents_embeddings)

print(f"Similarity scores:\n{similarities}")

# Use hierarchical pooling to reduce document embedding size and speed up downstream tasks
pooled_documents_embeddings = model.encode(
    sentences=["Paris is the capital of France.", "The sun is a star."],
    is_query=False,
    pool_factor=2, # Halves the number of token embeddings
)

similarities_pooled = model.similarity(queries_embeddings, pooled_documents_embeddings)

print(f"Similarity scores with pooling:\n{similarities_pooled}")

 

Rust

use anyhow::Result;
use candle_core::Device;
use pylate_rs::{hierarchical_pooling, ColBERT};

fn main() -> Result<()> {
    // Set the device (e.g., Cpu, Cuda, Metal)
    let device = Device::Cpu;

    // Initialize the model
    let mut model: ColBERT = ColBERT::from("lightonai/GTE-ModernColBERT-v1")
        .with_device(device)
        .try_into()?;

    // Encode queries and documents
    let queries = vec!["What is the capital of France?".to_string()];
    let documents = vec!["Paris is the capital of France.".to_string()];

    let query_embeddings = model.encode(&queries, true)?;
    let document_embeddings = model.encode(&documents, false)?;

    // Calculate similarity
    let similarities = model.similarity(&query_embeddings, &document_embeddings)?;
    println!("Similarity score: {}", similarities.data[0][0]);

    // Use hierarchical pooling
    let pooled_document_embeddings = hierarchical_pooling(&document_embeddings, 2)?;
    let pooled_similarities = model.similarity(&query_embeddings, &pooled_document_embeddings)?;
    println!("Similarity score after hierarchical pooling: {}", pooled_similarities.data[0][0]);

    Ok(())
}

 

📊 Benchmarks

Device    Backend      Queries per second    Documents per second    Model loading time (s)
cpu       PyLate       350.10                32.16                   2.06
cpu       pylate-rs    386.21 (+10%)         42.15 (+31%)            0.07 (-97%)
cuda      PyLate       2236.48               882.66                  3.62
cuda      pylate-rs    4046.88 (+81%)        976.23 (+11%)           1.95 (-46%)
mps       PyLate       580.81                103.10                  1.95
mps       pylate-rs    291.71 (-50%)         23.26 (-77%)            0.08 (-96%)

Benchmarks were run from Python. pylate-rs provides significant performance improvements, especially in scenarios requiring fast startup times. While it can take up to 5 seconds on a Mac to load a model with the Transformers backend and encode a single query, pylate-rs achieves this in just 0.11 seconds, making it ideal for low-latency applications. Don't expect pylate-rs to be much faster than PyLate when encoding large amounts of content at once, as PyTorch is heavily optimized for that workload.
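If you want a rough sense of these numbers on your own machine, the load and first-query times can be measured directly with the API from the Quick Start. This is an illustrative sketch, not the exact benchmark harness; absolute figures depend on hardware, backend, and whether the model weights are already cached locally.

import time

from pylate_rs import models

start = time.perf_counter()
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu",  # or "cuda" / "mps"
)
print(f"Model loading: {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
model.encode(sentences=["What is the capital of France?"], is_query=True)
print(f"First query encoding: {time.perf_counter() - start:.2f}s")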

 

📦 Using Custom Models

pylate-rs is compatible with any model saved in the PyLate format, whether it comes from the Hugging Face Hub or a local directory. PyLate itself supports a wide range of models, including those from Sentence Transformers, Hugging Face Transformers, and custom models. Before using pylate-rs, make sure your model is saved in the PyLate format; you can easily convert and upload your own models using PyLate.

Pushing a model to the Hugging Face Hub in PyLate format is straightforward. First, install PyLate:

pip install pylate

Then, you can use the following Python code snippet to push your model:

from pylate import models

# Load your model
model = models.ColBERT(model_name_or_path="your-base-model-on-hf")

# Push in PyLate format
model.push_to_hub(
    repo_id="YourUsername/YourModelName",
    private=False,
    token="YOUR_HUGGINGFACE_TOKEN",
)

If you want to save a model in PyLate format locally, you can do so with the following code snippet:

from pylate import models

# Load your model
model = models.ColBERT(model_name_or_path="your-base-model-on-hf")

# Save in PyLate format
model.save_pretrained("path/to/save/GTE-ModernColBERT-v1-pylate")
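
Once saved, the directory can be loaded directly by pylate-rs, since model_name_or_path accepts local paths as well as Hub identifiers. A minimal sketch, reusing the path from the snippet above:

from pylate_rs import models

# Load the locally saved PyLate-format model
model = models.ColBERT(
    model_name_or_path="path/to/save/GTE-ModernColBERT-v1-pylate",
    device="cpu",
)

embeddings = model.encode(
    sentences=["Paris is the capital of France."],
    is_query=False,
)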

An existing set of models compatible with pylate-rs is available on the Hugging Face Hub under the LightOn namespace.

 

Retrieval pipeline

pip install pylate-rs fast-plaid

Here is sample code for running ColBERT with pylate-rs and fast-plaid:

import torch
from fast_plaid import search
from pylate_rs import models

model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu", # mps or cuda
)

documents = [
    "1st Arrondissement: Louvre, Tuileries Garden, Palais Royal, historic, tourist.",
    "2nd Arrondissement: Bourse, financial, covered passages, Sentier, business.",
    "3rd Arrondissement: Marais, Musée Picasso, galleries, trendy, historic.",
    "4th Arrondissement: Notre-Dame, Marais, Hôtel de Ville, LGBTQ+.",
    "5th Arrondissement: Latin Quarter, Sorbonne, Panthéon, student, intellectual.",
    "6th Arrondissement: Saint-Germain-des-Prés, Luxembourg Gardens, chic, artistic, cafés.",
    "7th Arrondissement: Eiffel Tower, Musée d'Orsay, Les Invalides, affluent, prestigious.",
    "8th Arrondissement: Champs-Élysées, Arc de Triomphe, luxury, shopping, Élysée.",
    "9th Arrondissement: Palais Garnier, department stores, shopping, theaters.",
    "10th Arrondissement: Gare du Nord, Gare de l'Est, Canal Saint-Martin.",
    "11th Arrondissement: Bastille, nightlife, Oberkampf, revolutionary, hip.",
    "12th Arrondissement: Bois de Vincennes, Opéra Bastille, Bercy, residential.",
    "13th Arrondissement: Chinatown, Bibliothèque Nationale, modern, diverse, street-art.",
    "14th Arrondissement: Montparnasse, Catacombs, residential, artistic, quiet.",
    "15th Arrondissement: Residential, family, populous, Parc André Citroën.",
    "16th Arrondissement: Trocadéro, Bois de Boulogne, affluent, elegant, embassies.",
    "17th Arrondissement: Diverse, Palais des Congrès, residential, Batignolles.",
    "18th Arrondissement: Montmartre, Sacré-Cœur, Moulin Rouge, artistic, historic.",
    "19th Arrondissement: Parc de la Villette, Cité des Sciences, canals, diverse.",
    "20th Arrondissement: Père Lachaise, Belleville, cosmopolitan, artistic, historic.",
]

# Encoding documents
documents_embeddings = model.encode(
    sentences=documents,
    is_query=False,
    pool_factor=2, # Let's divide the number of embeddings by 2.
)

# Creating the FastPlaid index
fast_plaid = search.FastPlaid(index="index")


fast_plaid.create(
    documents_embeddings=[torch.tensor(embedding) for embedding in documents_embeddings]
)

We can then load the existing index and search for the most relevant documents:

import torch
from fast_plaid import search
from pylate_rs import models

# Reload the encoder and open the existing index
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu", # mps or cuda
)

fast_plaid = search.FastPlaid(index="index")

queries = [
    "arrondissement with the Eiffel Tower and Musée d'Orsay",
    "Latin Quarter and Sorbonne University",
    "arrondissement with Sacré-Cœur and Moulin Rouge",
    "arrondissement with the Louvre and Tuileries Garden",
    "arrondissement with Notre-Dame Cathedral and the Marais",
]

queries_embeddings = model.encode(
    sentences=queries,
    is_query=True,
)

scores = fast_plaid.search(
    queries_embeddings=torch.tensor(queries_embeddings),
    top_k=3,
)

print(scores)

📝 Citation

If you use pylate-rs in your research or project, please cite it as follows:

@misc{PyLate,
  title={PyLate: Flexible Training and Retrieval for Late Interaction Models},
  author={Chaffin, Antoine and Sourty, Raphaël},
  url={https://github.com/lightonai/pylate},
  year={2024}
}

 

WebAssembly

For JavaScript and TypeScript projects, install the WASM package from npm.

npm install pylate-rs

Load the model by fetching the required files from a local path or the Hugging Face Hub.

import { ColBERT } from "pylate-rs";

const REQUIRED_FILES = [
  "tokenizer.json",
  "model.safetensors",
  "config.json",
  "config_sentence_transformers.json",
  "1_Dense/model.safetensors",
  "1_Dense/config.json",
  "special_tokens_map.json",
];

async function loadModel(modelRepo) {
  const fetchAllFiles = async (basePath) => {
    const responses = await Promise.all(
      REQUIRED_FILES.map((file) => fetch(`${basePath}/${file}`))
    );
    for (const response of responses) {
      if (!response.ok) throw new Error(`File not found: ${response.url}`);
    }
    return Promise.all(
      responses.map((res) => res.arrayBuffer().then((b) => new Uint8Array(b)))
    );
  };

  try {
    let modelFiles;
    try {
      // Attempt to load from a local `models` directory first
      modelFiles = await fetchAllFiles(`models/${modelRepo}`);
    } catch (e) {
      console.warn(
        `Local model not found, falling back to Hugging Face Hub.`,
        e
      );
      // Fallback to fetching directly from the Hugging Face Hub
      modelFiles = await fetchAllFiles(
        `https://huggingface.co/${modelRepo}/resolve/main`
      );
    }

    const [
      tokenizer,
      model,
      config,
      stConfig,
      dense,
      denseConfig,
      tokensConfig,
    ] = modelFiles;

    // Instantiate the model with the loaded files
    const colbertModel = new ColBERT(
      model,
      dense,
      tokenizer,
      config,
      stConfig,
      denseConfig,
      tokensConfig,
      32
    );

    // You can now use `colbertModel` for encoding
    console.log("Model loaded successfully!");
    return colbertModel;
  } catch (error) {
    console.error("Model Loading Error:", error);
  }
}
