# transformers-knn-adapter

Hugging Face image embeddings with a scikit-learn KNN head.

`transformers_knn_adapter` extends Hugging Face image models by attaching a scikit-learn KNN classifier on top of transformer embeddings.
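The core idea can be sketched in a few lines: treat transformer embeddings as fixed feature vectors and fit a `KNeighborsClassifier` on them. This is an illustrative sketch, not the package's internals; the random vectors below stand in for real model embeddings.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)

# Stand-in embeddings: two well-separated clusters of 8-dim "features".
cats = rng.normal(loc=0.0, scale=0.1, size=(20, 8))
dogs = rng.normal(loc=1.0, scale=0.1, size=(20, 8))
X = np.vstack([cats, dogs])
y = np.array(["cat"] * 20 + ["dog"] * 20)

# The KNN head: no gradient training, just neighbor lookup in embedding space.
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X, y)

query = rng.normal(loc=1.0, scale=0.1, size=(1, 8))  # a "dog-like" embedding
print(knn.predict(query)[0])  # prints "dog"
```

In the real pipeline the embeddings come from a Hugging Face model and the fitted KNN is persisted as a `.joblib` file.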
## Requirements

- Python 3.11+
- uv
## Setup

```shell
uv sync --dev
```
## Run Tests

```shell
uv run pytest
```
## Train

```shell
uv run python -m transformers_knn_adapter.knn_image_pipeline train \
  --model microsoft/resnet-50 \
  --knn-model-path /tmp/knn/dinov2_small_mini_imagenet_full.joblib \
  --dataset timm/mini-imagenet \
  --split train \
  --max-samples 1000 \
  --shuffle \
  --grid-search \
  --grid-search-splits 3 \
  --grid-search-repeats 2 \
  --grid-search-scoring f1_macro
```
## Evaluate

```shell
uv run python -m transformers_knn_adapter.knn_image_pipeline eval \
  --model microsoft/resnet-50 \
  --knn-model-path /tmp/knn/dinov2_small_mini_imagenet_full.joblib \
  --dataset timm/mini-imagenet \
  --split test \
  --stratified \
  --max-samples 100 \
  --shuffle \
  --batch-size 100
```
## Inference

```shell
uv run python -m transformers_knn_adapter.knn_image_pipeline infer \
  --model microsoft/resnet-50 \
  --knn-model-path /tmp/knn/dinov2_small_mini_imagenet_full.joblib \
  --image https://picsum.photos/200 \
  --inference-batch-size 5
```
## CLI Arguments

### train
| Argument | Required | Default | Description |
|---|---|---|---|
| `--model` | Yes | - | HF model id/path for feature extraction |
| `--knn-model-path` | Yes | - | Path to save/load KNN model (`.joblib`) |
| `--dataset` | Yes | - | HF dataset name or local imagefolder path |
| `--split` | No | `train` | Dataset split |
| `--image-column` | No | `image` | Dataset image column |
| `--label-column` | No | `label` | Dataset label column |
| `--batch-size` | No | `16` | Embedding batch size |
| `--stream` | No | `false` | Enable streaming mode |
| `--stratified` | No | `false` | Stratified sampling (`--max-samples` subset size, non-streaming only) |
| `--shuffle` | No | `false` | Shuffle dataset before sampling/training |
| `--shuffle-seed` | No | `42` | Shuffle seed |
| `--shuffle-buffer-size` | No | `1000` | Streaming shuffle buffer size |
| `--max-samples` | No | `None` | Optional training sample cap |
| `--n-neighbors` | No | `None` | KNN neighbors (mutually exclusive with `--grid-search`) |
| `--grid-search` | No | `false` | Run GridSearchCV over neighbors/metrics |
| `--grid-search-splits` | No | `None` | Stratified splits per repeat for grid search |
| `--grid-search-repeats` | No | `None` | Repeat count for stratified folds in grid search |
| `--grid-search-scoring` | No | `None` | Grid-search scoring metric (`f1_macro`, `precision_macro`, `recall_macro`) |
| `--top-k` | No | `2` | Top-k at inference time |
| `--device` | No | `-1` | Transformers device index (`-1` for CPU) |
### eval
| Argument | Required | Default | Description |
|---|---|---|---|
| `--model` | Yes | - | HF model id/path for feature extraction |
| `--knn-model-path` | Yes | - | Path to trained KNN model (`.joblib`) |
| `--dataset` | Yes | - | HF dataset name or local imagefolder path |
| `--split` | No | `validation` | Dataset split |
| `--image-column` | No | `image` | Dataset image column |
| `--label-column` | No | `label` | Dataset label column |
| `--batch-size` | No | `16` | Embedding batch size |
| `--stream` | No | `false` | Enable streaming mode |
| `--stratified` | No | `false` | Stratified sampling (`--max-samples` subset size, non-streaming only) |
| `--shuffle` | No | `false` | Shuffle dataset before evaluation |
| `--shuffle-seed` | No | `42` | Shuffle seed |
| `--shuffle-buffer-size` | No | `1000` | Streaming shuffle buffer size |
| `--max-samples` | No | `None` | Optional evaluation sample cap |
| `--min-class-instances` | No | `None` | Drop classes with fewer than this number of eval instances |
| `--negative-classes` | No | `other` | Comma-separated classes treated as negative |
| `--positive-classes-population-ratio` | No | `None` | Target ratio of positive samples to total samples after subsampling |
| `--top-k` | No | `1` | Top-k predictions (evaluation uses top-1) |
| `--device` | No | `-1` | Transformers device index (`-1` for CPU) |
### infer
| Argument | Required | Default | Description |
|---|---|---|---|
| `--model` | Yes | - | HF model id/path for feature extraction |
| `--knn-model-path` | Yes | - | Path to trained KNN model (`.joblib`) |
| `--top-k` | No | `3` | Top-k predictions |
| `--device` | No | `-1` | Transformers device index (`-1` for CPU) |
| `--image` | No | `https://picsum.photos/200` | Image input (file path or URL) |
| `--inference-batch-size` | No | `5` | Number of images for batched inference |
### predict
| Argument | Required | Default | Description |
|---|---|---|---|
| `--model` | Yes | - | HF model id/path for feature extraction |
| `--knn-model-path` | Yes | - | Path to trained KNN model (`.joblib`) |
| `--image` | Yes | - | Image path/URL accepted by the Transformers image pipeline |
| `--top-k` | No | `3` | Top-k predictions |
| `--device` | No | `-1` | Transformers device index (`-1` for CPU) |
## Python API

```python
from transformers_knn_adapter import pipeline

clf = pipeline(
    "image-classification",
    model_path="microsoft/resnet-50",
    knn_model_path="/tmp/knn/model.joblib",
)
```
### Train from Python

```python
clf.train(
    dataset="timm/mini-imagenet",
    split="train",
    max_samples=1000,
    shuffle=True,
    grid_search=True,
    grid_search_splits=3,
    grid_search_repeats=2,
    grid_search_scoring="f1_macro",
)
```
### Evaluate from Python

```python
metrics = clf.evaluate(
    dataset="timm/mini-imagenet",
    split="test",
    max_samples=100,
    shuffle=True,
    batch_size=100,
)
print(metrics["top1_accuracy"])
```
### Inference from Python

```python
single = clf("https://picsum.photos/200")
batch = clf(["https://picsum.photos/200"] * 5)
print(single)
print(batch)
```
## Notes

- Real train/eval runs can download model and dataset artifacts from Hugging Face.
## Trainer Integration

The package also provides:

- `Dinov2ForImageClassificationWithArcFaceLoss`
- `KNNCallback`
- `FreezeScheduleCallback`

These are useful when you already have a working Hugging Face Trainer script for `Dinov2ForImageClassification` and want to:

- replace the classifier loss with ArcFace
- log KNN retrieval metrics during evaluation
- freeze and unfreeze modules by epoch
## ArcFace Dinov2 Model

`Dinov2ForImageClassificationWithArcFaceLoss` is intended as a Trainer-compatible replacement for `Dinov2ForImageClassification`.
Example:

```python
from transformers import AutoConfig, AutoImageProcessor, Trainer, TrainingArguments
from transformers_knn_adapter import Dinov2ForImageClassificationWithArcFaceLoss

# num_labels, label2id, and id2label come from your dataset; train_dataset,
# eval_dataset, data_collator, and compute_metrics come from your own script.
config = AutoConfig.from_pretrained("facebook/dinov2-small")
config.num_labels = num_labels
config.label2id = label2id
config.id2label = id2label
config.arcface_margin = 28.6
config.arcface_scale = 64.0

image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")

model = Dinov2ForImageClassificationWithArcFaceLoss.from_pretrained(
    "facebook/dinov2-small",
    config=config,
    ignore_mismatched_sizes=True,
)

training_args = TrainingArguments(
    output_dir="/tmp/arcface-run",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    dataloader_num_workers=4,
    learning_rate=1e-3,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()
```
Notes:

- training loss uses ArcFace logits
- `outputs.logits` returned to Trainer evaluation/prediction are inference logits, not margin-modified training logits
- the ArcFace head weight matrix is initialized automatically when loading from a plain Dinov2 checkpoint
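For intuition, the split between training and inference logits can be sketched in NumPy. This is an illustrative sketch of the standard ArcFace formulation, not the package's code; the default margin and scale mirror the config values above (28.6 degrees, scale 64).

```python
import numpy as np

def arcface_logits(embeddings, weights, labels, margin_deg=28.6, scale=64.0):
    """Cosine logits with an additive angular margin on each sample's true class."""
    # L2-normalize embeddings and class-weight rows so dot products are cosines.
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    w = weights / np.linalg.norm(weights, axis=1, keepdims=True)
    cos = e @ w.T                                # plain cosine (inference logits / scale)
    theta = np.arccos(np.clip(cos, -1.0, 1.0))   # angles to each class center
    # Penalize the true class only: a larger angle means a smaller logit.
    theta[np.arange(len(labels)), labels] += np.deg2rad(margin_deg)
    return scale * np.cos(theta)                 # margin-modified training logits

# The margin lowers the true-class logit, so training must pull embeddings
# closer to their class centers to compensate; other classes are untouched.
emb = np.array([[1.0, 0.0], [0.0, 1.0]])
W = np.array([[1.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1])
train_logits = arcface_logits(emb, W, labels)
infer_logits = 64.0 * (emb @ W.T)
```

This mirrors the note above: the margin only exists on the training path, while evaluation sees the unmodified cosine logits.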
## KNN Callback

`KNNCallback` evaluates embeddings with a scikit-learn KNN classifier during `Trainer.evaluate()`.

Current embedding sources: `cls`, `cls_mean`.
Example:

```python
from transformers_knn_adapter import KNNCallback

trainer.add_callback(
    KNNCallback(
        trainer=trainer,
        label_column="labels",
        ks=(1, 5),
        embedding_source="cls_mean",
    )
)
```
Requirements:

- train and eval datasets must yield `pixel_values`
- train and eval datasets must expose the label column you pass to `label_column`
- for ViT-like models, hidden states must be available so the callback can extract token embeddings

Current callback behavior:

- uses `trainer.args.per_device_eval_batch_size` by default
- uses `trainer.args.dataloader_num_workers` by default
- fits KNN once at `max(ks)` and derives smaller-`k` metrics from the same neighbor query
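The "query once at `max(ks)`, derive smaller-`k` metrics" idea can be sketched with scikit-learn. This is illustrative, not the callback's code; the random arrays stand in for real embeddings and the metric names are made up for the example.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
train_emb = rng.normal(size=(100, 16))       # stand-in training embeddings
train_labels = rng.integers(0, 5, size=100)
eval_emb = rng.normal(size=(20, 16))         # stand-in eval embeddings
eval_labels = rng.integers(0, 5, size=20)

ks = (1, 5)
nn = NearestNeighbors(n_neighbors=max(ks)).fit(train_emb)
_, idx = nn.kneighbors(eval_emb)             # one neighbor query at max(ks)

metrics = {}
for k in ks:
    # The first k columns of the same query give the k nearest neighbors,
    # so no second query is needed for smaller k.
    votes = train_labels[idx[:, :k]]
    preds = np.array([np.bincount(v).argmax() for v in votes])
    metrics[f"knn_top{k}_accuracy"] = float((preds == eval_labels).mean())
print(metrics)
```

Fitting once at the largest `k` avoids repeating the expensive neighbor search per metric.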
## Combined Trainer Example

```python
from transformers import AutoConfig, AutoImageProcessor, Trainer, TrainingArguments
from transformers_knn_adapter import (
    Dinov2ForImageClassificationWithArcFaceLoss,
    FreezeScheduleCallback,
    KNNCallback,
)

# num_labels, label2id, id2label, datasets, data_collator, and compute_metrics
# come from your own training script.
config = AutoConfig.from_pretrained("facebook/dinov2-small")
config.num_labels = num_labels
config.label2id = label2id
config.id2label = id2label
config.arcface_margin = 28.6
config.arcface_scale = 64.0

model = Dinov2ForImageClassificationWithArcFaceLoss.from_pretrained(
    "facebook/dinov2-small",
    config=config,
    ignore_mismatched_sizes=True,
)
image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")

training_args = TrainingArguments(
    output_dir="/tmp/arcface-run",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    dataloader_num_workers=4,
    learning_rate=1e-3,
    num_train_epochs=10,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.add_callback(
    FreezeScheduleCallback(
        trainer=trainer,
        freeze_schedule=[
            {"epoch": 0.0, "freeze_modules": ["dinov2"], "unfreeze_modules": []},
            {
                "epoch": 3.0,
                "freeze_modules": [],
                "unfreeze_modules": ["dinov2.encoder.layer.11"],
            },
        ],
    )
)
trainer.add_callback(
    KNNCallback(
        trainer=trainer,
        label_column="labels",
        ks=(1, 5),
        embedding_source="cls_mean",
    )
)

trainer.train()
metrics = trainer.evaluate()
```
## Freeze Schedule Callback

`FreezeScheduleCallback` applies module freeze/unfreeze rules based on epoch number.

You can provide the schedule inline without modifying `model.config`:
```python
from transformers_knn_adapter import FreezeScheduleCallback

freeze_schedule = [
    {"epoch": 0.0, "freeze_modules": ["dinov2"], "unfreeze_modules": []},
    {"epoch": 3.0, "freeze_modules": [], "unfreeze_modules": ["dinov2.encoder.layer.11"]},
]

trainer.add_callback(
    FreezeScheduleCallback(
        trainer=trainer,
        freeze_schedule=freeze_schedule,
    )
)
```
Or, if you prefer, store the schedule on `model.config.freeze_schedule` and instantiate the callback without the `freeze_schedule=` argument.
Schedule format:

```json
[
  {
    "epoch": 0.0,
    "freeze_modules": ["dinov2"],
    "unfreeze_modules": []
  },
  {
    "epoch": 3.0,
    "freeze_modules": [],
    "unfreeze_modules": ["dinov2.encoder.layer.11"]
  }
]
```
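Read operationally, the schedule above can be folded like this. The rule used here, that every entry whose epoch threshold has been reached is applied in order with later entries winning, is an assumption for illustration, not the package's documented semantics.

```python
def apply_schedule(schedule, current_epoch):
    """Return (frozen, unfrozen) module-name sets active at current_epoch."""
    frozen, unfrozen = set(), set()
    for entry in schedule:
        if entry["epoch"] <= current_epoch:
            for name in entry["freeze_modules"]:
                frozen.add(name)
                unfrozen.discard(name)   # a later freeze overrides an earlier unfreeze
            for name in entry["unfreeze_modules"]:
                unfrozen.add(name)
                frozen.discard(name)     # and vice versa
    return frozen, unfrozen

schedule = [
    {"epoch": 0.0, "freeze_modules": ["dinov2"], "unfreeze_modules": []},
    {"epoch": 3.0, "freeze_modules": [], "unfreeze_modules": ["dinov2.encoder.layer.11"]},
]

print(apply_schedule(schedule, 0.0))  # backbone frozen from the start
print(apply_schedule(schedule, 3.0))  # last encoder block unfrozen at epoch 3
```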
The callback:

- applies the schedule at train begin and at each epoch begin
- logs the trainable parameter count as `train/trainable_parameters`

Module names must match `model.named_modules()`, for example:

- `dinov2`
- `dinov2.embeddings`
- `dinov2.encoder.layer.11`
- `""` (the empty string) for the whole model
## Smoke Script

The repository includes a practical end-to-end example in `scripts/dogfaces_smoke.py`.

It shows how to combine:

- `Dinov2ForImageClassificationWithArcFaceLoss`
- `KNNCallback`
- `FreezeScheduleCallback`
- Hugging Face `Trainer`
- optional W&B logging
Example:

```shell
uv run --with accelerate --with pytorch-metric-learning python scripts/dogfaces_smoke.py \
  --dataset dimidagd/DogFaceNet_224resize \
  --model facebook/dinov2-small \
  --freeze-schedule-config configs/freeze_schedule.backbone_then_unfreeze.json \
  --num-train-epochs 10 \
  --learning-rate 0.001 \
  --warmup-steps 10 \
  --logging-strategy steps \
  --logging-steps 20
```
## File details

Details for the file `transformers_knn_adapter-0.8.0.tar.gz`.

File metadata:

- Download URL: transformers_knn_adapter-0.8.0.tar.gz
- Upload date:
- Size: 279.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `f1de30ccbf7d168b905267fe3e0d0cfe7e4c2eca72bce2dc185e639c3ea20717` |
| MD5 | `f6e18990f6a7b7021dfbc8149a02f95f` |
| BLAKE2b-256 | `35864b5fd19f3e65ab2ad4961c39b177178a45a39da332e93b5dd03f6b829361` |
Provenance:

The following attestation bundles were made for `transformers_knn_adapter-0.8.0.tar.gz`:

Publisher: `publish.yml` on dimidagd/transformers-knn-adapter

- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: transformers_knn_adapter-0.8.0.tar.gz
  - Subject digest: `f1de30ccbf7d168b905267fe3e0d0cfe7e4c2eca72bce2dc185e639c3ea20717`
  - Sigstore transparency entry: 1134089038
  - Sigstore integration time:
- Permalink: dimidagd/transformers-knn-adapter@07e5138a955a008d8467adecf29b91a4238dae0a
- Branch / Tag: refs/tags/v0.8.0
- Owner: https://github.com/dimidagd
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@07e5138a955a008d8467adecf29b91a4238dae0a
- Trigger Event: push
## File details

Details for the file `transformers_knn_adapter-0.8.0-py3-none-any.whl`.

File metadata:

- Download URL: transformers_knn_adapter-0.8.0-py3-none-any.whl
- Upload date:
- Size: 27.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `043375845afecf98554312271c6ab2f8df97d68add3f3cf21cd3e93363c6d0f4` |
| MD5 | `7155fceae40ac8302367c99e12ea81b9` |
| BLAKE2b-256 | `a672c310a3527e6379329248e806e1cac42920cae2c3a051a0f2d157c0209eff` |
Provenance:

The following attestation bundles were made for `transformers_knn_adapter-0.8.0-py3-none-any.whl`:

Publisher: `publish.yml` on dimidagd/transformers-knn-adapter

- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: transformers_knn_adapter-0.8.0-py3-none-any.whl
  - Subject digest: `043375845afecf98554312271c6ab2f8df97d68add3f3cf21cd3e93363c6d0f4`
  - Sigstore transparency entry: 1134089217
  - Sigstore integration time:
- Permalink: dimidagd/transformers-knn-adapter@07e5138a955a008d8467adecf29b91a4238dae0a
- Branch / Tag: refs/tags/v0.8.0
- Owner: https://github.com/dimidagd
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@07e5138a955a008d8467adecf29b91a4238dae0a
- Trigger Event: push