
Immutable Map, compatible with JAX & Equinox


xmmutablemap

JAX-compatible Immutable Mapping

JAX prefers immutable objects, but neither Python nor JAX provides an immutable dictionary. 😢
This repository defines a lightweight immutable map (lower-level than a dict) that JAX understands as a PyTree. 🎉 🕶️

Installation


Using pip:
pip install xmmutablemap

Using uv:
uv add xmmutablemap

From source, using pip:
pip install git+https://github.com/GalacticDynamics/xmmutablemap.git

Building from source:
cd /path/to/parent
git clone https://github.com/GalacticDynamics/xmmutablemap.git
cd xmmutablemap
pip install -e .  # editable mode

Documentation

xmmutablemap provides the class ImmutableMap, a full implementation of Python's Mapping ABC. If you've used a dict, you already know how to use ImmutableMap! ImmutableMap adds two things: 1) immutability (and related benefits like hashability) and 2) compatibility with JAX.

from xmmutablemap import ImmutableMap

print(ImmutableMap(a=1, b=2, c=3))
# ImmutableMap({'a': 1, 'b': 2, 'c': 3})

print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"}))
# ImmutableMap({'a': 1, 'b': 2.0, 'c': '3'})
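The sketch below shows what "a full implementation of the Mapping ABC, plus immutability and hashability" amounts to. It defines a hypothetical `FrozenMap` using only the standard library. This is an illustration under our own assumptions, not xmmutablemap's implementation. Subclassing `collections.abc.Mapping` and defining `__getitem__`, `__iter__` and `__len__` gives the rest of the read-only dict API (`keys`, `items`, `get`, `in`, `==`). Defining `__hash__` is what makes instances hashable.

```python
from collections.abc import Mapping
from typing import Any, Iterator


class FrozenMap(Mapping):
    """Hypothetical minimal immutable mapping (illustration only)."""

    __slots__ = ("_data",)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Copy the input so later changes to the caller's dict cannot leak in.
        object.__setattr__(self, "_data", dict(*args, **kwargs))

    # The three abstract methods Mapping requires; it derives keys(), items(),
    # values(), get(), __contains__ and __eq__ from them.
    def __getitem__(self, key: Any) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __setattr__(self, name: str, value: Any) -> None:
        # Block rebinding attributes such as _data after construction.
        raise AttributeError("FrozenMap is immutable")

    def __hash__(self) -> int:
        # Mapping defines __eq__ but no __hash__, so one must be supplied.
        # This only works if every value is itself hashable.
        return hash(frozenset(self._data.items()))

    def __repr__(self) -> str:
        return f"FrozenMap({self._data!r})"


m = FrozenMap(a=1, b=2)
print(m["a"], len(m), sorted(m.items()))  # 1 2 [('a', 1), ('b', 2)]
print(m == {"a": 1, "b": 2})  # True
print(hash(m) == hash(FrozenMap(b=2, a=1)))  # True

# There is no __setitem__, so item assignment raises TypeError.
try:
    m["a"] = 10
    assignment_rejected = False
except TypeError:
    assignment_rejected = True
print(assignment_rejected)  # True
```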

JAX Integration

One of the key benefits of ImmutableMap is its compatibility with JAX. Because it is immutable and hashable, it can be used in places where a regular dictionary, which is mutable and therefore unhashable, would be rejected, such as default values in dataclasses.
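Hashability matters because JAX caches compiled functions, and values passed as static arguments to `jax.jit` must be hashable so they can form part of the cache key. The snippet below uses the standard library's `functools.lru_cache` as a stand-in for that kind of cache. It is an analogy, not JAX's actual mechanism. A plain `dict` argument is rejected as unhashable. A hypothetical minimal frozen mapping, by contrast, works as a cache key, and equal maps share one entry.

```python
import functools
from collections.abc import Mapping
from typing import Any, Iterator


class FrozenMap(Mapping):
    """Hypothetical minimal hashable, read-only mapping (illustration only)."""

    def __init__(self, **kwargs: Any) -> None:
        self._data = dict(kwargs)

    def __getitem__(self, key: Any) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __hash__(self) -> int:
        # Equal maps hash equally, so they can share one cache entry.
        return hash(frozenset(self._data.items()))


@functools.lru_cache(maxsize=None)
def build_step(options: Mapping) -> str:
    # Stands in for an expensive step whose result is cached per `options`.
    return f"step with lr={options['learning_rate']}"


try:
    build_step({"learning_rate": 0.1})  # a dict cannot be a cache key
    dict_accepted = True
except TypeError:
    dict_accepted = False

print(dict_accepted)  # False
build_step(FrozenMap(learning_rate=0.1))  # computed (cache miss)
build_step(FrozenMap(learning_rate=0.1))  # equal map: reused (cache hit)
print(build_step.cache_info().hits)  # 1
```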

Using ImmutableMap as a Default in JAX Dataclasses

Here's an example showing how ImmutableMap can be used as a default value in a dataclass, which is particularly useful with JAX:

import functools
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from xmmutablemap import ImmutableMap


@functools.partial(
    jax.tree_util.register_dataclass, data_fields=["params"], meta_fields=["batch_size"]
)
@dataclass(frozen=True)
class Config:
    """Configuration with immutable default parameters."""

    # This works! ImmutableMap is immutable and hashable
    params: ImmutableMap[str, float] = ImmutableMap(
        learning_rate=0.001, momentum=0.9, weight_decay=1e-4
    )
    batch_size: int = 32


# JAX can safely transform functions using this dataclass
@jax.jit
def train_step(config: Config, data: jnp.ndarray) -> jnp.ndarray:
    """Example training step that uses config parameters."""
    lr = config.params["learning_rate"]
    return data * lr


# This works perfectly
config = Config()
data = jnp.array([1.0, 2.0, 3.0])
result = train_step(config, data)
print(f"Result: {result}")
# Result: [0.001 0.002 0.003]
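The restriction on defaults comes from Python's own `dataclasses` module rather than from JAX. The module refuses a `dict` default, and since Python 3.11 any unhashable default, by raising `ValueError` at class creation. It does this because one mutable default object would be shared by every instance. The standard-library sketch below reproduces that behaviour. The hypothetical `FrozenMap` stands in for `ImmutableMap` and is not the library's class.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Iterator


class FrozenMap(Mapping):
    """Hypothetical hashable stand-in for ImmutableMap (illustration only)."""

    def __init__(self, **kwargs: Any) -> None:
        self._data = dict(kwargs)

    def __getitem__(self, key: Any) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __hash__(self) -> int:
        return hash(frozenset(self._data.items()))


# A plain dict default is rejected as soon as the class is decorated.
try:

    @dataclass(frozen=True)
    class DictConfig:
        params: dict = {"learning_rate": 0.001}

    dict_default_ok = True
except ValueError:
    dict_default_ok = False

print(dict_default_ok)  # False


# A hashable, read-only mapping is accepted as a default value.
@dataclass(frozen=True)
class Config:
    params: Mapping = FrozenMap(learning_rate=0.001, momentum=0.9)
    batch_size: int = 32


a, b = Config(), Config()
# Both instances share one default object; that is safe because it is read-only.
print(a.params is b.params)  # True
print(a.params["learning_rate"])  # 0.001
print(hash(a) == hash(b))  # True
```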

Key Benefits for JAX

  • Immutability: Once created, ImmutableMap cannot be modified, preventing accidental mutations that could break JAX's functional programming model
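  • Hashability: ImmutableMap instances can be used where JAX requires hashable values, for example as static arguments to jax.jit, which become part of the key JAX uses to cache compiled functions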
  • PyTree Support: ImmutableMap is registered as a JAX PyTree, so it works seamlessly with JAX transformations like jit, grad, vmap, etc.
  • Safe Defaults: Can be used as default values in dataclasses without the typical pitfalls of mutable defaults
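PyTree support means JAX knows how to split a mapping into leaves and static structure, and how to rebuild it afterwards. The leaves are the values, which a transformation may trace or replace, and the static structure is the keys. JAX's `jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)` expects `flatten_func(node)` to return `(children, aux_data)` and `unflatten_func(aux_data, children)` to rebuild the node. The sketch below writes such a pair for a hypothetical `FrozenMap` using only the standard library, so it runs without JAX. Treating the keys in insertion order as `aux_data` is our own assumed design, not necessarily what xmmutablemap does.

```python
from collections.abc import Iterable, Mapping
from typing import Any, Iterator


class FrozenMap(Mapping):
    """Hypothetical minimal read-only mapping (illustration only)."""

    def __init__(self, items: Any = ()) -> None:
        # Accepts a mapping or an iterable of (key, value) pairs, like dict().
        self._data = dict(items)

    def __getitem__(self, key: Any) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)


def flatten(m: FrozenMap) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
    # children: the values (the leaves a transformation may trace or replace).
    # aux_data: the keys, kept as static, hashable structure.
    keys = tuple(m)
    return tuple(m[k] for k in keys), keys


def unflatten(keys: tuple[Any, ...], children: Iterable[Any]) -> FrozenMap:
    # Pair the static keys back up with the (possibly transformed) leaves.
    return FrozenMap(zip(keys, children))


# Mimic what a JAX transformation does: flatten, act on the leaves, rebuild.
params = FrozenMap({"learning_rate": 0.001, "momentum": 0.9})
leaves, keys = flatten(params)
scaled = unflatten(keys, [2 * x for x in leaves])
print(keys)  # ('learning_rate', 'momentum')
print(dict(scaled))  # {'learning_rate': 0.002, 'momentum': 1.8}

# With JAX installed, the same pair could be registered like this (not run here):
# jax.tree_util.register_pytree_node(FrozenMap, flatten, unflatten)
```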

Development


We welcome contributions!



Download files


Source Distribution

xmmutablemap-0.2.1.tar.gz (97.9 kB)

Uploaded Source

Built Distribution


xmmutablemap-0.2.1-py3-none-any.whl (7.6 kB)

Uploaded Python 3

File details

Details for the file xmmutablemap-0.2.1.tar.gz.

File metadata

  • Download URL: xmmutablemap-0.2.1.tar.gz
  • Upload date:
  • Size: 97.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for xmmutablemap-0.2.1.tar.gz
Algorithm Hash digest
SHA256 38a881b0fcc54352e2751628d11519d023bc6677e978483736f8f51b45c7147f
MD5 2df4cd4e3cff0d5092a5407a1a10a9c3
BLAKE2b-256 9fa6c4043eb00f297956a5447109985f41fabb0adbbf80786a21c2949f798ba4


Provenance

The following attestation bundles were made for xmmutablemap-0.2.1.tar.gz:

Publisher: cd.yml on GalacticDynamics/xmmutablemap


File details

Details for the file xmmutablemap-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: xmmutablemap-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 7.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for xmmutablemap-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 002f97986e1cc343d362e38eae62a0a51ca7b76bd3b1de0f70a8f05dbe2b8c5c
MD5 f60c29ac589f4dc987ed58218d62be32
BLAKE2b-256 99894628be42c56e0063b2158a468fc0de3e21285800f8cab07f2d3be684c763


Provenance

The following attestation bundles were made for xmmutablemap-0.2.1-py3-none-any.whl:

Publisher: cd.yml on GalacticDynamics/xmmutablemap

