terge is an easy-to-use Python library for merging PyTorch models. It works with models of any size and architecture, including Hugging Face 🤗 Transformers.

Features 🎯

  • 👌 Easy-to-use: a single line of code is all you need to get started.
  • ⚡ Lightning-fast: billions of parameters can be merged in mere seconds.
  • 📐 Architecture-agnostic: models of any size and architecture can be merged, provided they share at least one parameter with the same name and shape.
  • 🛠️ Hyper-customizable: parameters can be filtered in or out with regex, and custom weights can be assigned to models or even to their individual parameters.
  • 🌳 Lineage tracking: maps of merged parameter names to models' weightings can be produced to document precisely how models were merged.
  • 🤗 Hugging Face-friendly: Hugging Face 🤗 Transformers are supported out of the box.

Installation 🧑‍🔧

terge can be installed with pip:

pip install terge

Usage 👩‍💻

The following code snippet demonstrates how you can get started with terge:

import re
import torch
import terge

from transformers import AutoModel # NOTE `transformers` isn't required; it's only imported here for demonstration purposes.

# A single line is all it takes to merge any number of models.
model = terge.merge([torch.nn.Linear(10, 1) for _ in range(3)])

# This also works for models of different architectures...
model = terge.merge([torch.nn.LSTM(10, 1, num_layers = 1), torch.nn.LSTM(10, 1, num_layers = 2)])

# And models of different sizes...
model = terge.merge([torch.nn.LSTM(10, 1, num_layers = 1), torch.nn.LSTM(100, 1, num_layers = 2)])

# And even Hugging Face 🤗 Transformers...
model = terge.merge([AutoModel.from_pretrained('umarbutler/emubert'),
                     AutoModel.from_pretrained('roberta-base')],
                     progress = True)

# Just make sure there's at least one shared named parameter in there.
model = terge.merge([torch.nn.Linear(10, 1), torch.nn.Linear(1, 10)]) # -> terge.NoParametersToMergeWarning

If you want even greater control over the merging process, terge has got you covered:

# Changing how parameters are merged and what model serves as the base is trivial.
model = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    base = torch.nn.Linear(10, 1), # The base model doesn't even need to be one of the models being merged!
    # You can also pass the index of a model in the input models. The default is 0.
    weights = [1, 2, 3], # Weights are relative and correspond to the order of the input models: here, the
    # second model is given double the weight of the first and the third model is given triple the weight
    # of the first. The default is [1, 1, ...].
)

# Assigning custom weights to individual parameters is also easy.
model = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    weights = {re.compile(r'weight'): [1, 2, 3], 'bias': [3, 2, 1]}, # Anything that doesn't match this map
    # will get a weight of 1. You can change that by adding `re.compile(r'.*'): [...]` to the *end* of your
    # weights map.
)

# If you want to filter specific parameters in or out, that can be done too.
model = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    included = re.compile(r'weight'), # Only parameters with 'weight' in their name will be merged.
    # You could also pass a string for an exact match.
    excluded = ['bias', re.compile(r'bias')], # Lists of strings and regex patterns work as well.
    # NOTE Exclusions apply after inclusions, so this exclusion is redundant here.
)

# You can also enable lineage tracking to understand exactly how models got merged.
model, lineage = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    lineage = True,
) # -> {'weight': ('arithmetic', [(0, 0.3333333333333333), (1, 0.3333333333333333), (2, 0.3333333333333333)]),
  #     'bias': ('arithmetic', [(0, 0.3333333333333333), (1, 0.3333333333333333), (2, 0.3333333333333333)])}

# Finally, for an extra speed boost, you can merge in-place (just keep in mind, this will modify your base model).
model = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    inplace = True,
)

API 🧩

merge()

def merge(
    models: list[torch.nn.Module],
    base: torch.nn.Module | int = 0,
    method: Literal['arithmetic'] | dict[str | re.Pattern, Literal['arithmetic']] = 'arithmetic',
    weights: list[float] | dict[str | re.Pattern, list[float]] | None = None,
    included: re.Pattern | str | list[str | re.Pattern] | None = None,
    excluded: re.Pattern | str | list[str | re.Pattern] | None = None,
    inplace: bool = False,
    dtype: torch.dtype = torch.float64,
    lineage: bool = False,
    progress: bool = False,
) -> torch.nn.Module | tuple[torch.nn.Module, dict[str, tuple[str, list[tuple[int, float]]]]]

merge() merges PyTorch models.

models represents the models to be merged.

base represents the model whose parameters will be used as defaults and that, if inplace is set to True, will be merged into; or the index of such a model in models. It defaults to 0, that is, the index of the first model in models.
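
For instance, the base can be supplied as an index into models rather than as a separate module. The following is a minimal sketch based on the signature above:

import torch
import terge

models = [torch.nn.Linear(10, 1) for _ in range(3)]

# Use the second input model as the base by passing its index instead of the module itself.
model = terge.merge(models, base = 1)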

method represents the method to be used for merging the models' parameters, or a map of parameter names or regex patterns matching parameter names to the methods to be used to merge them. Currently, only the 'arithmetic' method is supported (that is, the merging of parameters by taking their ordinary or weighted arithmetic mean). method defaults to 'arithmetic'.
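
Although 'arithmetic' is currently the only available method, the map form can already be used to pin methods to specific parameters. The following is a minimal sketch based on the signature above:

import re
import torch
import terge

# Explicitly assign the 'arithmetic' method to specific parameters via a map of names and patterns.
model = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    method = {'weight': 'arithmetic', re.compile(r'bias'): 'arithmetic'},
)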

weights represents a list of all of the relative weights to be assigned to the models' parameters, or a map of parameter names or regex patterns matching parameter names to lists of weights. If set to None, all models will be weighted equally. If a dictionary is provided and there are any parameters to be merged that do not match any of its keys, they will also be weighted equally. weights defaults to None.
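
As the usage examples note, a catch-all pattern placed at the end of the weights map can override the equal weighting otherwise applied to unmatched parameters. The following is a minimal sketch of that approach:

import re
import torch
import terge

model = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    weights = {
        'bias': [3, 2, 1],            # Matched by exact name.
        re.compile(r'.*'): [1, 2, 3], # Catch-all so no parameter falls back to equal weighting.
    },
)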

included represents a regex pattern, string or list of regex patterns and strings matching parameter names to be merged. If set to None, all parameters will be merged. included defaults to None.

excluded represents a regex pattern, string or list of regex patterns and strings matching parameter names to be excluded from merging. If set to None, no parameters will be excluded. If included is provided, this argument will apply to the subset of parameters that match included. excluded defaults to None.

inplace represents whether the base should be merged into in place instead of being deep copied, which saves time and memory. It defaults to False.

dtype represents the data type to be used for storing the weightings. It defaults to torch.float64.
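
If the extra precision of torch.float64 isn't needed, a lower-precision dtype can be supplied. A minimal sketch:

import torch
import terge

# Store the merge weightings in single precision instead of the default torch.float64.
model = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    dtype = torch.float32,
)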

lineage represents whether to also return a dictionary mapping the names of merged parameters to the method used to merge them and the indices and weights of the models that contributed to them. It defaults to False.

progress represents whether to display a progress bar. It defaults to False.

merge() will return either the merged model or, if lineage is True, a tuple containing the merged model along with a dictionary mapping the names of merged parameters to the method used to merge them and a list of pairs of the indices of the contributing models and the weights they were assigned, which looks like this:

{
    'parameter_name': ('method', [(model_index, weight), ...]),
    ...
}
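
For example, the lineage mapping can be iterated to report each parameter's contributing models and weights. The following is a minimal sketch that simply reads the structure shown above:

import torch
import terge

model, lineage = terge.merge(
    [torch.nn.Linear(10, 1) for _ in range(3)],
    lineage = True,
)

# Print each merged parameter, the method used and the contribution of every input model.
for parameter_name, (method, contributions) in lineage.items():
    summary = ', '.join(f'model {index} x {weight:.2f}' for index, weight in contributions)
    print(f'{parameter_name} ({method}): {summary}')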

Changelog 🔄

terge adheres to Keep a Changelog and Semantic Versioning. All notable changes to terge are documented in its Changelog 🔄.

License 📜

terge is licensed under the MIT License.
