embed-train

embed-train is a config-driven library for training, evaluating, checkpointing, and publishing embedding models. It is designed for teams that want reusable training building blocks instead of one-off scripts.

The package currently supports two main training paths:

  • A custom PyTorch training loop for query/document contrastive learning
  • A SentenceTransformers-based trainer for Hugging Face Dataset workflows and IR evaluation

It also includes a dedicated runner for exporting checkpoints and pushing model repositories to Hugging Face Hub.

Why This Library Exists

Most embedding projects end up reimplementing the same pieces:

  • model loading and checkpoint restore
  • dataset adapters
  • collate/tokenization logic
  • contrastive losses
  • train/validation splitting
  • checkpoint saving
  • model packaging and Hub publishing

embed-train turns those pieces into small, replaceable abstractions connected by typed settings objects and module_path-based loading.

Installation

Python requirements from pyproject.toml:

  • Python >=3.11,<3.13
  • PyTorch
  • Sentence Transformers
  • Hugging Face Datasets
  • TensorBoard
  • Accelerate
  • retrievalbase

Install from source:

pip install -e .

For local development in this repository:

make dev-install

Library Structure

src/embed_train/
├── __init__.py                    # Base Runner abstraction and YAML runner loader
├── constants.py                   # Shared constants such as default seed and trust flags
├── exceptions.py                  # Library-specific exception hierarchy
├── models/                        # Base model wrapper abstraction
├── push_to_hf/                    # Runner for packaging and pushing checkpoints to HF Hub
├── settings.py                    # Pydantic settings models for all library components
├── train/
│   ├── __init__.py                # TrainRunner
│   ├── dataset/
│   │   ├── __init__.py            # Base TorchDataset and CollateFn abstractions
│   │   ├── collate.py             # Built-in in-batch positive collate functions
│   │   ├── torch_datasets.py      # Built-in query/positive dataset views
│   │   └── sampling/
│   │       └── samplers.py        # Built-in positive sampler(s)
│   └── trainers/
│       ├── __init__.py            # Base Trainer abstraction
│       ├── hf/                    # SentenceTransformers trainer
│       └── torch/
│           ├── __init__.py        # Custom PyTorch trainer base
│           └── loss.py            # Contrastive loss implementations
│
└── utils.py                       # Dynamic import, checkpoint loading, HF file helpers

Core Concepts

1. Runners

A runner is the top-level execution unit.

  • embed_train.Runner: abstract base class
  • embed_train.train.TrainRunner: loads a trainer and executes train()
  • embed_train.push_to_hf.PushToHFRunner: restores a checkpoint, saves a local HF-style repo, and optionally pushes it

2. Settings-Driven Composition

Every major component is configured through Pydantic settings objects in src/embed_train/settings.py. Most components are resolved dynamically from a module_path, which makes the library extensible without editing core orchestration code.
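
The module_path idea can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not the library's loader: load_from_module_path is a hypothetical helper, and the "package.module.ClassName" string format is an assumption about how module_path values are written.

```python
import importlib
from typing import Any


def load_from_module_path(module_path: str) -> Any:
    """Resolve a dotted path such as 'package.module.ClassName' to an object.

    Hypothetical helper: embed-train's own loader lives in utils.py and may differ.
    """
    module_name, _, attr_name = module_path.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected 'module.attr', got {module_path!r}")
    module = importlib.import_module(module_name)
    try:
        return getattr(module, attr_name)
    except AttributeError as exc:
        raise ImportError(f"{module_name!r} has no attribute {attr_name!r}") from exc


# Usage: resolve a standard-library class by its dotted path.
ordered_dict_cls = load_from_module_path("collections.OrderedDict")
```

Because a component is looked up by string at run time, a custom class only needs to be importable from a stable path; the orchestration code does not need to change.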

3. Model Wrappers

Models inherit from embed_train.models.Model and must implement:

  • to_hf_model(): return a Hugging Face PreTrainedModel

The base class already provides:

  • save(...)
  • to(device)
  • from_checkpoint(...)

4. Datasets and Collate Functions

The PyTorch flow separates three concerns:

  • TorchDataset: how rows are loaded and exposed
  • CollateFn: how rows become query/document text pairs and tokenized tensors
  • Processor: text normalization or preprocessing, provided by retrievalbase

5. Trainers

There are two trainer families:

  • PyTorchTrainer: custom loop with checkpoints, TensorBoard logging, train/val split, and pluggable losses
  • SentenceTransformersTrainer: wraps SentenceTransformerTrainer and builds an IR evaluator from a Hugging Face dataset

Public Building Blocks

The current repository includes these reusable implementations:

  • embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset
  • embed_train.train.dataset.torch_datasets.QueryPositiveDataset
  • embed_train.train.dataset.collate.InBatchNegativeCollateFn
  • embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn
  • embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss
  • embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss
  • embed_train.train.trainers.hf.SentenceTransformersTrainer
  • embed_train.push_to_hf.PushToHFRunner

Typical Workflows

Custom PyTorch Training

Use this when you want full control over the training loop and your model wrapper already exposes a Hugging Face model.

What the built-in trainer does:

  • loads the configured model
  • loads a TorchDataset
  • builds a collate function with tokenizer and processor
  • splits the dataset into train/validation sets
  • trains with AdamW
  • logs to TensorBoard
  • saves checkpoints to data_dir/checkpoints/...
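
The train/validation split step can be illustrated with a seeded shuffle. This is a standard-library sketch, not the trainer's actual code: split_train_val, val_fraction, and the seed value are hypothetical names and values chosen for the example, and the built-in trainer may well rely on PyTorch utilities instead.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_train_val(
    rows: Sequence[T], val_fraction: float, seed: int
) -> tuple[list[T], list[T]]:
    """Shuffle row indices with a seeded RNG, then split into (train, validation)."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be strictly between 0 and 1")
    indices = list(range(len(rows)))
    # A dedicated Random instance keeps the split reproducible and
    # leaves the global random state untouched.
    random.Random(seed).shuffle(indices)
    n_val = max(1, int(len(rows) * val_fraction))
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    return [rows[i] for i in train_idx], [rows[i] for i in val_idx]


rows = [f"row-{i}" for i in range(10)]
train_rows, val_rows = split_train_val(rows, val_fraction=0.2, seed=42)
```

Calling the function again with the same seed returns the same two lists, which is what makes validation metrics comparable across runs.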

SentenceTransformers Training

Use this when your data is naturally represented as a Hugging Face Dataset and you want a standard SentenceTransformers training path with IR evaluation.

The trainer:

  • materializes a base model locally if needed
  • constructs a SentenceTransformer from transformer + pooling modules
  • creates train/validation splits
  • builds a SentenceTransformers loss from module_path
  • evaluates with InformationRetrievalEvaluator

Publish to Hugging Face Hub

Use PushToHFRunner to:

  • load a model wrapper from a checkpoint
  • save a local HF model repository
  • copy model source files such as modeling_*.py and configuration_*.py
  • optionally create the remote repository and upload the folder
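
The source-file copy step might look roughly like the sketch below. It assumes only the two glob patterns named above; copy_model_sources and the directory layout are hypothetical, and the Hub upload itself is left out.

```python
import shutil
import tempfile
from pathlib import Path

# Glob patterns named in the list above.
SOURCE_PATTERNS = ("modeling_*.py", "configuration_*.py")


def copy_model_sources(src_dir: Path, repo_dir: Path) -> list[Path]:
    """Copy custom model source files into a local HF-style repository folder."""
    repo_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for pattern in SOURCE_PATTERNS:
        for src in sorted(src_dir.glob(pattern)):
            copied.append(Path(shutil.copy2(src, repo_dir / src.name)))
    return copied


# Usage with throwaway directories and placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    src_dir = Path(tmp) / "src"
    src_dir.mkdir()
    for name in ("modeling_my_model.py", "configuration_my_model.py", "notes.txt"):
        (src_dir / name).write_text("# placeholder\n")
    copied_names = [p.name for p in copy_model_sources(src_dir, Path(tmp) / "repo")]
```

Files that match neither pattern, such as notes.txt here, stay out of the repository folder.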

Extending the Library

Add a Model Wrapper

Implement a subclass of embed_train.models.Model:

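from transformers import AutoModel

from embed_train.models import Model


class MyModelWrapper(Model):
    """Wraps a Hugging Face encoder behind the library's Model interface."""

    def __init__(self, config):
        super().__init__(config)
        # Load the underlying encoder named in the settings object.
        self.model = AutoModel.from_pretrained(config.base_model_name)

    def to_hf_model(self):
        # Return the Hugging Face PreTrainedModel held by this wrapper.
        return self.model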

Add a Custom PyTorch Trainer

Subclass embed_train.train.trainers.torch.PyTorchTrainer and implement:

  • _encode_query(...)
  • _encode_documents(...)

This is the right place to define pooling, projection heads, shared or separate encoders, and any custom forward behavior.

Add a Custom Dataset

Subclass TorchDataset when you need a different row shape or data-loading strategy.

The built-in datasets show two common patterns:

  • grouped query -> many positives
  • flattened query -> single positive
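
The relationship between the two row shapes can be shown in a few lines. The field names "query" and "positives" and the helper flatten_rows are assumptions made for this sketch; the built-in datasets may use different column names.

```python
from typing import Iterator

# Grouped shape: one row per query, each with one or more positives.
grouped_rows = [
    {"query": "what is contrastive learning", "positives": ["doc a", "doc b"]},
    {"query": "how to pool token embeddings", "positives": ["doc c"]},
]


def flatten_rows(rows: list[dict]) -> Iterator[tuple[str, str]]:
    """Yield one (query, positive) pair per positive document."""
    for row in rows:
        for positive in row["positives"]:
            yield row["query"], positive


# Flattened shape: one row per (query, positive) pair.
flat_pairs = list(flatten_rows(grouped_rows))
```

The grouped form keeps all positives for a query together, which suits multi-positive losses; the flattened form repeats the query once per positive.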

Add a Custom Collate Function

Subclass CollateFn and implement _process_batch(batch). Return two string lists:

  • queries
  • candidate documents

The base class handles processor application and tokenizer calls.
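
The division of labor described above can be sketched with a stand-in base class. SketchCollateFn only imitates the described behavior and is not the library's CollateFn; its constructor arguments, the returned dictionary keys, the row field names, and the toy whitespace tokenizer are all assumptions for illustration.

```python
from abc import ABC, abstractmethod
from typing import Callable


class SketchCollateFn(ABC):
    """Stand-in for the library's CollateFn base class (assumed shape, not the real API)."""

    def __init__(
        self,
        processor: Callable[[str], str],
        tokenizer: Callable[[list[str]], dict],
    ) -> None:
        self.processor = processor
        self.tokenizer = tokenizer

    @abstractmethod
    def _process_batch(self, batch: list[dict]) -> tuple[list[str], list[str]]:
        """Return (queries, candidate documents) as two string lists."""

    def __call__(self, batch: list[dict]) -> dict:
        queries, documents = self._process_batch(batch)
        # The base class applies the processor first, then the tokenizer.
        queries = [self.processor(q) for q in queries]
        documents = [self.processor(d) for d in documents]
        return {"query": self.tokenizer(queries), "document": self.tokenizer(documents)}


class FirstPositiveCollateFn(SketchCollateFn):
    """Pairs each query with its first positive document."""

    def _process_batch(self, batch: list[dict]) -> tuple[list[str], list[str]]:
        return [row["query"] for row in batch], [row["positives"][0] for row in batch]


def toy_tokenizer(texts: list[str]) -> dict:
    # Whitespace split standing in for a real Hugging Face tokenizer.
    return {"tokens": [text.split() for text in texts]}


collate = FirstPositiveCollateFn(processor=str.lower, tokenizer=toy_tokenizer)
out = collate([{"query": "Hello World", "positives": ["Some Doc", "Other Doc"]}])
```

The subclass decides only which strings form a batch; normalization and tokenization stay in one place.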

Add a Custom Loss

Subclass embed_train.train.trainers.torch.loss.Loss and implement _get_loss(q_emb, c_emb).
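
To make the arithmetic of an in-batch-negative contrastive loss concrete, here is a plain-Python reference on nested lists. It is a sketch, not the library's InBatchNegativeContrastiveLoss: a real _get_loss would operate on tensors, and the dot-product similarity, the temperature value of 0.05, and the function name are assumptions.

```python
import math


def in_batch_negative_loss(
    q_emb: list[list[float]],
    c_emb: list[list[float]],
    temperature: float = 0.05,  # assumed value, chosen only for the example
) -> float:
    """Mean cross-entropy where candidate i is the positive for query i.

    Every other candidate in the batch acts as a negative for that query.
    """
    if not q_emb or len(q_emb) != len(c_emb):
        raise ValueError("expected a non-empty batch with one candidate per query")
    total = 0.0
    for i, q in enumerate(q_emb):
        # Scaled dot-product similarity between query i and every candidate.
        logits = [sum(a * b for a, b in zip(q, c)) / temperature for c in c_emb]
        # Numerically stable log-sum-exp over the row.
        row_max = max(logits)
        log_z = row_max + math.log(sum(math.exp(x - row_max) for x in logits))
        # Negative log-probability of the matching candidate.
        total += log_z - logits[i]
    return total / len(q_emb)


queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = in_batch_negative_loss(queries, [[1.0, 0.0], [0.0, 1.0]])
swapped = in_batch_negative_loss(queries, [[0.0, 1.0], [1.0, 0.0]])
```

The loss is close to zero when each query matches its own candidate (aligned) and large when the positives are swapped, which is the behavior a custom _get_loss should preserve.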

Output Conventions

PyTorch Trainer

  • TensorBoard logs: data_dir/tensorboard/runs/<run_name>
  • Checkpoints: data_dir/checkpoints/<run_name>/epoch_XXXX.pt
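
The two locations can be derived from data_dir and the run name as below. This is a sketch: both helper names are hypothetical, and reading XXXX as a zero-padded four-digit epoch number is an assumption.

```python
from pathlib import Path


def tensorboard_dir(data_dir: Path, run_name: str) -> Path:
    """Directory that holds TensorBoard event files for one run."""
    return data_dir / "tensorboard" / "runs" / run_name


def checkpoint_path(data_dir: Path, run_name: str, epoch: int) -> Path:
    """Checkpoint file for one epoch of one run."""
    # Assumption: XXXX stands for a zero-padded, four-digit epoch number.
    return data_dir / "checkpoints" / run_name / f"epoch_{epoch:04d}.pt"


log_dir = tensorboard_dir(Path("data"), "baseline")
ckpt = checkpoint_path(Path("data"), "baseline", 3)
```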

SentenceTransformers Trainer

  • Checkpoints: data_dir/checkpoints/...

PushToHFRunner

  • local repository saved in hf.repo
  • optional upload to Hugging Face Hub when push=True

Best Practices When Using This Library

  • Treat settings.py as the contract for supported configuration.
  • Keep your custom classes importable from stable module paths.
  • Prefer explicit Python settings construction while iterating locally; add YAML once the pipeline is stable.
  • Keep save_steps, eval_steps, and logging_steps aligned in SentenceTransformers configs. The library validates that logging_steps <= eval_steps <= save_steps and that save_steps is a multiple of eval_steps.
  • Make checkpoint directories part of your experiment artifact strategy, not an afterthought.
  • Use deterministic seeds in your own custom samplers and data preparation logic when reproducibility matters.
  • Keep tokenizer and processor behavior documented together, since both affect embedding quality.
  • Push model code alongside weights when using custom Hugging Face model implementations.
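
The step-alignment rule in the list above can be expressed as a small validator. The library uses Pydantic settings for this; the sketch below uses a dataclass only to stay self-contained. StepSchedule is a hypothetical name, and the positivity check is an extra guard added here, not something the source states.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StepSchedule:
    """Hypothetical stand-in for the step fields of a SentenceTransformers config."""

    logging_steps: int
    eval_steps: int
    save_steps: int

    def __post_init__(self) -> None:
        # Extra guard (an assumption of this sketch): all values must be positive.
        if self.logging_steps <= 0:
            raise ValueError("logging_steps must be positive")
        if not (self.logging_steps <= self.eval_steps <= self.save_steps):
            raise ValueError("expected logging_steps <= eval_steps <= save_steps")
        if self.save_steps % self.eval_steps != 0:
            raise ValueError("save_steps must be a multiple of eval_steps")


# Valid: 50 <= 100 <= 200 and 200 is a multiple of 100.
schedule = StepSchedule(logging_steps=50, eval_steps=100, save_steps=200)
```

A schedule such as eval_steps=100 with save_steps=250 would be rejected, because a checkpoint would be saved at a step that has no matching evaluation.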

Development

Common local commands:

make dev-install
make ci

Testing

The repository already contains:

  • unit tests for settings, runners, losses, models, and utility behavior
  • integration coverage for the train runner flow and dataset conversion behavior

Those tests are a good reference when you are unsure how a component is expected to behave.
