embed-train
embed-train is a config-driven library for training, evaluating, checkpointing, and publishing embedding models.
It is designed for teams that want reusable training building blocks instead of one-off scripts.
The package currently supports two main training paths:
- A custom PyTorch training loop for query/document contrastive learning
- A SentenceTransformers-based trainer for Hugging Face Dataset workflows and IR evaluation
It also includes a dedicated runner for exporting checkpoints and pushing model repositories to Hugging Face Hub.
Why This Library Exists
Most embedding projects end up reimplementing the same pieces:
- model loading and checkpoint restore
- dataset adapters
- collate/tokenization logic
- contrastive losses
- train/validation splitting
- checkpoint saving
- model packaging and Hub publishing
embed-train turns those pieces into small, replaceable abstractions connected by typed settings objects and module_path-based loading.
Installation
Python requirements from pyproject.toml:
- Python >=3.11,<3.13
- PyTorch
- Sentence Transformers
- Hugging Face Datasets
- TensorBoard
- Accelerate
- retrievalbase
Install from source:
pip install -e .
For local development in this repository:
make dev-install
Library Structure
src/embed_train/
├── __init__.py # Base Runner abstraction and YAML runner loader
├── constants.py # Shared constants such as default seed and trust flags
├── exceptions.py # Library-specific exception hierarchy
├── models/ # Base model wrapper abstraction
├── push_to_hf/ # Runner for packaging and pushing checkpoints to HF Hub
├── settings.py # Pydantic settings models for all library components
├── train/
│ ├── __init__.py # TrainRunner
│ ├── dataset/
│ │ ├── __init__.py # Base TorchDataset and CollateFn abstractions
│ │ ├── collate.py # Built-in in-batch positive collate functions
│ │ ├── torch_datasets.py # Built-in query/positive dataset views
│ │ └── sampling/
│ │ └── samplers.py # Built-in positive sampler(s)
│ └── trainers/
│ ├── __init__.py # Base Trainer abstraction
│ ├── hf/ # SentenceTransformers trainer
│ └── torch/
│ ├── __init__.py # Custom PyTorch trainer base
│ └── loss.py # Contrastive loss implementations
│
└── utils.py # Dynamic import, checkpoint loading, HF file helpers
Core Concepts
1. Runners
Runners are the top-level execution unit.
- embed_train.Runner: abstract base class
- embed_train.train.TrainRunner: loads a trainer and executes train()
- embed_train.push_to_hf.PushToHFRunner: restores a checkpoint, saves a local HF-style repo, and optionally pushes it
2. Settings-Driven Composition
Every major component is configured through Pydantic settings objects in src/embed_train/settings.py.
Most components are resolved dynamically from a module_path, which makes the library extensible without editing core orchestration code.
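As an illustration of the mechanism only (the library's actual helper lives in embed_train.utils; the function below is a stand-in), module_path resolution boils down to a dynamic import:

import importlib

def resolve_module_path(module_path: str):
    # Stand-in for the real helper in embed_train.utils: split
    # "pkg.module.ClassName" and return the named attribute.
    module_name, _, attr = module_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# A settings field holding this string resolves to the trainer class
# without any change to the orchestration code.
TrainerCls = resolve_module_path("embed_train.train.trainers.hf.SentenceTransformersTrainer")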
3. Model Wrappers
Models inherit from embed_train.models.Model and must implement:
- to_hf_model(): return a Hugging Face PreTrainedModel
The base class already provides:
- save(...)
- to(device)
- from_checkpoint(...)
4. Datasets and Collate Functions
The PyTorch flow separates concerns clearly:
- TorchDataset: how rows are loaded and exposed
- CollateFn: how rows become query/document text pairs and tokenized tensors
- Processor: text normalization or preprocessing, provided by retrievalbase
5. Trainers
There are two trainer families:
- PyTorchTrainer: custom loop with checkpoints, TensorBoard logging, train/val split, and pluggable losses
- SentenceTransformersTrainer: wraps SentenceTransformerTrainer and builds an IR evaluator from a Hugging Face dataset
Public Building Blocks
The current repository includes these reusable implementations:
- embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset
- embed_train.train.dataset.torch_datasets.QueryPositiveDataset
- embed_train.train.dataset.collate.InBatchPositiveCollateFn
- embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn
- embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss
- embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss
- embed_train.train.trainers.hf.SentenceTransformersTrainer
- embed_train.push_to_hf.PushToHFRunner
Typical Workflows
Custom PyTorch Training
Use this when you want full control over the training loop and your model wrapper already exposes a Hugging Face model.
What the built-in trainer does:
- loads the configured model
- loads a TorchDataset
- builds a collate function with tokenizer and processor
- splits the dataset into train/validation sets
- trains with AdamW
- logs to TensorBoard
- saves checkpoints to data_dir/checkpoints/...
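A minimal sketch of launching this path from Python. Apart from the TrainRunner import, everything here is an assumption (the settings class name, its fields, the constructor shape, and the run() entry point), so treat src/embed_train/settings.py and the runner source as authoritative:

from embed_train import settings as es  # Pydantic settings models (see settings.py)
from embed_train.train import TrainRunner

# Hypothetical: construct whichever settings model TrainRunner expects,
# pointing model, dataset, collate, and loss at importable module_paths.
train_settings = es.TrainSettings(...)  # hypothetical class name and fields

runner = TrainRunner(train_settings)    # assumed constructor shape
runner.run()                            # assumed entry point that executes train()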
SentenceTransformers Training
Use this when your data is naturally represented as a Hugging Face Dataset and you want a standard SentenceTransformers training path with IR evaluation.
The trainer:
- materializes a base model locally if needed
- constructs a SentenceTransformer from transformer + pooling modules
- creates train/validation splits
- builds a SentenceTransformers loss from module_path
- evaluates with InformationRetrievalEvaluator
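The evaluator at the end of that list is standard SentenceTransformers machinery built from three dicts. This self-contained example, with illustrative data and independent of this library's wiring, shows its shape:

from sentence_transformers.evaluation import InformationRetrievalEvaluator

# Illustrative inputs: query id -> text, doc id -> text, query id -> relevant doc ids.
queries = {"q1": "how do I reset my password?"}
corpus = {
    "d1": "Open account settings and choose 'reset password'.",
    "d2": "Standard shipping takes 3-5 business days.",
}
relevant_docs = {"q1": {"d1"}}

evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name="dev-ir")
# Calling evaluator(model) reports IR metrics such as MRR, NDCG, and Recall@k.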
Publish to Hugging Face Hub
Use PushToHFRunner to:
- load a model wrapper from a checkpoint
- save a local HF model repository
- copy model source files such as modeling_*.py and configuration_*.py
- optionally create the remote repository and upload the folder
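The last two steps map onto standard huggingface_hub primitives. This sketch shows the equivalent manual flow (repo id and folder path are placeholders), not the runner's internals:

from huggingface_hub import HfApi

api = HfApi()  # authenticates via `huggingface-cli login` or the HF_TOKEN env var

repo_id = "my-org/my-embedding-model"  # placeholder
api.create_repo(repo_id, exist_ok=True)
api.upload_folder(folder_path="./hf_repo", repo_id=repo_id)  # placeholder local path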
Extending the Library
Add a Model Wrapper
Implement a subclass of embed_train.models.Model:
from transformers import AutoModel
from embed_train.models import Model

class MyModelWrapper(Model):
    def __init__(self, config):
        super().__init__(config)
        # Load the base Hugging Face encoder named in the settings object.
        self.model = AutoModel.from_pretrained(config.base_model_name)

    def to_hf_model(self):
        # Required hook: expose the underlying HF model for saving and publishing.
        return self.model
Add a Custom PyTorch Trainer
Subclass embed_train.train.trainers.torch.PyTorchTrainer and implement:
- _encode_query(...)
- _encode_documents(...)
This is the right place to define pooling, projection heads, shared or separate encoders, and any custom forward behavior.
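A sketch of such a subclass with masked mean pooling over a shared encoder. Two assumptions to verify against the base class: each hook receives tokenized tensors (input_ids, attention_mask), and the wrapped Hugging Face model is reachable as self.model:

import torch
from embed_train.train.trainers.torch import PyTorchTrainer

class MeanPoolingTrainer(PyTorchTrainer):
    def _mean_pool(self, inputs: dict) -> torch.Tensor:
        # Assumed: self.model is the wrapper's underlying HF encoder.
        hidden = self.model(**inputs).last_hidden_state        # (batch, seq, dim)
        mask = inputs["attention_mask"].unsqueeze(-1).float()  # (batch, seq, 1)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def _encode_query(self, query_inputs: dict) -> torch.Tensor:
        return self._mean_pool(query_inputs)

    def _encode_documents(self, document_inputs: dict) -> torch.Tensor:
        # Shared encoder; a two-tower setup would branch to a separate model here.
        return self._mean_pool(document_inputs)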
Add a Custom Dataset
Subclass TorchDataset when you need a different row shape or data-loading strategy.
The built-in datasets show two common patterns:
- grouped query -> many positives
- flattened query -> single positive
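A sketch of the flattened pattern reading JSONL rows. It assumes the base class follows torch.utils.data.Dataset conventions (__len__/__getitem__) and that the collate function consumes plain dict rows; the built-in datasets in torch_datasets.py are the reference for the real contract:

import json
from embed_train.train.dataset import TorchDataset

class JsonlQueryPositiveDataset(TorchDataset):
    def __init__(self, path: str):
        # Hypothetical row format: one {"query": ..., "positive": ...} object per line.
        with open(path) as f:
            self.rows = [json.loads(line) for line in f]

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> dict:
        return self.rows[idx]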
Add a Custom Collate Function
Subclass CollateFn and implement _process_batch(batch).
Return two string lists:
- queries
- candidate documents
The base class handles processor application and tokenizer calls.
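For example, a sketch that pairs each query with a title-plus-body document (the row field names are hypothetical):

from embed_train.train.dataset import CollateFn

class TitleBodyCollateFn(CollateFn):
    def _process_batch(self, batch: list[dict]) -> tuple[list[str], list[str]]:
        # Hypothetical row fields: "query", "title", "body".
        queries = [row["query"] for row in batch]
        documents = [f"{row['title']}. {row['body']}" for row in batch]
        return queries, documents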
Add a Custom Loss
Subclass embed_train.train.trainers.torch.loss.Loss and implement _get_loss(q_emb, c_emb).
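A sketch of a temperature-scaled in-batch negative loss. It assumes q_emb and c_emb are aligned (batch, dim) tensors where c_emb[i] is the positive for q_emb[i] and every other row serves as a negative, and that the Loss base class is an nn.Module:

import torch
import torch.nn.functional as F
from embed_train.train.trainers.torch.loss import Loss

class ScaledInBatchLoss(Loss):
    scale: float = 20.0  # illustrative temperature; tune for your model

    def _get_loss(self, q_emb: torch.Tensor, c_emb: torch.Tensor) -> torch.Tensor:
        q = F.normalize(q_emb, dim=-1)
        c = F.normalize(c_emb, dim=-1)
        logits = q @ c.T * self.scale                      # (batch, batch) cosine similarities
        labels = torch.arange(q.size(0), device=q.device)  # assumed: positives on the diagonal
        return F.cross_entropy(logits, labels)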
Output Conventions
PyTorch Trainer
- TensorBoard logs: data_dir/tensorboard/runs/<run_name>
- Checkpoints: data_dir/checkpoints/<run_name>/epoch_XXXX.pt
SentenceTransformers Trainer
- checkpoints under data_dir/checkpoints/...
PushToHFRunner
- local repository saved in hf.repo
- optional upload to Hugging Face Hub when push=True
Best Practices When Using This Library
- Treat settings.py as the contract for supported configuration.
- Keep your custom classes importable from stable module paths.
- Prefer explicit Python settings construction while iterating locally; add YAML once the pipeline is stable.
- Keep save_steps, eval_steps, and logging_steps aligned in SentenceTransformers configs. The library validates that logging_steps <= eval_steps <= save_steps and that save_steps is a multiple of eval_steps (see the sketch after this list).
- Make checkpoint directories part of your experiment artifact strategy, not an afterthought.
- Use deterministic seeds in your own custom samplers and data preparation logic when reproducibility matters.
- Keep tokenizer and processor behavior documented together, since both affect embedding quality.
- Push model code alongside weights when using custom Hugging Face model implementations.
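A concrete step configuration that satisfies the validation rule mentioned in the list above (the values are illustrative):

logging_steps = 50   # log most frequently
eval_steps = 200     # logging_steps (50) <= eval_steps (200)
save_steps = 600     # eval_steps (200) <= save_steps (600), and 600 % 200 == 0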
Development
Common local commands:
make dev-install
make ci
Testing
The repository already contains:
- unit tests for settings, runners, losses, models, and utility behavior
- integration coverage for the train runner flow and dataset conversion behavior
Those tests are a good reference when you are unsure how a component is expected to behave.