
embed-train

embed-train is a config-driven library for training, evaluating, checkpointing, and publishing embedding models. It is designed for teams that want reusable training building blocks instead of one-off scripts.

The package currently supports two main training paths:

  • A custom PyTorch training loop for query/document contrastive learning
  • A SentenceTransformers-based trainer for Hugging Face Dataset workflows and IR evaluation

It also includes a dedicated runner for exporting checkpoints and pushing model repositories to the Hugging Face Hub.

Why This Library Exists

Most embedding projects end up reimplementing the same pieces:

  • model loading and checkpoint restore
  • dataset adapters
  • collate/tokenization logic
  • contrastive losses
  • train/validation splitting
  • checkpoint saving
  • model packaging and Hub publishing

embed-train turns those pieces into small, replaceable abstractions connected by typed settings objects and module_path-based loading.
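The module_path mechanism is standard Python dynamic importing. As a rough sketch (the helper name here is illustrative; the library's actual implementation lives in embed_train.utils and may differ), a dotted string is split into a module and an attribute:

```python
import importlib

def load_from_module_path(module_path: str):
    """Resolve a dotted 'pkg.module.ClassName' string to the named attribute.

    Illustrative only -- embed_train's real loader may behave differently.
    """
    module_name, _, attr_name = module_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)

# Resolve a standard-library class the same way a settings object might
# resolve a custom collate function or loss class.
OrderedDict = load_from_module_path("collections.OrderedDict")
print(OrderedDict)  # <class 'collections.OrderedDict'>
```

Because the string is resolved at run time, any importable class can be plugged in without touching the orchestration code.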

Installation

Requirements, as declared in pyproject.toml:

  • Python >=3.11,<3.13
  • PyTorch
  • Sentence Transformers
  • Hugging Face Datasets
  • TensorBoard
  • Accelerate
  • retrievalbase

Install from source:

pip install -e .

For local development in this repository:

make dev-install

Library Structure

src/embed_train/
├── __init__.py                    # Base Runner abstraction and YAML runner loader
├── constants.py                   # Shared constants such as default seed and trust flags
├── exceptions.py                  # Library-specific exception hierarchy
├── models/                        # Base model wrapper abstraction
├── push_to_hf/                    # Runner for packaging and pushing checkpoints to HF Hub
├── settings.py                    # Pydantic settings models for all library components
├── train/
│   ├── __init__.py                # TrainRunner
│   ├── dataset/
│   │   ├── __init__.py            # Base TorchDataset and CollateFn abstractions
│   │   ├── collate.py             # Built-in in-batch positive collate functions
│   │   ├── torch_datasets.py      # Built-in query/positive dataset views
│   │   └── sampling/
│   │       └── samplers.py        # Built-in positive sampler(s)
│   └── trainers/
│       ├── __init__.py            # Base Trainer abstraction
│       ├── hf/                    # SentenceTransformers trainer
│       └── torch/
│           ├── __init__.py        # Custom PyTorch trainer base
│           └── loss.py            # Contrastive loss implementations
│
└── utils.py                       # Dynamic import, checkpoint loading, HF file helpers

Core Concepts

1. Runners

Runners are the top-level execution unit.

  • embed_train.Runner: abstract base class
  • embed_train.train.TrainRunner: loads a trainer and executes train()
  • embed_train.push_to_hf.PushToHFRunner: restores a checkpoint, saves a local HF-style repo, and optionally pushes it
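The contract is small: a runner is one entry point that wires configured components together and executes a single job. A dependency-free sketch of that shape (the real embed_train.Runner interface may differ):

```python
from abc import ABC, abstractmethod

class Runner(ABC):
    """Minimal sketch of the runner contract."""
    @abstractmethod
    def run(self) -> None: ...

class TrainRunnerSketch(Runner):
    """Delegates to a trainer, mirroring the 'load a trainer and execute
    train()' behavior described above."""
    def __init__(self, trainer):
        self.trainer = trainer

    def run(self) -> None:
        self.trainer.train()

class DummyTrainer:
    def __init__(self):
        self.trained = False
    def train(self):
        self.trained = True

runner = TrainRunnerSketch(DummyTrainer())
runner.run()
print(runner.trainer.trained)  # True
```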

2. Settings-Driven Composition

Every major component is configured through Pydantic settings objects in src/embed_train/settings.py. Most components are resolved dynamically from a module_path, which makes the library extensible without editing core orchestration code.

3. Model Wrappers

Models inherit from embed_train.models.Model and must implement:

  • to_hf_model(): return a Hugging Face PreTrainedModel

The base class already provides:

  • save(...)
  • to(device)
  • from_checkpoint(...)

4. Datasets and Collate Functions

The PyTorch flow separates concerns clearly:

  • TorchDataset: how rows are loaded and exposed
  • CollateFn: how rows become query/document text pairs and tokenized tensors
  • Processor: text normalization or preprocessing, provided by retrievalbase
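The division of labor can be illustrated with plain functions (field names and the processor behavior here are made up for the example; the real Processor comes from retrievalbase):

```python
# A processor normalizes text; a collate function turns a batch of rows
# into parallel query/document text lists before tokenization.
def processor(text: str) -> str:
    return " ".join(text.lower().split())

rows = [
    {"query": "What is  RAG?", "positive": "Retrieval-augmented generation ..."},
]

def collate(batch):
    queries = [processor(r["query"]) for r in batch]
    documents = [processor(r["positive"]) for r in batch]
    return queries, documents  # a real CollateFn would now tokenize these

queries, documents = collate(rows)
print(queries)  # ['what is rag?']
```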

5. Trainers

There are two trainer families:

  • PyTorchTrainer: custom loop with checkpoints, TensorBoard logging, train/val split, and pluggable losses
  • SentenceTransformersTrainer: wraps SentenceTransformerTrainer and builds an IR evaluator from a Hugging Face dataset

Public Building Blocks

The current repository includes these reusable implementations:

  • embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset
  • embed_train.train.dataset.torch_datasets.QueryPositiveDataset
  • embed_train.train.dataset.collate.InBatchPositiveCollateFn
  • embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn
  • embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss
  • embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss
  • embed_train.train.trainers.hf.SentenceTransformersTrainer
  • embed_train.push_to_hf.PushToHFRunner

Typical Workflows

Custom PyTorch Training

Use this when you want full control over the training loop and your model wrapper already exposes a Hugging Face model.

What the built-in trainer does:

  • loads the configured model
  • loads a TorchDataset
  • builds a collate function with tokenizer and processor
  • splits the dataset into train/validation sets
  • trains with AdamW
  • logs to TensorBoard
  • saves checkpoints to data_dir/checkpoints/...
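The train/validation split in the list above is the kind of deterministic partition you would write yourself; the exact mechanics inside the built-in trainer are an assumption here, but the idea looks like this:

```python
import random

def train_val_split(items, val_fraction=0.1, seed=42):
    """Deterministic shuffle-then-split: the same seed always yields the
    same partition, which keeps validation metrics comparable across runs."""
    rng = random.Random(seed)
    shuffled = items[:]
    rng.shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]

train_rows, val_rows = train_val_split(list(range(100)))
print(len(train_rows), len(val_rows))  # 90 10
```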

SentenceTransformers Training

Use this when your data is naturally represented as a Hugging Face Dataset and you want a standard SentenceTransformers training path with IR evaluation.

The trainer:

  • materializes a base model locally if needed
  • constructs a SentenceTransformer from transformer + pooling modules
  • creates train/validation splits
  • builds a SentenceTransformers loss from module_path
  • evaluates with InformationRetrievalEvaluator
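At its core, the IR evaluation step asks: given embedded queries and an embedded corpus, does the relevant document rank first? A toy, dependency-free sketch of that check (the trainer itself delegates this to sentence_transformers' InformationRetrievalEvaluator, which reports richer metrics such as NDCG and MRR):

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))

# Toy embeddings: one query, two candidate documents, one known-relevant doc.
queries = {"q1": [1.0, 0.0]}
corpus = {"d1": [0.9, 0.1], "d2": [0.0, 1.0]}
relevant = {"q1": "d1"}

def accuracy_at_1(queries, corpus, relevant):
    hits = 0
    for qid, q_emb in queries.items():
        ranked = sorted(corpus, key=lambda d: cosine(q_emb, corpus[d]), reverse=True)
        hits += ranked[0] == relevant[qid]
    return hits / len(queries)

print(accuracy_at_1(queries, corpus, relevant))  # 1.0
```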

Publish to Hugging Face Hub

Use PushToHFRunner to:

  • load a model wrapper from a checkpoint
  • save a local HF model repository
  • copy model source files such as modeling_*.py and configuration_*.py
  • optionally create the remote repository and upload the folder
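The "copy model source files" step matters because custom Hugging Face models need their modeling_*.py and configuration_*.py shipped next to the weights for trust_remote_code loading. A stdlib sketch of that step (the helper name is illustrative, not the library's API):

```python
import shutil
import tempfile
from pathlib import Path

def copy_model_source_files(src_dir: Path, repo_dir: Path) -> list[str]:
    """Copy custom-model Python sources into a local HF-style repo folder."""
    copied = []
    for pattern in ("modeling_*.py", "configuration_*.py"):
        for path in src_dir.glob(pattern):
            shutil.copy(path, repo_dir / path.name)
            copied.append(path.name)
    return sorted(copied)

# Demonstrate with throwaway directories and dummy source files.
src = Path(tempfile.mkdtemp())
repo = Path(tempfile.mkdtemp())
(src / "modeling_custom.py").write_text("# model code")
(src / "configuration_custom.py").write_text("# config code")
print(copy_model_source_files(src, repo))
# ['configuration_custom.py', 'modeling_custom.py']
```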

Extending the Library

Add a Model Wrapper

Implement a subclass of embed_train.models.Model:

from transformers import AutoModel

from embed_train.models import Model


class MyModelWrapper(Model):
    def __init__(self, config):
        super().__init__(config)
        # The wrapped Hugging Face model; config.base_model_name comes from
        # your settings object.
        self.model = AutoModel.from_pretrained(config.base_model_name)

    def to_hf_model(self):
        # Required hook: expose the underlying PreTrainedModel so the base
        # class can save, move, and publish it.
        return self.model

Add a Custom PyTorch Trainer

Subclass embed_train.train.trainers.torch.PyTorchTrainer and implement:

  • _encode_query(...)
  • _encode_documents(...)

This is the right place to define pooling, projection heads, shared or separate encoders, and any custom forward behavior.
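A dependency-free sketch of the two hooks (the method names come from the description above, but the surrounding PyTorchTrainer machinery and a real torch-based forward pass are assumed; mean pooling over plain lists stands in for tensor operations):

```python
class SharedEncoderTrainerSketch:
    """Illustrates a shared-encoder setup: queries and documents go through
    the same encode function."""
    def __init__(self, encode_fn):
        self.encode_fn = encode_fn  # e.g. the wrapped HF model's forward

    def _encode_query(self, tokenized_queries):
        return [self.encode_fn(t) for t in tokenized_queries]

    def _encode_documents(self, tokenized_docs):
        # A bi-encoder with separate towers would call a second encoder here.
        return [self.encode_fn(t) for t in tokenized_docs]

def mean_pool(token_vectors):
    """Average token vectors into one embedding (list-based stand-in)."""
    dim = len(token_vectors[0])
    return [sum(v[i] for v in token_vectors) / len(token_vectors) for i in range(dim)]

trainer = SharedEncoderTrainerSketch(mean_pool)
print(trainer._encode_query([[[1.0, 3.0], [3.0, 5.0]]]))  # [[2.0, 4.0]]
```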

Add a Custom Dataset

Subclass TorchDataset when you need a different row shape or data-loading strategy.

The built-in datasets show two common patterns:

  • grouped query -> many positives
  • flattened query -> single positive
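The two row shapes relate by a simple expansion (field names here are illustrative, not the built-in datasets' schema):

```python
# Grouped: one row per query, holding all of its positives.
grouped = [
    {"query": "capital of france", "positives": ["Paris is ...", "France's capital ..."]},
]

# Flattened: one row per (query, positive) pair.
def flatten(rows):
    for row in rows:
        for positive in row["positives"]:
            yield {"query": row["query"], "positive": positive}

flat = list(flatten(grouped))
print(len(flat))  # 2
```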

Add a Custom Collate Function

Subclass CollateFn and implement _process_batch(batch). Return two string lists:

  • queries
  • candidate documents

The base class handles processor application and tokenizer calls.
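A dependency-free sketch of that contract, with a stub standing in for embed_train's base class (which additionally applies the processor and tokenizer):

```python
class CollateFnSketch:
    """Stub base class: the subclass only decides how a batch becomes
    (queries, documents)."""
    def __call__(self, batch):
        queries, documents = self._process_batch(batch)
        # The real base class would now run the processor and tokenizer.
        return queries, documents

class TitleBodyCollateFn(CollateFnSketch):
    """Hypothetical subclass that joins title and body into one document."""
    def _process_batch(self, batch):
        queries = [row["query"] for row in batch]
        documents = [f"{row['title']}\n{row['body']}" for row in batch]
        return queries, documents

collate = TitleBodyCollateFn()
q, d = collate([{"query": "q", "title": "t", "body": "b"}])
print(d)  # ['t\nb']
```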

Add a Custom Loss

Subclass embed_train.train.trainers.torch.loss.Loss and implement _get_loss(q_emb, c_emb).
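To make the shape of _get_loss concrete, here is a pure-Python sketch of an in-batch negative contrastive (InfoNCE-style) loss, where each query's matching document at the same batch index is the positive and every other document is a negative. The real losses in embed_train.train.trainers.torch.loss operate on torch tensors:

```python
import math

def in_batch_contrastive_loss(q_emb, c_emb, temperature=0.05):
    """Mean cross-entropy of picking candidate i for query i."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))
    losses = []
    for i, q in enumerate(q_emb):
        logits = [dot(q, c) / temperature for c in c_emb]
        # -log softmax probability of the positive at position i.
        log_denom = math.log(sum(math.exp(z) for z in logits))
        losses.append(log_denom - logits[i])
    return sum(losses) / len(losses)

loss = in_batch_contrastive_loss([[1.0, 0.0], [0.0, 1.0]],
                                 [[1.0, 0.0], [0.0, 1.0]])
print(round(loss, 4))  # ~0.0: matched pairs are already well separated
```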

Output Conventions

PyTorch Trainer

  • TensorBoard logs: data_dir/tensorboard/runs/<run_name>
  • Checkpoints: data_dir/checkpoints/<run_name>/epoch_XXXX.pt

SentenceTransformers Trainer

  • checkpoints under data_dir/checkpoints/...

PushToHFRunner

  • local repository saved in hf.repo
  • optional upload to Hugging Face Hub when push=True

Best Practices When Using This Library

  • Treat settings.py as the contract for supported configuration.
  • Keep your custom classes importable from stable module paths.
  • Prefer explicit Python settings construction while iterating locally; add YAML once the pipeline is stable.
  • Keep save_steps, eval_steps, and logging_steps aligned in SentenceTransformers configs. The library validates that logging_steps <= eval_steps <= save_steps and that save_steps is a multiple of eval_steps.
  • Make checkpoint directories part of your experiment artifact strategy, not an afterthought.
  • Use deterministic seeds in your own custom samplers and data preparation logic when reproducibility matters.
  • Keep tokenizer and processor behavior documented together, since both affect embedding quality.
  • Push model code alongside weights when using custom Hugging Face model implementations.
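The step-alignment rule above reduces to a small arithmetic check, sketched here (the library enforces it through its settings validation; this standalone function is just an illustration):

```python
def steps_are_aligned(logging_steps: int, eval_steps: int, save_steps: int) -> bool:
    """logging_steps <= eval_steps <= save_steps, and every save step
    coincides with an eval step."""
    return (logging_steps <= eval_steps <= save_steps
            and save_steps % eval_steps == 0)

print(steps_are_aligned(50, 100, 200))  # True
print(steps_are_aligned(50, 100, 250))  # False: 250 is not a multiple of 100
```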

Development

Common local commands:

make dev-install
make ci

Testing

The repository already contains:

  • unit tests for settings, runners, losses, models, and utility behavior
  • integration coverage for the train runner flow and dataset conversion behavior

Those tests are a good reference when you are unsure how a component is expected to behave.
