Library to map (deep learning) model weights between different model implementations.

weightbridge 🌉

What?

A library to map (deep learning) model weights between different model implementations in Python.

Why?

Model weights trained with one implementation of an architecture typically cannot be loaded directly into a different implementation of the same architecture, because of:

  • Different parameter and layer names.
  • Different nesting of modules.
  • Different parameter shapes (e.g. (8, 8) vs (64) vs (1, 8, 8)).
  • Different order of dimensions (e.g. (64, 48) vs (48, 64)).
  • Different deep learning frameworks (e.g. PyTorch, TensorFlow, Flax).
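
To make these mismatches concrete, the following sketch compares two hypothetical listings of the same tiny model. All names and shapes are invented for illustration, and plain name-to-shape dictionaries stand in for real checkpoints, so no framework is needed.

```python
# Hypothetical name -> shape listings for the same two-layer model,
# written once in a PyTorch-like style and once in a Flax-like style.
theirs = {
    "encoder.0.weight": (48, 64),  # linear weight stored as (out, in)
    "encoder.0.bias": (48,),
    "head.weight": (10, 48),
    "head.bias": (10,),
}
mine = {
    "params/encoder/layer_0/kernel": (64, 48),  # same weight stored as (in, out)
    "params/encoder/layer_0/bias": (48,),
    "params/head/kernel": (48, 10),
    "params/head/bias": (1, 10),  # extra leading axis
}


def num_elements(shape):
    """Number of scalar values in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


# No name is shared, so loading by key fails for every weight.
shared_names = set(theirs) & set(mine)

# Most shapes differ as well, yet both sides hold the same amounts of data.
shared_shapes = set(theirs.values()) & set(mine.values())
same_sizes = sorted(map(num_elements, theirs.values())) == sorted(
    map(num_elements, mine.values())
)

print(shared_names)   # set()
print(shared_shapes)  # {(48,)}
print(same_sizes)     # True
```

Only one shape coincides and no name does, although the element counts agree pairwise. This is the situation in which keys and shapes have to be translated before the weights can be loaded.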

Adapting the weights manually is a tedious and error-prone process:

import re

# Applied to every (k, v) = (name, tensor) entry of the pretrained checkpoint:
k = k.replace('downsample_layers.0.', 'stem.')
k = re.sub(r'stages\.([0-9]+)\.([0-9]+)', r'stages.\1.blocks.\2', k)
k = re.sub(r'downsample_layers\.([0-9]+)\.([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('pwconv', 'mlp.fc')
if 'grn' in k:
    k = k.replace('grn.beta', 'mlp.grn.bias')
    v = v.reshape(v.shape[-1])  # flatten to a 1-D vector
k = k.replace('head.', 'head.fc.')
if v.ndim == 2 and 'head' not in k:
    # Take the target shape from the freshly initialized model.
    model_shape = model.state_dict()[k].shape
    v = v.reshape(model_shape)

weightbridge does most of this work for you.

How?

import weightbridge
new_my_weights = weightbridge.adapt(their_weights, my_weights)

  • my_weights contains the (random) untrained weights created at model initialization (e.g. as the result of model.state_dict() in PyTorch, or using model.init in Flax and Haiku).
  • their_weights contains the pretrained weights (e.g. as the result of torch.load, tf.train.load_checkpoint or np.load).

The output has the same structure and weight shapes as my_weights, but with the weight values from their_weights. It can be used as a drop-in replacement for my_weights, for example by loading it back into the model with model.load_state_dict in PyTorch, or by passing it to model.apply in Flax and Haiku.
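
This contract (my names and shapes, their values) can be illustrated with a toy stand-in. The function `apply_pairing` below is hypothetical and not part of weightbridge: it expects an explicit name pairing, which weightbridge infers on its own, and it represents each weight as a `(shape, flat_values)` tuple instead of a framework tensor. Axis permutations are ignored in this sketch.

```python
import math


def apply_pairing(their_weights, my_weights, pairing):
    """Return weights keyed and shaped like my_weights, with values from their_weights.

    Each weight is a (shape, flat_values) tuple. `pairing` maps every name in
    my_weights to a name in their_weights.
    """
    new_my_weights = {}
    for my_name, (my_shape, _) in my_weights.items():
        their_shape, their_values = their_weights[pairing[my_name]]
        if math.prod(their_shape) != math.prod(my_shape):
            raise ValueError(f"{my_name}: cannot fit {their_shape} into {my_shape}")
        # Keep my shape, take their values.
        new_my_weights[my_name] = (my_shape, list(their_values))
    return new_my_weights


their_weights = {"head.bias": ((4,), [0.1, 0.2, 0.3, 0.4])}
my_weights = {"params/head/bias": ((1, 4), [0.0, 0.0, 0.0, 0.0])}

new_my_weights = apply_pairing(
    their_weights, my_weights, {"params/head/bias": "head.bias"}
)
print(new_my_weights)
# {'params/head/bias': ((1, 4), [0.1, 0.2, 0.3, 0.4])}
```

The result keeps the key and the `(1, 4)` shape of the untrained weight while carrying the pretrained values, which is why it can replace my_weights without further changes to the model code.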

Installation:

pip install weightbridge

Full examples:

Additional parameters:

  • {in_format|out_format}="{pytorch|tensorflow|flax|haiku|...}" when weights are adapted between different deep learning frameworks (to permute weight axes as required).
  • hints=[...] to provide additional hints when ambiguous matches cannot be resolved. weightbridge prints an error when this happens, for example:
    Failed to pair the following nodes
      OUT load_prefix/encode/stage3/block6/reduce/linear/w ((262144,),)
      OUT load_prefix/encode/stage3/block6/expand/linear/w ((262144,),)
      IN  backbone.0.body.layer3.5.conv1.weight ((262144,),)
      IN  backbone.0.body.layer3.5.conv3.weight ((262144,),)
    
    Passing hints=[("reduce", "conv1")], a pair of substrings that uniquely identify the nodes to be paired, resolves the matching failure.
  • cache="some-file" to store the mapping in a file and reuse it in subsequent calls. If it is not an absolute path, the file is created in the directory of the module from which weightbridge.adapt is called.
  • verbose=True to print the matching steps and the final mapping between weights.
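
The effect of a hint on the failure shown above can be sketched in a few lines. `resolve_with_hints` is a hypothetical helper, not the library's implementation, and it assumes that each hint is an `(OUT substring, IN substring)` tuple, as the example hint suggests.

```python
def resolve_with_hints(out_names, in_names, hints):
    """Pair names that are otherwise indistinguishable, using substring hints.

    Each hint is an (out_substring, in_substring) tuple. A hint is applied only
    if each substring identifies exactly one remaining name on its side.
    """
    out_left, in_left = list(out_names), list(in_names)
    pairs = {}
    for out_sub, in_sub in hints:
        out_hits = [n for n in out_left if out_sub in n]
        in_hits = [n for n in in_left if in_sub in n]
        if len(out_hits) == 1 and len(in_hits) == 1:
            pairs[out_hits[0]] = in_hits[0]
            out_left.remove(out_hits[0])
            in_left.remove(in_hits[0])
    # If exactly one name remains on each side, that pair is forced.
    if len(out_left) == 1 and len(in_left) == 1:
        pairs[out_left.pop()] = in_left.pop()
    return pairs, out_left, in_left


out_names = [
    "load_prefix/encode/stage3/block6/reduce/linear/w",
    "load_prefix/encode/stage3/block6/expand/linear/w",
]
in_names = [
    "backbone.0.body.layer3.5.conv1.weight",
    "backbone.0.body.layer3.5.conv3.weight",
]

pairs, out_left, in_left = resolve_with_hints(
    out_names, in_names, [("reduce", "conv1")]
)
for out_name, in_name in pairs.items():
    print(out_name, "<-", in_name)
```

A single hint is enough here: once `reduce` is tied to `conv1`, the only remaining candidates are `expand` and `conv3`, so the second pair follows by elimination.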

weightbridge internally uses a set of heuristics based on the weights' names and shapes to iteratively find mappings between subsets of my_weights and their_weights, until a unique pairing between all weights is found.
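
As a rough illustration of name- and shape-based matching, the sketch below buckets weights by element count and then pairs names inside each bucket by shared name tokens. It is a deliberately simplified stand-in under these assumptions, not weightbridge's actual heuristics; `tokens` and `match_weights` are hypothetical names.

```python
import math
import re
from collections import defaultdict


def tokens(name):
    """Split a weight name into lowercase alphanumeric tokens."""
    return set(re.findall(r"[a-z0-9]+", name.lower()))


def match_weights(out_shapes, in_shapes):
    """Pair weight names by element count first, then by shared name tokens."""
    buckets = defaultdict(lambda: ([], []))
    for name, shape in out_shapes.items():
        buckets[math.prod(shape)][0].append(name)
    for name, shape in in_shapes.items():
        buckets[math.prod(shape)][1].append(name)

    pairs, unresolved = {}, []
    for outs, ins in buckets.values():
        # Repeat until no further pair in this bucket can be decided.
        progress = True
        while outs and ins and progress:
            progress = False
            for out_name in outs:
                scores = sorted(
                    ((len(tokens(out_name) & tokens(i)), i) for i in ins),
                    reverse=True,
                )
                if len(scores) == 1 or scores[0][0] > scores[1][0]:
                    # The best candidate is unique: fix this pair and restart.
                    pairs[out_name] = scores[0][1]
                    outs.remove(out_name)
                    ins.remove(scores[0][1])
                    progress = True
                    break
        unresolved.extend(outs)
    return pairs, unresolved


mine = {
    "encoder/dense/kernel": (64, 48),
    "encoder/dense/bias": (48,),
    "head/norm/scale": (48,),
    "head/out/kernel": (48, 10),
}
theirs = {
    "encoder.dense.weight": (48, 64),
    "encoder.dense.bias": (48,),
    "head.norm.weight": (48,),
    "head.out.weight": (10, 48),
}
pairs, unresolved = match_weights(mine, theirs)
print(pairs["head/norm/scale"])  # head.norm.weight
print(unresolved)                # []

# Two equally sized weights whose names share no tokens stay unresolved.
_, stuck = match_weights(
    {"block/reduce/w": (262144,), "block/expand/w": (262144,)},
    {"layer.conv1.weight": (262144,), "layer.conv3.weight": (262144,)},
)
print(stuck)  # ['block/reduce/w', 'block/expand/w']
```

The second call reproduces the kind of ambiguity that the hints parameter is meant to resolve: sizes and names alone give no reason to prefer one pairing over the other.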

What does weightbridge not do?

  • Model implementation: weightbridge does not implement the model, but adapts the weights once the model is implemented (although it provides a partial sanity check of the implementation by ensuring that a mapping between the two sets of weights is possible). When the two implementations use different operations, the weights have to be adapted manually. For example, in Transformer attention, queries, keys and values can be computed using different operations:
    # Option 1
    x = nn.Linear(features=3 * c)(x)
    q, k, v = jnp.split(x, 3, axis=-1)
    
    # Option 2
    q = nn.Linear(features=c)(x)
    k = nn.Linear(features=c)(x)
    v = nn.Linear(features=c)(x)
    
    The corresponding weights have to be split or concatenated manually; weightbridge will not match them otherwise, since it relies on a one-to-one mapping between weights.
  • Hyperparameters: weightbridge does not ensure that hyperparameters like nn.LayerNorm(epsilon=1e-6) or nn.Conv(padding="SAME") are set correctly (although some hyperparameters like use_bias={True|False} will raise an exception if not set correctly).
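
The manual split or concatenation needed for the attention example (Option 1 vs. Option 2) might look like the sketch below. It assumes kernels stored as `(in_features, out_features)`, with the fused kernel laid out as `[q | k | v]` along the last axis. A layout that stores `(out, in)` would need the split along the other axis, and biases need the same treatment. Plain nested lists keep the sketch dependency-free; with arrays, `jnp.split` and `jnp.concatenate` do the same job. The helper names are hypothetical.

```python
def split_columns(kernel, parts):
    """Split a (rows, cols) matrix, given as a list of rows, into equal column blocks."""
    cols = len(kernel[0])
    if cols % parts != 0:
        raise ValueError("number of columns must be divisible by parts")
    width = cols // parts
    return [
        [row[p * width:(p + 1) * width] for row in kernel]
        for p in range(parts)
    ]


def concat_columns(kernels):
    """Concatenate matrices with equal row counts along the column axis."""
    return [sum(rows, []) for rows in zip(*kernels)]


# Fused projection kernel of shape (in_features=2, 3 * c = 6), as in Option 1.
fused = [
    [1, 2, 3, 4, 5, 6],
    [7, 8, 9, 10, 11, 12],
]

# Separate q, k, v kernels of shape (2, 2) each, as in Option 2.
q, k, v = split_columns(fused, 3)
print(q)  # [[1, 2], [7, 8]]
print(k)  # [[3, 4], [9, 10]]
print(v)  # [[5, 6], [11, 12]]

# Going the other way restores the fused kernel.
assert concat_columns([q, k, v]) == fused
```

After such a step both sides contain the same number of weights again, so a one-to-one mapping exists and the remaining renaming and reshaping can be left to weightbridge.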
