A lightweight library for fast finetuning of embeddings
litfit
litfit /lɪt fɪt/ — the shortest path from someone else's embedding to your task.
Why litfit?
Fine-tuning dense embedding models means writing a training loop, picking a loss function, tuning a learning rate, and waiting minutes to hours — whether you're working with text, images, or multimodal embeddings. litfit takes a different approach: given pairs of items that should be similar (duplicates, relevant matches, same-class images), it computes covariance statistics and solves for the optimal linear projection in closed form. No gradient descent, no hyperparameters to babysit.
What you get:
- Fast: one pass over your pairs to collect statistics, then everything is solved in closed form, with no iterative training.
- Any dense embeddings: text, vision, multimodal. If a model outputs a vector, litfit can probably improve it.
- Simple: pass in embeddings and pair labels, get a well-tuned projection matrix back.
Benchmarks
| Task | Model | | R@1 | MAP@50 | Dims | Time (CPU) |
|---|---|---|---|---|---|---|
| Fashion retrieval | SigLIP2-SO400M | baseline | 0.833 | 0.532 | 1152 | — |
| (DeepFashion In-Shop) | | + litfit | 0.923 | 0.738 | 228 | 37s |
| Duplicate detection | e5-base-v2 | baseline | 0.522 | 0.488 | 768 | — |
| (AskUbuntu) | | + litfit | 0.591 | 0.573 | 768 | ~3min |
Embeddings were precomputed. DeepFashion In-Shop uses fast mode (~40 configs); AskUbuntu uses the full sweep (~860 projections). Because the solution is closed-form, you can safely merge the val split into the training data without risk of overfitting; doing so typically adds a fraction of a second.
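For readers unfamiliar with the R@1 column, here is a minimal NumPy sketch of recall@1 over group labels. It assumes cosine similarity and leave-one-out retrieval within a single set. The function name `recall_at_1` is hypothetical, and this is not litfit's own evaluation code.

```python
import numpy as np

def recall_at_1(embs, ids, id_to_group):
    """Fraction of queries whose nearest neighbour (excluding itself) shares their group."""
    X = np.asarray(embs, dtype=np.float64)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-normalise for cosine similarity
    sims = X @ X.T
    np.fill_diagonal(sims, -np.inf)  # a query must not retrieve itself
    nearest = sims.argmax(axis=1)
    groups = np.array([id_to_group[i] for i in ids])
    return float((groups[nearest] == groups).mean())

# Toy data: two tight clusters, so every nearest neighbour is a true match.
embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
ids = [0, 1, 2, 3]
id_to_group = {0: "A", 1: "A", 2: "B", 3: "B"}
score = recall_at_1(embs, ids, id_to_group)
```

In this sketch an item whose group has no other member always counts as a miss; a real evaluation may prefer to exclude such queries.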
Try it yourself — no setup needed
More benchmarks welcome — if you run litfit on your dataset, open an issue or PR with results!
Installation
pip install litfit
For an editable (development) install:
pip install -e ".[dev]"
For faster statistics computation on CUDA GPUs:
pip install triton
Usage
from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_all_projections, evaluate_projections,
)

all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1000)
embs = encode_texts("intfloat/e5-base-v2", all_texts)
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids, _, val_embs, _ = data["val"]
test_ids, _, test_embs, _ = data["test"]

st = compute_stats(train_embs, train_ids, id_to_group)
all_W = generate_all_projections(st, neg=None, include_neg_methods=False)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=(0.1, 0.2, 0.5, 1.0),
)
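The walkthrough further down describes `split_data` as group-aware: all items in a group land in the same split, with a 60/20/20 default. A minimal pure-Python sketch of that idea follows. The helper `group_aware_split` and its `fractions` and `seed` arguments are illustrative assumptions, not litfit's implementation.

```python
import random

def group_aware_split(ids, id_to_group, fractions=(0.6, 0.2, 0.2), seed=0):
    """Assign whole groups to train/val/test so no group spans two splits."""
    groups = sorted({id_to_group[i] for i in ids}, key=str)
    random.Random(seed).shuffle(groups)  # deterministic shuffle of group labels
    n_train = int(len(groups) * fractions[0])
    n_val = int(len(groups) * fractions[1])
    split_of_group = {}
    for pos, g in enumerate(groups):
        if pos < n_train:
            split_of_group[g] = "train"
        elif pos < n_train + n_val:
            split_of_group[g] = "val"
        else:
            split_of_group[g] = "test"
    splits = {"train": [], "val": [], "test": []}
    for i in ids:
        splits[split_of_group[id_to_group[i]]].append(i)
    return splits

ids = list(range(20))
id_to_group = {i: i // 2 for i in ids}  # 10 groups of 2 items each
splits = group_aware_split(ids, id_to_group)
```

Splitting by group rather than by item matters because a positive pair divided between train and test would leak supervision into the evaluation.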
Streaming + fast + dim search (low memory)
Combine streaming statistics, fast projections (~40 configs), lazy evaluation, and automatic dimension search for a memory-efficient pipeline:
from litfit import (
    compute_stats_streaming, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# X_pairs_memmap and Y_pairs_memmap are placeholders for your own arrays
# (e.g. np.memmap), assumed row-aligned: row i of each holds one side of pair i.
def pair_batches():
    for i in range(0, len(X_pairs_memmap), 1024):
        yield X_pairs_memmap[i:i+1024], Y_pairs_memmap[i:i+1024]

st = compute_stats_streaming(pair_batches())
# val_embs, val_ids, test_embs, test_ids and id_to_group come from your own split.
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)
all_W = generate_fast_projections(st, lazy=True)
results, summary = evaluate_projections(
    all_W, test_embs, test_ids, id_to_group,
    dim_fractions=dim_fractions,
)
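Streaming is possible because second-moment statistics are sums over pairs, so they can be accumulated batch by batch without holding all pairs in memory. The sketch below illustrates that property with NumPy. `accumulate_pair_stats` and its dictionary keys are hypothetical, the moments are left uncentred for simplicity, and none of this reflects the internals of `compute_stats_streaming`.

```python
import numpy as np

def accumulate_pair_stats(batches):
    """Accumulate uncentred second moments of (X, Y) pair batches in one pass."""
    n, s_xx, s_xy, s_yy = 0, None, None, None
    for X, Y in batches:
        X = np.asarray(X, dtype=np.float64)
        Y = np.asarray(Y, dtype=np.float64)
        if s_xx is None:
            d = X.shape[1]
            s_xx, s_xy, s_yy = (np.zeros((d, d)) for _ in range(3))
        s_xx += X.T @ X   # each batch only adds a (d, d) term
        s_xy += X.T @ Y
        s_yy += Y.T @ Y
        n += len(X)
    return {"n": n, "xx": s_xx / n, "xy": s_xy / n, "yy": s_yy / n}

rng = np.random.default_rng(0)
X_all = rng.normal(size=(300, 8))
Y_all = X_all + 0.1 * rng.normal(size=(300, 8))

def batches(size=64):
    for i in range(0, len(X_all), size):
        yield X_all[i:i + size], Y_all[i:i + size]

stats = accumulate_pair_stats(batches())
```

Memory use is O(d²) regardless of the number of pairs, which is why the batches can come from a memory-mapped file.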
Full walkthrough: data concepts, splitting, extracting the best projection
litfit operates on three data structures:
- ids: a list of unique identifiers, one per embedding (strings, ints, anything hashable).
- id_to_group: a dict mapping each id to a group label. Items that share a group are treated as positives (duplicates / paraphrases / relevant matches). Everything else is a negative.
- embs: a numpy array or torch tensor of shape (n, d), one row per id.
For example, if questions 0, 1, 2 are duplicates and 3, 4 are duplicates:
ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
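To make "items that share a group are positives" concrete, this small pure-Python sketch enumerates the positive pairs implied by the example above. The helper `positive_pairs` is hypothetical and only illustrates the data model; litfit may build its pairs differently.

```python
from collections import defaultdict
from itertools import combinations

def positive_pairs(ids, id_to_group):
    """List every unordered pair of ids that share a group label."""
    members = defaultdict(list)
    for i in ids:
        members[id_to_group[i]].append(i)
    pairs = []
    for group_ids in members.values():
        pairs.extend(combinations(group_ids, 2))
    return pairs

ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
pairs = positive_pairs(ids, id_to_group)
# pairs == [(0, 1), (0, 2), (1, 2), (3, 4)]
```

A group of g items contributes g(g-1)/2 pairs, so a few large groups can dominate the statistics.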
Here is a complete pipeline — from loading data to exporting a torch.nn.Linear:
import torch
import torch.nn as nn
from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_fast_projections,
    find_dim_range, evaluate_projections,
)
# --- 1. Load & encode ---
# load_askubuntu returns (ids, texts, id_to_group).
# max_groups limits how many duplicate-groups to keep (for speed).
all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1500)
embs = encode_texts("intfloat/e5-base-v2", all_texts)
# --- 2. Split into train / val / test ---
# split_data does a group-aware split: all items in a group stay together,
# so no group leaks across splits. Default: 60/20/20.
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids, _, val_embs, _ = data["val"]
test_ids, _, test_embs, _ = data["test"]
# --- 3. Compute sufficient statistics from training pairs ---
# compute_stats builds covariance matrices (Sigma_XX, Sigma_XY, etc.)
# from all positive pairs implied by id_to_group.
st = compute_stats(train_embs, train_ids, id_to_group)
# --- 4. Find useful dimension range ---
# Scans Rayleigh projections at many dims to find where performance peaks.
# Returns dim_fractions focused on the useful range.
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)
# --- 5. Generate & evaluate projections ---
# generate_fast_projections returns ~40 (method, hyperparams) configs.
# evaluate_projections uses explore-exploit scheduling on the val set.
all_W = generate_fast_projections(st)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=dim_fractions,
)
# --- 6. Extract the best projection ---
# results keys are tuples like ('m_rayleigh', 'reg=0.1').
# Each value is {n_dims: {'MAP@50': ..., 'R@1': ..., ...}}.
# n_dims=None means full-dimensional.
best_key = max(results, key=lambda k: results[k][None]['MAP@50'])
W = all_W[best_key] # shape (d, d) or (d, k)
# Optionally truncate to the best reduced dimension:
best_dim = 128
projected = test_embs @ W[:, :best_dim] # shape (n, best_dim)
# --- 7. (Optional) Recompute stats on ALL data for best performance ---
# The train split was used for fitting and val/test for model selection.
# Once you've picked the best config, recompute stats on all available
# embeddings so the final projection sees the most signal.
full_st = compute_stats(embs, all_ids, id_to_group)
all_W_full = generate_fast_projections(full_st, verbose=False)
W = all_W_full[best_key]
# --- 8. Export as torch.nn.Linear for inference ---
out_dim = best_dim # or W.shape[1] for full
layer = nn.Linear(W.shape[0], out_dim, bias=False)
layer.weight = nn.Parameter(W[:, :out_dim].T.cpu().float())
# Use it: projected = layer(input_embs)
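Step 6 above hard-codes best_dim = 128 and selects the config at full dimension only. Given the results layout described in the comments (keys are tuples, values map n_dims to a metrics dict, with None meaning full-dimensional), a joint search over config and dimension could look like the sketch below. `best_config`, the method name 'm_other' and all numbers in `mock_results` are made up for illustration.

```python
def best_config(results, metric="MAP@50"):
    """Return (key, n_dims, score) with the highest metric across all configs and dims."""
    best = None
    for key, by_dim in results.items():
        for n_dims, metrics in by_dim.items():
            score = metrics[metric]
            if best is None or score > best[2]:
                best = (key, n_dims, score)
    return best

# Made-up numbers, only to mirror the structure described in step 6.
mock_results = {
    ("m_rayleigh", "reg=0.1"): {
        None: {"MAP@50": 0.55, "R@1": 0.57},
        128: {"MAP@50": 0.57, "R@1": 0.59},
    },
    ("m_other", "reg=1.0"): {
        None: {"MAP@50": 0.52, "R@1": 0.54},
    },
}
best_key, best_dim, best_score = best_config(mock_results)
```

If `best_dim` comes back as None, the full matrix is the winner and no column slicing is needed.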
See the docs for more examples, architecture diagrams, and streaming scripts.
Device Support
- CUDA: Full support with optional Triton acceleration
- CPU: Full support
- MPS: Not supported (missing linalg ops)
How it works
- You provide embeddings and group labels (which items are duplicates/relevant/same-class)
- litfit computes covariance matrices from all positive pairs (the "sufficient statistics")
- It generates ~40 candidate projections in fast mode, or 800+ in full sweep, using different methods (generalized Rayleigh quotients, CCA-style decompositions, asymmetric refinements, MSE regularization)
- You get a projection matrix W: multiply your embeddings by it and you're done
The result is a linear transformation that can also reduce dimensionality: a 1152-dim SigLIP embedding projected to 228 dims can score better than the original on your task.
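To give a feel for what "solved in closed form" can mean, here is a textbook generalized Rayleigh quotient in NumPy: find directions where positive pairs agree most relative to overall variance. This is a sketch of the general technique, not litfit's code. The function `rayleigh_projection`, the way `reg` is scaled, and the synthetic data are all assumptions.

```python
import numpy as np

def rayleigh_projection(X, Y, reg=0.1):
    """Solve A w = lam * B w, where X[i] and Y[i] are the two sides of positive pair i.

    A is the symmetrised cross-covariance of pairs, B the regularised covariance.
    """
    n, d = X.shape
    mu = np.vstack([X, Y]).mean(axis=0)
    Xc, Yc = X - mu, Y - mu
    s_xy = Xc.T @ Yc / n
    A = (s_xy + s_xy.T) / 2
    B = (Xc.T @ Xc + Yc.T @ Yc) / (2 * n)
    B += reg * np.trace(B) / d * np.eye(d)  # ridge term keeps B positive definite
    L = np.linalg.cholesky(B)               # B = L @ L.T
    L_inv = np.linalg.inv(L)
    C = L_inv @ A @ L_inv.T                 # whitened, ordinary symmetric eigenproblem
    eigvals, U = np.linalg.eigh((C + C.T) / 2)
    order = np.argsort(eigvals)[::-1]       # largest eigenvalue first
    return L_inv.T @ U[:, order], eigvals[order]

def mean_pair_cosine(P, Q):
    num = (P * Q).sum(axis=1)
    den = np.linalg.norm(P, axis=1) * np.linalg.norm(Q, axis=1)
    return float((num / den).mean())

# Synthetic pairs: 2 shared signal dims plus 4 loud, independent noise dims.
rng = np.random.default_rng(0)
n, d, k = 2000, 6, 2
signal = rng.normal(size=(n, k))
X = np.hstack([signal + 0.05 * rng.normal(size=(n, k)), 3 * rng.normal(size=(n, d - k))])
Y = np.hstack([signal + 0.05 * rng.normal(size=(n, k)), 3 * rng.normal(size=(n, d - k))])

W, eigvals = rayleigh_projection(X, Y)
before = mean_pair_cosine(X, Y)
after = mean_pair_cosine(X @ W[:, :k], Y @ W[:, :k])  # keep only the top-k columns
```

Because the columns of W are sorted by eigenvalue, truncating to the first k columns discards the directions where pairs agree least, which is one way a lower-dimensional projection can outscore the original embedding.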
Development
pip install -e ".[dev]"
pytest
mypy litfit
black litfit tests