Efficient few-shot audio classification with contrastive fine-tuning (a SetFit for audio).
AudioSetFit
audiosetfit ports SetFit's prompt-free, few-shot recipe from text to audio. Instead of a
SentenceTransformer body, it uses an audio encoder (CLAP by default) and trains in two
phases:
- Embedding fine-tuning (contrastive). From a handful of labeled clips it builds positive (same-class) and negative (different-class) pairs and fine-tunes the audio body so same-class clips embed closer together. A few examples explode into hundreds of informative pairs.
- Classifier head. A lightweight head (sklearn LogisticRegression by default, or a differentiable torch head) is fit on the resulting embeddings.
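To make the "handful of clips, hundreds of pairs" arithmetic concrete, here is a minimal sketch of same-class / different-class pair generation. The helper `build_pairs` is hypothetical and written only for illustration; the package's own sampler may order, cap, or balance pairs differently.

```python
from itertools import combinations


def build_pairs(labels):
    """Return (positive, negative) index pairs for a list of class labels."""
    positive, negative = [], []
    for i, j in combinations(range(len(labels)), 2):
        # Same label -> positive pair, different label -> negative pair.
        (positive if labels[i] == labels[j] else negative).append((i, j))
    return positive, negative


# 5 classes with 8 labeled clips each (the ESC-50 example defaults).
labels = [cls for cls in range(5) for _ in range(8)]
positive, negative = build_pairs(labels)
print(len(labels), len(positive), len(negative))  # 40 140 640
```

Forty clips already yield 780 distinct pairs, which is why a cap such as max_pairs is useful on small machines.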
The contrastive trainer is self-contained: it does not depend on
sentence-transformers. The pair-sampling and loss math are reimplemented to operate
directly on audio embeddings, so any HF audio model can be plugged in as the body.
The public API intentionally mirrors SetFit:
from audiosetfit import AudioSetFitModel, Trainer, TrainingArguments, sample_dataset
Installation
# from the repo root
python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -e .
This installs torch, transformers, datasets, librosa, soundfile, torchcodec,
scikit-learn, etc. On Apple Silicon, PyTorch will use the MPS backend automatically;
on NVIDIA GPUs it uses CUDA; otherwise CPU.
FFmpeg required. datasets >= 4 decodes audio via torchcodec, which needs FFmpeg (4–7) installed on your system. On macOS: brew install ffmpeg; on Debian/Ubuntu: sudo apt-get install ffmpeg.
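If you are unsure whether FFmpeg is visible to Python, a quick standard-library check like the following can help. It is a sketch only: `ffmpeg_version` is a hypothetical helper, and it confirms only that an `ffmpeg` executable is on PATH, not that torchcodec can load its libraries.

```python
import shutil
import subprocess


def ffmpeg_version():
    """Return FFmpeg's version banner line, or None if it is not on PATH."""
    exe = shutil.which("ffmpeg")
    if exe is None:
        return None
    out = subprocess.run([exe, "-version"], capture_output=True, text=True, check=False)
    return out.stdout.splitlines()[0] if out.stdout else None


banner = ffmpeg_version()
print(banner or "FFmpeg not found: install it before decoding audio with datasets")
```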
Examples
The examples/ scripts cover four small benchmarks across different audio domains. They share
the same CLI flags (--backbone, --classes, --num-samples, --epochs, --max-pairs,
--no-embedding-finetuning, --differentiable-head, --num-workers, ...).
ESC-50 (environmental sounds)
ESC-50 (2,000 clips, 50 classes) is the small default starting point. The example restricts to a few classes for a fast first run:
python examples/train_esc50.py # 5 classes, 8 shots, CLAP
python examples/train_esc50.py --classes 10 --num-samples 16
python examples/train_esc50.py --no-embedding-finetuning # frozen-backbone baseline
UrbanSound8K (urban sounds)
UrbanSound8K (8,732 clips, 10 classes, 10 folds) follows the dataset's fold protocol (folds 1-9 train, fold 10 test):
python examples/train_urbansound8k.py # 5 classes, 8 shots, CLAP
python examples/train_urbansound8k.py --classes 10 --num-samples 16
CREMA-D (speech emotion)
A speech task where self-supervised speech encoders shine: CREMA-D (7,442 clips, 91 actors, 6 emotions). The example defaults to facebook/wav2vec2-base and uses a speaker-disjoint train/test split:
python examples/train_cremad.py # wav2vec2-base
python examples/train_cremad.py --backbone microsoft/wavlm-base
python examples/train_cremad.py --backbone facebook/hubert-base-ls960
python examples/train_cremad.py --backbone laion/clap-htsat-unfused # compare vs CLAP
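A speaker-disjoint split means no actor contributes clips to both train and test, so the score reflects emotion recognition rather than voice memorization. The sketch below shows the idea on toy data; `speaker_disjoint_split`, the 20% test fraction, and the seed are assumptions for illustration, not the values used by examples/train_cremad.py.

```python
import random


def speaker_disjoint_split(speaker_ids, test_fraction=0.2, seed=42):
    """Split clip indices so that no speaker appears in both train and test."""
    speakers = sorted(set(speaker_ids))
    random.Random(seed).shuffle(speakers)
    n_test = max(1, round(len(speakers) * test_fraction))
    test_speakers = set(speakers[:n_test])
    train_idx = [i for i, s in enumerate(speaker_ids) if s not in test_speakers]
    test_idx = [i for i, s in enumerate(speaker_ids) if s in test_speakers]
    return train_idx, test_idx


# Toy data: 10 hypothetical speakers with 3 clips each.
speaker_ids = [f"spk{n:02d}" for n in range(10) for _ in range(3)]
train_idx, test_idx = speaker_disjoint_split(speaker_ids)
train_speakers = {speaker_ids[i] for i in train_idx}
test_speakers = {speaker_ids[i] for i in test_idx}
print(len(train_idx), len(test_idx), train_speakers.isdisjoint(test_speakers))
```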
MSWC (keyword spotting)
A keyword-spotting example (SUPERB KS-style) on MSWC (Multilingual Spoken Words Corpus). Each clip is a single spoken word; this is a lexical/phonetic task, so it also defaults to a speech encoder and uses the dataset's predefined train/test splits:
python examples/train_mswc_keywords.py # 10 keywords, wav2vec2-base
python examples/train_mswc_keywords.py --classes 5 --num-samples 16
python examples/train_mswc_keywords.py --language spanish
python examples/train_mswc_keywords.py --backbone laion/clap-htsat-unfused # compare vs CLAP
Benchmarking (multi-backbone / multi-seed)
examples/benchmark.py drives the training scripts above across a grid of backbones x seeds and
prints a mean +/- std table, so backbone comparisons are reproducible instead of single noisy runs.
It reuses each dataset's own split logic. Trainer.evaluate reports both accuracy and macro-F1,
and Trainer.classification_report(...) adds per-class accuracy and a confusion matrix.
# CLAP vs wav2vec2 on CREMA-D over 3 seeds
python examples/benchmark.py --dataset cremad \
--backbones laion/clap-htsat-unfused facebook/wav2vec2-base --seeds 41 42 43
# Keyword spotting with speech encoders, write a CSV of every run
python examples/benchmark.py --dataset mswc \
--backbones facebook/wav2vec2-base microsoft/wavlm-base --seeds 42 43 --csv results.csv
# Forward extra flags to the training script after a literal `--`
python examples/benchmark.py --dataset esc50 --seeds 41 42 43 -- --no-embedding-finetuning
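For readers unfamiliar with the reported numbers, the following plain-Python sketch shows what accuracy, macro-F1, and a mean +/- std summary compute. It is not the package's implementation: the labels and per-seed scores are made up for illustration, and the harness may use a different std convention (this one is the sample standard deviation).

```python
from statistics import mean, stdev


def accuracy(y_true, y_pred):
    """Fraction of clips whose predicted label matches the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1, so rare classes count as much as common ones."""
    scores = []
    for cls in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
        fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return mean(scores)


y_true = ["dog", "dog", "rain", "rain", "rain", "siren"]
y_pred = ["dog", "rain", "rain", "rain", "rain", "siren"]
print(round(accuracy(y_true, y_pred), 3), round(macro_f1(y_true, y_pred), 3))

# Made-up per-seed accuracies, only to show the mean +/- std aggregation.
per_seed = [0.70, 0.74, 0.72]
print(f"{mean(per_seed):.3f} +/- {stdev(per_seed):.3f}")
```

Macro-F1 is the more informative of the two when the sampled classes are imbalanced in the test split.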
Minimal end-to-end usage
from datasets import Audio, load_dataset
from audiosetfit import AudioSetFitModel, Trainer, TrainingArguments, sample_dataset
ds = load_dataset("ashraq/esc50", split="train").cast_column("audio", Audio(sampling_rate=48000))
labels = sorted(set(ds["category"]))
train_ds = sample_dataset(ds, label_column="category", num_samples=8)
model = AudioSetFitModel.from_pretrained("laion/clap-htsat-unfused", labels=labels)
trainer = Trainer(
    model=model,
    args=TrainingArguments(embedding_num_epochs=1, max_pairs=256),
    train_dataset=train_ds,
    column_mapping={"category": "label"},  # the 'audio' column already matches
)
trainer.train()
preds = model.predict(["dog_bark.wav", "rain.wav"]) # file paths, arrays, or Audio dicts
model.save_pretrained("my-esc50-model")
reloaded = AudioSetFitModel.from_pretrained("my-esc50-model")
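The few-shot subset in the snippet above comes from sample_dataset, which draws num_samples clips per class. Conceptually that is a per-class random draw, sketched here on a plain list of labels. `sample_per_class` and its seed are hypothetical and only illustrate the idea; the real function operates on a datasets.Dataset and may behave differently.

```python
import random
from collections import defaultdict


def sample_per_class(labels, num_samples=8, seed=42):
    """Pick up to `num_samples` row indices for every distinct label."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    rng = random.Random(seed)
    chosen = []
    for label in sorted(by_label):
        rows = by_label[label]
        # Classes with fewer rows than requested contribute everything they have.
        chosen.extend(rng.sample(rows, min(num_samples, len(rows))))
    return sorted(chosen)


# Toy label column: 3 classes with 40 rows each.
labels = [name for name in ("dog", "rain", "siren") for _ in range(40)]
subset = sample_per_class(labels, num_samples=8)
print(len(subset))  # 24
```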
Inputs accepted everywhere
predict / encode / datasets accept any mix of:
- file paths ("clip.wav"),
- raw waveforms (np.ndarray, assumed at the backbone's sample rate),
- Hugging Face datasets Audio dicts ({"array", "sampling_rate", "path"}).
Everything is resampled to the backbone's expected rate (CLAP = 48 kHz).
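The three input kinds differ in what is known about their sample rate, which is why raw arrays are assumed to already match the backbone. The sketch below spells out that distinction. `describe_input` is a hypothetical helper, not part of audiosetfit, and plain lists stand in for np.ndarray to keep it dependency-free.

```python
from pathlib import Path


def describe_input(item, backbone_sr=48000):
    """Return (kind, sampling_rate) for one input item."""
    if isinstance(item, (str, Path)):
        # File path: the rate is only known after decoding the file.
        return "path", None
    if isinstance(item, dict) and "array" in item and "sampling_rate" in item:
        # datasets Audio dict: carries its own rate, so it can be resampled.
        return "audio_dict", item["sampling_rate"]
    # Raw waveform: no rate attached, so it is assumed to match the backbone.
    return "waveform", backbone_sr


batch = [
    "clip.wav",
    {"array": [0.0, 0.1], "sampling_rate": 16000, "path": None},
    [0.0, 0.1, 0.2],
]
kinds = [describe_input(x) for x in batch]
print(kinds)
```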
Project layout
src/audiosetfit/
├── encoders.py # AudioEncoder base + CLAP/AST/wav2vec2-family/Whisper + build_encoder()
├── modeling.py # AudioSetFitModel, AudioSetFitHead, save/from_pretrained
├── sampler.py # ContrastiveDataset (same/different-label pair generation)
├── losses.py # CosineSimilarityLoss, ContrastiveLoss (on embedding tensors)
├── data.py # load_audio (resampling), sample_dataset
├── training_args.py # TrainingArguments (both phases)
└── trainer.py # self-contained two-phase Trainer
examples/train_esc50.py
examples/train_urbansound8k.py
examples/train_cremad.py
examples/train_mswc_keywords.py
examples/benchmark.py # multi-backbone / multi-seed harness
Key training arguments
| Argument | Default | Purpose |
|---|---|---|
| train_embeddings | True | Run phase 1. Set False for a frozen-backbone baseline. |
| embedding_num_epochs | 1 | Epochs over contrastive pairs. |
| embedding_batch_size | 16 | Pair batch size (lower it if you hit memory limits). |
| body_learning_rate | 2e-5 | LR for the audio body. |
| loss | "cosine" | "cosine" or "contrastive" (or pass an nn.Module). |
| sampling_strategy | "oversampling" | "unique" / "oversampling" / "undersampling". |
| max_steps / max_pairs | -1 | Cap phase-1 work (handy on CPU/laptops). |
| classifier_num_epochs | 25 | Torch-head epochs (ignored for sklearn head). |
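As a rough guide to the two loss options, here is a per-pair sketch in plain Python that follows the definitions used by the sentence-transformers losses of the same names. This is an assumption about the math, not a copy of losses.py: the package works on batched tensors, and the margin of 0.5 is the sentence-transformers default rather than a value confirmed for audiosetfit.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def cosine_similarity_loss(u, v, label):
    """Squared error between cos(u, v) and the pair label (1.0 same class, 0.0 different)."""
    return (cosine_similarity(u, v) - label) ** 2


def contrastive_loss(u, v, label, margin=0.5):
    """Pull same-class pairs together; push different-class pairs beyond `margin`."""
    distance = 1.0 - cosine_similarity(u, v)
    return 0.5 * (label * distance ** 2 + (1.0 - label) * max(0.0, margin - distance) ** 2)


same = ([1.0, 0.0], [1.0, 0.0])
ortho = ([1.0, 0.0], [0.0, 1.0])
print(cosine_similarity_loss(*same, 1.0))   # 0.0
print(cosine_similarity_loss(*ortho, 1.0))  # 1.0
print(contrastive_loss(*same, 0.0))         # 0.125
```

The cosine loss penalizes every pair that misses its target similarity, while the contrastive loss stops penalizing different-class pairs once they are further apart than the margin.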
Backbones
Pick a backbone by passing its Hugging Face id to from_pretrained (or --backbone in the
example scripts). The right AudioEncoder is selected automatically from the model's model_type.
AudioSetFitModel.from_pretrained("laion/clap-htsat-unfused") # CLAP (default)
AudioSetFitModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593") # AST
AudioSetFitModel.from_pretrained("facebook/wav2vec2-base") # wav2vec2
AudioSetFitModel.from_pretrained("facebook/hubert-base-ls960") # HuBERT
AudioSetFitModel.from_pretrained("microsoft/wavlm-base-plus") # WavLM
AudioSetFitModel.from_pretrained("openai/whisper-base") # Whisper encoder
python examples/train_esc50.py --backbone facebook/wav2vec2-base
python examples/train_esc50.py --backbone MIT/ast-finetuned-audioset-10-10-0.4593
Audio is resampled to each backbone's expected rate automatically (CLAP 48 kHz, others 16 kHz), so the same dataset works across all of them.
| Backbone | model_type (built-in) | Embedding | Pooling | Best for |
|---|---|---|---|---|
| CLAP (default) | clap | 512-d, normalized | projection head | General sound events, environmental audio |
| AST | audio-spectrogram-transformer | 768-d | CLS+dist pooled | AudioSet-style tagging |
| wav2vec2 / HuBERT / WavLM | wav2vec2 / hubert / wavlm (+ unispeech, unispeech-sat, data2vec-audio, wav2vec2-conformer) | hidden_size | masked mean over time | Speech (commands, speaker, emotion) |
| Whisper encoder | whisper | d_model | mean over frames | Robust speech in noise |
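"Masked mean over time" in the table means averaging frame-level features while ignoring padded frames, so short clips in a padded batch are not diluted by padding. A minimal plain-Python sketch for a single clip follows; `masked_mean` is a hypothetical helper, and the package itself operates on batched tensors.

```python
def masked_mean(hidden_states, mask):
    """Average frame vectors over time, skipping frames where mask is 0 (padding)."""
    kept = [frame for frame, keep in zip(hidden_states, mask) if keep]
    if not kept:
        raise ValueError("mask excludes every frame")
    dim = len(kept[0])
    return [sum(frame[d] for frame in kept) / len(kept) for d in range(dim)]


# One clip: 4 frames of 2-d features, the last frame is padding.
frames = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [100.0, 100.0]]
mask = [1, 1, 1, 0]
print(masked_mean(frames, mask))          # [3.0, 4.0]
print(masked_mean(frames, [1, 1, 1, 1]))  # [27.25, 28.0]
```

The second call shows how an unmasked mean would be skewed by the padding frame.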
Adding another backbone
The encoder is the only modality-specific piece. Subclass AudioEncoder, implement
prepare (waveforms → model inputs) and forward_features (inputs → [B, D]), set
target_sr / embedding_dim, then register it:
from audiosetfit import encoders
class MyEncoder(encoders.AudioEncoder):
    def __init__(self, model_id, device=None):
        super().__init__()
        self.model_id = model_id
        ...  # load backbone + feature extractor
        self.target_sr = 16000
        self.embedding_dim = ...  # output dim
        self.to(encoders._resolve_device(device))

    def prepare(self, waveforms): ...
    def forward_features(self, inputs): ...
    def save(self, save_directory): ...

encoders._ENCODER_REGISTRY["my_model_type"] = MyEncoder
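The registration step above is an instance of a plain registry lookup: a dict maps a model_type string to an encoder class, and a builder instantiates whichever class is registered. The standalone sketch below illustrates that pattern. All names in it (`_REGISTRY`, `BaseEncoder`, `ClapLikeEncoder`, `SpeechLikeEncoder`, `build`) are hypothetical stand-ins, and the real build_encoder() may resolve model_type and handle errors differently.

```python
_REGISTRY = {}  # hypothetical stand-in for encoders._ENCODER_REGISTRY


class BaseEncoder:
    target_sr = 16000  # default rate for speech-style backbones

    def __init__(self, model_id):
        self.model_id = model_id


class ClapLikeEncoder(BaseEncoder):
    target_sr = 48000  # CLAP expects 48 kHz audio


class SpeechLikeEncoder(BaseEncoder):
    pass


_REGISTRY["clap"] = ClapLikeEncoder
for model_type in ("wav2vec2", "hubert", "wavlm"):
    # Several model types can share one encoder class.
    _REGISTRY[model_type] = SpeechLikeEncoder


def build(model_type, model_id):
    """Instantiate the encoder class registered for `model_type`."""
    try:
        encoder_cls = _REGISTRY[model_type]
    except KeyError:
        raise ValueError(f"no encoder registered for model_type={model_type!r}") from None
    return encoder_cls(model_id)


enc = build("hubert", "facebook/hubert-base-ls960")
print(type(enc).__name__, enc.target_sr)  # SpeechLikeEncoder 16000
```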
Roadmap / next steps
Benchmarking & evaluation
- Reproducible multi-backbone / multi-seed benchmark harness (examples/benchmark.py) with mean ± std tables.
- Richer metrics in Trainer.evaluate (accuracy + macro-F1); per-class accuracy and confusion matrix via Trainer.classification_report.
- Published results table (CLAP vs wav2vec2 vs WavLM across all example datasets).
Training method
- SupConLoss / InfoNCE with in-batch negatives + group-by-label batch sampler (so larger batches add real negatives, as in SetFit).
- Audio augmentation for the few-shot regime (SpecAugment, additive noise, gain, time-shift, random crop).
- Embedding cache for the frozen-backbone path (skip re-encoding clips across runs/sweeps).
- Knowledge distillation from a large unlabeled audio pool (teacher → student).
Models & inputs
- Long-clip handling: windowing/chunking → encode → pool/vote.
- Multilabel audio tagging end-to-end example (sampler already supports multilabel pairs).
- (Optional) BEATs / OpenBEATs backbone (strongest general-purpose SSL embeddings).
Productionization
- ONNX / torch.compile export for fast CPU inference.
- Hub push_to_hub with an auto-generated model card (incl. the eval table).
- Smoke-test suite + CI using small real models (e.g. openai/whisper-tiny).
Acknowledgements
Architecture and training recipe adapted from Hugging Face SetFit (Tunstall et al., Efficient Few-Shot Learning Without Prompts, 2022).