
FMS Acceleration plugin for online data mixing


Online Data Mixing

This library contains a plugin for an online data mixing framework driven by dynamic, learnable rewards. It mixes datasets online during training and adapts the mixture based on signals from training (e.g., training loss, gradient norm).

Plugins

| Plugin | Description | Depends | Loading | Augmentation | Callbacks |
|---|---|---|---|---|---|
| odm | OnlineMixingDataset PyTorch IterableDataset and custom rewards | | | | |

Design

Usage in Custom Training Loop

OnlineMixingDataset can be imported and integrated into existing training loops with minimal changes. A sample custom training loop implementation can be found here. The sample uses two instruction-tuning datasets and trains the ibm-granite/granite-3.1-2b-instruct model on a next-token prediction task.
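To illustrate the feedback cycle such a loop relies on (sample a category, train on it, report a signal, update the mixing weights), here is a minimal, self-contained sketch. It does not use OnlineMixingDataset: ToyOnlineMixer and fake_loss are hypothetical, and the EXP3-style update with an exploration floor is an assumption about the kind of multi-armed bandit that could drive the mixing, not the library's implementation.

```python
import math
import random


class ToyOnlineMixer:
    """Hypothetical stand-in that mimics the sample -> reward -> update cycle.

    NOT the OnlineMixingDataset API; the EXP3-style bandit below is an assumption.
    """

    def __init__(self, datasets, lr=0.1, gamma=0.1, seed=0):
        self.names = sorted(datasets)  # arm index follows sorted category names
        self.iters = {name: iter(datasets[name]) for name in self.names}
        self.log_weights = [0.0] * len(self.names)
        self.lr = lr          # step size of the weight update
        self.gamma = gamma    # exploration floor mixed into the probabilities
        self.rng = random.Random(seed)

    def probabilities(self):
        # Softmax over log-weights, mixed with a uniform distribution for exploration.
        k = len(self.log_weights)
        top = max(self.log_weights)
        exps = [math.exp(w - top) for w in self.log_weights]
        total = sum(exps)
        return [(1 - self.gamma) * e / total + self.gamma / k for e in exps]

    def sample(self):
        # Weighted random choice of a category, then the next example from it.
        probs = self.probabilities()
        arm = self.rng.choices(range(len(self.names)), weights=probs, k=1)[0]
        return arm, next(self.iters[self.names[arm]])

    def update(self, arm, reward):
        # Importance-weighted update: dividing by the arm's probability makes the
        # expected increment per step lr * reward for every arm.
        prob = self.probabilities()[arm]
        self.log_weights[arm] += self.lr * reward / prob


def fake_loss(example):
    # Placeholder for forward/backward; "hard" examples have a higher loss.
    return 2.0 if example.startswith("hard") else 0.5


datasets = {
    "easy": [f"easy-{i}" for i in range(1000)],
    "hard": [f"hard-{i}" for i in range(1000)],
}
mixer = ToyOnlineMixer(datasets)
for step in range(200):
    arm, example = mixer.sample()
    loss = fake_loss(example)
    # TRAIN_LOSS-style reward, clipped to [0, 1]: higher loss -> more samples.
    mixer.update(arm, reward=min(loss / 2.0, 1.0))

print(dict(zip(mixer.names, mixer.probabilities())))
```

Because the "hard" category keeps returning the larger reward, its sampling probability grows over the 200 steps, while the exploration floor (gamma / number of categories) keeps "easy" from being starved completely.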

Automatic Categorization

When only a single dataset (without category splits) is passed, its samples are embedded with a sentence-transformer model and clustered (K-Means by default) to build pseudo-categories for the online data mixer.

from datasets import load_dataset
from fms_acceleration_odm import OnlineMixingDataset

dataset = load_dataset("tatsu-lab/alpaca", split="train[:1%]")
collator = ...  # e.g., DataCollatorForLanguageModeling(...)

odm_dataset = OnlineMixingDataset(
    dataset_dict=dataset,
    collators_dict={"train": collator},
    eval_dataset_dict={},
    eval_collators_dict={},
    auto_categorize_config={
        "input_column": "text",
        "num_categories": 6,
        "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    },
)

Without an explicit num_categories, a heuristic based on the square root of the dataset size is used. Additional knobs such as category_prefix, batch_size, or clustering-specific kwargs can also be provided through auto_categorize_config.
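The categorization step can be approximated in plain Python as below. This is a sketch under stated assumptions: toy_embed is a hypothetical keyword-count stand-in for the sentence-transformer model, kmeans is a small Lloyd's-algorithm implementation with deterministic farthest-point initialization rather than a library K-Means, default_num_categories assumes round(sqrt(N)) for the square-root heuristic, and category_prefix is assumed to be a prefix for the generated category names.

```python
import math


def default_num_categories(num_samples):
    # Assumption: round(sqrt(N)); the library's exact heuristic may differ.
    return max(1, round(math.sqrt(num_samples)))


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(vectors, k, iterations=20):
    # Farthest-point initialization keeps this toy version deterministic.
    centroids = [list(vectors[0])]
    while len(centroids) < k:
        farthest = max(vectors, key=lambda v: min(squared_distance(v, c) for c in centroids))
        centroids.append(list(farthest))
    labels = [0] * len(vectors)
    for _ in range(iterations):
        # Assignment step: each vector goes to its nearest centroid.
        labels = [min(range(k), key=lambda j: squared_distance(v, centroids[j])) for v in vectors]
        # Update step: each centroid moves to the mean of its members.
        for j in range(k):
            members = [v for v, label in zip(vectors, labels) if label == j]
            if members:
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return labels


def toy_embed(text):
    # Hypothetical stand-in for a sentence-transformer: counts two keyword families.
    words = text.lower().split()
    code = sum(w in {"python", "function", "loop"} for w in words)
    cook = sum(w in {"recipe", "bake", "oven"} for w in words)
    return (float(code), float(cook))


def auto_categorize(texts, embed, num_categories=None, category_prefix="category_"):
    """Group texts into pseudo-categories (hypothetical helper, not the library API)."""
    k = num_categories or default_num_categories(len(texts))
    labels = kmeans([embed(t) for t in texts], k)
    groups = {}
    for text, label in zip(texts, labels):
        groups.setdefault(f"{category_prefix}{label}", []).append(text)
    return groups


texts = [
    "write a python function", "explain a python loop", "fix this function loop",
    "bake a cake recipe", "oven recipe for bread", "how long to bake in the oven",
]
groups = auto_categorize(texts, toy_embed, num_categories=2)
print(groups)
```

Each resulting group plays the role of one category/dataset for the mixer; with a real embedding model the clusters are, of course, only as meaningful as the embeddings.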

Metrics

All metrics related to online data mixing are logged to the odm.jsonl file in the checkpoint output directory.

| Metric | Description |
|---|---|
| samples_produced_so_far | Total number of samples produced by the dataset up to the time of logging. |
| sampling_interval | Sampling interval n (a sample count given as input): every n steps, a category/dataset is chosen by weighted random sampling, with the weights provided by the multi-armed bandit (MAB) algorithm. |
| total_categories | Total number of categories or datasets involved in mixing. |
| current_sampling_weights | Sampling weights at the time of logging. |
| current_sampling_ratio | Sampling ratios at the time of logging. |
| arm_idx | Index of the most recently sampled category. Categories/datasets are sorted by name in ascending order and indexed from 0, so each index maps to one category/dataset. |
| category_level_counts_so_far | Number of samples drawn from each category/dataset up to the time of logging. |
| rewards | Rewards at the time of logging, i.e., the most recent reward provided for each category/dataset. |
| action | Action that produced the log entry: either "update" (a weight update of the MAB algorithm) or "sample" (category sampling). |
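Because odm.jsonl holds one JSON object per line, it can be inspected with the standard library. The sketch below parses two made-up records (not real output) that use the field names from the table, maps arm_idx back to a name via the sorted category list, and derives the observed mix from category_level_counts_so_far. The value types (for example, lists indexed by arm) are assumptions; in practice the text would be read from odm.jsonl in the checkpoint output directory.

```python
import json

# Two made-up records using the field names from the metrics table; the real
# value types (lists vs. dicts, etc.) are assumptions.
SAMPLE_LOG = """\
{"action": "sample", "samples_produced_so_far": 8, "total_categories": 2, "arm_idx": 1, "category_level_counts_so_far": [3, 5], "rewards": [0.2, 0.9]}
{"action": "update", "samples_produced_so_far": 16, "total_categories": 2, "arm_idx": 1, "category_level_counts_so_far": [5, 11], "rewards": [0.3, 0.8]}
"""


def read_odm_log(text):
    # One JSON object per line; blank lines are skipped.
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def summarize(records, category_names):
    # Categories are sorted by name, so arm_idx i refers to sorted(names)[i].
    names = sorted(category_names)
    last = records[-1]
    counts = last["category_level_counts_so_far"]
    total = sum(counts)
    return {
        "last_action": last["action"],
        "last_arm": names[last["arm_idx"]],
        "observed_ratio": {name: count / total for name, count in zip(names, counts)},
        "num_updates": sum(r["action"] == "update" for r in records),
    }


records = read_odm_log(SAMPLE_LOG)
print(summarize(records, ["math", "code"]))
```

Passing the names in arbitrary order is deliberate: sorting them first is what makes arm_idx 1 resolve to "math" here.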

Rewards

The rewards currently available are listed below. We are continually improving the existing rewards and adding new ones, and we encourage users to identify rewards that help their use cases.

| Reward | Description |
|---|---|
| ENTROPY | Shannon entropy of the logits, averaged across all tokens. Higher entropy means the model needs more samples from that dataset/category. |
| ENTROPY3_VARENT1 | Combination of 3 parts Shannon entropy and 1 part variance of the entropy. Higher values mean more samples are needed. |
| ENTROPY_LAST_TOKEN | Shannon entropy of the last token in the sample. Higher values mean more samples are needed. |
| TRAIN_LOSS | Training loss, maintained per category and updated with the latest loss of the sampled dataset/category. Higher values mean more samples are needed. |
| VALIDATION_LOSS | Validation loss per category, computed on the evaluation dataset of each category. Higher values mean more samples are needed. |
| GRADNORM | Gradient norm, maintained per category and updated with the latest value of the sampled dataset/category. Higher values mean fewer samples are drawn from that dataset/category. |
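The entropy-based rewards can be sketched on raw per-token logits as follows. This is an illustration, not the library's code: the function names are hypothetical, entropy is measured in nats, and ENTROPY3_VARENT1 is read as a 3:1 weighted average of mean entropy and entropy variance, which is an assumed normalization.

```python
import math


def token_entropy(logits):
    # Shannon entropy (nats) of softmax(logits) for one token position.
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)


def entropy_reward(sequence_logits):
    # ENTROPY: mean token entropy over all positions.
    ents = [token_entropy(t) for t in sequence_logits]
    return sum(ents) / len(ents)


def entropy_last_token_reward(sequence_logits):
    # ENTROPY_LAST_TOKEN: entropy at the final position only.
    return token_entropy(sequence_logits[-1])


def entropy3_varent1_reward(sequence_logits):
    # ENTROPY3_VARENT1, read here as (3 * mean entropy + 1 * entropy variance) / 4;
    # the weighting/normalization is an assumption.
    ents = [token_entropy(t) for t in sequence_logits]
    mean = sum(ents) / len(ents)
    var = sum((e - mean) ** 2 for e in ents) / len(ents)
    return (3 * mean + var) / 4


confident = [[10.0, 0.0, 0.0, 0.0]] * 3   # model is sure at every position
uncertain = [[0.0, 0.0, 0.0, 0.0]] * 3    # uniform over 4 tokens
print(entropy_reward(confident), entropy_reward(uncertain))
```

A category whose samples leave the model uncertain (like `uncertain`) scores higher and would, under these rewards, be sampled more often.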

Adding a Custom Reward

Custom rewards can be added by implementing them in the compute_reward function and adding a corresponding member to the Reward enum. If a custom reward needs specific information from the training loop, extend _extract_information_from_state_for_reward, a member function of OnlineMixingDataset, to extract that information from the trainer state.
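The extension pattern can be sketched as below. Everything here is a hypothetical stand-in: this Reward enum, compute_reward, and extract_information_from_state only mirror the roles of the library's Reward enum, compute_reward, and _extract_information_from_state_for_reward, and their real signatures and member values may differ. CUSTOM_LENGTH and the num_tokens field are invented for illustration.

```python
from enum import Enum


class Reward(Enum):
    # Hypothetical stand-in; CUSTOM_LENGTH is the newly added member.
    TRAIN_LOSS = "train_loss"
    CUSTOM_LENGTH = "custom_length"


def extract_information_from_state(state):
    # Hypothetical analogue of _extract_information_from_state_for_reward:
    # pull only the fields the rewards need out of a trainer-state dict.
    return {
        "loss": state.get("loss"),
        "num_tokens": state.get("num_tokens"),  # extra field for the new reward
    }


def compute_reward(reward_type, info):
    # Hypothetical dispatch; the library's compute_reward signature may differ.
    if reward_type is Reward.TRAIN_LOSS:
        return info["loss"]
    if reward_type is Reward.CUSTOM_LENGTH:
        # New reward: favour categories whose samples are longer.
        return info["num_tokens"] / 1000.0
    raise ValueError(f"unknown reward: {reward_type}")


state = {"loss": 1.7, "grad_norm": 0.4, "num_tokens": 512}
info = extract_information_from_state(state)
print(compute_reward(Reward.CUSTOM_LENGTH, info))
```

The two steps stay separate on purpose: the extractor decides what leaves the trainer state, and the reward function only turns that information into a number.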

Planned TODOs

Please see issue #153.



Download files


Source Distributions

No source distribution files available for this release.

Built Distribution

fms_acceleration_odm-0.1.5-py3-none-any.whl (20.2 kB)

Uploaded Python 3

File details

Details for the file fms_acceleration_odm-0.1.5-py3-none-any.whl.


File hashes

Hashes for fms_acceleration_odm-0.1.5-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | d307215caee9716ba4295331eb10f0e86f7caf576f071e747aeb82d9a73e33b2 |
| MD5 | 546079d937fd7b1136dfc2f8d5ef1613 |
| BLAKE2b-256 | 2c574e1338f281f8aa1f99e763ed3ace77d62741dd02274777afafdf19543b26 |


Provenance

The following attestation bundles were made for fms_acceleration_odm-0.1.5-py3-none-any.whl:

Publisher: build-and-publish.yml on foundation-model-stack/fms-acceleration

