
Mask algebra for selecting and combining JAX PyTree leaves.


maskx

Minimal path-based masking for JAX PyTrees.

maskx builds Mask objects from PyTree paths and combines them with a small set of mask operators.

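import jax
import maskx

# `model` is any JAX PyTree, such as a nested dict of arrays.
weight = maskx.select(model, target=r".*/weight", leaf_type=jax.Array)
decoder = maskx.select(model, target=r"decoder/.*", leaf_type=jax.Array)

# Intersection: keep only the leaves selected by both masks.
mask = decoder & weight
paths = mask.paths()  # paths of the selected leaves
count = mask.count()  # number of selected leaves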

Selectors can be based on target, path_prefix, path_in, leaf_type, shape, dtype, and ndim.
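To make the idea of path-based selection concrete, here is a minimal pure-Python sketch. It is not maskx's implementation: the helpers `flatten_paths` and `select_paths` are hypothetical, paths are assumed to be keys joined with `/`, and `target` is assumed to be a regular expression that must match the whole path. Only three of the selectors listed above are imitated.

```python
import re


def flatten_paths(tree, prefix=""):
    """Yield (path, leaf) pairs for a nested dict, joining keys with '/'."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            yield from flatten_paths(value, path)
        else:
            yield path, value


def select_paths(tree, target=None, path_prefix=None, leaf_type=None):
    """Return the set of leaf paths that satisfy every criterion given."""
    pattern = re.compile(target) if target is not None else None
    selected = set()
    for path, leaf in flatten_paths(tree):
        # Assumption: the pattern must match the entire path.
        if pattern is not None and not pattern.fullmatch(path):
            continue
        if path_prefix is not None and not path.startswith(path_prefix):
            continue
        if leaf_type is not None and not isinstance(leaf, leaf_type):
            continue
        selected.add(path)
    return selected


# Plain lists stand in for arrays so the sketch needs no dependencies.
model = {
    "encoder": {"weight": [1.0, 2.0], "bias": [0.0]},
    "decoder": {"weight": [3.0], "bias": [0.0], "name": "dec"},
}

weights = select_paths(model, target=r".*/weight", leaf_type=list)
decoder = select_paths(model, path_prefix="decoder/", leaf_type=list)
```

Criteria are combined with a logical AND, so the string leaf `decoder/name` is dropped by `leaf_type=list` even though its path has the right prefix.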

Mask operators: |, &, ^, +, -, ~
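The description does not define each operator, so the following sketch only shows the usual set-style readings. `PathMask` is a hypothetical stand-in, not maskx's `Mask` class, and treating `+` as union and `-` as difference is an assumption.

```python
class PathMask:
    """Hypothetical mask: a set of selected paths within a fixed universe."""

    def __init__(self, universe, selected):
        self.universe = frozenset(universe)
        self.selected = frozenset(selected) & self.universe

    def _new(self, selected):
        return PathMask(self.universe, selected)

    def __or__(self, other):  # union
        return self._new(self.selected | other.selected)

    def __and__(self, other):  # intersection
        return self._new(self.selected & other.selected)

    def __xor__(self, other):  # selected by exactly one of the two masks
        return self._new(self.selected ^ other.selected)

    def __sub__(self, other):  # difference (assumed meaning of `-`)
        return self._new(self.selected - other.selected)

    def __invert__(self):  # complement within the universe
        return self._new(self.universe - self.selected)

    __add__ = __or__  # assumed meaning of `+`: same as union

    def paths(self):
        return sorted(self.selected)

    def count(self):
        return len(self.selected)


universe = ["encoder/weight", "encoder/bias", "decoder/weight", "decoder/bias"]
weight = PathMask(universe, ["encoder/weight", "decoder/weight"])
decoder = PathMask(universe, ["decoder/weight", "decoder/bias"])

both = (decoder & weight).paths()
either = (decoder | weight).count()
only_one = (decoder ^ weight).paths()
decoder_not_weight = (decoder - weight).paths()
not_weight = (~weight).paths()
```

Keeping the universe of paths on every mask is what makes `~` well defined: the complement is taken relative to all leaves of the tree rather than to the selected set alone.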

Works with Optax:

import jax
import optax
import maskx

weight = maskx.select(model, target=r".*/weight", leaf_type=jax.Array)
optimizer = optax.masked(optax.adam(1e-3), weight.tree)

Works with Paramax:

import jax
import jax.tree_util as jtu
import maskx
import paramax

weight_mask = maskx.select(model, target="weight", leaf_type=jax.Array)

frozen = jtu.tree_map(
    lambda leaf, selected: paramax.NonTrainable(leaf) if selected else leaf,
    model,
    weight_mask.tree,
)

The library is intentionally small: it only builds and combines masks.

Download files


Source Distribution

maskx-0.1.0.tar.gz (55.5 kB)

Built Distribution

maskx-0.1.0-py3-none-any.whl (5.4 kB, Python 3)
