A powerful library for transforming PyTorch state dict keys with rule-based mappings, arithmetic operations, and advanced transformations.

torch-state-bridge

torch-state-bridge is a powerful and flexible library for transforming PyTorch state_dict keys using rule-based mappings, regex captures, arithmetic expressions, and composable transformation pipelines.

It is designed to make model weight conversion easy across:

  • different architectures
  • renamed modules
  • framework migrations
  • checkpoints with inconsistent naming

✨ Features

  • 🔁 Rule-based key transformation using readable patterns
  • 🔢 Arithmetic expressions in key mappings
  • 🔄 Forward & reverse mappings
  • 🧩 Composable pipelines for complex workflows
  • 🌳 Nested dictionary support
  • 👀 Preview & diff tools before applying changes
  • 🚫 Collision detection
  • ♻️ Reusable rule templates
  • LRU-cached transformations for performance
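The README does not say what exactly is cached. As a rough illustration of the idea, and not the library's code, a per-key rename function can be memoized with functools.lru_cache so that repeated keys (for example, across several checkpoints that share a layout) are pattern-matched only once. The function `rename_key` and its single hard-coded rule are hypothetical.

```python
import re
from functools import lru_cache

# Hypothetical single rule: layer.<digits>.<rest> -> block.<digits>.<rest>
_RULE = re.compile(r"^layer\.(\d+)\.(.+)$")


@lru_cache(maxsize=4096)
def rename_key(key: str) -> str:
    """Rename one key; results are cached per distinct key string."""
    m = _RULE.match(key)
    if m is None:
        return key  # keys that do not match the rule are left unchanged
    return f"block.{m.group(1)}.{m.group(2)}"


keys = ["layer.0.weight", "layer.0.bias", "layer.0.weight", "head.fc.weight"]
renamed = [rename_key(k) for k in keys]
print(renamed)
# ['block.0.weight', 'block.0.bias', 'block.0.weight', 'head.fc.weight']
print(rename_key.cache_info())  # the repeated "layer.0.weight" is a cache hit
```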

📦 Installation

pip install torch-state-bridge

Or install from source:

git clone https://github.com/yourname/torch-state-bridge.git
cd torch-state-bridge
pip install -e .

🚀 Quick Start

import torch
from torch_state_bridge import state_bridge

# Example tensors standing in for real checkpoint weights
state_dict = {
    "layer.0.weight": torch.randn(4, 4),
    "layer.0.bias": torch.zeros(4),
}

rules = """
layer.{n}.weight, block.{n}.weight
layer.{n}.bias,   block.{n}.bias
"""

new_state_dict = state_bridge(state_dict, rules)

Result:

layer.0.weight → block.0.weight
layer.0.bias   → block.0.bias

🧠 Rule Syntax

Basic Rule

source_pattern, destination_pattern

Capture Groups

layer.{n}.weight, block.{n}.weight
  • {n} captures numeric values
  • Captures are reusable in destination
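One way to picture how a `{n}` placeholder works: compile the source pattern into a regular expression in which literal text (including the dots) is escaped and each `{name}` becomes a named digit group, then substitute the captured values into the destination. This is a sketch under that assumption; `compile_pattern` and `apply_rule` are hypothetical helpers, not part of torch-state-bridge.

```python
import re

# Matches plain placeholders such as {n}; arithmetic forms like {(n + 1)} are not handled here.
_PLACEHOLDER = re.compile(r"\{(\w+)\}")


def compile_pattern(pattern: str) -> re.Pattern:
    """Turn 'layer.{n}.weight' into an anchored regex with a named digit group for n."""
    parts = []
    pos = 0
    for m in _PLACEHOLDER.finditer(pattern):
        parts.append(re.escape(pattern[pos:m.start()]))  # literal text, dots escaped
        parts.append(f"(?P<{m.group(1)}>\\d+)")          # numeric capture
        pos = m.end()
    parts.append(re.escape(pattern[pos:]))
    return re.compile("^" + "".join(parts) + "$")


def apply_rule(key: str, src: str, dst: str) -> str:
    """Rewrite key if it matches src; otherwise return it unchanged."""
    m = compile_pattern(src).match(key)
    if m is None:
        return key
    # Reuse each captured value in the destination pattern.
    return _PLACEHOLDER.sub(lambda p: m.group(p.group(1)), dst)


print(apply_rule("layer.3.weight", "layer.{n}.weight", "block.{n}.weight"))
# block.3.weight
```

Escaping the literal text matters: without it, the `.` in `layer.{n}` would match any character and `layerX3.weight` would be renamed too.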

Arithmetic Expressions

layer.{n}.weight, block.{(n + 1)}.weight

Supported operators:

  • + - * / // % **

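Arithmetic is sandboxed: expressions are evaluated from a parsed AST rather than with eval, and only the operators listed above are accepted.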
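The Safety section describes this as AST-based, integer-only evaluation. A minimal sketch of that approach, assuming captures are bound to ints and only the listed operators are allowed, might look like the following. `safe_eval` is a hypothetical name, and two choices here are assumptions of this sketch: `/` is accepted only when the division is exact, and negative exponents are rejected because they would produce non-integers.

```python
import ast
import operator

# Allowed binary operators (except "/", handled separately); anything else is rejected.
_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}


def safe_eval(expr: str, variables: dict[str, int]) -> int:
    """Evaluate an integer expression such as 'n + 1' without calling eval()."""

    def walk(node: ast.AST) -> int:
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.Name) and node.id in variables:
            return variables[node.id]
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -walk(node.operand)
        if isinstance(node, ast.BinOp):
            left, right = walk(node.left), walk(node.right)
            if isinstance(node.op, ast.Div):
                q, r = divmod(left, right)  # assumption: "/" must divide exactly
                if r:
                    raise ValueError(f"{left} / {right} is not an integer")
                return q
            if isinstance(node.op, ast.Pow) and right < 0:
                raise ValueError("negative exponent would give a non-integer")
            op = _BINOPS.get(type(node.op))
            if op is not None:
                return op(left, right)
        # Calls, attributes, unknown names, strings, floats, etc. all end up here.
        raise ValueError(f"disallowed expression: {ast.dump(node)}")

    return walk(ast.parse(expr, mode="eval"))


print(safe_eval("n + 1", {"n": 3}))         # 4
print(safe_eval("(n * 2) // 3", {"n": 5}))  # 3
```

Walking the tree with an explicit allow-list means a string like `__import__('os')` is rejected as an unsupported node instead of being executed.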


🔄 Reverse Rules

Apply the rules in the opposite direction, mapping destination patterns back to source patterns:

state_bridge(state_dict, rules, reverse=True)

⚠️ Reverse mode is not allowed for rules with arithmetic expressions.
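One reason this restriction makes sense is that an arithmetic destination need not be invertible. With a hypothetical rule `layer.{n}.weight, block.{(n // 2)}.weight`, two source layers collapse onto one destination key, so there is no unique way back. The small check below only illustrates that point; `forward` is a hand-written stand-in for the rule, not a library function.

```python
def forward(key: str) -> str:
    """Hand-written stand-in for: layer.{n}.weight -> block.{(n // 2)}.weight"""
    n = int(key.split(".")[1])
    return f"block.{n // 2}.weight"


sources = [f"layer.{n}.weight" for n in range(4)]

# Group source keys by the destination key they produce.
preimages: dict[str, list[str]] = {}
for key in sources:
    preimages.setdefault(forward(key), []).append(key)

print(preimages)
# {'block.0.weight': ['layer.0.weight', 'layer.1.weight'],
#  'block.1.weight': ['layer.2.weight', 'layer.3.weight']}
```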


🧪 Preview Before Applying

from torch_state_bridge import state_bridge_preview

mapping, unchanged, collisions = state_bridge_preview(state_dict, rules)
  • mapping: old → new keys
  • unchanged: keys not affected
  • collisions: conflicting output keys
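The three return values can be computed in a single pass over the keys. The sketch below assumes a per-key rename function and treats a collision as any destination key produced by more than one source key, including a renamed key that clashes with a key left unchanged. It illustrates the idea only; `preview` is hypothetical and is not the implementation of `state_bridge_preview`.

```python
from collections import defaultdict
from typing import Callable, Iterable


def preview(keys: Iterable[str], rename: Callable[[str], str]):
    mapping = {}                 # old key -> new key, only for keys that change
    unchanged = []               # keys the rename leaves as-is
    targets = defaultdict(list)  # new key -> every source key producing it
    for key in keys:
        new = rename(key)
        targets[new].append(key)
        if new == key:
            unchanged.append(key)
        else:
            mapping[key] = new
    collisions = {new: srcs for new, srcs in targets.items() if len(srcs) > 1}
    return mapping, unchanged, collisions


def to_block(key: str) -> str:
    return key.replace("layer.", "block.", 1)


keys = ["layer.0.weight", "block.0.weight", "layer.0.bias"]
mapping, unchanged, collisions = preview(keys, to_block)
print(mapping)     # {'layer.0.weight': 'block.0.weight', 'layer.0.bias': 'block.0.bias'}
print(unchanged)   # ['block.0.weight']
print(collisions)  # {'block.0.weight': ['layer.0.weight', 'block.0.weight']}
```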

🖨 Pretty Diff Output

from torch_state_bridge import print_diff

print_diff(state_dict, rules)

Example output:

============================================================
TRANSFORMATION PREVIEW
============================================================

📝 CHANGES:
  layer.0.weight -> block.0.weight
  layer.0.bias   -> block.0.bias

✓ UNCHANGED: 12 keys
============================================================

🧩 Batch Operations Pipeline

Apply multiple transformations sequentially:

from torch_state_bridge import state_bridge_batch

ops = [
    {"type": "prefix", "add": "model."},
    {"type": "rules", "rules_text": "layer.{n}, block.{n}"},
    {"type": "remove_prefix", "remove": "model."}
]

new_sd = state_bridge_batch(state_dict, ops)

Supported Batch Operations

  • prefix
  • suffix
  • remove_prefix
  • remove_suffix
  • replace
  • rules
  • filter
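Each operation in the list transforms every key and hands the result to the next step. A minimal sketch of that sequential design follows; `run_ops` and `transform_key` are hypothetical. Only `add` and `remove` come from the example above. Using `add` for `suffix` and `remove` for `remove_suffix` is an assumption, as are the `old`/`new` parameter names for `replace`. `rules` and `filter` are left out because their exact semantics are not described here.

```python
def transform_key(key: str, op: dict) -> str:
    """Apply one operation to one key."""
    kind = op["type"]
    if kind == "prefix":
        return op["add"] + key
    if kind == "suffix":
        return key + op["add"]          # assumed parameter name
    if kind == "remove_prefix":
        return key.removeprefix(op["remove"])
    if kind == "remove_suffix":
        return key.removesuffix(op["remove"])  # assumed parameter name
    if kind == "replace":
        return key.replace(op["old"], op["new"])  # assumed parameter names
    raise ValueError(f"operation not covered by this sketch: {kind}")


def run_ops(state_dict: dict, ops: list[dict]) -> dict:
    """Run operations in order; values are passed through untouched."""
    for op in ops:
        # No collision check here: a later key silently overwrites an earlier one.
        state_dict = {transform_key(k, op): v for k, v in state_dict.items()}
    return state_dict


sd = {"layer.0.weight": 1, "layer.0.bias": 2}  # ints stand in for tensors
ops = [
    {"type": "prefix", "add": "model."},
    {"type": "replace", "old": "layer.", "new": "block."},
    {"type": "remove_prefix", "remove": "model."},
]
print(run_ops(sd, ops))  # {'block.0.weight': 1, 'block.0.bias': 2}
```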

🌳 Nested State Dicts

Handles deeply nested dictionaries:

import torch
from torch_state_bridge import state_bridge_nested

nested = {
    "model": {
        "layer1": {
            "weight": torch.zeros(3)
        }
    }
}

rules = "model.layer1, model.block1"
new_nested = state_bridge_nested(nested, rules)
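A common way to handle nesting is to flatten the dictionary into dot-joined keys, apply the flat rules, and then rebuild the nesting. The README does not say whether `state_bridge_nested` works exactly this way. The sketch below only illustrates the idea with hypothetical `flatten` and `unflatten` helpers, and it assumes nested key names contain no dots themselves.

```python
def flatten(d: dict, prefix: str = "") -> dict:
    """{'a': {'b': x}} -> {'a.b': x}"""
    flat = {}
    for k, v in d.items():
        key = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            flat.update(flatten(v, key))
        else:
            flat[key] = v
    return flat


def unflatten(flat: dict) -> dict:
    """{'a.b': x} -> {'a': {'b': x}}"""
    root: dict = {}
    for key, v in flat.items():
        *parents, leaf = key.split(".")
        node = root
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = v
    return root


nested = {"model": {"layer1": {"weight": "W"}}}  # "W" stands in for a tensor
flat = flatten(nested)  # {'model.layer1.weight': 'W'}
renamed = {k.replace("model.layer1", "model.block1", 1): v for k, v in flat.items()}
print(unflatten(renamed))  # {'model': {'block1': {'weight': 'W'}}}
```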

🔗 Rule Chains

Chain multiple rule engines with tracing:

from torch_state_bridge import RuleChain

chain = (
    RuleChain()
    .add("rename layers", "layer.{n}, block.{n}")
    .add("add prefix", "{key}, model.{key}")
)

new_sd = chain.apply(state_dict, trace=True)
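Tracing can be pictured as recording, for each named step, which keys that step renamed. The sketch below works under that assumption: `TracedChain` is hypothetical, each step is a plain per-key function rather than a rules string, and the trace format is invented for illustration.

```python
from typing import Callable


class TracedChain:
    def __init__(self) -> None:
        self.steps: list[tuple[str, Callable[[str], str]]] = []
        self.last_trace: list[tuple[str, dict]] = []

    def add(self, name: str, fn: Callable[[str], str]) -> "TracedChain":
        self.steps.append((name, fn))
        return self  # returning self enables the fluent .add().add() style

    def apply(self, state_dict: dict, trace: bool = False) -> dict:
        self.last_trace = []
        for name, fn in self.steps:
            changes = {k: fn(k) for k in state_dict if fn(k) != k}
            state_dict = {fn(k): v for k, v in state_dict.items()}
            self.last_trace.append((name, changes))
            if trace:
                print(f"[{name}] {len(changes)} key(s) changed")
                for old, new in changes.items():
                    print(f"  {old} -> {new}")
        return state_dict


chain = (
    TracedChain()
    .add("rename layers", lambda k: k.replace("layer.", "block.", 1))
    .add("add prefix", lambda k: "model." + k)
)
new_sd = chain.apply({"layer.0.weight": 1}, trace=True)
# [rename layers] 1 key(s) changed
#   layer.0.weight -> block.0.weight
# [add prefix] 1 key(s) changed
#   block.0.weight -> model.block.0.weight
```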

🧱 Rule Templates

Reusable built-in templates:

from torch_state_bridge import RuleTemplate

new_sd = RuleTemplate.apply_template(
    state_dict,
    "huggingface_to_timm"
)

Available Templates

  • huggingface_to_timm
  • pytorch_to_tensorflow
  • add_prefix
  • remove_prefix

You can also expand templates manually:

rules = RuleTemplate.expand_template("add_prefix", prefix="model")
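Template expansion can be thought of as filling named parameters into a stored rules string. This sketch uses the standard library's `string.Template`. The `add_prefix` rule text is an assumption modelled on the `{key}, model.{key}` rule from the Rule Chains example, since the templates shipped with the library are not shown in this README, and `expand_template` here is a stand-alone hypothetical function.

```python
from string import Template

# Hypothetical template table: "{key}" is left for the rule engine,
# while "$prefix" is filled in at expansion time.
TEMPLATES = {
    "add_prefix": Template("{key}, $prefix.{key}"),
}


def expand_template(name: str, **params: str) -> str:
    # substitute() raises KeyError if a required parameter is missing.
    return TEMPLATES[name].substitute(**params)


print(expand_template("add_prefix", prefix="model"))  # {key}, model.{key}
```

Using `$`-style placeholders for template parameters keeps them from clashing with the `{...}` capture syntax used inside the rules themselves.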

🔍 Rule Validation

from torch_state_bridge import validate_rules

errors = validate_rules(rules)
if errors:
    print(errors)
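The README does not list the exact checks `validate_rules` performs. As an illustration of what such checks might cover, a hypothetical `check_rules` could require exactly two comma-separated patterns per line and require that every plain capture used in the destination also appears in the source.

```python
import re

# Plain captures like {n}; arithmetic forms such as {(n + 1)} are not matched.
_CAPTURE = re.compile(r"\{(\w+)\}")


def check_rules(rules_text: str) -> list[str]:
    errors = []
    for lineno, line in enumerate(rules_text.splitlines(), start=1):
        if not line.strip():
            continue  # blank lines are ignored
        parts = [p.strip() for p in line.split(",")]
        if len(parts) != 2 or not all(parts):
            errors.append(f"line {lineno}: expected 'source, destination'")
            continue
        src, dst = parts
        missing = set(_CAPTURE.findall(dst)) - set(_CAPTURE.findall(src))
        if missing:
            errors.append(f"line {lineno}: unknown capture(s) {sorted(missing)}")
    return errors


print(check_rules("layer.{n}.weight, block.{m}.weight\nbad line"))
# ["line 1: unknown capture(s) ['m']", "line 2: expected 'source, destination'"]
```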

📐 Range Expansion

from torch_state_bridge import expand_range_rules

rules = "layer.{0..2}.weight, block.{0..2}.weight"
print(expand_range_rules(rules))
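Assuming a `{a..b}` range is inclusive and is expanded in lockstep in both columns (so the example above becomes three explicit rules for 0, 1 and 2), the expansion could be sketched as follows. `expand_ranges` is hypothetical, and it also assumes that all ranges on one line are identical.

```python
import re

_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_ranges(rules_text: str) -> str:
    out = []
    for line in rules_text.splitlines():
        m = _RANGE.search(line)
        if m is None:
            out.append(line)  # lines without a range pass through
            continue
        start, stop = int(m.group(1)), int(m.group(2))
        for i in range(start, stop + 1):  # inclusive upper bound (assumed)
            # Every range on the line receives the same value i.
            out.append(_RANGE.sub(str(i), line))
    return "\n".join(out)


print(expand_ranges("layer.{0..2}.weight, block.{0..2}.weight"))
# layer.0.weight, block.0.weight
# layer.1.weight, block.1.weight
# layer.2.weight, block.2.weight
```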

🔁 Inverse Rule Generation

from torch_state_bridge import generate_inverse_rules

inverse = generate_inverse_rules("layer.{n}, block.{n}")
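For rules with plain captures only, the inverse of `source, destination` is `destination, source`. The sketch below performs that swap and refuses arithmetic destinations, in line with the reverse-mode restriction above. `invert_rules` is a hypothetical name, and the sketch assumes arithmetic is always written as `{(...)}`.

```python
def invert_rules(rules_text: str) -> str:
    """Swap source and destination on every non-blank rule line."""
    inverted = []
    for line in rules_text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        src, dst = (part.strip() for part in line.split(",", 1))
        # Assumption: arithmetic always appears as {(...)} in the destination.
        if "{(" in dst:
            raise ValueError(f"cannot invert arithmetic rule: {line.strip()}")
        inverted.append(f"{dst}, {src}")
    return "\n".join(inverted)


print(invert_rules("layer.{n}, block.{n}"))  # block.{n}, layer.{n}
```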

⚠️ Collision Detection

By default, key collisions raise an error:

state_bridge(state_dict, rules, detect_collision=True)

Disable if needed:

state_bridge(state_dict, rules, detect_collision=False)
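Here is a sketch of the check, assuming a collision means two source keys are renamed to the same destination key. With the check turned off, this sketch simply lets the later key overwrite the earlier one. The README does not say what the library itself does in that case, and `rename_all` is hypothetical.

```python
from typing import Callable


def rename_all(state_dict: dict, rename: Callable[[str], str],
               detect_collision: bool = True) -> dict:
    out = {}
    origin = {}  # new key -> source key that produced it
    for key, value in state_dict.items():
        new = rename(key)
        if detect_collision and new in out:
            raise ValueError(f"{key!r} and {origin[new]!r} both map to {new!r}")
        out[new] = value  # without detection, a later key overwrites silently
        origin[new] = key
    return out


def to_block(key: str) -> str:
    return key.replace("layer.", "block.", 1)


sd = {"layer.0.weight": 1, "block.0.weight": 2}  # ints stand in for tensors
try:
    rename_all(sd, to_block)
except ValueError as err:
    print("collision:", err)
print(rename_all(sd, to_block, detect_collision=False))  # {'block.0.weight': 2}
```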

🛡 Safety

  • No eval
  • AST-based math evaluation
  • Strict regex capture rules
  • Safe integer-only arithmetic

📄 License

MIT License © 2026 Your Name


🤝 Contributing

Contributions are welcome!

  • Bug reports
  • Feature requests
  • New rule templates
  • Documentation improvements

🌟 Why torch-state-bridge?

Because renaming model weights should be declarative, safe, and composable.

If you work with:

  • model conversion
  • checkpoint surgery
  • research code cleanup
  • framework interoperability

torch-state-bridge is built for you.


Happy bridging 🚀

Download files

Download the file for your platform.

Source Distribution

torch_state_bridge-0.1.0.tar.gz (8.2 kB)

Uploaded Source

Built Distribution

torch_state_bridge-0.1.0-py3-none-any.whl (9.3 kB)

Uploaded Python 3

File details

Details for the file torch_state_bridge-0.1.0.tar.gz.

File metadata

  • Download URL: torch_state_bridge-0.1.0.tar.gz
  • Upload date:
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.20

File hashes

Hashes for torch_state_bridge-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        a19d7e3279ab519b10b2fbe817b2cfef9e48375876c6fce570b3c1dfedd8e5b8
MD5           95e2c3cf84a07644e77de8e691f666fb
BLAKE2b-256   a32c6b9b0041d76ddea72f5b7d91bf6b74857a18d76e1c1c67b795066262e10d

File details

Details for the file torch_state_bridge-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torch_state_bridge-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.20

File hashes

Hashes for torch_state_bridge-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        fe1e55dfc19b107b672ac4477164f548f1a48bf762e96fd5d5cffb74edc2fa56
MD5           a2e504c3955bb366d9938a446cb262f5
BLAKE2b-256   c683bb3ae92fa0f058ce857cd48cdf604c4bd0be6d82a0a4e0d8c41eaae287aa
