
embedata

Unified API for loading 29 vision datasets and extracting embeddings from any PyTorch model. Handles torchvision, HuggingFace, and custom ImageFolder datasets through a single interface. Ships with 12 built-in embedding models (CLIP, DINOv2, DINOv3) and supports async extraction on secondary GPUs.

Install

pip install embedata                # core (torchvision datasets)
pip install "embedata[hf]"          # + HuggingFace datasets (ImageNet, CUB, RESISC45)

From source:

git clone https://github.com/ayghri/embedata.git && cd embedata
pip install -e ".[hf]"

Quick start

from embedata import list_datasets, get_dataset, get_datasets, get_dataloaders

print(list_datasets())  # all 29 datasets

train_ds = get_dataset("cifar10", split="train", root_dir="./data")
train_ds, val_ds = get_datasets("cifar10", root_dir="./data")
train_loader, val_loader = get_dataloaders("cifar10", batch_size=128, root_dir="./data")

Most datasets auto-download on first use. Datasets that require manual setup (Kaggle downloads, frame extraction, etc.) document their steps via get_spec("dataset_name").notes.

Extracting embeddings

Pass any PyTorch model that maps images to feature vectors. Embeddings are saved as .npy files:

import torch
from embedata import get_dataloaders, extract

model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
model.eval().to("cuda:0")

train_loader, val_loader = get_dataloaders("cifar10", batch_size=256, root_dir="./data")

extract(train_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="train")
extract(val_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="val")
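Under the hood, an extraction pass like this amounts to a batched forward loop that stacks per-batch outputs and writes them to disk. A minimal, framework-agnostic sketch (the `features_*/labels_*` file names are illustrative, not embedata's actual on-disk layout; `model` is any callable mapping a batch of images to a 2-D array of features):

```python
import os
import numpy as np

def extract_sketch(loader, model, save_dir, suffix):
    """Run `model` over every (images, labels) batch in `loader`,
    concatenate the results, and save them as .npy arrays."""
    os.makedirs(save_dir, exist_ok=True)
    feats, labels = [], []
    for images, targets in loader:
        feats.append(np.asarray(model(images)))   # one feature vector per image
        labels.append(np.asarray(targets))
    np.save(os.path.join(save_dir, f"features_{suffix}.npy"), np.concatenate(feats))
    np.save(os.path.join(save_dir, f"labels_{suffix}.npy"), np.concatenate(labels))
```

The real `extract` additionally moves batches to the given device and disables gradients, but the shape of the computation is the same.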

Load them back as a PyTorch Dataset:

from embedata import load_embeddings

ds = load_embeddings("cifar10", "dinov2", repr_dir="./representations", split="train")
feat, label = ds[0]
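Conceptually, the loaded object is just an indexable view over the saved arrays. A minimal stand-in, assuming a hypothetical `features_*/labels_*` file layout (embedata's actual file names may differ):

```python
import os
import numpy as np

class NpyEmbeddings:
    """Index into saved feature/label arrays as (feature, label) pairs."""
    def __init__(self, repr_dir, suffix):
        self.feats = np.load(os.path.join(repr_dir, f"features_{suffix}.npy"))
        self.labels = np.load(os.path.join(repr_dir, f"labels_{suffix}.npy"))

    def __len__(self):
        return len(self.feats)

    def __getitem__(self, i):
        return self.feats[i], self.labels[i]
```

Because `__len__` and `__getitem__` are all a map-style PyTorch `Dataset` needs, such an object can be passed straight to a `DataLoader`.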

For on-the-fly extraction without disk I/O, EmbeddingDataLoader runs a model on a secondary device in a background thread:

from embedata import EmbeddingDataLoader, get_dataset

dataset = get_dataset("cifar10", split="train", root_dir="./data")

loader = EmbeddingDataLoader(dataset, model, device="cuda:1", batch_size=256, prefetch=2)

for embeddings, labels in loader:
    ...

Built-in models

The package includes 12 pre-registered models accessible via load_model: eight CLIP variants (clipRN50, clipRN101, clipRN50x4, clipRN50x16, clipRN50x64, clipvitB32, clipvitB16, clipvitL14), DINOv2 ViT-g/14 (dinov2), and three DINOv3 variants (dinov3s, dinov3b, dinov3l).

from embedata import list_models, load_model

model, preprocess = load_model("clipvitL14", device="cuda:0", models_dir="./models")

The returned preprocess is the model's image transform -- pass it to get_dataloaders or get_dataset. For DINOv2, preprocess is None; use get_default_transforms() instead.

CLIP models require pip install git+https://github.com/openai/CLIP.git. DINOv3 models require pip install transformers.

Custom datasets

Register your own dataset loader with the @register decorator. The function receives a torchvision transform and a data path, and returns (train_dataset, val_dataset):

from embedata import register

@register("my_dataset", notes="Setup instructions shown by get_spec()")
def _my_dataset(transform, data_path):
    train_ds = ...
    val_ds = ...
    return train_ds, val_ds

HuggingFace streaming

Three datasets (ImageNet, CUB, RESISC45) load via HuggingFace and support streaming mode (requires pip install embedata[hf]). ImageNet-1k is gated and requires huggingface-cli login.

train_ds, val_ds = get_datasets("imagenet", streaming=True)

Dataset preparation

Some datasets require manual download and preparation before use. Python prepare scripts are bundled with the package:

python -m embedata.prepare.birdsnap     --root_dir ROOT_DIR
python -m embedata.prepare.fer2013      --root_dir ROOT_DIR
python -m embedata.prepare.ucf101       --root_dir ROOT_DIR [--download]
python -m embedata.prepare.hatefulmemes --root_dir ROOT_DIR

Shell scripts for cars, eurosat, and sun397 are included under embedata/prepare/. All scripts expect raw data under ROOT_DIR/datasets/{dataset_name}/ and write prepared splits to the same location.

Available datasets

| Dataset | Train | Val/Test | Classes | Size | Source | Notes |
| --- | --- | --- | --- | --- | --- | --- |
| aircraft | 6,667 | 3,333 | 100 | variable | FGVCAircraft | trainval / test |
| birdsnap | ~25,000 | ~24,829 | 500 | variable | ImageFolder | manual download + prepare |
| caltech101 | ~3,060 | ~5,587 | 101 | variable | Caltech101 | 30/class for train |
| cars | 8,144 | 8,041 | 196 | variable | StanfordCars | Kaggle download |
| cifar10 | 50,000 | 10,000 | 10 | 32x32 | CIFAR10 | auto-download |
| cifar100 | 50,000 | 10,000 | 100 | 32x32 | CIFAR100 | auto-download |
| clevr | 70,000 | 15,000 | 11 | 320x240 | CLEVRClassification | count 0-10 |
| country211 | 42,200 | 21,100 | 211 | variable | Country211 | train+valid / test |
| cub | 5,994 | 5,794 | 200 | variable | HuggingFace | CUB-200-2011 |
| dtd | 3,760 | 1,880 | 47 | variable | DTD | train+val / test |
| eurosat | 10,000 | 5,000 | 10 | 64x64 | EuroSAT | 1k+500 per class |
| fashionmnist | 60,000 | 10,000 | 10 | 28x28 | FashionMNIST | auto-download |
| fer2013 | 28,709 | 3,589 | 7 | 48x48 | FER2013 | Kaggle + prepare |
| flowers | 2,040 | 6,149 | 102 | variable | Flowers102 | train+val / test |
| food101 | 75,750 | 25,250 | 101 | variable | Food101 | auto-download |
| gtsrb | 26,640 | 12,630 | 43 | variable | GTSRB | auto-download |
| hatefulmemes | ~8,500 | ~500 | 2 | variable | ImageFolder | Kaggle + prepare |
| imagenet | 1,281,167 | 50,000 | 1,000 | variable | HuggingFace | gated, HF login |
| imagenette | 9,469 | 3,925 | 10 | variable | Imagenette | ImageNet subset |
| kinetics700 | varies | varies | 700 | variable | ImageFolder | frame extraction |
| kitti | varies | varies | 4 | variable | ImageFolder | manual prep |
| mnist | 60,000 | 10,000 | 10 | 28x28 | MNIST | auto-download |
| pcam | 294,912 | 32,768 | 2 | 96x96 | PCAM | train+val / test |
| pets | 3,680 | 3,669 | 37 | variable | OxfordIIITPet | auto-download |
| resisc45 | 25,200 | 6,300 | 45 | 256x256 | HuggingFace | remote sensing |
| sst | 7,792 | 1,821 | 2 | variable | RenderedSST2 | train+val / test |
| stl10 | 5,000 | 8,000 | 10 | 96x96 | STL10 | auto-download |
| sun397 | ~19,850 | ~19,850 | 397 | variable | SUN397 | manual + prepare |
| ucf101 | varies | varies | 101 | variable | ImageFolder | frame extraction + prepare |

Counts reflect splits as loaded by embedata (some merge train+val for training). "variable" means images have different native resolutions -- all are resized by the transform (default 224x224). Datasets marked "auto-download" are fetched on first use.

License

MIT
