
A lightweight library for fast fine-tuning of embeddings


litfit


litfit /lɪt fɪt/ — the shortest path from someone else's embedding to your task.

Why litfit?

Fine-tuning dense embedding models usually means writing a training loop, picking a loss function, tuning a learning rate, and waiting minutes to hours. This holds whether you're working with text, image, or multimodal embeddings. litfit takes a different approach. Given pairs of items that should be similar (duplicates, relevant matches, same-class images), it computes covariance statistics and solves for linear projections in closed form. There is no gradient descent and no learning rate or training schedule to babysit. The remaining choices (projection method, regularization, output dimension) are swept and scored on a validation set.

What you get:

  • Fast: one pass over your pairs collects the statistics. Every projection is then solved in closed form, with no iterative training.
  • Any dense embeddings: text, vision, multimodal. If it outputs a vector, litfit can probably improve it.
  • Simple: pass in embeddings plus group (or pair) labels and get a tuned projection matrix back.

Benchmarks

| Task | Model | R@1 | MAP@50 | Dims | Time (CPU) |
|---|---|---|---|---|---|
| Fashion retrieval (DeepFashion In-Shop) | SigLIP2-SO400M baseline | 0.833 | 0.532 | 1152 | |
| | + litfit | 0.923 | 0.738 | 228 | 37s |
| Duplicate detection (AskUbuntu) | e5-base-v2 baseline | 0.556 | 0.522 | 768 | |
| | + litfit | 0.598 | 0.590 | 294 | ~3 min |

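Times exclude embedding computation because the embeddings were precomputed. DeepFashion In-Shop uses fast mode (~40 configs); AskUbuntu uses the full sweep (~860 projections). The solution is closed-form, so there are no training epochs to early-stop. Once a configuration is chosen, you can merge the validation split into the training data and refit. This typically adds a fraction of a second.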
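R@1 and MAP@50 are retrieval metrics computed over group-labeled items. The sketch below shows how they are commonly computed. It makes three assumptions: every item acts as a query against all other items (leave-one-out), similarity is cosine similarity, and average precision is normalised by min(#relevant, k). The function name is hypothetical. litfit's own evaluator may define these metrics differently.

```python
import numpy as np


def recall_at_1_and_map_at_k(embs, ids, id_to_group, k=50):
    """Leave-one-out retrieval metrics (hypothetical helper, not litfit's API).

    Assumptions: cosine similarity, items sharing a group are relevant,
    queries with no other member in their group are skipped.
    """
    x = np.asarray(embs, dtype=np.float64)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)  # unit rows -> dot = cosine
    sims = x @ x.T
    np.fill_diagonal(sims, -np.inf)  # a query never retrieves itself
    groups = np.array([id_to_group[i] for i in ids])

    r1, aps = [], []
    for q in range(len(ids)):
        relevant = groups == groups[q]
        relevant[q] = False
        n_rel = int(relevant.sum())
        if n_rel == 0:
            continue
        ranking = np.argsort(-sims[q])[:k]  # top-k most similar items
        hits = relevant[ranking]
        r1.append(float(hits[0]))
        # Precision at each rank, averaged over the ranks that are hits
        precisions = np.cumsum(hits) / (np.arange(len(hits)) + 1)
        aps.append(float((precisions * hits).sum() / min(n_rel, k)))
    return float(np.mean(r1)), float(np.mean(aps))


# Two well-separated groups: every query's nearest neighbour is its duplicate
embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
r1, map50 = recall_at_1_and_map_at_k(embs, [0, 1, 2, 3], {0: "A", 1: "A", 2: "B", 3: "B"})
```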

Try it yourself, no setup needed: Open In Colab

More benchmarks welcome — if you run litfit on your dataset, open an issue or PR with results!

Installation

pip install litfit

For an editable (development) install:

pip install -e ".[dev]"

For faster statistics computation on CUDA GPUs:

pip install triton

Usage

from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_all_projections, evaluate_projections,
)

# Load AskUbuntu duplicate groups and encode them with e5-base-v2
all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1000)
embs = encode_texts("intfloat/e5-base-v2", all_texts)

# Group-aware train / val / test split
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids, _, val_embs, _ = data["val"]
test_ids, _, test_embs, _ = data["test"]

# Covariance statistics from training positives, then the full projection sweep
st = compute_stats(train_embs, train_ids, id_to_group)
all_W = generate_all_projections(st, neg=None, include_neg_methods=False)

# Score every projection on val (and report test) at 10%, 20%, 50% and 100%
# of the input dimension
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=(0.1, 0.2, 0.5, 1.0),
)

Streaming + fast + dim search (low memory)

Combine streaming statistics, fast projections (~40 configs), lazy evaluation, and automatic dimension search for a memory-efficient pipeline:

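import numpy as np
from litfit import (
    compute_stats_streaming, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# Both sides of each positive pair live on disk as (n_pairs, d) arrays;
# mmap_mode="r" reads only the slices that are accessed.
# "X_pairs.npy" / "Y_pairs.npy" are placeholder file names.
X_pairs_memmap = np.load("X_pairs.npy", mmap_mode="r")
Y_pairs_memmap = np.load("Y_pairs.npy", mmap_mode="r")
BATCH_SIZE = 1024

def pair_batches():
    # Yield aligned (X, Y) batches so only one batch is in memory at a time
    for i in range(0, len(X_pairs_memmap), BATCH_SIZE):
        yield X_pairs_memmap[i:i + BATCH_SIZE], Y_pairs_memmap[i:i + BATCH_SIZE]

# val_* / test_* embeddings, ids and id_to_group come from a group-aware
# split, as in the full walkthrough below.
st = compute_stats_streaming(pair_batches())
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)
all_W = generate_fast_projections(st, lazy=True)  # projections built lazily
results, summary = evaluate_projections(
    all_W, test_embs, test_ids, id_to_group,
    dim_fractions=dim_fractions,
)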
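Streaming works because covariance statistics are built from running sums. They can be accumulated batch by batch in constant memory and still match a single full pass. The minimal sketch below illustrates this. It assumes the statistics are the centered covariance blocks of the paired embeddings; the Sigma_XX / Sigma_XY names follow the walkthrough's comments. The function and its return dict are hypothetical, not litfit's internals.

```python
import numpy as np


def streaming_pair_stats(batches):
    """Accumulate covariance blocks over (X, Y) batches of positive pairs.

    Hypothetical helper: only d-length and d x d running sums are kept,
    never the full pair arrays.
    """
    n = 0
    sums = None
    for xb, yb in batches:
        xb = np.asarray(xb, dtype=np.float64)
        yb = np.asarray(yb, dtype=np.float64)
        if sums is None:
            d = xb.shape[1]
            sums = {"x": np.zeros(d), "y": np.zeros(d),
                    "xx": np.zeros((d, d)), "xy": np.zeros((d, d)), "yy": np.zeros((d, d))}
        n += len(xb)
        sums["x"] += xb.sum(axis=0)
        sums["y"] += yb.sum(axis=0)
        sums["xx"] += xb.T @ xb
        sums["xy"] += xb.T @ yb
        sums["yy"] += yb.T @ yb
    mx, my = sums["x"] / n, sums["y"] / n
    # Population covariance: E[ab^T] - E[a]E[b]^T
    return {
        "n": n,
        "Sigma_XX": sums["xx"] / n - np.outer(mx, mx),
        "Sigma_XY": sums["xy"] / n - np.outer(mx, my),
        "Sigma_YY": sums["yy"] / n - np.outer(my, my),
    }


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 8))
Y = X + 0.1 * rng.normal(size=(100, 8))  # noisy "duplicates" of X
batches = ((X[i:i + 32], Y[i:i + 32]) for i in range(0, len(X), 32))
st = streaming_pair_stats(batches)
```

The batch size only affects peak memory, not the result.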
Full walkthrough: data concepts, splitting, extracting the best projection

litfit operates on three data structures:

  • ids — a list of unique identifiers, one per embedding (strings, ints, anything hashable).
  • id_to_group — a dict mapping each id to a group label. Items that share a group are treated as positives (duplicates / paraphrases / relevant matches). Everything else is a negative.
  • embs — a numpy array or torch tensor of shape (n, d), one row per id.

For example, if questions 0, 1, 2 are duplicates and 3, 4 are duplicates:

ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
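A group of size m yields m(m-1)/2 unordered positive pairs, so in this example group A gives 3 pairs and group B gives 1. The sketch below enumerates them with a hypothetical helper that is not part of litfit. Whether litfit also uses both orderings of each pair internally is left open here.

```python
from collections import defaultdict
from itertools import combinations


def positive_pairs(ids, id_to_group):
    """List unordered positive pairs implied by shared group labels (hypothetical helper)."""
    by_group = defaultdict(list)
    for i in ids:
        by_group[id_to_group[i]].append(i)
    pairs = []
    for members in by_group.values():
        pairs.extend(combinations(members, 2))  # every unordered pair within a group
    return pairs


ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
pairs = positive_pairs(ids, id_to_group)
print(pairs)  # [(0, 1), (0, 2), (1, 2), (3, 4)]
```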

Here is a complete pipeline — from loading data to exporting a torch.nn.Linear:

import torch
import torch.nn as nn
from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# --- 1. Load & encode ---
# load_askubuntu returns (ids, texts, id_to_group).
# max_groups limits how many duplicate-groups to keep (for speed).
all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1500)
embs = encode_texts("intfloat/e5-base-v2", all_texts)

# --- 2. Split into train / val / test ---
# split_data does a group-aware split: all items in a group stay together,
# so no group leaks across splits. Default: 60/20/20.
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids,   _, val_embs,   _ = data["val"]
test_ids,  _, test_embs,  _ = data["test"]

# --- 3. Compute sufficient statistics from training pairs ---
# compute_stats builds covariance matrices (Sigma_XX, Sigma_XY, etc.)
# from all positive pairs implied by id_to_group.
st = compute_stats(train_embs, train_ids, id_to_group)

# --- 4. Find useful dimension range ---
# Scans Rayleigh projections at many dims to find where performance peaks.
# Returns dim_fractions focused on the useful range.
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)

# --- 5. Generate & evaluate projections ---
# generate_fast_projections returns ~40 (method, hyperparams) configs.
# evaluate_projections uses explore-exploit scheduling on the val set.
all_W = generate_fast_projections(st)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=dim_fractions,
)

# --- 6. Extract the best projection ---
# results keys are tuples like ('m_rayleigh', 'reg=0.1').
# Each value is {n_dims: {'MAP@50': ..., 'R@1': ..., ...}}.
# n_dims=None means full-dimensional.
best_key = max(results, key=lambda k: results[k][None]['MAP@50'])
W = all_W[best_key]                # shape (d, d) or (d, k)

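# Optionally truncate to the best reduced dimension for this config,
# i.e. the integer n_dims key with the highest MAP@50:
reduced_dims = [n for n in results[best_key] if n is not None]
best_dim = max(reduced_dims, key=lambda n: results[best_key][n]['MAP@50'])
projected = test_embs @ W[:, :best_dim]  # shape (n, best_dim)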

# --- 7. (Optional) Recompute stats on ALL data for best performance ---
# The train split was used for fitting and val/test for model selection.
# Once you've picked the best config, recompute stats on all available
# embeddings so the final projection sees the most signal.
full_st = compute_stats(embs, all_ids, id_to_group)
all_W_full = generate_fast_projections(full_st, verbose=False)
W = all_W_full[best_key]

# --- 8. Export as torch.nn.Linear for inference ---
out_dim = best_dim             # or W.shape[1] for full
layer = nn.Linear(W.shape[0], out_dim, bias=False)
layer.weight = nn.Parameter(W[:, :out_dim].T.cpu().float())
# Use it: projected = layer(input_embs)
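Step 2 relies on a group-aware split, where every member of a group lands in the same split. The sketch below shows the idea. It assumes the 60/20/20 fractions apply to groups rather than to individual items. `group_split` is a hypothetical stand-in, not litfit's `split_data`, whose exact behavior may differ.

```python
import random


def group_split(ids, id_to_group, fractions=(0.6, 0.2, 0.2), seed=0):
    """Assign whole groups to train/val/test (hypothetical helper)."""
    groups = sorted({id_to_group[i] for i in ids}, key=str)  # deterministic order
    random.Random(seed).shuffle(groups)
    n_train = round(fractions[0] * len(groups))
    n_val = round(fractions[1] * len(groups))
    assignment = {}
    for k, g in enumerate(groups):
        if k < n_train:
            assignment[g] = "train"
        elif k < n_train + n_val:
            assignment[g] = "val"
        else:
            assignment[g] = "test"
    splits = {"train": [], "val": [], "test": []}
    for i in ids:
        splits[assignment[id_to_group[i]]].append(i)  # items follow their group
    return splits


ids = list(range(20))
id_to_group = {i: i // 2 for i in ids}  # 10 groups of 2 items each
splits = group_split(ids, id_to_group)
```

Splitting by group prevents a duplicate of a test item from appearing in training, which would otherwise inflate the test scores.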

See the docs for more examples, architecture diagrams, and streaming scripts.

Device Support

  • CUDA: Full support with optional Triton acceleration
  • CPU: Full support
  • MPS: Not supported (missing linalg ops)

How it works

  1. You provide embeddings and group labels (which items are duplicates/relevant/same-class)
  2. litfit computes covariance matrices from all positive pairs (the "sufficient statistics")
  3. It generates ~40 candidate projections in fast mode, or 800+ in full sweep, using different methods (generalized Rayleigh quotients, CCA-style decompositions, asymmetric refinements, MSE regularization)
  4. You get a projection matrix W — multiply your embeddings by it and you're done
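Step 3 mentions generalized Rayleigh quotients. The minimal sketch below shows one such projection. It assumes the objective maximizes the (symmetrized) cross-covariance of positive pairs relative to their average within-view covariance plus a ridge term `reg`. This is one plausible member of the family, not necessarily how litfit's `m_rayleigh` method is defined. The problem A w = λ B w is reduced to a standard symmetric eigenproblem via a Cholesky factor of B. Keeping the first k columns of the result gives a k-dimensional projection.

```python
import numpy as np


def rayleigh_projection(sigma_xx, sigma_yy, sigma_xy, reg=0.1):
    """Solve A w = lam B w; return eigenvalues (descending) and W (columns = w).

    Assumed objective: A = symmetrized cross-covariance of positive pairs,
    B = mean within-view covariance + reg * I (ridge keeps B positive definite).
    """
    d = sigma_xx.shape[0]
    a = 0.5 * (sigma_xy + sigma_xy.T)
    b = 0.5 * (sigma_xx + sigma_yy) + reg * np.eye(d)
    l_inv = np.linalg.inv(np.linalg.cholesky(b))  # B = L L^T
    c = l_inv @ a @ l_inv.T                      # L^-1 A L^-T is symmetric
    c = 0.5 * (c + c.T)                          # remove rounding asymmetry
    eigvals, u = np.linalg.eigh(c)               # ascending order
    order = np.argsort(eigvals)[::-1]
    w = l_inv.T @ u[:, order]                    # map back: w = L^-T u
    return eigvals[order], w


rng = np.random.default_rng(0)
shared = rng.normal(size=(500, 6))               # signal common to both views
X = shared + 0.5 * rng.normal(size=(500, 6))
Y = shared + 0.5 * rng.normal(size=(500, 6))
full = np.cov(np.hstack([X, Y]), rowvar=False, bias=True)
lam, W = rayleigh_projection(full[:6, :6], full[6:, 6:], full[:6, 6:], reg=0.1)
projected = X @ W[:, :2]  # keep the 2 directions with the highest quotient
```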

The result is a linear transformation that can also reduce dimensionality. In the DeepFashion In-Shop benchmark above, a 1152-dim SigLIP2 embedding projected to 228 dims scores higher on both R@1 and MAP@50 than the original embedding.

Development

pip install -e ".[dev]"
pytest
mypy litfit
black litfit tests
