embedata
Unified API for loading 29 vision datasets and extracting embeddings from any PyTorch model. Handles torchvision, HuggingFace, and custom ImageFolder datasets through a single interface. Ships with 12 built-in embedding models (CLIP, DINOv2, DINOv3) and supports async extraction on secondary GPUs.
Install
pip install embedata # core (torchvision datasets)
pip install embedata[hf] # + HuggingFace datasets (ImageNet, CUB, RESISC45)
From source:
git clone https://github.com/ayghri/embedata.git && cd embedata
pip install -e ".[hf]"
Quick start
from embedata import list_datasets, get_dataset, get_datasets, get_dataloaders
print(list_datasets()) # all 29 datasets
train_ds = get_dataset("cifar10", split="train", root_dir="./data")
train_ds, val_ds = get_datasets("cifar10", root_dir="./data")
train_loader, val_loader = get_dataloaders("cifar10", batch_size=128, root_dir="./data")
Most datasets auto-download on first use. Datasets that require manual setup (Kaggle downloads, frame extraction, etc.) document their steps via get_spec("dataset_name").notes.
Extracting embeddings
Pass any PyTorch model that maps images to feature vectors. Embeddings are saved as .npy files:
import torch
from embedata import get_dataloaders, extract
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
model.eval().to("cuda:0")
train_loader, val_loader = get_dataloaders("cifar10", batch_size=256, root_dir="./data")
extract(train_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="train")
extract(val_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="val")
Load them back as a PyTorch Dataset:
from embedata import load_embeddings
ds = load_embeddings("cifar10", "dinov2", repr_dir="./representations", split="train")
feat, label = ds[0]
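The object returned by load_embeddings indexes like any map-style PyTorch dataset: ds[i] yields a (feature, label) pair. A minimal stand-in illustrating that contract (plain lists substitute for the .npy arrays; the class name is hypothetical, not embedata's actual implementation):

```python
class EmbeddingDataset:
    """Minimal map-style dataset pairing feature vectors with labels."""

    def __init__(self, features, labels):
        assert len(features) == len(labels), "one label per feature vector"
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # Mirrors the `feat, label = ds[0]` access pattern above.
        return self.features[idx], self.labels[idx]

ds = EmbeddingDataset([[0.1, 0.2], [0.3, 0.4]], [0, 1])
feat, label = ds[0]
```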
For on-the-fly extraction without disk I/O, EmbeddingDataLoader runs a model on a secondary device in a background thread:
from embedata import EmbeddingDataLoader, get_dataset
dataset = get_dataset("cifar10", split="train", root_dir="./data")
loader = EmbeddingDataLoader(dataset, model, device="cuda:1", batch_size=256, prefetch=2)
for embeddings, labels in loader:
    ...
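The overlap between data loading and embedding extraction can be sketched with a background thread feeding a bounded queue. This is a simplified stand-in for what EmbeddingDataLoader does, not its actual internals; the class name and `compute_fn` are assumptions:

```python
import queue
import threading

class PrefetchLoader:
    """Runs compute_fn on batches in a background thread, keeping up to
    `prefetch` results buffered ahead of the consumer."""

    _SENTINEL = object()

    def __init__(self, batches, compute_fn, prefetch=2):
        self.batches = batches
        self.compute_fn = compute_fn
        self.queue = queue.Queue(maxsize=prefetch)

    def _worker(self):
        for batch in self.batches:
            # In embedata this step would run the model on the secondary GPU.
            self.queue.put(self.compute_fn(batch))
        self.queue.put(self._SENTINEL)  # signal end of iteration

    def __iter__(self):
        threading.Thread(target=self._worker, daemon=True).start()
        while True:
            item = self.queue.get()
            if item is self._SENTINEL:
                break
            yield item

# Toy "model": embed each batch by doubling its values.
loader = PrefetchLoader([[1, 2], [3, 4]], lambda b: [x * 2 for x in b])
results = list(loader)
```

Because the queue is bounded, the worker stalls once `prefetch` batches are buffered, so memory stays flat while the consumer catches up.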
Built-in models
The package includes 12 pre-registered models accessible via load_model.
Eight CLIP variants (clipRN50, clipRN101, clipRN50x4, clipRN50x16, clipRN50x64, clipvitB32, clipvitB16, clipvitL14),
DINOv2 ViT-g/14 (dinov2), and three DINOv3 variants (dinov3s, dinov3b, dinov3l).
from embedata import list_models, load_model
model, preprocess = load_model("clipvitL14", device="cuda:0", models_dir="./models")
The returned preprocess is the model's image transform -- pass it to get_dataloaders or get_dataset. For DINOv2, preprocess is None; use get_default_transforms() instead.
CLIP models require pip install git+https://github.com/openai/CLIP.git. DINOv3 models require pip install transformers.
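The fallback rule above amounts to "use the model's bundled transform when it ships one, otherwise the library default". A tiny sketch of that rule (the function name is hypothetical and strings stand in for actual transform objects):

```python
def resolve_transform(preprocess, default_transform):
    """Prefer the model's bundled transform; fall back to the default."""
    return preprocess if preprocess is not None else default_transform

# CLIP-style model ships its own preprocess:
clip_choice = resolve_transform("clip_preprocess", "default_224")
# DINOv2-style model returns preprocess=None, so the default applies:
dino_choice = resolve_transform(None, "default_224")
```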
Custom datasets
Register your own dataset loader with the @register decorator. The function receives a torchvision transform and a data path, and returns (train_dataset, val_dataset):
from embedata import register
@register("my_dataset", notes="Setup instructions shown by get_spec()")
def _my_dataset(transform, data_path):
    train_ds = ...
    val_ds = ...
    return train_ds, val_ds
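Under the hood this is a straightforward name-to-loader registry. A self-contained sketch of the pattern (the dict layout and internals are assumptions, not embedata's actual code):

```python
_REGISTRY = {}

def register(name, notes=""):
    """Decorator recording a loader function and its setup notes under `name`."""
    def wrap(fn):
        _REGISTRY[name] = {"loader": fn, "notes": notes}
        return fn
    return wrap

@register("toy_dataset", notes="No setup needed")
def _toy_dataset(transform, data_path):
    # A real loader would construct torchvision/HF datasets here.
    return f"train@{data_path}", f"val@{data_path}"

# Lookup by name, as get_dataset/get_spec would do:
spec = _REGISTRY["toy_dataset"]
train_ds, val_ds = spec["loader"](None, "./data")
```

This is why `get_spec("dataset_name").notes` can surface setup instructions before any data is downloaded: the notes live in the registry, independent of the loader being run.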
HuggingFace streaming
Three datasets (ImageNet, CUB, RESISC45) load via HuggingFace and support streaming mode (requires pip install embedata[hf]). ImageNet-1k is gated and requires huggingface-cli login.
train_ds, val_ds = get_datasets("imagenet", streaming=True)
Dataset preparation
Some datasets require manual download and preparation before use. Python prepare scripts are bundled with the package:
python -m embedata.prepare.birdsnap --root_dir ROOT_DIR
python -m embedata.prepare.fer2013 --root_dir ROOT_DIR
python -m embedata.prepare.ucf101 --root_dir ROOT_DIR [--download]
python -m embedata.prepare.hatefulmemes --root_dir ROOT_DIR
Shell scripts for cars, eurosat, and sun397 are included under embedata/prepare/. All scripts expect raw data under ROOT_DIR/datasets/{dataset_name}/ and write prepared splits to the same location.
Available datasets
| Dataset | Train | Val/Test | Classes | Size | Source | Notes |
|---|---|---|---|---|---|---|
| aircraft | 6,667 | 3,333 | 100 | variable | FGVCAircraft | trainval / test |
| birdsnap | ~25,000 | ~24,829 | 500 | variable | ImageFolder | manual download + prepare |
| caltech101 | ~3,060 | ~5,587 | 101 | variable | Caltech101 | 30/class for train |
| cars | 8,144 | 8,041 | 196 | variable | StanfordCars | Kaggle download |
| cifar10 | 50,000 | 10,000 | 10 | 32x32 | CIFAR10 | auto-download |
| cifar100 | 50,000 | 10,000 | 100 | 32x32 | CIFAR100 | auto-download |
| clevr | 70,000 | 15,000 | 11 | 320x240 | CLEVRClassification | count 0-10 |
| country211 | 42,200 | 21,100 | 211 | variable | Country211 | train+valid / test |
| cub | 5,994 | 5,794 | 200 | variable | HuggingFace | CUB-200-2011 |
| dtd | 3,760 | 1,880 | 47 | variable | DTD | train+val / test |
| eurosat | 10,000 | 5,000 | 10 | 64x64 | EuroSAT | 1k+500 per class |
| fashionmnist | 60,000 | 10,000 | 10 | 28x28 | FashionMNIST | auto-download |
| fer2013 | 28,709 | 3,589 | 7 | 48x48 | FER2013 | Kaggle + prepare |
| flowers | 2,040 | 6,149 | 102 | variable | Flowers102 | train+val / test |
| food101 | 75,750 | 25,250 | 101 | variable | Food101 | auto-download |
| gtsrb | 26,640 | 12,630 | 43 | variable | GTSRB | auto-download |
| hatefulmemes | ~8,500 | ~500 | 2 | variable | ImageFolder | Kaggle + prepare |
| imagenet | 1,281,167 | 50,000 | 1,000 | variable | HuggingFace | gated, HF login |
| imagenette | 9,469 | 3,925 | 10 | variable | Imagenette | ImageNet subset |
| kinetics700 | varies | varies | 700 | variable | ImageFolder | frame extraction |
| kitti | varies | varies | 4 | variable | ImageFolder | manual prep |
| mnist | 60,000 | 10,000 | 10 | 28x28 | MNIST | auto-download |
| pcam | 294,912 | 32,768 | 2 | 96x96 | PCAM | train+val / test |
| pets | 3,680 | 3,669 | 37 | variable | OxfordIIITPet | auto-download |
| resisc45 | 25,200 | 6,300 | 45 | 256x256 | HuggingFace | remote sensing |
| sst | 7,792 | 1,821 | 2 | variable | RenderedSST2 | train+val / test |
| stl10 | 5,000 | 8,000 | 10 | 96x96 | STL10 | auto-download |
| sun397 | ~19,850 | ~19,850 | 397 | variable | SUN397 | manual + prepare |
| ucf101 | varies | varies | 101 | variable | ImageFolder | frame extraction + prepare |
Counts reflect splits as loaded by embedata (some merge train+val for training). "variable" means images have different native resolutions -- all are resized by the transform (default 224x224). Datasets marked "auto-download" are fetched on first use.
License
MIT