
Hugging Face Transformers with Hydra integration

Project description

Hydra Trainer

A package that wraps Hugging Face's Transformers Trainer with Hydra integration for cleaner configuration management and built-in hyperparameter optimization support.

Features

  • Hydra configuration management
  • Optuna hyperparameter optimization integration
  • Easy-to-extend base classes for custom datasets and trainers
  • Specify TrainingArguments or hyperparameter search parameters within a Hydra configuration file
    • An example config, base.yaml, is provided in this package.

Installation

pip install hydra-trainer

Quick Start

  1. Create your dataset class by extending BaseDataset or use any dataset that extends datasets.Dataset:
from typing import Literal
from omegaconf import DictConfig
from hydra_trainer import BaseDataset

class ExampleDataset(BaseDataset):
    def __init__(self, cfg: DictConfig, dataset_key: Literal["train", "eval"]):
        super().__init__(cfg)
        self.dataset_key = dataset_key
        # TODO: implement dataset loading and preprocessing
        raise NotImplementedError

    def __len__(self):
        # TODO: implement this method
        raise NotImplementedError

    def __getitem__(self, idx):
        # TODO: implement this method
        raise NotImplementedError
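The TODOs above are intentionally left open. As a rough, dependency-free sketch of what a filled-in dataset looks like, here is a hypothetical stand-in class (not the real BaseDataset, and with hard-coded pre-tokenized samples instead of actual loading/preprocessing) that returns dicts of the kind transformers.Trainer consumes:

```python
class ToyDataset:
    """Hypothetical stand-in for BaseDataset, for illustration only."""

    def __init__(self, cfg, dataset_key):
        self.cfg = cfg
        self.dataset_key = dataset_key
        # Real code would load and tokenize data based on cfg here;
        # we hard-code two pre-tokenized samples instead.
        self.samples = [
            {"input_ids": [101, 7592, 102], "labels": 0},
            {"input_ids": [101, 2088, 102], "labels": 1},
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
```

Each item is a dict of model inputs; transformers.Trainer collates such dicts into batches via its data collator.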
  2. Create your trainer class by extending BaseTrainer:
from typing import Literal

import optuna
from omegaconf import DictConfig

from hydra_trainer import BaseTrainer


class ExampleTrainer(BaseTrainer[ExampleDataset]):
    def model_init_factory(self):
        def model_init(trial: optuna.Trial | None = None):
            model_cfg = self.get_trial_model_cfg(trial, self.cfg)
            # TODO: implement model initialization
            raise NotImplementedError

        return model_init

    def dataset_factory(
        self, dataset_cfg: DictConfig, dataset_key: Literal["train", "eval"]
    ) -> ExampleDataset:
        # TODO: implement this method
        raise NotImplementedError
  3. Set up your training script with Hydra:
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="hydra_trainer", config_name="base", version_base=None)
def main(cfg: DictConfig):
    trainer = ExampleTrainer(cfg)
    trainer.train()

if __name__ == "__main__":
    main()

BaseTrainer Key Features

  1. Model Initialization Factory: Implement model_init_factory() to define how your model is created.
  2. Dataset Factory: Implement dataset_factory() to create your training and evaluation datasets.
  3. Early Stopping: Built-in early stopping support with configurable patience.
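To make the patience semantics concrete: with early_stopping_patience: 3, training stops after three consecutive evaluations without an improvement in the monitored metric. A minimal stdlib-only sketch of that rule (not the package's actual implementation, which presumably delegates to a callback such as transformers' EarlyStoppingCallback):

```python
def early_stop_index(eval_losses, patience=3):
    """Return the index of the evaluation that triggers early stopping,
    or None if training runs to completion."""
    best = float("inf")
    evals_without_improvement = 0
    for i, loss in enumerate(eval_losses):
        if loss < best:
            # New best metric: reset the patience counter.
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return i
    return None
```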

Configuration

The package uses Hydra for configuration management. Here's the base configuration structure:

seed: 42
checkpoint_path: null
resume_from_checkpoint: null
do_hyperoptim: false
early_stopping_patience: 3

model: # model parameters - access them within `model_init_factory` implementation
  d_model: 128
  n_layers: 12
  n_heads: 16
  d_ff: 512

trainer: #  transformers.TrainingArguments
  num_train_epochs: 3
  eval_strategy: steps
  eval_steps: 50
  logging_steps: 5
  output_dir: training_output
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 4096
  learning_rate: 5e-3
  weight_decay: 0.0
  fp16: true

hyperopt:
  n_trials: 128
  patience: 2
  persistence: false
  hp_space:
    training:
      - name: learning_rate # TrainingArguments attribute name
        type: float
        low: 5e-5
        high: 5e-3
        step: 1e-5
        log: true
    model:
      - name: d_model # model parameters
        type: int
        low: 128
        high: 512
        step: 128
        log: true
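The hp_space above only shows float and int entries. Categorical parameters are also supported; a plausible entry might look like the fragment below. Note the choices key name is an assumption mirroring Optuna's suggest_categorical, so check the bundled base.yaml for the exact schema:

```yaml
hyperopt:
  hp_space:
    training:
      - name: weight_decay # TrainingArguments attribute name
        type: categorical
        choices: [0.0, 0.01, 0.1]
```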

Hyperparameter Optimization

Enable hyperparameter optimization by setting do_hyperoptim: true in your config. The package uses Optuna for hyperparameter optimization with support for:

  • Integer parameters
  • Float parameters
  • Categorical parameters
  • Persistent storage with a relational database
  • Early stopping of unpromising trials via patience-based pruning
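These parameter types map naturally onto Optuna's suggest API (trial.suggest_int, trial.suggest_float, trial.suggest_categorical). Below is a stdlib-only sketch of how hp_space entries from the configuration above could be dispatched — not the package's actual code — with a fake trial standing in for optuna.Trial so it runs without Optuna installed. (In Optuna itself, suggest_int with log=True requires the default step of 1, so log-scaled integer ranges may need step omitted.)

```python
class FakeTrial:
    """Minimal stand-in for optuna.Trial: always returns the lower bound
    or first choice, just to exercise the dispatch logic."""

    def suggest_int(self, name, low, high, step=1, log=False):
        return low

    def suggest_float(self, name, low, high, step=None, log=False):
        return low

    def suggest_categorical(self, name, choices):
        return choices[0]


def suggest_param(trial, spec):
    """Dispatch one hp_space entry (a dict parsed from YAML) to the
    matching suggest_* call."""
    kind = spec["type"]
    if kind == "int":
        return trial.suggest_int(
            spec["name"], spec["low"], spec["high"],
            step=spec.get("step", 1), log=spec.get("log", False),
        )
    if kind == "float":
        return trial.suggest_float(
            spec["name"], spec["low"], spec["high"],
            step=spec.get("step"), log=spec.get("log", False),
        )
    if kind == "categorical":
        return trial.suggest_categorical(spec["name"], spec["choices"])
    raise ValueError(f"unknown parameter type: {kind}")
```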


Download files

Download the file for your platform.

Source Distribution

hydra_trainer-0.1.1.tar.gz (5.6 kB)

Uploaded Source

Built Distribution


hydra_trainer-0.1.1-py3-none-any.whl (5.7 kB)

Uploaded Python 3

File details

Details for the file hydra_trainer-0.1.1.tar.gz.

File metadata

  • Download URL: hydra_trainer-0.1.1.tar.gz
  • Size: 5.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for hydra_trainer-0.1.1.tar.gz
  • SHA256: c0c2289420b7fd84498c18e04c90e5f3d79ddfb2df2b642507cd01d09cccfa3d
  • MD5: d72a979941b48b5436a1883022669bfd
  • BLAKE2b-256: 676f3b09c3d79050ba304506d6bae4da34d5b3d3d95f2822667a6f90fd0bc368


Provenance

The following attestation bundles were made for hydra_trainer-0.1.1.tar.gz:

Publisher: release.yml on emapco/HydraTrainer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file hydra_trainer-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: hydra_trainer-0.1.1-py3-none-any.whl
  • Size: 5.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for hydra_trainer-0.1.1-py3-none-any.whl
  • SHA256: a55fd43da63a91f59c4521093747ecbd6d261d355c76b8a2e0fe14612f5ff250
  • MD5: fb0a9fee8b4d1ec88d2e89573074c2c3
  • BLAKE2b-256: 6637e29f06871bd4fc35cdae041ba6d305dfc20591c1f3f0b2c021ae0bf8ef0e


Provenance

The following attestation bundles were made for hydra_trainer-0.1.1-py3-none-any.whl:

Publisher: release.yml on emapco/HydraTrainer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
