embed-train

embed-train is a config-driven library for training, evaluating, checkpointing, and publishing embedding models. It is designed for teams that want reusable training building blocks instead of one-off scripts.

The package currently supports two main training paths:

  • A custom PyTorch training loop for query/document contrastive learning
  • A SentenceTransformers-based trainer for Hugging Face Dataset workflows and IR evaluation

It also includes a dedicated runner for exporting checkpoints and pushing model repositories to Hugging Face Hub.

Why This Library Exists

Most embedding projects end up reimplementing the same pieces:

  • model loading and checkpoint restore
  • dataset adapters
  • collate/tokenization logic
  • contrastive losses
  • train/validation splitting
  • checkpoint saving
  • model packaging and Hub publishing

embed-train turns those pieces into small, replaceable abstractions connected by typed settings objects and module_path-based loading.

Installation

Requirements from pyproject.toml:

  • Python >=3.11,<3.13
  • PyTorch
  • Sentence Transformers
  • Hugging Face Datasets
  • TensorBoard
  • Accelerate
  • retrievalbase

Install from source:

pip install -e .

For local development in this repository:

make dev-install

Library Structure

src/embed_train/
├── __init__.py                    # Base Runner abstraction and YAML runner loader
├── constants.py                   # Shared constants such as default seed and trust flags
├── exceptions.py                  # Library-specific exception hierarchy
├── models/                        # Base model wrapper abstraction
├── push_to_hf/                    # Runner for packaging and pushing checkpoints to HF Hub
├── settings.py                    # Pydantic settings models for all library components
├── train/
│   ├── __init__.py                # TrainRunner
│   ├── dataset/
│   │   ├── __init__.py            # Base TorchDataset, HardNegativeMiner, and CollateFn abstractions
│   │   ├── collate.py             # Built-in in-batch and hard-negative collate functions
│   │   ├── hard_negatives.py      # SentenceTransformers hard-negative miner
│   │   ├── torch_datasets.py      # Built-in query/positive and hard-negative dataset views
│   │   └── sampling/
│   │       └── samplers.py        # Built-in positive sampler(s)
│   └── trainers/
│       ├── __init__.py            # Base Trainer abstraction
│       ├── hf/                    # SentenceTransformers trainer
│       └── torch/
│           ├── __init__.py        # Custom PyTorch trainer base
│           └── loss.py            # Contrastive loss implementations
└── utils.py                       # Dynamic import, checkpoint loading, HF file helpers

Core Concepts

1. Runners

Runners are the top-level execution unit.

  • embed_train.Runner: abstract base class
  • embed_train.train.TrainRunner: loads a trainer and executes train()
  • embed_train.push_to_hf.PushToHFRunner: restores a checkpoint, saves a local HF-style repo, and optionally pushes it
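To make the runner idea concrete, here is a minimal sketch of how a runner can delegate to a trainer. It does not import embed-train: the class names, the `run()` method name, and the stub trainer are assumptions made for illustration, not the library's actual signatures.

```python
from abc import ABC, abstractmethod


class SketchRunner(ABC):
    """Stand-in for an abstract top-level execution unit (hypothetical shape)."""

    @abstractmethod
    def run(self) -> str:
        """Execute the unit of work and return a short status string."""


class StubTrainer:
    """Hypothetical trainer that only records whether train() was called."""

    def __init__(self) -> None:
        self.trained = False

    def train(self) -> None:
        self.trained = True


class SketchTrainRunner(SketchRunner):
    """Holds a trainer and executes its train() method when run."""

    def __init__(self, trainer: StubTrainer) -> None:
        self.trainer = trainer

    def run(self) -> str:
        self.trainer.train()
        return "trained"


trainer = StubTrainer()
status = SketchTrainRunner(trainer).run()
```

Keeping the runner this thin means orchestration stays separate from training logic, so another runner (for example, one that publishes a checkpoint) can share the same entry-point shape.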

2. Settings-Driven Composition

Every major component is configured through Pydantic settings objects in src/embed_train/settings.py. Most components are resolved dynamically from a module_path, which makes the library extensible without editing core orchestration code.
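The general technique behind module_path-based loading can be sketched with the standard library alone. The helper name `load_from_module_path` and the dictionary-shaped settings below are assumptions for illustration; the library's real settings are Pydantic models and its own loader lives in utils.py.

```python
import importlib
from typing import Any


def load_from_module_path(module_path: str) -> Any:
    """Resolve a dotted path such as 'package.module.ClassName' to the object it names."""
    module_name, _, attr_name = module_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attribute', got {module_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Hypothetical settings payload: a dotted path plus constructor arguments.
settings = {"module_path": "collections.OrderedDict", "kwargs": {}}

component_cls = load_from_module_path(settings["module_path"])
component = component_cls(**settings["kwargs"])
```

Because the class is looked up by name at run time, swapping an implementation is a configuration change rather than a code change, as long as the new class is importable from a stable path.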

3. Model Wrappers

Models inherit from embed_train.models.Model and must implement:

  • to_hf_model(): return a Hugging Face PreTrainedModel

The base class already provides:

  • save(...)
  • to(device)
  • from_checkpoint(...)

4. Datasets and Collate Functions

The PyTorch flow separates concerns clearly:

  • TorchDataset: how rows are loaded, optionally mined, and exposed
  • HardNegativeMiner: how query/positive rows can be expanded with mined negatives
  • CollateFn: how rows become query/document text pairs and tokenized tensors
  • Processor: text normalization or preprocessing, provided by retrievalbase

5. Trainers

There are two trainer families:

  • PyTorchTrainer: custom loop with checkpoints, TensorBoard logging, train/val split, and pluggable losses
  • SentenceTransformersTrainer: wraps SentenceTransformerTrainer and builds an IR evaluator from a Hugging Face dataset

Public Building Blocks

The current repository includes these reusable implementations:

  • embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset
  • embed_train.train.dataset.torch_datasets.QueryPositiveDataset
  • embed_train.train.dataset.torch_datasets.HardNegativeDataset
  • embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner
  • embed_train.train.dataset.collate.HardNegativeCollateFn
  • embed_train.train.dataset.collate.InBatchNegativeCollateFn
  • embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn
  • embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss
  • embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss
  • embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss
  • embed_train.train.trainers.hf.SentenceTransformersTrainer
  • embed_train.push_to_hf.PushToHFRunner

Typical Workflows

Custom PyTorch Training

Use this when you want full control over the training loop and your model wrapper already exposes a Hugging Face model.

What the built-in trainer does:

  • loads the configured model
  • loads a TorchDataset
  • builds a collate function with tokenizer and processor
  • splits the dataset into train/validation sets
  • trains with AdamW
  • logs to TensorBoard
  • saves checkpoints to data_dir/checkpoints/...

For hard-negative training, configure the trainer with:

  • torch_dataset: embed_train.train.dataset.torch_datasets.HardNegativeDataset
  • torch_dataset.hard_negative_miner: embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner
  • collate_fn: embed_train.train.dataset.collate.HardNegativeCollateFn
  • loss: embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss

HardNegativeDataset converts rows with metadata.query and page_content into a Hugging Face Dataset, mines negatives with sentence_transformers.util.mine_hard_negatives, and then exposes the mined rows through the normal TorchDataset interface.

HardNegativeCollateFn expects each row to contain query, positive, and either negative or numbered negative_<n> fields. It emits one positive followed by that row's negatives, which is the candidate layout required by HardNegativeContrastiveLoss.
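The candidate layout described above (one positive followed by that row's negatives) can be illustrated with a small standalone function. The function name `row_to_candidates` is hypothetical, and ordering numbered negatives by their numeric suffix is an assumption; the real collate function may order them differently.

```python
import re
from typing import Mapping

_NEGATIVE_KEY = re.compile(r"^negative_(\d+)$")


def row_to_candidates(row: Mapping[str, str]) -> tuple[str, list[str]]:
    """Return (query, [positive, negative, ...]) for one mined row."""
    if "negative" in row:
        negatives = [row["negative"]]
    else:
        # Collect negative_<n> fields and order them by their numeric suffix.
        numbered = sorted(
            (int(match.group(1)), row[key])
            for key in row
            if (match := _NEGATIVE_KEY.match(key))
        )
        negatives = [text for _, text in numbered]
    if not negatives:
        raise ValueError("row has no 'negative' or 'negative_<n>' field")
    # The positive always sits at index 0 of the candidate list.
    return row["query"], [row["positive"], *negatives]


query, candidates = row_to_candidates(
    {"query": "q", "positive": "p", "negative_2": "n2", "negative_1": "n1"}
)
```

With the positive fixed at index 0, a loss can treat position 0 as the target class for every query without needing a separate label tensor.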

SentenceTransformers Training

Use this when your data is naturally represented as a Hugging Face Dataset and you want a standard SentenceTransformers training path with IR evaluation.

The trainer:

  • materializes a base model locally if needed
  • constructs a SentenceTransformer from transformer + pooling modules
  • creates train/validation splits
  • builds a SentenceTransformers loss from module_path
  • evaluates with InformationRetrievalEvaluator

Publish to Hugging Face Hub

Use PushToHFRunner to:

  • load a model wrapper from a checkpoint
  • save a local HF model repository
  • copy model source files such as modeling_*.py, configuration_*.py, and vllm_*.py when with_vllm_integration is enabled
  • optionally rewrite the local repository as a SentenceTransformer model when sentence_transformer is configured
  • optionally create the remote repository and upload the folder

Set with_vllm_integration: false for wrappers that save a standard HF model and do not have a local model source directory to copy. Leave it enabled for custom models that need their modeling_*, configuration_*, or vllm_* files packaged into the repo.

Example sentence_transformer block for push configs:

sentence_transformer:
  pooling_mode: mean
  trust_remote_code: false
  tokenizer:
    name: your-tokenizer-name
    padding: true
    truncation: true
    max_length: 512

Extending the Library

Add a Model Wrapper

Implement a subclass of embed_train.models.Model:

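from transformers import AutoModel

from embed_train.models import Model


class MyModelWrapper(Model):
    def __init__(self, config):
        super().__init__(config)
        # Load the underlying Hugging Face model named in the wrapper's config.
        self.model = AutoModel.from_pretrained(config.base_model_name)

    def to_hf_model(self):
        # Required hook: return the wrapped Hugging Face PreTrainedModel.
        return self.model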

Add a Custom PyTorch Trainer

Subclass embed_train.train.trainers.torch.PyTorchTrainer and implement:

  • _encode_query(...)
  • _encode_documents(...)

This is the right place to define pooling, projection heads, shared or separate encoders, and any custom forward behavior.

Add a Custom Dataset

Subclass TorchDataset when you need a different row shape or data-loading strategy.

The built-in datasets show these common patterns:

  • grouped query -> many positives
  • flattened query -> single positive
  • flattened query -> single positive plus mined hard negatives

Add a Custom Collate Function

Subclass CollateFn and implement _process_batch(batch). Return two string lists:

  • queries
  • candidate documents

The base class handles processor application and tokenizer calls.
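The division of labor described here can be sketched without the library. `SketchCollateFn` below is a stand-in base class written for this example; its constructor arguments and `__call__` behavior are assumptions about the general pattern, and the whitespace "tokenizer" is a toy. Only the `_process_batch` contract (return queries and candidate documents as two string lists) comes from the text above.

```python
from abc import ABC, abstractmethod
from typing import Callable, Mapping, Sequence

Row = Mapping[str, str]


class SketchCollateFn(ABC):
    """Stand-in base: applies a processor, then hands the text to a tokenizer."""

    def __init__(
        self,
        processor: Callable[[str], str],
        tokenizer: Callable[[list[str]], list[list[str]]],
    ) -> None:
        self.processor = processor
        self.tokenizer = tokenizer

    @abstractmethod
    def _process_batch(self, batch: Sequence[Row]) -> tuple[list[str], list[str]]:
        """Return (queries, candidate documents) as two string lists."""

    def __call__(self, batch: Sequence[Row]) -> tuple[list[list[str]], list[list[str]]]:
        queries, documents = self._process_batch(batch)
        queries = [self.processor(q) for q in queries]
        documents = [self.processor(d) for d in documents]
        return self.tokenizer(queries), self.tokenizer(documents)


class TitleBodyCollateFn(SketchCollateFn):
    """Custom row shape: the document text is built from title and body fields."""

    def _process_batch(self, batch: Sequence[Row]) -> tuple[list[str], list[str]]:
        queries = [row["query"] for row in batch]
        documents = [f'{row["title"]}. {row["body"]}' for row in batch]
        return queries, documents


collate = TitleBodyCollateFn(
    processor=str.lower,
    tokenizer=lambda texts: [text.split() for text in texts],  # toy tokenizer
)
q_tokens, d_tokens = collate([{"query": "What Is X", "title": "X", "body": "An Example"}])
```

The subclass only decides which strings to emit; normalization and tokenization stay in one place, so every collate function applies them consistently.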

Add a Custom Loss

Subclass embed_train.train.trainers.torch.loss.Loss and implement _get_loss(q_emb, c_emb).
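As a reference for what such a loss typically computes, here is an in-batch negative contrastive (InfoNCE-style) loss written in plain Python on lists of floats. It is a sketch for intuition only: a real `_get_loss` would operate on tensors, and the function name, the cosine similarity choice, and the `temperature` default are assumptions rather than the library's implementation.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def _cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; assumes both vectors are non-zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def in_batch_contrastive_loss(
    q_emb: Sequence[Vector], c_emb: Sequence[Vector], temperature: float = 0.05
) -> float:
    """Mean cross-entropy where candidate i is the positive for query i."""
    if len(q_emb) != len(c_emb):
        raise ValueError("need exactly one candidate per query")
    total = 0.0
    for i, query in enumerate(q_emb):
        logits = [_cosine(query, cand) / temperature for cand in c_emb]
        # Log-sum-exp with the max subtracted for numerical stability.
        max_logit = max(logits)
        log_denominator = max_logit + math.log(
            sum(math.exp(logit - max_logit) for logit in logits)
        )
        total += log_denominator - logits[i]
    return total / len(q_emb)


aligned = in_batch_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = in_batch_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

The loss is near zero when each query is closest to its own candidate and grows when another item in the batch scores higher, which is the signal that pushes matching pairs together.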

Output Conventions

PyTorch Trainer

  • TensorBoard logs: data_dir/tensorboard/runs/<run_name>
  • Checkpoints: data_dir/checkpoints/<run_name>/epoch_XXXX.pt
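If you need to locate these outputs from your own tooling, the layout above can be reproduced with pathlib. The helper names are hypothetical, and reading XXXX as a four-digit zero-padded epoch number is an assumption; check an actual run directory before relying on it.

```python
from pathlib import Path


def tensorboard_dir(data_dir: str, run_name: str) -> Path:
    """Directory holding TensorBoard event files for one run."""
    return Path(data_dir) / "tensorboard" / "runs" / run_name


def checkpoint_path(data_dir: str, run_name: str, epoch: int) -> Path:
    """Checkpoint file for one epoch (assumes XXXX means a 4-digit zero-padded epoch)."""
    return Path(data_dir) / "checkpoints" / run_name / f"epoch_{epoch:04d}.pt"


ckpt = checkpoint_path("data", "baseline", 3)
logs = tensorboard_dir("data", "baseline")
```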

SentenceTransformers Trainer

  • checkpoints under data_dir/checkpoints/...

PushToHFRunner

  • local repository saved in hf.repo
  • optional upload to Hugging Face Hub when push=True

Best Practices When Using This Library

  • Treat settings.py as the contract for supported configuration.
  • Keep your custom classes importable from stable module paths.
  • Prefer explicit Python settings construction while iterating locally; add YAML once the pipeline is stable.
  • Keep save_steps, eval_steps, and logging_steps aligned in SentenceTransformers configs. The library validates that logging_steps <= eval_steps <= save_steps and that save_steps is a multiple of eval_steps.
  • Make checkpoint directories part of your experiment artifact strategy, not an afterthought.
  • Use deterministic seeds in your own custom samplers and data preparation logic when reproducibility matters.
  • Keep tokenizer and processor behavior documented together, since both affect embedding quality.
  • Push model code alongside weights when using custom Hugging Face model implementations.
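The step-alignment rule in the list above (logging_steps <= eval_steps <= save_steps, with save_steps a multiple of eval_steps) is easy to check before launching a run. The function below is a standalone sketch, not the library's validator; its name is hypothetical, and the positivity check is an extra guard added here so the modulo is always defined.

```python
def step_setting_errors(logging_steps: int, eval_steps: int, save_steps: int) -> list[str]:
    """Return human-readable problems with the step settings (empty list if valid)."""
    if min(logging_steps, eval_steps, save_steps) <= 0:
        # Extra guard for this sketch; not stated in the rule above.
        return ["all step values must be positive integers"]
    errors = []
    if not (logging_steps <= eval_steps <= save_steps):
        errors.append("expected logging_steps <= eval_steps <= save_steps")
    if save_steps % eval_steps != 0:
        errors.append("save_steps must be a multiple of eval_steps")
    return errors


ok = step_setting_errors(logging_steps=50, eval_steps=100, save_steps=200)
bad_multiple = step_setting_errors(logging_steps=50, eval_steps=100, save_steps=250)
bad_order = step_setting_errors(logging_steps=200, eval_steps=100, save_steps=200)
```

Returning a list of messages rather than raising on the first problem lets a config check report every misalignment at once.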

Development

Common local commands:

make dev-install
make ci

Testing

The repository already contains:

  • unit tests for settings, runners, losses, models, and utility behavior
  • integration coverage for the train runner flow and dataset conversion behavior

Those tests are a good reference when you are unsure how a component is expected to behave.

