
rebasin

An implementation of methods described in the "Git Re-basin" paper by Ainsworth et al.

Installation

pip install rebasin

Usage

Currently, weight matching is the only implemented re-basin method, and only a simplified form of linear interpolation is available.
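
For context, linear interpolation between two models simply blends their parameters element-wise. The following is a minimal, library-independent sketch of that idea (not the rebasin implementation itself; it assumes both models share exactly the same architecture):

import copy

import torch
from torch import nn


def lerp_models(model_a: nn.Module, model_b: nn.Module, t: float) -> nn.Module:
    """Return a copy of model_a whose parameters are (1 - t) * a + t * b."""
    blended = copy.deepcopy(model_a)
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    blended.load_state_dict(
        {k: torch.lerp(state_a[k].float(), state_b[k].float(), t) for k in state_a}
    )
    return blended

The full rebasin workflow (weight matching followed by interpolation) looks like this: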

import torch
from torch import nn
from rebasin import PermutationCoordinateDescent
from rebasin import interpolation

model_a, model_b, train_dl, val_dl, loss_fn = ...

def eval_fn(model: nn.Module, device: str | torch.device | None = None) -> float:
    """Return the mean validation loss of `model`; lower is better."""
    model.eval()
    loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_dl:
            if device is not None:
                inputs = inputs.to(device)
                targets = targets.to(device)
            outputs = model(inputs)
            loss += loss_fn(outputs, targets).item()  # accumulate over batches
    return loss / len(val_dl)

input_data = next(iter(train_dl))[0]  # one example batch; PermutationCoordinateDescent uses it to trace the models

# Rebasin: weight matching via permutation coordinate descent
pcd = PermutationCoordinateDescent(model_a, model_b, input_data)
pcd.calculate_permutations()
pcd.apply_permutations()

# Interpolate
lerp = interpolation.LerpSimple(
    models=[model_a, model_b],
    eval_fn=eval_fn,  # any metric works, as long as the function takes a model and a device
    eval_mode="min",  # "min" if lower is better (e.g. loss), "max" if higher is better (e.g. accuracy)
    train_dataloader=train_dl,  # optional; used to recalculate BatchNorm statistics
)
lerp.interpolate(steps=10)

# Access model with lowest validation loss:
lerp.best_model
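
The model returned by `lerp.best_model` is an ordinary `nn.Module`, so it can be evaluated or saved with standard PyTorch tooling (a usage sketch; the file name is a placeholder):

best = lerp.best_model

# Re-check its validation loss with the same eval_fn used during interpolation
print(f"Best validation loss: {eval_fn(best):.4f}")

# Persist it with standard PyTorch serialization
torch.save(best.state_dict(), "rebasin_merged_model.pt")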

