Modelling Training Dynamics and Interpreting the Dynamics
Project description
Latent State Models of Training Dynamics
Directly model training dynamics, then interpret the dynamics model.
Setup
To use the package, clone the repo and install it in editable mode:
git clone https://github.com/michahu/visualizing-training.git
cd visualizing-training
pip install -e .
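To confirm the editable install worked, you can run a quick import check from the repo root (the package's modules are imported as src.*, matching the steps below):
python -c "from src.train import ModelManager; from src.hmm import HMM; print('install ok')"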
Usage
The commands below walk you through using the package; the demo_run notebook provides a notebook version of the same workflow.
Step 0: Config Setup.
Set up the config required to collect training data and evaluate the model.
num_epochs = 100
lr = 1e-3
train_bsz = 256
cpu = True
weight_decay = 1.0
eval_every = 10
n_heads = 4
init_scaling = 1.0
optimizer = "adamw" # type of optimizer
output_dir = 'modular_demo_run' # output directory for writing collected metrics
seed = 0
use_ln = True
test_bsz = 2048
clip_grad = True
dataset_name = "modular" # specifying the type of dataset to be trained on
n_seeds = 4 # number of seeds for which data will be collected
config = {
"lr": lr,
"cpu": cpu,
"num_epochs": num_epochs,
"eval_every": eval_every,
"run_output_dir": output_dir,
"use_accelerator": False,
"init_scaling": init_scaling,
"optimizer": optimizer,
"weight_decay": weight_decay,
"clip_grad": clip_grad,
"is_transformer": True,
"train_bsz": train_bsz,
"test_bsz": test_bsz,
"n_heads": n_heads,
"dataset_name": dataset_name
}
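The package does not require this, but it can be handy to snapshot the config next to the collected metrics so each run is reproducible; a minimal sketch using only the standard library:
import json
import os

# Optional: record the run config alongside the collected metrics.
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "config.json"), "w") as f:
    json.dump(config, f, indent=2)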
Step 1: Training a model and collecting metrics
from src.training.modular_addition import get_dataloaders
from src.model import Transformer
from src.train import ModelManager
for seed in range(n_seeds):  # collect and save metrics for every seed
    config["seed"] = seed
    model = Transformer(
        d_model=128,
        d_head=32,
        d_vocab=114,
        num_heads=n_heads,
        num_layers=1,
        n_ctx=3,
        use_ln=use_ln,
    )
    train_loader, test_loader = get_dataloaders(train_bsz=train_bsz, test_bsz=test_bsz)
    # Initialize the ModelManager class, which takes care of training and collecting the metrics
    model_manager = ModelManager(model, train_loader, test_loader, config)
    # Specify the layers where hooks need to be attached
    model_manager.attach_hooks(['blocks.0.mlp', 'blocks.0.attn'])
    # Train and save the metrics in output_dir
    model_manager.train_and_save_metrics()
print("Finished training and saving metrics for all seeds")
Step 2: Collate statistics into one file.
Take the stats computed in step 1 and organize them into CSVs suitable for training the HMM.
from src.utils import training_run_json_to_csv

training_run_json_to_csv(
    config['run_output_dir'],
    is_transformer=True,
    has_loss=False,
    lr=lr,
    optimizer=config['optimizer'],
    init_scaling=config['init_scaling'],
    input_dir=config['run_output_dir'],
    n_seeds=n_seeds,
)
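To sanity-check the collated data, you can load the CSVs and confirm the feature columns used by the HMM in the next step (var_w, l1, l2) are present; a sketch assuming pandas is installed and the CSVs land directly in run_output_dir:
import glob

import pandas as pd

# Check that the HMM feature columns exist in the collated CSVs.
for path in glob.glob(f"{output_dir}/*.csv"):
    df = pd.read_csv(path)
    print(path, df.shape, [c for c in ('var_w', 'l1', 'l2') if c in df.columns])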
Step 3: Train HMM.
Model selection computes the AIC, BIC, and log-likelihood curves for a varying number of hidden states in the HMM, and saves the best model for each number of hidden states.
from src.hmm import HMM

max_components = 8  # maximum number of hidden states for which the HMM will be trained
cov_type = "diag"   # type of covariance matrix for the HMM
n_seeds = 4         # number of seeds the HMM is trained on
n_iter = 10         # number of training iterations per HMM
cols = ['var_w', 'l1', 'l2']  # columns of interest
first_n = 100       # number of rows to consider in the dataset

hmm_model = HMM(max_components, cov_type, n_seeds, n_iter)
data_dir = 'modular_demo_run/'
hmm_output = hmm_model.get_avg_log_likelihood(data_dir, cols)
Step 4: HMM Model Selection
Visualizing the average log-likelihood alongside AIC and BIC helps with selecting among the different HMMs we have trained. Currently we select the HMM with the lowest BIC.
from src.visualize import visualize_avg_log_likelihood

visualize_avg_log_likelihood(hmm_output, 'modular_demo_run')
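For reference, the two criteria trade off fit against model size: AIC = 2k - 2 ln L and BIC = k ln(n) - 2 ln L, where k is the number of free parameters, n the number of observations, and ln L the maximized log-likelihood; lower is better for both. A standalone sketch of lowest-BIC selection, using hypothetical numbers (the log-likelihoods and parameter counts below are placeholders, not real results):
import numpy as np

def bic(log_likelihood, n_params, n_obs):
    """Bayesian information criterion: k*ln(n) - 2*ln(L)."""
    return n_params * np.log(n_obs) - 2.0 * log_likelihood

# Placeholder values for illustration only, keyed by number of hidden states.
log_likelihoods = {2: -1500.0, 4: -1200.0, 8: -1150.0}
param_counts = {2: 14, 4: 44, 8: 152}
n_obs = 400
best = min(log_likelihoods, key=lambda k: bic(log_likelihoods[k], param_counts[k], n_obs))
print("lowest-BIC number of hidden states:", best)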
Step 5: Saving the model
from src.utils import save_model  # assumption: save_model is exposed by src.utils

model_path = 'model_path'
save_model(model_path, hmm_output)
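If save_model is not available in your version of the package, the fitted HMM is an ordinary Python object and can be pickled directly as a fallback (hmm_model.best_model is the attribute used in the feature-importance step below):
import pickle

# Fallback: serialize the selected HMM directly.
with open('best_hmm.pkl', 'wb') as f:
    pickle.dump(hmm_model.best_model, f)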
Step 6: Calculating Feature Importance
Calculate feature importance to shortlist the top n features contributing to each state transition.
from src.utils import munge_data

n_components = 8  # best model chosen from the visualization
model_path = 'model_path'
model, data, best_predictions, lengths = munge_data(hmm_model, model_path, data_dir, cols, n_components)
phases = list(set(hmm_model.best_model.predict(data, lengths=lengths)))
# dictionary mapping state transitions to their most important features
state_transitions = hmm_model.feature_importance(cols, data, best_predictions, phases, lengths)
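The exact schema of state_transitions is package-defined, but since it is used as hover text for the DAG edges in the next step, a generic way to inspect it is to print each transition with its associated features:
# Print each state transition and its feature-importance entry;
# the value's structure depends on the package.
for transition, features in state_transitions.items():
    print(transition, '->', features)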
Step 7: Visualizing State Transitions
Visualize state transitions as an interactive DAG, which the user can explore for deeper insight into the training dynamics.
from src.visualize import visualize_dag

best_model_transmat = model.transmat_
visualize_dag(best_model_transmat, edge_hover_dict=state_transitions)
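visualize_dag gives an interactive view; for a quick static look at the learned transition structure, the same matrix can be drawn with networkx (an alternative sketch, assuming networkx and matplotlib are installed; the 0.05 probability cutoff is an arbitrary choice to hide negligible edges):
import matplotlib.pyplot as plt
import networkx as nx

# Static alternative: keep only transitions above an (arbitrary) probability cutoff.
G = nx.DiGraph()
n_states = best_model_transmat.shape[0]
for i in range(n_states):
    for j in range(n_states):
        if i != j and best_model_transmat[i, j] > 0.05:
            G.add_edge(i, j, weight=best_model_transmat[i, j])

pos = nx.circular_layout(G)
nx.draw_networkx(G, pos, node_color='lightblue', arrows=True)
nx.draw_networkx_edge_labels(
    G, pos, edge_labels={e: f"{G.edges[e]['weight']:.2f}" for e in G.edges}
)
plt.axis('off')
plt.show()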
Citation
Thank you for your interest in our work! If you use this repo, please cite:
@article{hu2023latent,
  title={Latent State Models of Training Dynamics},
  author={Michael Y. Hu and Angelica Chen and Naomi Saphra and Kyunghyun Cho},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2023},
  url={https://openreview.net/forum?id=NE2xXWo0LF},
}
Hashes for visualizing_training-1.0.0.tar.gz

Algorithm | Hash digest
---|---
SHA256 | 3564260013da199668b0371d99e771443e0d29915d81629625dc13281b5daa16
MD5 | 551f9aa4a445ead575658c9a32a91bbc
BLAKE2b-256 | 2a1b1218316daac7e6b438c52fe83b8efa231f60064da693795e3b7d7a26fc64
Hashes for visualizing_training-1.0.0-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 0a15aac32bff7b7850b823bc3047256577af4d590b610ecec56f0239d0f9e212
MD5 | 83ca5804f0bf86fd2a508a2ef2c751ac
BLAKE2b-256 | cb6a467faa42670cf790fc5724731b6aec994ef37b6de52c2862338f22e0d47e