# rebasin

An implementation of the methods described in the "Git Re-basin" paper by Ainsworth et al.
## Installation

```bash
pip install rebasin
```
## Usage

Currently, weight matching is the only implemented rebasing method, and interpolation is limited to a simplified form of linear interpolation.
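The idea behind weight matching is that a network's hidden units can be permuted without changing its function, so two independently trained models can be aligned by finding good permutations. A toy illustration of this invariance (plain PyTorch, not part of rebasin's API):

```python
import torch
from torch import nn

# Permuting a hidden layer's output units, and applying the same permutation
# to the next layer's input columns, leaves the network's function unchanged.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)
y_before = model(x)

perm = torch.randperm(8)
with torch.no_grad():
    model[0].weight.copy_(model[0].weight[perm])      # permute output rows
    model[0].bias.copy_(model[0].bias[perm])          # permute biases to match
    model[2].weight.copy_(model[2].weight[:, perm])   # permute next layer's input columns

y_after = model(x)
print(torch.allclose(y_before, y_after, atol=1e-6))  # True
```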
```python
import torch
from torch import nn

from rebasin import PermutationCoordinateDescent
from rebasin import interpolation

model_a, model_b, train_dl, val_dl, loss_fn = ...


def eval_fn(model: nn.Module, device: str | torch.device | None = None) -> float:
    loss = 0.0
    for inputs, targets in val_dl:
        if device is not None:
            inputs = inputs.to(device)
            targets = targets.to(device)
        outputs = model(inputs)
        loss += loss_fn(outputs, targets).item()
    return loss / len(val_dl)


input_data = next(iter(train_dl))[0]

# Rebasin
pcd = PermutationCoordinateDescent(model_a, model_b, input_data)
pcd.calculate_permutations()
pcd.apply_permutations()

# Interpolate
lerp = interpolation.LerpSimple(
    models=[model_a, model_b],
    eval_fn=eval_fn,  # Can be any metric as long as the function takes a model and a device
    eval_mode="min",  # "min" or "max"
    train_dataloader=train_dl,  # Used to recalculate BatchNorm statistics; optional
)
lerp.interpolate(steps=10)

# Access the model with the lowest validation loss:
lerp.best_model
```
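Conceptually, linear interpolation blends the two models' parameters as `theta(t) = (1 - t) * theta_a + t * theta_b`. A minimal sketch of this step (the helper below is illustrative, not part of rebasin's API; `LerpSimple` additionally evaluates each interpolation step and keeps the best model):

```python
import torch
from torch import nn

def lerp_state_dicts(model_a: nn.Module, model_b: nn.Module, t: float) -> dict:
    """Illustrative helper: per-parameter linear interpolation of two models."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    # theta(t) = (1 - t) * theta_a + t * theta_b, applied to each parameter tensor
    return {key: (1 - t) * sd_a[key] + t * sd_b[key] for key in sd_a}

model_a, model_b = nn.Linear(2, 2), nn.Linear(2, 2)
midpoint = nn.Linear(2, 2)
midpoint.load_state_dict(lerp_state_dicts(model_a, model_b, t=0.5))
```

Note that this naive blend is only expected to land in a low-loss region after the models have been brought into the same loss basin, which is what the permutation step above is for.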
## Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

- Source distribution: rebasin-0.0.20.tar.gz (18.1 kB)
- Built distribution: rebasin-0.0.20-py3-none-any.whl (19.0 kB)