A lightweight library for fast finetuning of embeddings

litfit

litfit /lɪt fɪt/ — the shortest path from someone else's embedding to your task.

Why litfit?

Fine-tuning dense embedding models means writing a training loop, picking a loss function, tuning a learning rate, and waiting minutes to hours — whether you're working with text, images, or multimodal embeddings. litfit takes a different approach: given pairs of items that should be similar (duplicates, relevant matches, same-class images), it computes covariance statistics and solves for the optimal linear projection in closed form. No gradient descent, no hyperparameters to babysit.

What you get:

  • Fast: one pass over your pairs collects the statistics, then everything is solved in closed form. There is no iterative training to wait for.
  • Any dense embeddings: text, vision, multimodal. If it outputs a vector, litfit can probably improve it.
  • Simple: pass in embeddings + pair labels, get a well-tuned projection matrix back.

Benchmarks

Task                                     Model                    R@1    MAP@50  Dims  Time (CPU)
Fashion retrieval (DeepFashion In-Shop)  SigLIP2-SO400M baseline  0.833  0.532   1152
                                         + litfit                 0.923  0.738   228   37s
Duplicate detection (AskUbuntu)          e5-base-v2 baseline      0.556  0.522   768
                                         + litfit                 0.598  0.590   294   ~3min

Embeddings are precomputed. DeepFashion In-Shop uses fast mode (~40 configs); AskUbuntu uses the full sweep (~860 projections). Because the solution is closed-form, the validation split can safely be merged into the training data without risk of overfitting; doing so typically adds a fraction of a second.
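For readers unfamiliar with the R@1 column, here is a minimal sketch of how recall-at-1 is commonly computed for group retrieval: each item queries all the others by cosine similarity, and it counts as a hit when the nearest neighbour shares its group. The function name `recall_at_1` is hypothetical and this is not litfit's evaluation code, which may differ in details such as how queries without positives are handled.

```python
import numpy as np

def recall_at_1(embs, ids, id_to_group):
    """Fraction of queries whose nearest other item (by cosine) shares their group."""
    embs = np.asarray(embs, dtype=np.float64)
    normed = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    sims = normed @ normed.T
    np.fill_diagonal(sims, -np.inf)  # a query must not retrieve itself
    nearest = sims.argmax(axis=1)
    groups = [id_to_group[i] for i in ids]
    hits = [groups[q] == groups[n] for q, n in enumerate(nearest)]
    return float(np.mean(hits))

# Toy data: two tight groups of two items each.
ids = [0, 1, 2, 3]
id_to_group = {0: 'A', 1: 'A', 2: 'B', 3: 'B'}
embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
score = recall_at_1(embs, ids, id_to_group)
print(score)  # 1.0
```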

Try it yourself in Colab; no setup needed.

More benchmarks welcome — if you run litfit on your dataset, open an issue or PR with results!

Installation

pip install litfit

For an editable (development) install:

pip install -e ".[dev]"

For faster statistics computation on CUDA GPUs:

pip install triton

Usage

from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_all_projections, evaluate_projections,
)

all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1000)
embs = encode_texts("intfloat/e5-base-v2", all_texts)
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids, _, val_embs, _ = data["val"]
test_ids, _, test_embs, _ = data["test"]

st = compute_stats(train_embs, train_ids, id_to_group)
all_W = generate_all_projections(st, neg=None, include_neg_methods=False)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=(0.1, 0.2, 0.5, 1.0),
)

Streaming + fast + dim search (low memory)

Combine streaming statistics, fast projections (~40 configs), lazy evaluation, and automatic dimension search for a memory-efficient pipeline:

from litfit import (
    compute_stats_streaming, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

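# X_pairs_memmap and Y_pairs_memmap are assumed to be pre-built arrays
# (e.g. np.memmap) of shape (n_pairs, d); row i of each forms one positive pair.
def pair_batches():
    for i in range(0, len(X_pairs_memmap), 1024):
        yield X_pairs_memmap[i:i+1024], Y_pairs_memmap[i:i+1024]

# val_embs, val_ids, test_embs, test_ids and id_to_group come from your own split.
st = compute_stats_streaming(pair_batches())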
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)
all_W = generate_fast_projections(st, lazy=True)
results, summary = evaluate_projections(
    all_W, test_embs, test_ids, id_to_group,
    dim_fractions=dim_fractions,
)
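
The idea behind streaming statistics can be sketched in plain NumPy: second-moment sums are additive over batches, so only d × d accumulators need to stay in memory. This is an illustration under assumptions, not the implementation of `compute_stats_streaming`. The helper `streaming_pair_stats` is hypothetical, it uses uncentered moments for brevity, and litfit may track additional or different quantities.

```python
import numpy as np

def streaming_pair_stats(batches, dim):
    """Accumulate second-moment sums over (X, Y) pair batches."""
    n = 0
    sxx = np.zeros((dim, dim))
    sxy = np.zeros((dim, dim))
    for x, y in batches:
        x = np.asarray(x, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        sxx += x.T @ x  # sum of outer products of the first pair members
        sxy += x.T @ y  # sum of cross outer products between pair members
        n += len(x)
    return sxx / n, sxy / n

# Synthetic pairs: Y is a noisy copy of X.
rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 16))
Y = X + 0.1 * rng.normal(size=(5000, 16))

def batches(size=1024):
    for i in range(0, len(X), size):
        yield X[i:i + size], Y[i:i + size]

sxx_stream, sxy_stream = streaming_pair_stats(batches(), dim=16)
sxx_full, sxy_full = X.T @ X / len(X), X.T @ Y / len(X)
stats_match = np.allclose(sxx_stream, sxx_full) and np.allclose(sxy_stream, sxy_full)
```

The batched result matches the all-at-once computation up to floating-point error, which is why the batch size only affects memory use and not the statistics.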

Full walkthrough: data concepts, splitting, extracting the best projection

litfit operates on three data structures:

  • ids — a list of unique identifiers, one per embedding (strings, ints, anything hashable).
  • id_to_group — a dict mapping each id to a group label. Items that share a group are treated as positives (duplicates / paraphrases / relevant matches). Everything else is a negative.
  • embs — a numpy array or torch tensor of shape (n, d), one row per id.

For example, if questions 0, 1, 2 are duplicates and 3, 4 are duplicates:

ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
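
To make "positives implied by id_to_group" concrete, the sketch below enumerates every unordered pair of ids that share a group. The helper `positive_pairs` is hypothetical and only illustrates the data model; litfit computes its statistics internally and may never materialise the pairs this way.

```python
from collections import defaultdict
from itertools import combinations

def positive_pairs(ids, id_to_group):
    """Return every unordered pair of ids that share a group label."""
    members = defaultdict(list)
    for item_id in ids:
        members[id_to_group[item_id]].append(item_id)
    pairs = []
    for group_ids in members.values():
        pairs.extend(combinations(group_ids, 2))
    return pairs

ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
pairs = positive_pairs(ids, id_to_group)
print(pairs)  # [(0, 1), (0, 2), (1, 2), (3, 4)]
```

A group of size m contributes m(m-1)/2 pairs, so a few large groups can dominate the pair count.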

Here is a complete pipeline — from loading data to exporting a torch.nn.Linear:

import torch
import torch.nn as nn
from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# --- 1. Load & encode ---
# load_askubuntu returns (ids, texts, id_to_group).
# max_groups limits how many duplicate-groups to keep (for speed).
all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1500)
embs = encode_texts("intfloat/e5-base-v2", all_texts)

# --- 2. Split into train / val / test ---
# split_data does a group-aware split: all items in a group stay together,
# so no group leaks across splits. Default: 60/20/20.
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids,   _, val_embs,   _ = data["val"]
test_ids,  _, test_embs,  _ = data["test"]

# --- 3. Compute sufficient statistics from training pairs ---
# compute_stats builds covariance matrices (Sigma_XX, Sigma_XY, etc.)
# from all positive pairs implied by id_to_group.
st = compute_stats(train_embs, train_ids, id_to_group)

# --- 4. Find useful dimension range ---
# Scans Rayleigh projections at many dims to find where performance peaks.
# Returns dim_fractions focused on the useful range.
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)

# --- 5. Generate & evaluate projections ---
# generate_fast_projections returns ~40 (method, hyperparams) configs.
# evaluate_projections uses explore-exploit scheduling on the val set.
all_W = generate_fast_projections(st)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=dim_fractions,
)

# --- 6. Extract the best projection ---
# results keys are tuples like ('m_rayleigh', 'reg=0.1').
# Each value is {n_dims: {'MAP@50': ..., 'R@1': ..., ...}}.
# n_dims=None means full-dimensional.
best_key = max(results, key=lambda k: results[k][None]['MAP@50'])
W = all_W[best_key]                # shape (d, d) or (d, k)

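# Optionally truncate to a reduced dimension. 128 is an illustrative value;
# in practice use the n_dims with the best validation score in results[best_key].
best_dim = 128
projected = test_embs @ W[:, :best_dim]  # shape (n, best_dim)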

# --- 7. (Optional) Recompute stats on ALL data for best performance ---
# The train split was used for fitting and val/test for model selection.
# Once you've picked the best config, recompute stats on all available
# embeddings so the final projection sees the most signal.
full_st = compute_stats(embs, all_ids, id_to_group)
all_W_full = generate_fast_projections(full_st, verbose=False)
W = all_W_full[best_key]

# --- 8. Export as torch.nn.Linear for inference ---
out_dim = best_dim             # or W.shape[1] for full
layer = nn.Linear(W.shape[0], out_dim, bias=False)
layer.weight = nn.Parameter(W[:, :out_dim].T.cpu().float())
# Use it: projected = layer(input_embs)
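
Step 2 of the walkthrough relies on a group-aware split. As a rough sketch of that idea (not the code of `split_data`; the helper `group_aware_split` and its parameters are hypothetical), groups rather than individual items are shuffled and assigned to splits, so every member of a group lands in the same split:

```python
import random
from collections import defaultdict

def group_aware_split(ids, id_to_group, fractions=(0.6, 0.2, 0.2), seed=0):
    """Split ids into train/val/test so that no group spans two splits."""
    members = defaultdict(list)
    for item_id in ids:
        members[id_to_group[item_id]].append(item_id)
    groups = sorted(members, key=str)  # deterministic order before shuffling
    random.Random(seed).shuffle(groups)
    n_train = round(fractions[0] * len(groups))
    n_val = round(fractions[1] * len(groups))
    chunks = {
        "train": groups[:n_train],
        "val": groups[n_train:n_train + n_val],
        "test": groups[n_train + n_val:],
    }
    return {name: [i for g in gs for i in members[g]] for name, gs in chunks.items()}

ids = list(range(20))
id_to_group = {i: i // 2 for i in ids}  # ten groups of two items each
split = group_aware_split(ids, id_to_group)
```

Note that the fractions here apply to the number of groups, not the number of items, so splits can be uneven when group sizes vary widely.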

See the docs for more examples, architecture diagrams, and streaming scripts.

Device Support

  • CUDA: Full support with optional Triton acceleration
  • CPU: Full support
  • MPS: Not supported (missing linalg ops)

How it works

  1. You provide embeddings and group labels (which items are duplicates/relevant/same-class)
  2. litfit computes covariance matrices from all positive pairs (the "sufficient statistics")
  3. It generates ~40 candidate projections in fast mode, or 800+ in full sweep, using different methods (generalized Rayleigh quotients, CCA-style decompositions, asymmetric refinements, MSE regularization)
  4. You get a projection matrix W — multiply your embeddings by it and you're done

The result is a linear transformation that can also reduce dimensionality: a 1152-dim SigLIP embedding projected to 228 dims can score better than the original on your task.
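
As an illustration of what a closed-form projection from pair statistics can look like, here is a textbook generalized Rayleigh quotient solved with NumPy. It is a sketch under assumptions, not litfit's implementation: the choice of matrices (symmetrised pair cross-covariance over a ridge-regularised total covariance), the `reg` scaling, and the function name `rayleigh_projection` are all hypothetical.

```python
import numpy as np

def rayleigh_projection(X, Y, reg=0.1, k=None):
    """Maximise w^T A w / w^T B w for pair cross-covariance A and covariance B."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    Z = np.vstack([X, Y])
    mean = Z.mean(axis=0)
    Xc, Yc, Zc = X - mean, Y - mean, Z - mean
    d = X.shape[1]
    B = Zc.T @ Zc / len(Z)
    B += reg * np.trace(B) / d * np.eye(d)       # ridge term keeps B positive definite
    A = (Xc.T @ Yc + Yc.T @ Xc) / (2 * len(X))   # symmetrised cross-covariance of pairs
    L = np.linalg.cholesky(B)                    # B = L L^T
    Linv = np.linalg.inv(L)
    M = Linv @ A @ Linv.T                        # whitening turns it into an ordinary eigenproblem
    eigvals, V = np.linalg.eigh(M)               # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1]
    W = Linv.T @ V[:, order]                     # columns sorted by decreasing quotient
    return W if k is None else W[:, :k]

# Synthetic pairs: 4 shared signal dimensions followed by 12 independent noise dimensions.
rng = np.random.default_rng(0)
signal = rng.normal(size=(2000, 4))
X = np.hstack([signal + 0.1 * rng.normal(size=(2000, 4)), rng.normal(size=(2000, 12))])
Y = np.hstack([signal + 0.1 * rng.normal(size=(2000, 4)), rng.normal(size=(2000, 12))])
W = rayleigh_projection(X, Y, k=4)
projected = X @ W  # shape (2000, 4)
```

Because the columns are ordered by how strongly each direction is shared between pair members, keeping only the first k columns discards the directions that carry mostly noise, which is one way a lower-dimensional projection can outperform the original embedding.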

Development

pip install -e ".[dev]"
pytest
mypy litfit
black litfit tests
