# Hydra Trainer

Hugging Face Transformers with Hydra integration.

A package that wraps Hugging Face's Transformers `Trainer` with Hydra integration for better configuration management and hyperparameter-optimization support.

Check out my powerformer repo for a concrete example.
## Features

- Hydra configuration management
- Optuna hyperparameter optimization integration
- Easy-to-extend base classes for custom datasets and trainers
- Specify `TrainingArguments` or hyperparameter search parameters within a Hydra configuration file
- An example config, `base.yaml`, is provided in this package
## Installation

```shell
pip install hydra-trainer
```
## Quick Start

- Create your dataset class by extending `BaseDataset`, or use any dataset that extends `datasets.Dataset`:

```python
from typing import Literal

from omegaconf import DictConfig

from hydra_trainer import BaseDataset


class ExampleDataset(BaseDataset):
    def __init__(self, cfg: DictConfig, dataset_key: Literal["train", "eval"]):
        super().__init__(cfg)
        self.dataset_key = dataset_key
        # TODO: implement dataset loading and preprocessing
        raise NotImplementedError

    def __len__(self):
        # TODO: implement this method
        raise NotImplementedError

    def __getitem__(self, idx):
        # TODO: implement this method
        raise NotImplementedError
```
- Create your trainer class by extending `BaseTrainer`:

```python
from typing import Literal

import optuna
from omegaconf import DictConfig

from hydra_trainer import BaseTrainer


class ExampleTrainer(BaseTrainer[ExampleDataset, DictConfig]):
    def model_init_factory(self):
        def model_init(trial: optuna.Trial | None = None):
            model_cfg = self.get_trial_model_cfg(trial, self.cfg)
            # TODO: implement model initialization
            raise NotImplementedError

        return model_init

    def dataset_factory(
        self, dataset_cfg: DictConfig, dataset_key: Literal["train", "eval"]
    ) -> ExampleDataset:
        # TODO: implement this method
        raise NotImplementedError
```
- Set up your training script with Hydra:

```python
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="hydra_trainer", config_name="base", version_base=None)
def main(cfg: DictConfig):
    trainer = ExampleTrainer(cfg)
    trainer.train()


if __name__ == "__main__":
    main()
```
## BaseTrainer Key Features

- **Model Initialization Factory**: Implement `model_init_factory()` to define how your model is created.
- **Dataset Factory**: Implement `dataset_factory()` to create your training and evaluation datasets.
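The factory indirection matters because, during a hyperparameter search, the Transformers `Trainer` calls `model_init` to build a fresh model once per trial. Here is a dependency-free sketch of that closure shape; the plain dicts standing in for the config, the trial overrides, and the model object are illustrative assumptions, not the package's actual types:

```python
def model_init_factory(cfg: dict):
    """Return a model_init closure, as the Transformers Trainer expects
    for hyperparameter search (one fresh model per trial)."""

    def model_init(trial=None):
        # In the real package, get_trial_model_cfg merges per-trial
        # overrides into the model config; here a dict.update stands in.
        model_cfg = dict(cfg["model"])
        if trial is not None:
            model_cfg.update(trial)  # stand-in for trial-suggested params
        # A dict stands in for the actual model object.
        return {"architecture": "toy", **model_cfg}

    return model_init


cfg = {"model": {"d_model": 128, "n_layers": 12}}
init = model_init_factory(cfg)
base_model = init()                   # built from the default config
trial_model = init({"d_model": 256})  # built with a per-trial override
```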
## Configuration

The package uses Hydra for configuration management. Here's the base configuration structure:

```yaml
seed: 42
checkpoint_path: null
resume_from_checkpoint: null
do_hyperoptim: false
early_stopping_patience: 3

model: # model parameters - access them within your `model_init_factory` implementation
  d_model: 128
  n_layers: 12
  n_heads: 16
  d_ff: 512

trainer: # transformers.TrainingArguments
  num_train_epochs: 3
  eval_strategy: steps
  eval_steps: 50
  logging_steps: 5
  output_dir: training_output
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 4096
  learning_rate: 5e-3
  weight_decay: 0.0
  fp16: true

hyperopt:
  n_trials: 128
  patience: 2
  persistence: true # set to false to use in-memory storage instead of a database
  load_if_exists: true
  storage_url: postgresql://postgres:password@127.0.0.1:5432/postgres
  storage_heartbeat_interval: 15
  storage_engine_kwargs:
    pool_size: 5
    connect_args:
      keepalives: 1
  hp_space:
    training:
      - name: learning_rate # TrainingArguments attribute name
        type: float
        low: 5e-5
        high: 5e-3
        step: 1e-5
        log: true
    model:
      - name: d_model # model parameters
        type: int
        low: 128
        high: 512
        step: 128
        log: true
```
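Once loaded, the config resolves nested keys by dotted path, which is also the grammar Hydra uses for command-line overrides (e.g. `trainer.learning_rate=1e-4`). As a dependency-free illustration of how a dotted key walks the structure above (plain dicts stand in for the loaded `DictConfig`):

```python
def resolve(cfg: dict, dotted_key: str):
    """Walk a nested dict with a dotted path, e.g. 'trainer.learning_rate'."""
    node = cfg
    for part in dotted_key.split("."):
        node = node[part]
    return node


# A small slice of the base configuration, as plain dicts.
cfg = {
    "seed": 42,
    "model": {"d_model": 128, "n_layers": 12},
    "trainer": {"learning_rate": 5e-3, "num_train_epochs": 3},
}

resolve(cfg, "model.d_model")          # the value under model -> d_model
resolve(cfg, "trainer.learning_rate")  # the value under trainer -> learning_rate
```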
## Hyperparameter Optimization

Enable hyperparameter optimization by setting `do_hyperoptim: true` in your config. The package uses Optuna for hyperparameter optimization, with support for:

- Integer parameters
- Float parameters
- Categorical parameters
- Persistent storage in a relational database
- Early stopping with patient pruning
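How `BaseTrainer` maps `hp_space` entries onto Optuna's `suggest_*` API is internal to the package; the following is only a plausible sketch of that dispatch, with a stub class in place of `optuna.Trial` so it runs standalone. (Note that real Optuna rejects `step` combined with `log=True` for float parameters, so the sketch passes only one of the two.)

```python
class StubTrial:
    """Minimal stand-in for optuna.Trial: always suggests the lower bound
    (or the first choice), so results are deterministic."""

    def suggest_int(self, name, low, high, step=1, log=False):
        return low

    def suggest_float(self, name, low, high, step=None, log=False):
        return low

    def suggest_categorical(self, name, choices):
        return choices[0]


def suggest_param(trial, spec: dict):
    """Dispatch one hp_space entry to the matching suggest_* call."""
    kind = spec["type"]
    if kind == "int":
        return trial.suggest_int(
            spec["name"], spec["low"], spec["high"], step=spec.get("step", 1)
        )
    if kind == "float":
        # Optuna disallows step together with log=True, so prefer log.
        if spec.get("log"):
            return trial.suggest_float(spec["name"], spec["low"], spec["high"], log=True)
        return trial.suggest_float(
            spec["name"], spec["low"], spec["high"], step=spec.get("step")
        )
    if kind == "categorical":
        return trial.suggest_categorical(spec["name"], spec["choices"])
    raise ValueError(f"unknown parameter type: {kind}")


trial = StubTrial()
lr = suggest_param(
    trial,
    {"name": "learning_rate", "type": "float", "low": 5e-5, "high": 5e-3, "log": True},
)
d_model = suggest_param(
    trial, {"name": "d_model", "type": "int", "low": 128, "high": 512, "step": 128}
)
```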
## File details

Details for the file `hydra_trainer-0.1.2.tar.gz`.

### File metadata

- Download URL: hydra_trainer-0.1.2.tar.gz
- Size: 5.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `17359d1bb46be5bafaf4cbd87f89e5102336705ccc5f5658864a0e1e1b2b8a46` |
| MD5 | `3823403d075c83aa19c066587d4aa38e` |
| BLAKE2b-256 | `80cb099a7ea1a8f87a447c0c41142b3f4c3795304ed789e9d23328ffac75a940` |
### Provenance

The following attestation bundles were made for `hydra_trainer-0.1.2.tar.gz`:

Publisher: release.yml on emapco/HydraTrainer

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: hydra_trainer-0.1.2.tar.gz
- Subject digest: 17359d1bb46be5bafaf4cbd87f89e5102336705ccc5f5658864a0e1e1b2b8a46
- Sigstore transparency entry: 176422089
- Permalink: emapco/HydraTrainer@e28e68655423e2f8aa518c54dc4b3df5f660d208
- Branch / Tag: refs/heads/main
- Owner: https://github.com/emapco
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@e28e68655423e2f8aa518c54dc4b3df5f660d208
- Trigger Event: push
## File details

Details for the file `hydra_trainer-0.1.2-py3-none-any.whl`.

### File metadata

- Download URL: hydra_trainer-0.1.2-py3-none-any.whl
- Size: 5.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `a600e7642ac80e0918d282216f49775302d7139f41e869aed2847da90a32c593` |
| MD5 | `d62001c7013bb6768013ffd218274e4e` |
| BLAKE2b-256 | `bf0c871c08f29b23f5d1cfc9d2d59b14f06060286452ade87ec089e2057e6b99` |
### Provenance

The following attestation bundles were made for `hydra_trainer-0.1.2-py3-none-any.whl`:

Publisher: release.yml on emapco/HydraTrainer

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: hydra_trainer-0.1.2-py3-none-any.whl
- Subject digest: a600e7642ac80e0918d282216f49775302d7139f41e869aed2847da90a32c593
- Sigstore transparency entry: 176422090
- Permalink: emapco/HydraTrainer@e28e68655423e2f8aa518c54dc4b3df5f660d208
- Branch / Tag: refs/heads/main
- Owner: https://github.com/emapco
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@e28e68655423e2f8aa518c54dc4b3df5f660d208
- Trigger Event: push