embed-train

embed-train is a config-driven library for training, evaluating, checkpointing, and publishing embedding models. It is designed for teams that want reusable training building blocks instead of one-off scripts.

The package currently supports two main training paths:

  • A custom PyTorch training loop for query/document contrastive learning
  • A SentenceTransformers-based trainer for Hugging Face Dataset workflows and IR evaluation

It also includes a dedicated runner for exporting checkpoints and pushing model repositories to Hugging Face Hub.

Why This Library Exists

Most embedding projects end up reimplementing the same pieces:

  • model loading and checkpoint restore
  • dataset adapters
  • collate/tokenization logic
  • contrastive losses
  • train/validation splitting
  • checkpoint saving
  • model packaging and Hub publishing

embed-train turns those pieces into small, replaceable abstractions connected by typed settings objects and module_path-based loading.

Installation

Requirements declared in pyproject.toml:

  • Python >=3.11,<3.13
  • PyTorch
  • Sentence Transformers
  • Hugging Face Datasets
  • TensorBoard
  • Accelerate
  • retrievalbase

Install from source:

pip install -e .

For local development in this repository:

make dev-install

Library Structure

src/embed_train/
├── __init__.py                    # Base Runner abstraction and YAML runner loader
├── constants.py                   # Shared constants such as default seed and trust flags
├── exceptions.py                  # Library-specific exception hierarchy
├── models/                        # Base model wrapper abstraction
├── push_to_hf/                    # Runner for packaging and pushing checkpoints to HF Hub
├── settings.py                    # Pydantic settings models for all library components
├── train/
│   ├── __init__.py                # TrainRunner
│   ├── dataset/
│   │   ├── __init__.py            # Base TorchDataset, HardNegativeMiner, and CollateFn abstractions
│   │   ├── collate.py             # Built-in in-batch and hard-negative collate functions
│   │   ├── hard_negatives.py      # SentenceTransformers hard-negative miner
│   │   ├── torch_datasets.py      # Built-in query/positive and hard-negative dataset views
│   │   └── sampling/
│   │       └── samplers.py        # Built-in positive sampler(s)
│   └── trainers/
│       ├── __init__.py            # Base Trainer abstraction
│       ├── hf/                    # SentenceTransformers trainer
│       └── torch/
│           ├── __init__.py        # Custom PyTorch trainer base
│           └── loss.py            # Contrastive loss implementations
│
└── utils.py                       # Dynamic import, checkpoint loading, HF file helpers

Core Concepts

1. Runners

Runners are the top-level execution unit.

  • embed_train.Runner: abstract base class
  • embed_train.train.TrainRunner: loads a trainer and executes train()
  • embed_train.push_to_hf.PushToHFRunner: restores a checkpoint, saves a local HF-style repo, and optionally pushes it

2. Settings-Driven Composition

Every major component is configured through Pydantic settings objects in src/embed_train/settings.py. Most components are resolved dynamically from a module_path, which makes the library extensible without editing core orchestration code.
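As a rough illustration of module_path-based loading, the sketch below resolves a dotted path to a Python object with importlib. The helper name load_from_module_path is hypothetical and is not the library's own loader; the real implementation lives in utils.py and may differ in validation and error handling.

```python
import importlib


def load_from_module_path(module_path: str):
    """Resolve a dotted path such as 'package.module.ClassName' to the object it names.

    Hypothetical helper for illustration only; not part of embed-train's API.
    """
    module_name, _, attr_name = module_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attribute', got {module_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# A standard-library class is used here so the sketch runs without embed-train installed.
ordered_dict_cls = load_from_module_path("collections.OrderedDict")
instance = ordered_dict_cls(a=1)
```

Because a component is looked up by name at run time, swapping an implementation only requires changing the configured path, provided the new class is importable.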

3. Model Wrappers

Models inherit from embed_train.models.Model and must implement:

  • to_hf_model(): return a Hugging Face PreTrainedModel

The base class already provides:

  • save(...)
  • to(device)
  • from_checkpoint(...)

4. Datasets and Collate Functions

The PyTorch flow separates concerns clearly:

  • TorchDataset: how rows are loaded, optionally mined, and exposed
  • HardNegativeMiner: how query/positive rows can be expanded with mined negatives
  • CollateFn: how rows become query/document text pairs and tokenized tensors
  • Processor: text normalization or preprocessing, provided by retrievalbase

5. Trainers

There are two trainer families:

  • PyTorchTrainer: custom loop with checkpoints, TensorBoard logging, train/val split, and pluggable losses
  • SentenceTransformersTrainer: wraps the sentence_transformers SentenceTransformerTrainer class and builds an IR evaluator from a Hugging Face dataset

Public Building Blocks

The current repository includes these reusable implementations:

  • embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset
  • embed_train.train.dataset.torch_datasets.QueryPositiveDataset
  • embed_train.train.dataset.torch_datasets.HardNegativeDataset
  • embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner
  • embed_train.train.dataset.collate.HardNegativeCollateFn
  • embed_train.train.dataset.collate.InBatchNegativeCollateFn
  • embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn
  • embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss
  • embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss
  • embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss
  • embed_train.train.trainers.hf.SentenceTransformersTrainer
  • embed_train.push_to_hf.PushToHFRunner

Typical Workflows

Custom PyTorch Training

Use this when you want full control over the training loop and your model wrapper already exposes a Hugging Face model.

What the built-in trainer does:

  • loads the configured model
  • loads a TorchDataset
  • builds a collate function with tokenizer and processor
  • splits the dataset into train/validation sets
  • trains with AdamW
  • logs to TensorBoard
  • saves checkpoints to data_dir/checkpoints/...
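The train/validation split step can be pictured with a small seeded sketch. The function name split_indices and the parameters val_fraction and seed are assumptions made for illustration; the built-in trainer's actual splitting logic and setting names may differ.

```python
import random


def split_indices(n_rows: int, val_fraction: float, seed: int) -> tuple[list[int], list[int]]:
    """Shuffle row indices with a fixed seed and return (train, validation) index lists.

    Hypothetical helper; not the library's implementation.
    """
    indices = list(range(n_rows))
    # A dedicated Random instance keeps the split independent of global random state.
    random.Random(seed).shuffle(indices)
    n_val = int(n_rows * val_fraction)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(n_rows=10, val_fraction=0.2, seed=42)
```

Using the same seed yields the same split on every run, which keeps validation metrics comparable across experiments.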

For hard-negative training, configure the trainer with:

  • torch_dataset: embed_train.train.dataset.torch_datasets.HardNegativeDataset
  • torch_dataset.hard_negative_miner: embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner
  • collate_fn: embed_train.train.dataset.collate.HardNegativeCollateFn
  • loss: embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss

HardNegativeDataset converts rows with metadata.query and page_content into a Hugging Face Dataset, mines negatives with sentence_transformers.util.mine_hard_negatives, and then exposes the mined rows through the normal TorchDataset interface. HardNegativeCollateFn expects each row to contain query, positive, and either negative or numbered negative_<n> fields. It emits one positive followed by that row's negatives, which is the candidate layout required by HardNegativeContrastiveLoss.
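The candidate layout described above (one positive followed by that row's negatives) can be sketched in plain Python. The functions candidates_for_row and flatten_batch are hypothetical stand-ins, not HardNegativeCollateFn itself; only the field names query, positive, negative, and negative_<n> come from the description above.

```python
import re

_NUMBERED_NEGATIVE = re.compile(r"^negative_(\d+)$")


def candidates_for_row(row: dict[str, str]) -> list[str]:
    """Return [positive, negative, ...] for one mined row."""
    if "negative" in row:
        negatives = [row["negative"]]
    else:
        numbered = []
        for key, value in row.items():
            match = _NUMBERED_NEGATIVE.match(key)
            if match:
                numbered.append((int(match.group(1)), value))
        # Order negatives by their numeric suffix, not by dict insertion order.
        negatives = [text for _, text in sorted(numbered, key=lambda item: item[0])]
    if not negatives:
        raise ValueError("row has neither 'negative' nor 'negative_<n>' fields")
    return [row["positive"], *negatives]


def flatten_batch(batch: list[dict[str, str]]) -> tuple[list[str], list[str]]:
    """Return (queries, candidates) with each row's positive placed before its negatives."""
    queries = [row["query"] for row in batch]
    candidates = [doc for row in batch for doc in candidates_for_row(row)]
    return queries, candidates


batch = [
    {"query": "q1", "positive": "p1", "negative_2": "n1b", "negative_1": "n1a"},
    {"query": "q2", "positive": "p2", "negative_1": "n2a", "negative_2": "n2b"},
]
queries, candidates = flatten_batch(batch)
```

With two negatives per row, every third entry in candidates is a positive, so a loss can locate the positive for query i at a fixed offset.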

SentenceTransformers Training

Use this when your data is naturally represented as a Hugging Face Dataset and you want a standard SentenceTransformers training path with IR evaluation.

The trainer:

  • materializes a base model locally if needed
  • constructs a SentenceTransformer from transformer + pooling modules
  • creates train/validation splits
  • builds a SentenceTransformers loss from module_path
  • evaluates with InformationRetrievalEvaluator

Publish to Hugging Face Hub

Use PushToHFRunner to:

  • load a model wrapper from a checkpoint
  • save a local HF model repository
  • copy model source files such as modeling_*.py and configuration_*.py
  • optionally create the remote repository and upload the folder
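The source-file copy step can be approximated with pathlib and shutil. This is a sketch under assumptions: copy_model_sources and SOURCE_PATTERNS are hypothetical names, and the directories are temporary placeholders; PushToHFRunner's own file handling may differ.

```python
import shutil
import tempfile
from pathlib import Path

# Glob patterns taken from the file names mentioned above.
SOURCE_PATTERNS = ("modeling_*.py", "configuration_*.py")


def copy_model_sources(source_dir: Path, repo_dir: Path) -> list[str]:
    """Copy files matching SOURCE_PATTERNS into repo_dir and return the copied names."""
    repo_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for pattern in SOURCE_PATTERNS:
        for src in sorted(source_dir.glob(pattern)):
            shutil.copy2(src, repo_dir / src.name)
            copied.append(src.name)
    return copied


# Demonstration with throwaway files; notes.txt matches no pattern and is skipped.
with tempfile.TemporaryDirectory() as tmp:
    source = Path(tmp) / "src"
    source.mkdir()
    for name in ("modeling_my_model.py", "configuration_my_model.py", "notes.txt"):
        (source / name).write_text("# placeholder\n")
    copied_names = copy_model_sources(source, Path(tmp) / "repo")
    repo_files = sorted(p.name for p in (Path(tmp) / "repo").iterdir())
```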

Extending the Library

Add a Model Wrapper

Implement a subclass of embed_train.models.Model:

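from transformers import AutoModel

from embed_train.models import Model


class MyModelWrapper(Model):
    """Wraps a pretrained Hugging Face encoder for use with embed-train."""

    def __init__(self, config):
        super().__init__(config)
        # Load the pretrained backbone named by config.base_model_name.
        self.model = AutoModel.from_pretrained(config.base_model_name)

    def to_hf_model(self):
        # Required by Model: return the underlying Hugging Face PreTrainedModel.
        return self.model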

Add a Custom PyTorch Trainer

Subclass embed_train.train.trainers.torch.PyTorchTrainer and implement:

  • _encode_query(...)
  • _encode_documents(...)

This is the right place to define pooling, projection heads, shared or separate encoders, and any custom forward behavior.

Add a Custom Dataset

Subclass TorchDataset when you need a different row shape or data-loading strategy.

The built-in datasets show these common patterns:

  • grouped query -> many positives
  • flattened query -> single positive
  • flattened query -> single positive plus mined hard negatives

Add a Custom Collate Function

Subclass CollateFn and implement _process_batch(batch). Return two string lists:

  • queries
  • candidate documents

The base class handles processor application and tokenizer calls.

Add a Custom Loss

Subclass embed_train.train.trainers.torch.loss.Loss and implement _get_loss(q_emb, c_emb).
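To make the shape of such a loss concrete, here is a plain-Python sketch of the arithmetic commonly used for in-batch negative contrastive learning: cross-entropy over a query-by-candidate similarity matrix, where candidate i is the positive for query i. It is an assumption-laden illustration, not the library's InBatchNegativeContrastiveLoss. It operates on lists of floats rather than tensors, and the temperature parameter is hypothetical.

```python
import math


def in_batch_contrastive_loss(
    q_emb: list[list[float]],
    c_emb: list[list[float]],
    temperature: float = 0.05,
) -> float:
    """Mean cross-entropy where c_emb[i] is the positive for q_emb[i]."""
    if len(q_emb) != len(c_emb):
        raise ValueError("expected one candidate per query")
    losses = []
    for i, query in enumerate(q_emb):
        # Dot-product similarity of this query against every candidate in the batch.
        logits = [sum(a * b for a, b in zip(query, cand)) / temperature for cand in c_emb]
        # Log-sum-exp with the max subtracted for numerical stability.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


aligned = in_batch_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = in_batch_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

The loss is near zero when each query is most similar to its own candidate and grows when another candidate in the batch scores higher. A real _get_loss implementation would express the same computation with tensor operations so gradients can flow.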

Output Conventions

PyTorch Trainer

  • TensorBoard logs: data_dir/tensorboard/runs/<run_name>
  • Checkpoints: data_dir/checkpoints/<run_name>/epoch_XXXX.pt
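These conventions can be written as two small path helpers. The function names are hypothetical, and the sketch assumes that XXXX stands for a zero-padded four-digit epoch number, which the text above does not state explicitly.

```python
from pathlib import Path


def tensorboard_dir(data_dir: str, run_name: str) -> Path:
    """Directory holding TensorBoard event files for one run."""
    return Path(data_dir) / "tensorboard" / "runs" / run_name


def checkpoint_path(data_dir: str, run_name: str, epoch: int) -> Path:
    """Checkpoint file for one epoch; four-digit zero padding is an assumption."""
    return Path(data_dir) / "checkpoints" / run_name / f"epoch_{epoch:04d}.pt"


example_checkpoint = checkpoint_path("data", "baseline", 3)
example_logs = tensorboard_dir("data", "baseline")
```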

SentenceTransformers Trainer

  • Checkpoints: data_dir/checkpoints/...

PushToHFRunner

  • local repository saved in hf.repo
  • optional upload to Hugging Face Hub when push=True

Best Practices When Using This Library

  • Treat settings.py as the contract for supported configuration.
  • Keep your custom classes importable from stable module paths.
  • Prefer explicit Python settings construction while iterating locally; add YAML once the pipeline is stable.
  • Keep save_steps, eval_steps, and logging_steps aligned in SentenceTransformers configs. The library validates that logging_steps <= eval_steps <= save_steps and that save_steps is a multiple of eval_steps.
  • Make checkpoint directories part of your experiment artifact strategy, not an afterthought.
  • Use deterministic seeds in your own custom samplers and data preparation logic when reproducibility matters.
  • Keep tokenizer and processor behavior documented together, since both affect embedding quality.
  • Push model code alongside weights when using custom Hugging Face model implementations.
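The step-alignment rule in the list above can be checked before launching a run. The function below is a standalone sketch of those two constraints, not the library's validator; the name validate_steps and the positivity check are assumptions added for illustration.

```python
def validate_steps(logging_steps: int, eval_steps: int, save_steps: int) -> None:
    """Raise ValueError unless the three step intervals are mutually consistent."""
    # Positivity is an added assumption; it also guards the modulo below.
    if min(logging_steps, eval_steps, save_steps) <= 0:
        raise ValueError("step intervals must be positive")
    if not (logging_steps <= eval_steps <= save_steps):
        raise ValueError("expected logging_steps <= eval_steps <= save_steps")
    if save_steps % eval_steps != 0:
        raise ValueError("save_steps must be a multiple of eval_steps")


# Passes: 50 <= 100 <= 200 and 200 is a multiple of 100.
validate_steps(logging_steps=50, eval_steps=100, save_steps=200)
```

A configuration such as eval_steps=100 with save_steps=150 fails the second check, because a save would fall between two evaluations.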

Development

Common local commands:

make dev-install
make ci

Testing

The repository already contains:

  • unit tests for settings, runners, losses, models, and utility behavior
  • integration coverage for the train runner flow and dataset conversion behavior

Those tests are a good reference when you are unsure how a component is expected to behave.
