
Hugging Face Transformers with Hydra integration

Project description

Hydra Trainer

A package that wraps Hugging Face's Transformers Trainer with Hydra integration for better configuration management and Optuna-backed hyperparameter optimization.

Features

  • Hydra configuration management
  • Optuna hyperparameter optimization integration
  • Easy-to-extend base classes for custom datasets and trainers
  • Specify TrainingArguments or hyperparameter search parameters within a Hydra configuration file
    • An example config, base.yaml, is provided in this package.

Installation

pip install hydra-trainer

Quick Start

  1. Create your dataset class by extending BaseDataset or use any dataset that extends datasets.Dataset:
from typing import Literal
from omegaconf import DictConfig
from hydra_trainer import BaseDataset

class ExampleDataset(BaseDataset):
    def __init__(self, cfg: DictConfig, dataset_key: Literal["train", "eval"]):
        super().__init__(cfg)
        self.dataset_key = dataset_key
        # TODO: implement dataset loading and preprocessing
        raise NotImplementedError

    def __len__(self):
        # TODO: implement this method
        raise NotImplementedError

    def __getitem__(self, idx):
        # TODO: implement this method
        raise NotImplementedError
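For orientation, here is a minimal concrete sketch of such a dataset. The cfg.dataset.name and cfg.dataset.tokenizer config keys, the split names, and the text/label columns are illustrative assumptions, not part of this package:

from typing import Literal

import datasets
from omegaconf import DictConfig
from transformers import AutoTokenizer

from hydra_trainer import BaseDataset

class TextClassificationDataset(BaseDataset):
    def __init__(self, cfg: DictConfig, dataset_key: Literal["train", "eval"]):
        super().__init__(cfg)
        self.dataset_key = dataset_key
        # Assumed config keys: cfg.dataset.name and cfg.dataset.tokenizer
        split = "train" if dataset_key == "train" else "validation"
        self.data = datasets.load_dataset(cfg.dataset.name, split=split)
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.dataset.tokenizer)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Assumes the underlying dataset has "text" and "label" columns
        encoding = self.tokenizer(
            self.data[idx]["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        encoding["labels"] = self.data[idx]["label"]
        return encoding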
  2. Create your trainer class by extending BaseTrainer:
from typing import Literal

import optuna
from omegaconf import DictConfig

from hydra_trainer import BaseTrainer


class ExampleTrainer(BaseTrainer[ExampleDataset]):
    def _model_init_factory(self):
        model_cfg = self.cfg.model

        def model_init_closure(trial: optuna.Trial | None = None):
            if trial is not None and hasattr(self.cfg.hyperopt.hp_space, "model"):
                # Override config with trial parameters for model
                for param in self.cfg.hyperopt.hp_space.model:
                    param_name = param.name
                    trial_name = f"model__{param_name}"
                    model_cfg[param_name] = trial.params[trial_name]

            # TODO: implement model initialization
            raise NotImplementedError

        return model_init_closure

    def _dataset_factory(
        self, dataset_cfg: DictConfig, dataset_key: Literal["train", "eval"]
    ) -> ExampleDataset:
        # TODO: implement this method
        raise NotImplementedError
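Filled in, the two hooks might look like the sketch below. MyTransformer stands in for whatever model class you actually train (a hypothetical torch.nn.Module), and TextClassificationDataset is the example from step 1:

from typing import Literal

import optuna
from omegaconf import DictConfig

from hydra_trainer import BaseTrainer

class MyTrainer(BaseTrainer[TextClassificationDataset]):
    def _model_init_factory(self):
        model_cfg = self.cfg.model

        def model_init_closure(trial: optuna.Trial | None = None):
            if trial is not None and hasattr(self.cfg.hyperopt.hp_space, "model"):
                # Override config with trial parameters for the model
                for param in self.cfg.hyperopt.hp_space.model:
                    model_cfg[param.name] = trial.params[f"model__{param.name}"]
            # Build the model from the (possibly trial-overridden) config values;
            # MyTransformer is hypothetical
            return MyTransformer(
                d_model=model_cfg.d_model,
                n_layers=model_cfg.n_layers,
                n_heads=model_cfg.n_heads,
                d_ff=model_cfg.d_ff,
            )

        return model_init_closure

    def _dataset_factory(
        self, dataset_cfg: DictConfig, dataset_key: Literal["train", "eval"]
    ) -> TextClassificationDataset:
        return TextClassificationDataset(self.cfg, dataset_key)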
  3. Set up your training script with Hydra:
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="../conf", config_name="base", version_base=None)
def main(cfg: DictConfig):
    trainer = ExampleTrainer(cfg)
    trainer.train()

if __name__ == "__main__":
    main()
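
Because main is decorated with @hydra.main, any value in the config tree can be overridden from the command line; this is standard Hydra behavior. Assuming the script above is saved as train.py:

python train.py trainer.learning_rate=1e-4 trainer.num_train_epochs=5 do_hyperoptim=true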

BaseTrainer Key Features

  1. Model Initialization Factory: Implement _model_init_factory() to define how your model is created.
  2. Dataset Factory: Implement _dataset_factory() to create your training and evaluation datasets.
  3. Early Stopping: Built-in early stopping support with configurable patience.

Configuration

The package uses Hydra for configuration management. Here's the base configuration structure:

seed: 42
checkpoint_path: null
resume_from_checkpoint: null
do_hyperoptim: false
early_stopping_patience: 3

model: # model parameters - accessed within your _model_init_factory implementation
  d_model: 128
  n_layers: 12
  n_heads: 16
  d_ff: 512

trainer: # transformers.TrainingArguments
  num_train_epochs: 3
  eval_strategy: steps
  eval_steps: 50
  logging_steps: 5
  output_dir: training_output
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 4096
  learning_rate: 5e-3
  weight_decay: 0.0
  fp16: true

hyperopt:
  n_trials: 128
  patience: 2
  persistence: false
  hp_space:
    training:
      - name: learning_rate # TrainingArguments attribute name
        type: float
        low: 5e-5
        high: 5e-3
        step: 1e-5
        log: true
    model:
      - name: d_model # model parameters
        type: int
        low: 128
        high: 512
        step: 128
        log: true
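
To inspect the composed configuration without launching a run, you can use Hydra's compose API; this is plain Hydra, nothing package-specific, and the config_path assumes the layout from the Quick Start:

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="../conf", version_base=None):
    cfg = compose(config_name="base", overrides=["trainer.learning_rate=1e-4"])
    print(OmegaConf.to_yaml(cfg))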

Hyperparameter Optimization

Enable hyperparameter optimization by setting do_hyperoptim: true in your config. The package uses Optuna under the hood, with support for:

  • Integer parameters
  • Float parameters
  • Categorical parameters (see the sketch after this list)
  • Persistent storage in a relational database
  • Early stopping with patient pruning
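
The base config above shows int and float search dimensions. By analogy, a categorical dimension would presumably take the choices accepted by Optuna's suggest_categorical; the exact schema below (the type and choices keys) is an assumption, so check the base.yaml shipped with the package:

hp_space:
  training:
    - name: lr_scheduler_type # a TrainingArguments attribute
      type: categorical
      choices: [linear, cosine, constant] # assumed key, mirroring optuna.Trial.suggest_categorical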
