FMS Acceleration plugin for online data mixing
Online Data Mixing
This library contains a plugin for an online, reward-based (learnable) data mixing framework. It mixes datasets dynamically during training and adapts the mixture based on signals from training (e.g., training loss, gradient norm).
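To make the idea of reward-driven mixing concrete, here is a minimal, self-contained sketch. It is not the plugin's implementation: this document only says that a Multi-Armed Bandit algorithm supplies the sampling weights, so the Exp3-style update, the class name `Exp3MixerSketch`, and the `gamma` parameter below are assumptions chosen for illustration.

```python
import math
import random


class Exp3MixerSketch:
    """Toy Exp3-style bandit over dataset categories (illustrative only)."""

    def __init__(self, categories, gamma=0.1, seed=0):
        # Categories are sorted by name so that index 0 is the first name.
        self.categories = sorted(categories)
        self.gamma = gamma  # exploration rate, assumed to lie in (0, 1]
        self.weights = [1.0] * len(self.categories)
        self.rng = random.Random(seed)

    def probabilities(self):
        # Mix the normalized weights with a uniform exploration term.
        total = sum(self.weights)
        k = len(self.weights)
        return [(1 - self.gamma) * w / total + self.gamma / k for w in self.weights]

    def sample(self):
        # Weighted random choice of the next category index.
        probs = self.probabilities()
        return self.rng.choices(range(len(probs)), weights=probs, k=1)[0]

    def update(self, arm_idx, reward):
        # Importance-weighted reward estimate; reward is assumed to be in [0, 1].
        probs = self.probabilities()
        estimated = reward / probs[arm_idx]
        self.weights[arm_idx] *= math.exp(self.gamma * estimated / len(self.weights))


mixer = Exp3MixerSketch(["code", "math"])
arm = mixer.sample()            # pick a category to draw the next samples from
mixer.update(arm, reward=0.8)   # feed back a normalized training signal
```

In this sketch a category that keeps returning high rewards gradually receives a larger share of the samples, while the uniform term keeps every category reachable.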
Plugins
| Plugin | Description | Depends | Loading | Augmentation | Callbacks |
|---|---|---|---|---|---|
| odm | OnlineMixingDataset PyTorch IterableDataset and custom rewards | ✅ | ✅ | ✅ |
Design
Usage in Custom Training Loop
OnlineMixingDataset can be imported and integrated into existing training loops with minimal changes. A sample custom training loop implementation can be found here. The code sample uses two instruction-tuning datasets and trains the ibm-granite/granite-3.1-2b-instruct model on a next-token prediction task.
Automatic Categorization
When only a single dataset (without category splits) is passed, the dataset will be embedded with a sentence-transformer model and clustered (K-Means by default) to build pseudo categories used by the online data mixer.
```python
from datasets import load_dataset
from fms_acceleration_odm import OnlineMixingDataset

# A single dataset without category splits
dataset = load_dataset("tatsu-lab/alpaca", split="train[:1%]")
collator = ...  # e.g., DataCollatorForLanguageModeling(...)

odm_dataset = OnlineMixingDataset(
    dataset_dict=dataset,
    collators_dict={"train": collator},
    eval_dataset_dict={},
    eval_collators_dict={},
    # Settings for building pseudo categories by clustering embeddings
    auto_categorize_config={
        "input_column": "text",
        "num_categories": 6,
        "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    },
)
```
Without an explicit num_categories, a heuristic based on the square root of the dataset size is used. Additional knobs such as category_prefix, batch_size, or clustering-specific kwargs can also be provided through auto_categorize_config.
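As a rough illustration of the two ideas above, the sketch below picks a default cluster count from the square root of the dataset size and groups rows under prefixed pseudo-category names. The exact heuristic and naming scheme used by the plugin are not given here, so the rounding, the minimum of 2, the default prefix, and both function names are assumptions.

```python
import math
from collections import defaultdict


def default_num_categories(num_samples):
    # Assumption: round sqrt(n) and keep at least 2 clusters.
    return max(2, round(math.sqrt(num_samples)))


def group_by_cluster(rows, labels, category_prefix="auto_category"):
    # Build {"<prefix>_<label>": [rows...]} from per-row cluster labels,
    # such as the labels a K-Means fit over sentence embeddings would return.
    groups = defaultdict(list)
    for row, label in zip(rows, labels):
        groups[f"{category_prefix}_{label}"].append(row)
    return dict(groups)


rows = ["add two numbers", "sort a list", "write a poem", "write a story"]
labels = [0, 0, 1, 1]  # hypothetical cluster assignments
pseudo_categories = group_by_cluster(rows, labels)
```

Each resulting group can then be treated as one category by the mixer, just like an explicitly provided dataset split.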
Metrics
All metrics related to online data mixing are logged to the odm.jsonl file in the checkpoint output directory.
| Metric | Description |
|---|---|
| samples_produced_so_far | Total number of samples produced by the dataset at the time of logging. |
| sampling_interval | Takes a sample count "n" as input. Every "n" steps, a category/dataset is chosen by weighted random sampling, where the weights are provided by the Multi-Armed Bandit (MAB) algorithm. |
| total_categories | Total number of categories or datasets involved in mixing. |
| current_sampling_weights | State of the sampling weights at the time of logging. |
| current_sampling_ratio | State of the sampling ratios at the time of logging. |
| arm_idx | Index of the last sampled category. Categories/datasets are sorted by name in ascending order, indices start from 0, and each index corresponds to the respective category/dataset. |
| category_level_counts_so_far | Split of the sample count across datasets at the time of logging. |
| rewards | State of the rewards at the time of logging, i.e., the last provided rewards across datasets. |
| action | Type of action that took place at the time of logging. It is either "update" or "sample", corresponding to a weight update of the MAB algorithm or to category sampling, respectively. |
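The log can be inspected with a few lines of standard-library Python. The sketch below assumes each line of odm.jsonl is one JSON object keyed by the metric names above; the value types are not specified in this document, and the helper names are hypothetical.

```python
import json


def parse_odm_lines(lines):
    # Accepts any iterable of strings, e.g. an open file handle on odm.jsonl.
    records = []
    for line in lines:
        line = line.strip()
        if line:  # skip blank lines
            records.append(json.loads(line))
    return records


def latest_record(records, action):
    # Return the most recent record whose "action" matches, or None.
    for record in reversed(records):
        if record.get("action") == action:
            return record
    return None


def arm_name(category_names, arm_idx):
    # arm_idx refers to categories sorted by name in ascending order.
    return sorted(category_names)[arm_idx]
```

For example, passing the file handle from `open("odm.jsonl")` to `parse_odm_lines` and then calling `latest_record(records, "update")` gives the record for the most recent weight update.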
Rewards
The currently available rewards are listed below. We are continually working to improve them and add new ones, and we encourage users to identify rewards that help their use cases.
| Reward | Description |
|---|---|
| ENTROPY | Shannon entropy of the logits, averaged across all tokens. Higher entropy means the model requires more samples from that dataset/category. |
| ENTROPY3_VARENT1 | 3 parts Shannon entropy and 1 part variance of the entropy. Higher values mean more samples are required. |
| ENTROPY_LAST_TOKEN | Shannon entropy of the last token in the sample. Higher values mean more samples are required. |
| TRAIN_LOSS | Training loss, maintained per category and updated based on the latest loss and the sampled dataset/category. Higher values mean more samples are required. |
| VALIDATION_LOSS | Validation loss across categories, calculated using the evaluation datasets of each category. Higher values mean more samples are required. |
| GRADNORM | Gradient norm, maintained per category and updated based on the latest values and the sampled dataset/category. Higher values mean fewer samples are drawn from that dataset/category. |
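The entropy-based rewards can be illustrated with a small pure-Python sketch. It computes Shannon entropy (in nats) from per-token logits via a softmax, then either averages across tokens (as described for ENTROPY) or takes only the last token (as described for ENTROPY_LAST_TOKEN). The function names and the use of the natural logarithm are assumptions; the plugin itself presumably operates on tensors.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def token_entropy(logits):
    # Shannon entropy of the predicted distribution for a single token.
    probs = softmax(logits)
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def mean_entropy(token_logits):
    # Average entropy over all tokens of a sample.
    entropies = [token_entropy(logits) for logits in token_logits]
    return sum(entropies) / len(entropies)


def last_token_entropy(token_logits):
    # Entropy of the final token only.
    return token_entropy(token_logits[-1])
```

A uniform distribution over the vocabulary gives the maximum entropy, so a category on which the model is still uncertain yields a larger reward under this scheme.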
Adding a Custom Reward
A custom reward can be added by implementing it in the compute_reward function and adding it to the Reward enum. If the custom reward requires specific information from the training loop, the _extract_information_from_state_for_reward function, a member function of OnlineMixingDataset, has to be extended to extract that information from the trainer state.
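The overall pattern can be sketched in isolation as follows. This is a standalone mock, not the plugin's code: the real signatures of compute_reward and _extract_information_from_state_for_reward are not shown in this document, so every name, the state layout, and the `LOSS_SPREAD` reward below are hypothetical.

```python
from enum import Enum


class RewardSketch(Enum):
    """Stand-in for the plugin's Reward enum (values are illustrative)."""

    TRAIN_LOSS = "train_loss"
    LOSS_SPREAD = "loss_spread"  # hypothetical custom reward


def extract_information_sketch(state):
    # Stand-in for the extraction step: pull what the rewards need
    # from a (hypothetical) trainer-state dictionary.
    return {"losses": state["per_category_loss"]}


def compute_reward_sketch(reward_type, info):
    # One branch per enum member; a new reward adds a member and a branch.
    losses = info["losses"]
    if reward_type is RewardSketch.TRAIN_LOSS:
        return list(losses)
    if reward_type is RewardSketch.LOSS_SPREAD:
        mean = sum(losses) / len(losses)
        return [loss - mean for loss in losses]
    raise ValueError(f"Unsupported reward: {reward_type}")


state = {"per_category_loss": [1.0, 3.0]}
rewards = compute_reward_sketch(
    RewardSketch.LOSS_SPREAD, extract_information_sketch(state)
)
```

The point of the split is that the reward function only sees the extracted information, so a new reward that needs an extra signal touches the extraction step as well as the enum and the compute branch.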
Planned TODOs
Please see issue #153.