
Hugging Face Transformers with Hydra integration

Project description

Hydra Trainer

A package that wraps the Hugging Face Transformers Trainer with Hydra integration for cleaner configuration management and built-in hyperparameter-optimization support.

Check out my powerformer repo for a concrete example.

Features

  • Hydra configuration management
  • Optuna hyperparameter optimization integration
  • Easy-to-extend base classes for custom datasets and trainers
  • Specify TrainingArguments or hyperparameter-search parameters within a Hydra configuration file
    • An example config, base.yaml, is provided in this package.

Installation

pip install hydra-trainer

Quick Start

  1. Create your dataset class by extending BaseDataset, or use any dataset that extends datasets.Dataset:
from typing import Literal
from omegaconf import DictConfig
from hydra_trainer import BaseDataset

class ExampleDataset(BaseDataset):
    def __init__(self, cfg: DictConfig, dataset_key: Literal["train", "eval"]):
        super().__init__(cfg)
        self.dataset_key = dataset_key
        # TODO: implement dataset loading and preprocessing
        raise NotImplementedError

    def __len__(self):
        # TODO: implement this method
        raise NotImplementedError

    def __getitem__(self, idx):
        # TODO: implement this method
        raise NotImplementedError
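For a sense of what a filled-in implementation looks like, here is a minimal sketch that tokenizes a dataset from the Hugging Face Hub. The config keys (cfg.name, cfg.tokenizer) and the dataset/tokenizer choices are illustrative assumptions, not part of the package:

from typing import Literal

from datasets import load_dataset
from omegaconf import DictConfig
from transformers import AutoTokenizer

from hydra_trainer import BaseDataset


class TextClassificationDataset(BaseDataset):
    def __init__(self, cfg: DictConfig, dataset_key: Literal["train", "eval"]):
        super().__init__(cfg)
        self.dataset_key = dataset_key
        # Hypothetical config keys: cfg.name (Hub dataset id), cfg.tokenizer (tokenizer id)
        tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)
        split = "train" if dataset_key == "train" else "validation"
        raw = load_dataset(cfg.name, split=split)
        # Tokenize up front; each item becomes a dict with input_ids, attention_mask, etc.
        self.data = raw.map(
            lambda batch: tokenizer(batch["text"], truncation=True),
            batched=True,
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]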
  2. Create your trainer class by extending BaseTrainer:
from typing import Literal

import optuna
from omegaconf import DictConfig

from hydra_trainer import BaseTrainer


class ExampleTrainer(BaseTrainer[ExampleDataset, DictConfig]):
    def model_init_factory(self):
        def model_init(trial: optuna.Trial | None = None):
            model_cfg = self.get_trial_model_cfg(trial, self.cfg)
            # TODO: implement model initialization
            raise NotImplementedError

        return model_init

    def dataset_factory(
        self, dataset_cfg: DictConfig, dataset_key: Literal["train", "eval"]
    ) -> ExampleDataset:
        # TODO: implement this method
        raise NotImplementedError
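And a corresponding sketch for the trainer, assuming a sequence-classification model; AutoModelForSequenceClassification and the model_cfg keys (pretrained_name, num_labels) are illustrative assumptions:

from typing import Literal

import optuna
from omegaconf import DictConfig
from transformers import AutoModelForSequenceClassification

from hydra_trainer import BaseTrainer


class TextClassificationTrainer(BaseTrainer[TextClassificationDataset, DictConfig]):
    def model_init_factory(self):
        def model_init(trial: optuna.Trial | None = None):
            # Resolve model parameters for the current Optuna trial
            # (or the plain config values when trial is None).
            model_cfg = self.get_trial_model_cfg(trial, self.cfg)
            return AutoModelForSequenceClassification.from_pretrained(
                model_cfg.pretrained_name, num_labels=model_cfg.num_labels
            )

        return model_init

    def dataset_factory(
        self, dataset_cfg: DictConfig, dataset_key: Literal["train", "eval"]
    ) -> TextClassificationDataset:
        return TextClassificationDataset(dataset_cfg, dataset_key)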
  3. Set up your training script with Hydra:
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="hydra_trainer", config_name="base", version_base=None)
def main(cfg: DictConfig):
    trainer = ExampleTrainer(cfg)
    trainer.train()

if __name__ == "__main__":
    main()
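
Because the entry point is wrapped with @hydra.main, any config value can be overridden from the command line using Hydra's standard override syntax. For example, if the script above is saved as train.py:

python train.py trainer.learning_rate=1e-4 trainer.num_train_epochs=5 do_hyperoptim=true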

BaseTrainer Key Features

  1. Model Initialization Factory: Implement model_init_factory() to define how your model is created.
  2. Dataset Factory: Implement dataset_factory() to create your training and evaluation datasets.
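  3. Trial-Aware Model Config: Call get_trial_model_cfg(trial, cfg) inside model_init (as in the Quick Start example) to fetch the model parameters for the current Optuna trial; outside of a hyperparameter search it presumably just returns the static model section of the config.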

Configuration

The package uses Hydra for configuration management. Here's the base configuration structure:

seed: 42
checkpoint_path: null
resume_from_checkpoint: null
do_hyperoptim: false
early_stopping_patience: 3

model: # model parameters - access them in your `model_init_factory` implementation
  d_model: 128
  n_layers: 12
  n_heads: 16
  d_ff: 512

trainer: # transformers.TrainingArguments
  num_train_epochs: 3
  eval_strategy: steps
  eval_steps: 50
  logging_steps: 5
  output_dir: training_output
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 4096
  learning_rate: 5e-3
  weight_decay: 0.0
  fp16: true

hyperopt:
  n_trials: 128
  patience: 2
  persistence: true # set to false to use in-memory storage instead of DB storage
  load_if_exists: true
  storage_url: postgresql://postgres:password@127.0.0.1:5432/postgres
  storage_heartbeat_interval: 15
  storage_engine_kwargs:
    pool_size: 5
    connect_args:
      keepalives: 1
  hp_space:
    training:
      - name: learning_rate # TrainingArguments attribute name
        type: float
        low: 5e-5
        high: 5e-3
        step: 1e-5
        log: true
    model:
      - name: d_model # model parameters
        type: int
        low: 128
        high: 512
        step: 128
        log: true
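
The int and float entries above correspond to Optuna's integer and float suggestions. Categorical parameters are supported as well; such an entry would presumably look like the sketch below, where the choices key is an assumption modeled on Optuna's suggest_categorical:

hp_space:
  training:
    - name: lr_scheduler_type # TrainingArguments attribute name
      type: categorical
      choices: [linear, cosine, constant]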

Hyperparameter Optimization

Enable hyperparameter optimization by setting do_hyperoptim: true in your config. The search is powered by Optuna, with support for:

  • Integer parameters
  • Float parameters
  • Categorical parameters
  • Persistent storage with a relational database
  • Early stopping via patience-based pruning

