A HF Trainer Implementation for Multitask Training Logs

Project description

Enhanced Multitask Trainer for Separately Reporting Each Task's Metrics and Losses in HuggingFace Transformers

The HuggingFace transformers library is widely used for model training. For example, to adapt a pretrained BERT model to a specific domain, we often continue pretraining with BERT's two objectives: 1) Next Sentence Prediction (NSP) using the [CLS] token, and 2) Masked Language Modeling (MLM) using masked tokens.

A key issue is that the default Trainer in transformers assumes the first element of the model's output is the final loss to minimize. Since the loss returned by the forward method must be a scalar, training a multitask model like BERT requires combining the per-task losses into a single value.

The Trainer class offers command-line arguments to control the training process, but it only logs the combined loss across all tasks, which obscures the individual loss of each task. This makes it hard to monitor training and to debug different task settings. Likewise, the TensorBoard report only shows the combined loss among its metrics.
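
To make the problem concrete, here is a minimal sketch (hypothetical module and tensor names, not this package's API) of a BERT-style forward that computes an MLM loss and an NSP loss but can only hand the Trainer their sum:

import torch.nn as nn

class BertStyleHeads(nn.Module):
    """Sketch: two pretraining losses folded into the one scalar the Trainer sees."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.mlm_head = nn.Linear(hidden_size, vocab_size)  # per-token predictions
        self.nsp_head = nn.Linear(hidden_size, 2)           # binary [CLS] classification
        self.ce = nn.CrossEntropyLoss()

    def forward(self, sequence_output, pooled_output, mlm_labels, nsp_labels):
        mlm_loss = self.ce(
            self.mlm_head(sequence_output).flatten(0, 1),  # (batch * seq_len, vocab)
            mlm_labels.flatten(),                          # (batch * seq_len,)
        )
        nsp_loss = self.ce(self.nsp_head(pooled_output), nsp_labels)
        # Only this sum reaches the Trainer; mlm_loss and nsp_loss never appear in the logs.
        return (mlm_loss + nsp_loss,)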

This trainer implementation is a simple, useful way to train multitask models while reviewing the loss of each task alongside other training metrics.

The trainer works like the original Trainer in the transformers library. You just need to call the report_metrics(...) method to report the metrics that are important to you.

By the way, another utility you might need is parser-binding, which builds argument parsers from dataclasses and reads arguments from the command line.
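
The demo below uses transformers' built-in HfArgumentParser for the same dataclass-to-CLI pattern; a minimal sketch with a made-up MyArgs dataclass:

from dataclasses import dataclass, field

from transformers.hf_argparser import HfArgumentParser

@dataclass
class MyArgs:
    # Hypothetical arguments, for illustration only
    learning_rate: float = field(default=5e-5, metadata={'help': 'Peak learning rate'})
    task_weights: str = field(default='1.0,1.0', metadata={'help': 'Comma-separated task loss weights'})

parser = HfArgumentParser(MyArgs)
my_args, = parser.parse_args_into_dataclasses()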

Usage

Follow these steps to use the HfMultiTaskTrainer:

  1. Install the trainer:

    pip install hf-mtask-trainer
    
  2. Replace the default trainer with HfMultiTaskTrainer:

    from hf_mtask_trainer import HfMultiTaskTrainer
    
    class Trainer(HfMultiTaskTrainer):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Additional initialization code
    

    Alternatively, you can directly instantiate the HfMultiTaskTrainer:

    trainer = HfMultiTaskTrainer(...)
    
  3. Report metrics in the model:

    import torch.nn as nn
    
    class Model(nn.Module):
    
        supports_report_metrics: bool = True
    
        def __init__(self):
            super().__init__()
            # Additional initialization code
        
        def forward(self, inputs, **kwargs):
            # Calculate metrics like loss, accuracy, etc.
            task1_loss = ...
            task2_loss = ...
            acc = ...
            f1 = ...
            # Report the metrics
            self.report_metrics(loss1=task1_loss, loss2=task2_loss, acc=acc, f1=f1)
            # Combine the per-task losses into the single scalar the Trainer expects
            loss = task1_loss + task2_loss
            return (loss,)
    

    Add the supports_report_metrics flag to any model that needs to report metrics; otherwise, report_metrics will not be accessible.
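
    If the same model may also run under the stock Trainer, it is safer to guard the call, as the demo below does (names match the snippet above):

    # report_metrics is only available when training with HfMultiTaskTrainer
    if hasattr(self, 'report_metrics'):
        self.report_metrics(loss1=task1_loss, loss2=task2_loss, acc=acc, f1=f1)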

  4. Start training the model:

    As usual, call trainer.train() to start training.

Now you can enjoy multitask training. If you set --report_to tensorboard, the metrics reported in the model will be displayed as TensorBoard diagrams.
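
To view the diagrams, point TensorBoard at the trainer's log directory (with default settings the logs land under the output directory), assuming the tensorboard package is installed:

    tensorboard --logdir ./test-output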

Demo

We give a simple demo that mocks multitask training in test_trainer.py.

The source code is:

import random

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers.hf_argparser import HfArgumentParser
from transformers.training_args import TrainingArguments

from hf_mtask_trainer import HfMultiTaskTrainer

# The model class
class TestModel(nn.Module):
    supports_report_metrics: bool = True # IMPORTANT

    def __init__(self) -> None:
        super().__init__()
        self.scaler = nn.Parameter(torch.ones(1))

    def forward(self, x):
        test_tensor = x + self.scaler
        test_np = np.array(np.random.randn()).astype(np.float32)
        test_int = random.randint(1, 100)
        test_float = random.random()
        if hasattr(self, 'report_metrics'):  # checking whether report_metrics is available is the robust practice
            self.report_metrics(
                tensor=test_tensor,
                np=test_np,
                integer=test_int,
                fp_num=test_float
            )

        loss = (
            test_tensor + torch.from_numpy(test_np) +
            torch.tensor(test_int) + torch.tensor(test_float)
        ).mean()

        outputs = (loss, )

        return outputs

# Mock dataset
class MockDataset(Dataset):

    def __len__(self):
        return 1000

    def __getitem__(self, index: int):
        return dict(x=torch.randn(10, dtype=torch.float32))


def main():
    parser = HfArgumentParser(TrainingArguments)
    args, = parser.parse_args_into_dataclasses()
    model = TestModel()
    ds = MockDataset()
    # Use HfMultiTaskTrainer rather than Trainer
    trainer = HfMultiTaskTrainer(model, args, train_dataset=ds)

    trainer.train()


if __name__ == '__main__':
    main()

Run the script to start training:

    python test_trainer.py --output_dir ./test-output --per_device_train_batch_size 8 --gradient_accumulation_steps 4 --logging_steps 10 --num_train_epochs 10

The progress in the terminal is:

{'loss': 55.509, 'grad_norm': 1.0, 'learning_rate': 4.8387096774193554e-05, 'tensor': 0.9841784507036209, 'np': -0.21863683552946894, 'integer': 54.3, 'fp_num': 0.4434714410935673, 'epoch': 0.32}               
{'loss': 48.2661, 'grad_norm': 1.0, 'learning_rate': 4.67741935483871e-05, 'tensor': 0.9985833063721656, 'np': -0.02904345905408263, 'integer': 46.725, 'fp_num': 0.5715085473120125, 'epoch': 0.64}              
{'loss': 46.4612, 'grad_norm': 1.0, 'learning_rate': 4.516129032258064e-05, 'tensor': 0.9966234847903251, 'np': 0.010173140384722501, 'integer': 44.95, 'fp_num': 0.5043715258219943, 'epoch': 0.96}              
{'loss': 48.6079, 'grad_norm': 1.0, 'learning_rate': 4.3548387096774194e-05, 'tensor': 0.9955430060625077, 'np': -0.03289987128227949, 'integer': 47.175, 'fp_num': 0.47028585139293366, 'epoch': 1.28}           
{'loss': 50.091, 'grad_norm': 1.0, 'learning_rate': 4.1935483870967746e-05, 'tensor': 0.9734495922923088, 'np': 0.06655221048276871, 'integer': 48.55, 'fp_num': 0.5009696474466848, 'epoch': 1.6}                
{'loss': 52.1638, 'grad_norm': 1.0, 'learning_rate': 4.032258064516129e-05, 'tensor': 1.0023577958345413, 'np': 0.18944044597446918, 'integer': 50.5, 'fp_num': 0.47205086657725437, 'epoch': 1.92}               
{'loss': 61.3063, 'grad_norm': 1.0, 'learning_rate': 3.870967741935484e-05, 'tensor': 1.0168104887008667, 'np': -0.10900555825792253, 'integer': 60.0, 'fp_num': 0.39849607236524115, 'epoch': 2.24}              
{'loss': 55.318, 'grad_norm': 1.0, 'learning_rate': 3.7096774193548386e-05, 'tensor': 1.015606315433979, 'np': 0.21950888196006418, 'integer': 53.575, 'fp_num': 0.5078790131376146, 'epoch': 2.56}               
{'loss': 57.1703, 'grad_norm': 1.0, 'learning_rate': 3.548387096774194e-05, 'tensor': 1.0161942049860955, 'np': -0.08120755353011191, 'integer': 55.675, 'fp_num': 0.5603439507938002, 'epoch': 2.88}             
{'loss': 47.6687, 'grad_norm': 1.0, 'learning_rate': 3.387096774193548e-05, 'tensor': 0.9780291050672532, 'np': 0.21060471932869404, 'integer': 46.025, 'fp_num': 0.4550899259063651, 'epoch': 3.2}               
{'loss': 50.6742, 'grad_norm': 1.0, 'learning_rate': 3.2258064516129034e-05, 'tensor': 0.9773322150111199, 'np': 0.053728557180147615, 'integer': 49.15, 'fp_num': 0.4931880990797102, 'epoch': 3.52}             
{'loss': 55.3104, 'grad_norm': 1.0, 'learning_rate': 3.0645161290322585e-05, 'tensor': 0.962137694656849, 'np': -0.079732296615839, 'integer': 53.975, 'fp_num': 0.45303205544101893, 'epoch': 3.84}              
{'loss': 55.3539, 'grad_norm': 1.0, 'learning_rate': 2.9032258064516133e-05, 'tensor': 1.0214665666222573, 'np': -0.15776186664588748, 'integer': 53.9, 'fp_num': 0.590140296440284, 'epoch': 4.16}               
{'loss': 49.332, 'grad_norm': 1.0, 'learning_rate': 2.7419354838709678e-05, 'tensor': 1.0191335454583168, 'np': -0.2712035422213376, 'integer': 48.025, 'fp_num': 0.5590896723075907, 'epoch': 4.48}              
{'loss': 49.8865, 'grad_norm': 1.0, 'learning_rate': 2.5806451612903226e-05, 'tensor': 1.0170967370271682, 'np': 0.02669397685676813, 'integer': 48.275, 'fp_num': 0.5677363725430722, 'epoch': 4.8}              
{'loss': 55.0644, 'grad_norm': 1.0, 'learning_rate': 2.4193548387096777e-05, 'tensor': 0.99910968542099, 'np': 0.12097712438553572, 'integer': 53.475, 'fp_num': 0.4693036682925622, 'epoch': 5.12}               
{'loss': 56.9469, 'grad_norm': 1.0, 'learning_rate': 2.258064516129032e-05, 'tensor': 1.0159066557884215, 'np': -0.06122639870736748, 'integer': 55.6, 'fp_num': 0.3922143274213026, 'epoch': 5.44}               
{'loss': 58.3238, 'grad_norm': 1.0, 'learning_rate': 2.0967741935483873e-05, 'tensor': 0.9946490600705147, 'np': -0.038768217992037536, 'integer': 56.875, 'fp_num': 0.49290766579450906, 'epoch': 5.76}          
{'loss': 57.8349, 'grad_norm': 1.0, 'learning_rate': 1.935483870967742e-05, 'tensor': 0.9948656186461449, 'np': -0.15342782847583294, 'integer': 56.55, 'fp_num': 0.4434852700815277, 'epoch': 6.08}              
{'loss': 57.5093, 'grad_norm': 1.0, 'learning_rate': 1.774193548387097e-05, 'tensor': 0.9814934283494949, 'np': 0.17727854922413827, 'integer': 55.85, 'fp_num': 0.5005189062297719, 'epoch': 6.4}                
{'loss': 54.0808, 'grad_norm': 1.0, 'learning_rate': 1.6129032258064517e-05, 'tensor': 1.0003552585840225, 'np': 0.09905800204724073, 'integer': 52.425, 'fp_num': 0.5563636991813741, 'epoch': 6.72}             
{'loss': 41.9312, 'grad_norm': 1.0, 'learning_rate': 1.4516129032258066e-05, 'tensor': 0.9884074732661248, 'np': 0.1483861011918634, 'integer': 40.275, 'fp_num': 0.51941084196083, 'epoch': 7.04}                
{'loss': 54.1181, 'grad_norm': 1.0, 'learning_rate': 1.2903225806451613e-05, 'tensor': 1.0151973858475685, 'np': 0.47866107723675666, 'integer': 52.175, 'fp_num': 0.4492639089144623, 'epoch': 7.36}             
{'loss': 50.6587, 'grad_norm': 1.0, 'learning_rate': 1.129032258064516e-05, 'tensor': 0.9820004492998123, 'np': -0.012274338398128748, 'integer': 49.2, 'fp_num': 0.4889366814531261, 'epoch': 7.68}              
{'loss': 55.0801, 'grad_norm': 1.0, 'learning_rate': 9.67741935483871e-06, 'tensor': 0.9795809179544449, 'np': -0.07257360897492618, 'integer': 53.725, 'fp_num': 0.448120297698113, 'epoch': 8.0}                
{'loss': 44.1352, 'grad_norm': 1.0, 'learning_rate': 8.064516129032258e-06, 'tensor': 0.9734664395451545, 'np': 0.286221909429878, 'integer': 42.375, 'fp_num': 0.5005484995389671, 'epoch': 8.32}                
{'loss': 66.0453, 'grad_norm': 1.0, 'learning_rate': 6.451612903225806e-06, 'tensor': 0.9795126229524612, 'np': 0.030494442163035273, 'integer': 64.525, 'fp_num': 0.5103026267522799, 'epoch': 8.64}             
{'loss': 56.4957, 'grad_norm': 1.0, 'learning_rate': 4.838709677419355e-06, 'tensor': 0.9856317490339279, 'np': 0.2455663602799177, 'integer': 54.775, 'fp_num': 0.48952095056963413, 'epoch': 8.96}              
{'loss': 58.896, 'grad_norm': 1.0, 'learning_rate': 3.225806451612903e-06, 'tensor': 0.9927483782172203, 'np': 0.14382120433729143, 'integer': 57.275, 'fp_num': 0.4843773558505675, 'epoch': 9.28}               
{'loss': 51.3854, 'grad_norm': 1.0, 'learning_rate': 1.6129032258064516e-06, 'tensor': 0.974456462264061, 'np': -0.03793883747421205, 'integer': 49.975, 'fp_num': 0.4738723365026173, 'epoch': 9.6}              
{'loss': 51.1056, 'grad_norm': 1.0, 'learning_rate': 0.0, 'tensor': 1.0064959138631822, 'np': 0.17969400193542243, 'integer': 49.375, 'fp_num': 0.5444181010304485, 'epoch': 9.92}                                
{'train_runtime': 0.4796, 'train_samples_per_second': 20852.484, 'train_steps_per_second': 646.427, 'train_loss': 53.31389662219632, 'epoch': 9.92}                                                               
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 310/310 [00:00<00:00, 648.81it/s]

Limitation

This trainer has not been fully tested yet, but it works for simple multitask training. Please report an issue if this plugin does not work for you.


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

hf_mtask_trainer-0.0.2a2.tar.gz (13.2 kB)

Built Distribution

hf_mtask_trainer-0.0.2a2-py3-none-any.whl (11.8 kB)


File details

Details for the file hf_mtask_trainer-0.0.2a2.tar.gz.

File metadata

  • Size: 13.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for hf_mtask_trainer-0.0.2a2.tar.gz

  • SHA256: 71903ea9f32a907f2dd64fdbcd71036e9c488e72f3d5468da3bf3e402e9a838f
  • MD5: ef8eb1b9a6cdf38ff06249e65551bb1c
  • BLAKE2b-256: 17b28519dbb59bcbd31870fe42cdf5582f7d63c231d55c8fc0d27937297657aa


File details

Details for the file hf_mtask_trainer-0.0.2a2-py3-none-any.whl.

File hashes

Hashes for hf_mtask_trainer-0.0.2a2-py3-none-any.whl

  • SHA256: 41d817fbcc1eca503831782f25db642b9f352cf367bea46cf4535aa71a41b749
  • MD5: ce434df0ff3da3ed27fd7c3f0f964fdd
  • BLAKE2b-256: 5105646f9780fa57fd3e37e60dbdc003d65c806b07e7d277e13853eff5b0b0a2

