
A lightweight library for fast finetuning of embeddings


litfit


litfit /lɪt fɪt/ — the shortest path from someone else's embedding to your task.

Why litfit?

Fine-tuning dense embedding models means writing a training loop, picking a loss function, tuning a learning rate, and waiting minutes to hours — whether you're working with text, images, or multimodal embeddings. litfit takes a different approach: given pairs of items that should be similar (duplicates, relevant matches, same-class images), it computes covariance statistics and solves for the optimal linear projection in closed form. No gradient descent, no hyperparameters to babysit.
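To make "closed form" concrete, here is a minimal sketch of the simplest estimator in this family: a ridge-regularized least-squares (MSE) map that sends one side of each positive pair toward the other. It is an illustration under our own assumptions, not litfit's implementation; the function name `fit_ridge_map` and the `reg` value are hypothetical. It shows the key property: the statistics take one pass over the pairs, and the projection comes from a single linear solve with no learning rate or epochs.

```python
import numpy as np


def fit_ridge_map(X: np.ndarray, Y: np.ndarray, reg: float = 1e-3) -> np.ndarray:
    """Closed-form W minimizing (1/n) * ||X W - Y||^2 + reg * ||W||^2.

    Illustration only (hypothetical helper, not litfit's code).
    X, Y: (n_pairs, d) arrays; row i of X and row i of Y form a positive pair.
    """
    n, d = X.shape
    # Sufficient statistics: a single pass over the pairs.
    Sxx = X.T @ X / n
    Sxy = X.T @ Y / n
    # Normal equations: (Sxx + reg * I) W = Sxy, solved directly.
    return np.linalg.solve(Sxx + reg * np.eye(d), Sxy)


# Toy check: if Y is an exact linear function of X, the solve recovers it.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 8))
W_true = rng.normal(size=(8, 8))
Y = X @ W_true
W = fit_ridge_map(X, Y, reg=1e-8)
print(np.allclose(W, W_true, atol=1e-5))  # True
```

litfit's own candidates (generalized Rayleigh quotients, CCA-style decompositions, and so on) are different estimators, but they are built from the same kind of pair statistics and share this "statistics once, then solve" structure.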

What you get:

  • Fast: one pass over your pairs collects the statistics, then every projection is solved in closed form. There is no iterative training loop.
  • Any dense embeddings: text, vision, multimodal. If it outputs a vector, litfit can probably improve it.
  • Simple: pass in embeddings and pair labels, get a well-tuned projection matrix back.

Benchmarks

Task                                     Model                      R@1    MAP@50  Dims  Time (CPU)
Fashion retrieval (DeepFashion In-Shop)  SigLIP2-SO400M (baseline)  0.833  0.532   1152
                                         SigLIP2-SO400M + litfit    0.923  0.738   228   37s
Duplicate detection (AskUbuntu)          e5-base-v2 (baseline)      0.522  0.488   768
                                         e5-base-v2 + litfit        0.591  0.573   768   ~3min

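Times exclude embedding computation (embeddings are precomputed). DeepFashion In-Shop uses fast mode (~40 configs); AskUbuntu uses the full sweep (~860 projections). Because the solution is closed-form, there is no training run that needs a validation set for early stopping, so once you have picked a config you can merge val into the training data and refit; this typically adds a fraction of a second.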
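R@1 and MAP@50 are standard retrieval metrics. As a sketch, assuming leave-one-out retrieval by cosine similarity (each item queries all other items, and items in the same group count as relevant) and average precision normalized by min(number of relevant items, k), they can be computed as below. The helper name `recall_at_1_and_map` is hypothetical, and litfit's own evaluator may differ in details such as normalization or tie handling.

```python
import numpy as np


def recall_at_1_and_map(embs: np.ndarray, groups, k: int = 50):
    """Leave-one-out R@1 and MAP@k by cosine similarity (hypothetical helper).

    embs: (n, d) array; groups: length-n sequence of group labels.
    """
    groups = np.asarray(groups)
    E = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    sims = E @ E.T
    np.fill_diagonal(sims, -np.inf)  # a query never retrieves itself
    hits1, aps = [], []
    for q in range(len(E)):
        rel = groups == groups[q]
        rel[q] = False
        n_rel = int(rel.sum())
        if n_rel == 0:
            continue  # singleton group: nothing to retrieve
        order = np.argsort(-sims[q])[:k]  # top-k neighbours, best first
        hit = rel[order]
        hits1.append(float(hit[0]))
        # Precision at every rank, kept only where a relevant item appears.
        precisions = np.cumsum(hit) / np.arange(1, len(hit) + 1)
        aps.append(float((precisions * hit).sum() / min(n_rel, k)))
    return float(np.mean(hits1)), float(np.mean(aps))


# Two tight pairs: every query's nearest neighbour is its own duplicate.
toy = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
print(recall_at_1_and_map(toy, ["A", "A", "B", "B"]))  # (1.0, 1.0)
```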

Try it yourself in Colab; no setup needed.

More benchmarks welcome — if you run litfit on your dataset, open an issue or PR with results!

Installation

pip install litfit

For an editable (development) install, run this from a clone of the repository:

pip install -e ".[dev]"

For faster statistics computation on CUDA GPUs:

pip install triton

Usage

from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_all_projections, evaluate_projections,
)

all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1000)
embs = encode_texts("intfloat/e5-base-v2", all_texts)
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids, _, val_embs, _ = data["val"]
test_ids, _, test_embs, _ = data["test"]

st = compute_stats(train_embs, train_ids, id_to_group)
all_W = generate_all_projections(st, neg=None, include_neg_methods=False)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=(0.1, 0.2, 0.5, 1.0),
)

Streaming + fast + dim search (low memory)

Combine streaming statistics, fast projections (~40 configs), lazy evaluation, and automatic dimension search for a memory-efficient pipeline:

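import numpy as np
from litfit import (
    compute_stats_streaming, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# Hypothetical files: row i of each array holds one side of the i-th
# positive pair, shape (n_pairs, d). mmap_mode="r" keeps the data on disk.
X_pairs_memmap = np.load("x_pairs.npy", mmap_mode="r")
Y_pairs_memmap = np.load("y_pairs.npy", mmap_mode="r")
BATCH_SIZE = 1024

def pair_batches():
    # Yield aligned (X, Y) slices; only one batch is read into memory at a time.
    for i in range(0, len(X_pairs_memmap), BATCH_SIZE):
        yield X_pairs_memmap[i:i + BATCH_SIZE], Y_pairs_memmap[i:i + BATCH_SIZE]

st = compute_stats_streaming(pair_batches())
# val_*, test_* and id_to_group come from a split like the one in the walkthrough below.
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)
all_W = generate_fast_projections(st, lazy=True)  # lazy: projections built during evaluation
results, summary = evaluate_projections(
    all_W, test_embs, test_ids, id_to_group,
    dim_fractions=dim_fractions,
)
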
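Streaming works because pair statistics are sums over pairs: accumulating them batch by batch yields exactly the same matrices as one pass over all pairs at once. A minimal sketch, assuming uncentered second moments (a centered version would also track per-side sums); the class `StreamingPairStats` is hypothetical and is not how compute_stats_streaming is implemented internally.

```python
import numpy as np


class StreamingPairStats:
    """Accumulate second-moment statistics of (X, Y) pair batches in one pass.

    Hypothetical illustration; uses uncentered moments for simplicity.
    """

    def __init__(self, d: int):
        self.n = 0
        self.sxx = np.zeros((d, d))
        self.syy = np.zeros((d, d))
        self.sxy = np.zeros((d, d))

    def update(self, X: np.ndarray, Y: np.ndarray) -> None:
        # Memory use: one batch plus three d x d accumulators.
        self.n += len(X)
        self.sxx += X.T @ X
        self.syy += Y.T @ Y
        self.sxy += X.T @ Y

    def finalize(self):
        return self.sxx / self.n, self.syy / self.n, self.sxy / self.n


rng = np.random.default_rng(0)
X_all = rng.normal(size=(5000, 16))
Y_all = X_all + 0.1 * rng.normal(size=(5000, 16))  # noisy "duplicates"

acc = StreamingPairStats(d=16)
for i in range(0, len(X_all), 1024):
    acc.update(X_all[i:i + 1024], Y_all[i:i + 1024])
Sxx, Syy, Sxy = acc.finalize()

# Same result as computing the statistics on everything at once.
print(np.allclose(Sxy, X_all.T @ Y_all / len(X_all)))  # True
```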
Full walkthrough: data concepts, splitting, extracting the best projection

litfit operates on three data structures:

  • ids — a list of unique identifiers, one per embedding (strings, ints, anything hashable).
  • id_to_group — a dict mapping each id to a group label. Items that share a group are treated as positives (duplicates / paraphrases / relevant matches). Everything else is a negative.
  • embs — a numpy array or torch tensor of shape (n, d), one row per id.

For example, if questions 0, 1, 2 are duplicates and 3, 4 are duplicates:

ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
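The positive pairs implied by id_to_group are all pairs of ids within the same group, so a group of size m contributes m(m-1)/2 pairs: the example above yields three pairs from group A and one from group B. A small sketch of that expansion; `positive_pairs` is a hypothetical helper, and whether litfit also counts the reversed pair (j, i) is not stated here.

```python
from collections import defaultdict
from itertools import combinations


def positive_pairs(ids, id_to_group):
    """All unordered pairs of ids that share a group label (hypothetical helper)."""
    by_group = defaultdict(list)
    for i in ids:
        by_group[id_to_group[i]].append(i)
    pairs = []
    for members in by_group.values():
        pairs.extend(combinations(members, 2))
    return pairs


ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
print(positive_pairs(ids, id_to_group))
# [(0, 1), (0, 2), (1, 2), (3, 4)]
```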

Here is a complete pipeline — from loading data to exporting a torch.nn.Linear:

import torch
import torch.nn as nn
from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# --- 1. Load & encode ---
# load_askubuntu returns (ids, texts, id_to_group).
# max_groups limits how many duplicate-groups to keep (for speed).
all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1500)
embs = encode_texts("intfloat/e5-base-v2", all_texts)

# --- 2. Split into train / val / test ---
# split_data does a group-aware split: all items in a group stay together,
# so no group leaks across splits. Default: 60/20/20.
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids,   _, val_embs,   _ = data["val"]
test_ids,  _, test_embs,  _ = data["test"]

# --- 3. Compute sufficient statistics from training pairs ---
# compute_stats builds covariance matrices (Sigma_XX, Sigma_XY, etc.)
# from all positive pairs implied by id_to_group.
st = compute_stats(train_embs, train_ids, id_to_group)

# --- 4. Find useful dimension range ---
# Scans Rayleigh projections at many dims to find where performance peaks.
# Returns dim_fractions focused on the useful range.
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)

# --- 5. Generate & evaluate projections ---
# generate_fast_projections returns ~40 (method, hyperparams) configs.
# evaluate_projections uses explore-exploit scheduling on the val set.
all_W = generate_fast_projections(st)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=dim_fractions,
)

# --- 6. Extract the best projection ---
# results keys are tuples like ('m_rayleigh', 'reg=0.1').
# Each value is {n_dims: {'MAP@50': ..., 'R@1': ..., ...}}.
# n_dims=None means full-dimensional.
best_key = max(results, key=lambda k: results[k][None]['MAP@50'])
W = all_W[best_key]                # shape (d, d) or (d, k)

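# Optionally truncate to the best reduced dimension (assuming the non-None
# n_dims keys are integer dimension counts):
reduced = {n: m['MAP@50'] for n, m in results[best_key].items() if n is not None}
best_dim = max(reduced, key=reduced.get) if reduced else W.shape[1]
projected = test_embs @ W[:, :best_dim]  # shape (n, best_dim)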

# --- 7. (Optional) Recompute stats on ALL data for best performance ---
# The train split was used for fitting and val/test for model selection.
# Once you've picked the best config, recompute stats on all available
# embeddings so the final projection sees the most signal.
full_st = compute_stats(embs, all_ids, id_to_group)
all_W_full = generate_fast_projections(full_st, verbose=False)
W = all_W_full[best_key]

# --- 8. Export as torch.nn.Linear for inference ---
out_dim = best_dim             # or W.shape[1] for full
layer = nn.Linear(W.shape[0], out_dim, bias=False)
layer.weight = nn.Parameter(W[:, :out_dim].T.cpu().float())
# Use it: projected = layer(input_embs)
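Step 2 relies on a group-aware split: shuffle groups rather than items and assign whole groups to train, val, and test, so duplicates of a test item never appear in training. A minimal sketch under that assumption; the function `group_split` is hypothetical, and although the default 60/20/20 fractions come from the comment in step 2, split_data's exact assignment rule (for example, whether it balances by group count or by item count) is not documented here. This sketch splits by group count and requires sortable group labels.

```python
import random
from collections import defaultdict


def group_split(ids, id_to_group, fractions=(0.6, 0.2, 0.2), seed=0):
    """Split ids into train/val/test so each group lands in exactly one split.

    Hypothetical helper; fractions are applied to the number of groups.
    """
    by_group = defaultdict(list)
    for i in ids:
        by_group[id_to_group[i]].append(i)
    groups = sorted(by_group)  # deterministic order before shuffling
    random.Random(seed).shuffle(groups)
    n_train = int(fractions[0] * len(groups))
    n_val = int(fractions[1] * len(groups))
    parts = (groups[:n_train],
             groups[n_train:n_train + n_val],
             groups[n_train + n_val:])
    return [[i for g in part for i in by_group[g]] for part in parts]


ids = list(range(20))
id_to_group = {i: i // 2 for i in ids}  # 10 groups of two items each
train, val, test = group_split(ids, id_to_group)
print(len(train), len(val), len(test))  # 12 4 4
```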

See the docs for more examples, architecture diagrams, and streaming scripts.

Device Support

  • CUDA: Full support with optional Triton acceleration
  • CPU: Full support
  • MPS: Not supported (missing linalg ops)

How it works

  1. You provide embeddings and group labels (which items are duplicates/relevant/same-class)
  2. litfit computes covariance matrices from all positive pairs (the "sufficient statistics")
  3. It generates ~40 candidate projections in fast mode, or 800+ in full sweep, using different methods (generalized Rayleigh quotients, CCA-style decompositions, asymmetric refinements, MSE regularization)
  4. You get a projection matrix W — multiply your embeddings by it and you're done
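Of the methods in step 3, the generalized Rayleigh quotient is the easiest to sketch. One plausible formulation, which is an assumption on our part rather than litfit's exact objective: find directions w that maximize how much positive pairs agree (a symmetrized cross-moment C) relative to the overall spread of the embeddings (a pooled, ridge-regularized second moment A). That is the generalized eigenproblem C w = λ A w, solved below by Cholesky whitening. Sorting columns by λ is what lets W[:, :k] act as a k-dimensional projection. `rayleigh_projection` and `reg` are hypothetical names.

```python
import numpy as np


def rayleigh_projection(Sxx, Syy, Sxy, reg=1e-3):
    """Columns w maximizing w^T C w / w^T A w, sorted by decreasing quotient.

    Hypothetical sketch. C: symmetrized pair cross-moment (what pairs share).
    A: pooled second moment plus ridge (overall spread to normalize by).
    """
    d = Sxx.shape[0]
    C = 0.5 * (Sxy + Sxy.T)
    A = 0.5 * (Sxx + Syy) + reg * np.eye(d)
    L = np.linalg.cholesky(A)        # A = L L^T
    Linv = np.linalg.inv(L)
    M = Linv @ C @ Linv.T            # whitened problem: M v = lambda v
    evals, V = np.linalg.eigh(M)     # eigenvalues in ascending order
    order = np.argsort(evals)[::-1]  # largest quotient first
    W = Linv.T @ V[:, order]         # map back: w = L^{-T} v
    return W, evals[order]


# Toy data: the first 2 dims are shared within each pair, the other 4 are noise.
rng = np.random.default_rng(0)
n, d = 4000, 6
shared = rng.normal(size=(n, 2))
X = np.hstack([shared, rng.normal(size=(n, 4))])
Y = np.hstack([shared, rng.normal(size=(n, 4))])
Sxx, Syy, Sxy = X.T @ X / n, Y.T @ Y / n, X.T @ Y / n

W, quotients = rayleigh_projection(Sxx, Syy, Sxy)
print(np.round(quotients, 2))
```

On this toy data the first two quotients should come out close to 1 (the shared directions) and the rest close to 0, so keeping W[:, :2] discards the noise dimensions. This is the mechanism behind a projection scoring better with fewer dimensions.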

The result is a linear transformation that can also reduce dimensionality: a 1152-dim SigLIP embedding projected to 228 dims can score better than the original on your task.

Development

pip install -e ".[dev]"
pytest
mypy litfit
black litfit tests


Download files


Source Distribution

litfit-0.1.4.tar.gz (37.5 kB)

Uploaded Source

Built Distribution


litfit-0.1.4-py3-none-any.whl (27.8 kB)

Uploaded Python 3

File details

Details for the file litfit-0.1.4.tar.gz.

File metadata

  • Download URL: litfit-0.1.4.tar.gz
  • Upload date:
  • Size: 37.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for litfit-0.1.4.tar.gz
Algorithm Hash digest
SHA256 c286fb59330fc610e51d2719623235b285287405d8702525fa5cd21c1eaa54d8
MD5 bae20aa4e13288fb3257a3361ba47493
BLAKE2b-256 0dbda9254be774b4fcdc6c2002191df51b64965da538f1c5bb4d430c028b324a


Provenance

The following attestation bundles were made for litfit-0.1.4.tar.gz:

Publisher: publish.yml on b0nce/litfit


File details

Details for the file litfit-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: litfit-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 27.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for litfit-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 a987c0a5ce2ea5ae6a55c2b919397788330613bcf09c56fc3f9c6ddff94c0b14
MD5 e386351b9981826cae436bcba7f94401
BLAKE2b-256 ff7d28f2376c007d7d724ee70fb3f7d5cf955006ff38030645afb5fa036098fb


Provenance

The following attestation bundles were made for litfit-0.1.4-py3-none-any.whl:

Publisher: publish.yml on b0nce/litfit

