
Modelling Training Dynamics and Interpreting the Dynamics


Latent State Models of Training Dynamics

Directly model training dynamics, then interpret the dynamics model.

Setup

To use the package, clone the repo and install it in editable mode:

git clone https://github.com/michahu/visualizing-training.git

pip install -e .
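
After installing, you can confirm the import path used throughout this walkthrough (this assumes the editable install exposes the src package, as in the usage examples below):

python -c "from src.model import Transformer"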

Usage

The commands below walk you through the package's usage; the demo_run notebook provides a notebook version of the same workflow.

Step 0: Config Setup

Set up the config required to collect training data and evaluate the model:

num_epochs = 100
lr = 1e-3
train_bsz = 256
cpu = True
weight_decay = 1.0
eval_every = 10
n_heads = 4
init_scaling = 1.0
optimizer = "adamw" # type of optimizer
output_dir = 'modular_demo_run' # output directory for writing collected metrics
seed = 0
use_ln = True
test_bsz = 2048
clip_grad = True
dataset_name = "modular" # dataset to train on
n_seeds = 4 # number of seeds to collect data for

config = {
    "lr": lr,
    "cpu": cpu,
    "num_epochs": num_epochs,
    "eval_every": eval_every,
    "run_output_dir": output_dir,
    "use_accelerator": False,
    "init_scaling": init_scaling,
    "optimizer": optimizer,
    "weight_decay": weight_decay,
    "clip_grad": clip_grad,
    "is_transformer": True,
    "train_bsz": train_bsz,
    "test_bsz": test_bsz,
    "n_heads": n_heads,
    "dataset_name": dataset_name
}

Step 1: Training a model and collecting metrics

from src.training.modular_addition import get_dataloaders
from src.model import Transformer
from src.train import ModelManager

for seed in range(n_seeds): # train and collect metrics for each seed
    config["seed"] = seed

    model = Transformer(
        d_model=128,
        d_head=32,
        d_vocab=114,
        num_heads=n_heads,
        num_layers=1,
        n_ctx=3,
        use_ln=use_ln,
    )

    train_loader, test_loader = get_dataloaders(train_bsz=train_bsz, test_bsz=test_bsz)

    # Initialize the ModelManager, which handles training and metric collection
    model_manager = ModelManager(model, train_loader, test_loader, config)

    # Specify the layers where hooks need to be attached
    model_manager.attach_hooks(['blocks.0.mlp','blocks.0.attn'])

    # Train the model and save the metrics to output_dir
    model_manager.train_and_save_metrics()
print("Finished training and saving metrics for all seeds")

Step 2: Collate statistics into one file

Take the stats computed in Step 1 and organize them into CSVs suitable for training the HMM.

from src.utils import training_run_json_to_csv

training_run_json_to_csv(
    config['run_output_dir'],
    is_transformer=True,
    has_loss=False,
    lr=lr,
    optimizer=config['optimizer'],
    init_scaling=config['init_scaling'],
    input_dir=config['run_output_dir'],
    n_seeds=n_seeds,
)
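
To sanity-check the collated data, you can peek at the resulting CSVs with pandas (the glob pattern below is an assumption about where the files are written):

import glob
import pandas as pd

# Assumes the collated CSVs land under the run output directory
for path in sorted(glob.glob('modular_demo_run/*.csv'))[:1]:
    df = pd.read_csv(path)
    print(path, df.shape)
    print(df.head())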

Step 3: Train HMM

Model selection computes AIC, BIC, and log-likelihood curves for a varying number of hidden states in the HMM and saves the best model for each number of hidden states.

from src.hmm import HMM

max_components = 8 # maximum number of hidden states to try
cov_type = "diag" # covariance type for the HMM
n_seeds = 4 # number of seeds to train the HMM on
n_iter = 10 # number of training iterations
cols = ['var_w', 'l1', 'l2'] # columns of interest
first_n = 100 # number of rows to consider in the dataset
hmm_model = HMM(max_components, cov_type, n_seeds, n_iter)
data_dir = 'modular_demo_run/'

hmm_output = hmm_model.get_avg_log_likelihood(data_dir, cols)

Step 4: HMM Model Selection

Visualizing the average log-likelihood, along with AIC and BIC, helps with model selection across the HMMs we have trained. Currently, we select the HMM with the lowest BIC.

from src.visualize import visualize_avg_log_likelihood

visualize_avg_log_likelihood(hmm_output, 'modular_demo_run')
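
For reference, AIC and BIC for a diagonal-covariance Gaussian HMM are typically computed as below. This is a standalone sketch following hmmlearn conventions, not the package's internal code.

import numpy as np

def gaussian_hmm_aic_bic(model, X, lengths):
    # Free parameters for a k-state, d-feature diagonal-covariance HMM:
    # start probs (k - 1), transitions k(k - 1), means k*d, diagonal covariances k*d
    k, d = model.n_components, X.shape[1]
    n_params = (k - 1) + k * (k - 1) + 2 * k * d
    log_l = model.score(X, lengths=lengths) # total log-likelihood of the data
    aic = -2 * log_l + 2 * n_params
    bic = -2 * log_l + n_params * np.log(X.shape[0])
    return aic, bic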

Step 5: Saving the model

Save the selected HMM so it can be reloaded later:

from src.utils import save_model # assumed import path for save_model

model_path = 'model_path'
save_model(model_path, hmm_output)

Step 6: Calculating Feature Importance

Calculate feature importance to shortlist the top-n features contributing to each state transition:

from src.utils import munge_data

n_components = 8 # number of hidden states of the best model chosen in Step 4
model_path = 'model_path'

model, data, best_predictions, lengths = munge_data(hmm_model, model_path, data_dir, cols, n_components)

phases = list(set(hmm_model.best_model.predict(data, lengths=lengths))) # distinct hidden states visited

state_transitions = hmm_model.feature_importance(cols, data, best_predictions, phases, lengths) # dictionary storing state transitions
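
The resulting dictionary can then be inspected directly; the loop below simply iterates over it, assuming keys identify state transitions and values rank the contributing features:

for transition, top_features in state_transitions.items():
    print(transition, top_features)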

Step 7: Visualizing State Transitions

Visualize the state transitions as an interactive DAG, which lets the user explore the training dynamics in more depth:

from src.visualize import visualize_dag

best_model_transmat = model.transmat_

visualize_dag(best_model_transmat, edge_hover_dict=state_transitions)

Citation

Thank you for your interest in our work! If you use this repo, please cite:

@article{hu2023latent,
    title={Latent State Models of Training Dynamics},
    author={Michael Y. Hu and Angelica Chen and Naomi Saphra and Kyunghyun Cho},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2023},
    url={https://openreview.net/forum?id=NE2xXWo0LF},
    note={}
}
