
Mask algebra for selecting and combining JAX PyTree leaves.


maskx

pip install maskx

Mask algebra for selecting and combining JAX PyTree leaves. Backed by flat NumPy arrays for fast operations on large trees.

import jax
import maskx

# model: any JAX PyTree (e.g. an Equinox module)
weight = maskx.select(model, target=r".*/weight", leaf_type=jax.Array)
decoder = maskx.select(model, target=r"decoder/.*", leaf_type=jax.Array)

mask = decoder & weight
mask.paths()    # selected leaf paths
mask.count()    # number of selected leaves
mask.summary()  # "2/348 leaves selected"

Selectors (keyword arguments to maskx.select): target, path_prefix, path_in, leaf_type, shape, dtype, ndim, where.
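To make the selector idea concrete, here is a minimal, dependency-free sketch of how regex and prefix selectors could turn a tree into a flat boolean mask with one entry per leaf. It is not maskx's implementation. The nested-dict stand-in for a PyTree, the "/"-joined paths, the use of re.fullmatch for target, and the rule that a leaf is selected only when every given selector matches are all assumptions; flatten_with_paths and select_sketch are hypothetical names.

```python
import re
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

# Hypothetical PyTree stand-in: nested dicts whose non-dict values are leaves.
Tree = Dict[str, Any]


def flatten_with_paths(tree: Tree, prefix: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (path, leaf) pairs, visiting keys in sorted order and joining them with "/"."""
    for key in sorted(tree):
        path = f"{prefix}/{key}" if prefix else key
        value = tree[key]
        if isinstance(value, dict):
            yield from flatten_with_paths(value, path)
        else:
            yield path, value


def select_sketch(
    tree: Tree,
    target: Optional[str] = None,
    path_prefix: Optional[str] = None,
    where: Optional[Callable[[Any], bool]] = None,
) -> Tuple[List[str], Tuple[bool, ...]]:
    """Return all leaf paths plus one bool per leaf; every given selector must match."""
    paths: List[str] = []
    bits: List[bool] = []
    for path, leaf in flatten_with_paths(tree):
        keep = True
        if target is not None:
            keep = keep and re.fullmatch(target, path) is not None
        if path_prefix is not None:
            keep = keep and path.startswith(path_prefix)
        if where is not None:
            keep = keep and bool(where(leaf))
        paths.append(path)
        bits.append(keep)
    return paths, tuple(bits)


model = {
    "encoder": {"weight": [[1.0, 2.0]], "bias": [0.0]},
    "decoder": {"weight": [[3.0]], "bias": [1.0], "norm": {"scale": [1.0]}},
}
paths, mask = select_sketch(model, target=r".*/weight")
selected = [p for p, keep in zip(paths, mask) if keep]
print(selected)                                     # ['decoder/weight', 'encoder/weight']
print(f"{sum(mask)}/{len(mask)} leaves selected")  # 2/5 leaves selected
```

Visiting dict keys in sorted order mirrors how jax.tree_util orders dict entries when flattening, which keeps mask positions aligned with leaf positions. Predicates such as leaf_type, shape, dtype and ndim could be expressed through where in the same way.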

Operators: |, &, ^, +, -, ~

a = maskx.select(model, target=r"decoder/.*", leaf_type=jax.Array)
b = maskx.select(model, target=r".*/weight", leaf_type=jax.Array)

a | b   # union — decoder leaves OR weights
a & b   # intersection — decoder weights only
a ^ b   # symmetric difference — in one but not both
a + b   # alias for union (a | b)
a - b   # difference — decoder leaves that are NOT weights
~a      # complement — everything except decoder leaves

# chain freely
trainable = (a | b) - maskx.select(model, target=r".*norm.*")

# cumulative: build up from multiple masks
prefixes = ["encoder", "decoder/layer_0"]  # hypothetical path prefixes
masks = [maskx.select(model, path_prefix=p) for p in prefixes]
combined = masks[0]
for m in masks[1:]:
    combined = combined | m

# or via combine_masks
combined = maskx.combine_masks(*masks, op="or")   # "and", "xor" also supported
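The operator comments above map onto elementwise boolean logic over a flat, per-leaf mask. The sketch below assumes that representation but uses plain Python tuples instead of the NumPy arrays maskx is described as using; FlatMask and combine_sketch are hypothetical names, not maskx API. It also checks two identities implied by the comments: a - b equals a & ~b, and a ^ b equals (a | b) - (a & b). The combine step is a functools.reduce over a single operator.

```python
import functools
import operator
from dataclasses import dataclass
from typing import Iterator, Tuple


@dataclass(frozen=True)
class FlatMask:
    """Hypothetical mask: one bool per leaf, aligned with a fixed leaf order."""

    bits: Tuple[bool, ...]

    def _pairs(self, other: "FlatMask") -> Iterator[Tuple[bool, bool]]:
        # Masks built from different trees cannot be combined position by position.
        if len(self.bits) != len(other.bits):
            raise ValueError("masks must cover the same number of leaves")
        return zip(self.bits, other.bits)

    def __or__(self, other: "FlatMask") -> "FlatMask":   # union
        return FlatMask(tuple(x or y for x, y in self._pairs(other)))

    def __and__(self, other: "FlatMask") -> "FlatMask":  # intersection
        return FlatMask(tuple(x and y for x, y in self._pairs(other)))

    def __xor__(self, other: "FlatMask") -> "FlatMask":  # in exactly one of the two
        return FlatMask(tuple(x != y for x, y in self._pairs(other)))

    def __sub__(self, other: "FlatMask") -> "FlatMask":  # in self but not in other
        return FlatMask(tuple(x and not y for x, y in self._pairs(other)))

    __add__ = __or__  # "+" is an alias for union

    def __invert__(self) -> "FlatMask":                  # complement
        return FlatMask(tuple(not x for x in self.bits))


_OPS = {"or": operator.or_, "and": operator.and_, "xor": operator.xor}


def combine_sketch(*masks: FlatMask, op: str = "or") -> FlatMask:
    """Fold masks left to right with one binary operator."""
    return functools.reduce(_OPS[op], masks)


a = FlatMask((True, True, False, False))  # e.g. decoder leaves
b = FlatMask((True, False, True, False))  # e.g. weight leaves
print((a | b).bits)  # (True, True, True, False)
print((a - b).bits)  # (False, True, False, False)
print((a ^ b).bits)  # (False, True, True, False)
```

Because every operator works position by position, two masks only make sense together when they were built from the same tree; the sketch raises ValueError on a length mismatch. Whether maskx performs an equivalent check is not stated here.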

Apply a function to selected leaves only:

zeroed = mask.apply(model, fn=lambda x: x * 0)  # zero the selected leaves; all other leaves are left as-is
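Since JAX arrays are immutable and Equinox modules are frozen dataclasses, a selective apply naturally builds a new tree instead of editing the old one. The sketch below shows that idea on nested dicts, with the selection given as a set of "/"-joined paths. apply_sketch is a hypothetical helper, and it is an assumption that maskx's apply likewise returns a new tree with unselected leaves passed through unchanged.

```python
from typing import Any, Callable, Dict, Set


def apply_sketch(
    tree: Dict[str, Any], selected: Set[str], fn: Callable[[Any], Any], prefix: str = ""
) -> Dict[str, Any]:
    """Return a new nested dict in which fn is applied only to leaves whose path is selected."""
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            out[key] = apply_sketch(value, selected, fn, path)
        elif path in selected:
            out[key] = fn(value)
        else:
            out[key] = value  # unselected leaves pass through untouched
    return out


model = {"encoder": {"weight": 2.0, "bias": 0.5}, "decoder": {"weight": 3.0}}
zeroed = apply_sketch(model, {"encoder/weight", "decoder/weight"}, fn=lambda x: x * 0)
print(zeroed)                      # {'encoder': {'weight': 0.0, 'bias': 0.5}, 'decoder': {'weight': 0.0}}
print(model["encoder"]["weight"])  # 2.0 (the input tree is not modified)
```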

Works with Optax:

import optax

weight = maskx.select(model, target=r".*/weight", leaf_type=jax.Array)
optimizer = optax.masked(optax.adam(1e-3), weight.tree)  # Adam is applied only where the mask is True
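optax.masked takes its mask as a PyTree of booleans with the same structure as (or a prefix of) the parameters, or as a callable that returns one. weight.tree is presumably such a tree. As a hypothetical illustration of its shape, the sketch below mirrors a nested-dict parameter tree and replaces each leaf with whether its "/"-joined path was selected; bool_tree_sketch is not a maskx function.

```python
from typing import Any, Dict, Set


def bool_tree_sketch(tree: Dict[str, Any], selected: Set[str], prefix: str = "") -> Dict[str, Any]:
    """Mirror tree's nesting, replacing each leaf with True if its "/"-joined path is selected."""
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            out[key] = bool_tree_sketch(value, selected, path)
        else:
            out[key] = path in selected
    return out


params = {"encoder": {"weight": [1.0], "bias": [0.0]}, "decoder": {"weight": [2.0]}}
mask_tree = bool_tree_sketch(params, {"encoder/weight", "decoder/weight"})
print(mask_tree)  # {'encoder': {'weight': True, 'bias': False}, 'decoder': {'weight': True}}
```

Note that optax.masked passes the updates for unselected leaves through unchanged rather than zeroing them. To keep those leaves fixed, they usually need their own transformation, such as optax.set_to_zero(), for instance through optax.multi_transform.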

Works with Paramax:

import paramax

weight_mask = maskx.select(model, target="weight", leaf_type=jax.Array)
frozen = weight_mask.apply(model, fn=paramax.NonTrainable)  # wrap selected leaves to mark them non-trainable

Example notebook

See docs/notebooks/equinox_optax_demo.ipynb for a small Equinox MLP example that uses maskx to train only selected parameters with Optax.


Download files


Source Distribution

maskx-0.1.2.tar.gz (59.0 kB)

Uploaded Source

Built Distribution


maskx-0.1.2-py3-none-any.whl (6.7 kB)

Uploaded Python 3

File details

Details for the file maskx-0.1.2.tar.gz.

File metadata

  • Download URL: maskx-0.1.2.tar.gz
  • Upload date:
  • Size: 59.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for maskx-0.1.2.tar.gz
Algorithm Hash digest
SHA256 2fefd665d129373dea307b99a715ec9c50d09a895cefa41b97e84e4b7d6e2996
MD5 da78bde221ef313a7a2f0d354d4d2ee6
BLAKE2b-256 c902d20caa70d960a8cab9c8e4e47265ebb7262d11d888400ac12cfd2b8e5f9d


File details

Details for the file maskx-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: maskx-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 6.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for maskx-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 e116b59ada9b88d4ef448fb48e427d8885028d9f872708f3211680edb77a67a0
MD5 3188225f60cc47f58a11cc9677bfbc9a
BLAKE2b-256 3fff3810995b931d447b97cadd64213993b1ed5b04e760e4ad0bff5d45233853

