Zero-overhead functional lensing for JAX PyTrees

Project description

Optix 🔍

A functional lensing library for JAX/Equinox that lets you focus on and modify values nested inside PyTree structures. Optix generates the same HLO as direct attribute access, so using a lens adds no runtime overhead.
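If the term is new: a lens pairs a getter with a functional setter (one that returns an updated copy instead of mutating), and lenses compose to reach deeper fields. The sketch below shows only the idea, in plain Python on `NamedTuple`s, which JAX also treats as PyTrees. The `Lens` class and `attr` helper are hypothetical illustrations. They are not part of optix's API and say nothing about how optix works internally.

```python
from typing import Any, Callable, NamedTuple


class Lens:
    """Hypothetical lens: a getter plus a setter that returns a new copy."""

    def __init__(self, getter: Callable[[Any], Any], setter: Callable[[Any, Any], Any]):
        self.getter = getter
        self.setter = setter

    def then(self, inner: "Lens") -> "Lens":
        # Compose: focus with `self`, then focus further with `inner`.
        return Lens(
            getter=lambda s: inner.getter(self.getter(s)),
            setter=lambda s, v: self.setter(s, inner.setter(self.getter(s), v)),
        )

    def apply(self, s: Any, fn: Callable[[Any], Any]) -> Any:
        # Modify the focused value without mutating `s`.
        return self.setter(s, fn(self.getter(s)))


def attr(name: str) -> Lens:
    """Hypothetical helper: a lens onto one NamedTuple field."""
    return Lens(
        getter=lambda s: getattr(s, name),
        setter=lambda s, v: s._replace(**{name: v}),
    )


class Nested(NamedTuple):
    y: float


class Data(NamedTuple):
    x: list
    nested: Nested


data = Data(x=[1.0, 2.0], nested=Nested(y=3.0))
y_lens = attr("nested").then(attr("y"))
result = y_lens.apply(data, lambda v: v * v)
print(result)         # Data(x=[1.0, 2.0], nested=Nested(y=9.0))
print(data.nested.y)  # 3.0 (the original is untouched)
```

Because each setter rebuilds only the path it touches, untouched branches such as `x` are shared between the old and new structures rather than copied.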

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (generates HLO identical to direct access)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support
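One way to check an overhead claim like this on your own code is to lower both versions with `jax.jit(...).lower(...)` and compare the StableHLO text JAX emits. The sketch below compares two plain-JAX ways of updating a `NamedTuple` PyTree; to test optix itself you would swap one side for the corresponding `focus(...).at(...).apply(...)` call. The struct names and the `hlo_body` helper are illustrative, and the exact text JAX emits can differ between JAX versions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Nested(NamedTuple):
    y: jax.Array


class Data(NamedTuple):
    x: jax.Array
    nested: Nested


def update_direct(d: Data) -> Data:
    # Rebuild the structure by hand, squaring the nested leaf.
    return Data(x=d.x, nested=Nested(y=jnp.square(d.nested.y)))


def update_replace(d: Data) -> Data:
    # The same update written with NamedTuple._replace.
    return d._replace(nested=d.nested._replace(y=jnp.square(d.nested.y)))


def hlo_body(fn, *args) -> str:
    """Lower `fn` under jit and return its StableHLO text minus the first line,
    which holds the `module @jit_<name>` header (it differs per function name)."""
    text = jax.jit(fn).lower(*args).as_text()
    return "\n".join(text.splitlines()[1:])


data = Data(x=jnp.array([1.0, 2.0]), nested=Nested(y=jnp.array(3.0)))
same = hlo_body(update_direct, data) == hlo_body(update_replace, data)
print("identical lowered programs:", same)
```

Comparing the lowered text, rather than timing runs, tests the claim directly: if the programs handed to XLA are the same, the compiled code cannot differ.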

Example

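import equinox as eqx
import jax
import jax.numpy as jnp
from optix import focus

# Define the nested PyTree structure (assumed here to be Equinox modules)
class NestedStruct(eqx.Module):
    y: jax.Array

class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct

# Create a nested PyTree instance
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0))
)

# Focus on data.nested.y and square it; a new MyStruct is returned
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)
# result.x        == Array([1., 2.], dtype=float32)  (unchanged)
# result.nested.y == Array(9., dtype=float32)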

Installation

pip install jax-optix

License

MIT License

Credits

Special thanks to Patrick Kidger for providing helpful hints and the Equinox library, which this project builds upon.

Download files

Download the file for your platform.

Source Distribution

jax_optix-0.1.3.tar.gz (3.3 kB)

Built Distribution


jax_optix-0.1.3-py3-none-any.whl (3.5 kB)

File details

Details for the file jax_optix-0.1.3.tar.gz.

File metadata

  • Download URL: jax_optix-0.1.3.tar.gz
  • Upload date:
  • Size: 3.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for jax_optix-0.1.3.tar.gz
Algorithm     Hash digest
SHA256        79c9b645452e88d3a6125c20486bf21ccd68e2d8ce22683124d157d0b0d8bf6f
MD5           f788f7947e9b838c1c9d88777a6cdcb7
BLAKE2b-256   a78c576ec6fe4960add17fd6428e00a9c2eb341f2a5c7c73ef5f47d4b98ad0c0

File details

Details for the file jax_optix-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: jax_optix-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 3.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for jax_optix-0.1.3-py3-none-any.whl
Algorithm     Hash digest
SHA256        558d603e1e4ac52ba5bf55e3c8c2569fce316a30026b68b61132c9bc13130614
MD5           a0baa172dd42e5edf01995de9b103222
BLAKE2b-256   46456024e784fd44d58d53219b52ad0dbeec53ccaf2bbcb6d299e2a959398219
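The published digests can be used to check a downloaded file before installing it. Below is a minimal sketch using Python's standard `hashlib`, assuming the wheel sits in the current directory; `sha256_file` is an illustrative helper name.

```python
import hashlib
from pathlib import Path


def sha256_file(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 published above for the wheel
EXPECTED = "558d603e1e4ac52ba5bf55e3c8c2569fce316a30026b68b61132c9bc13130614"

wheel = Path("jax_optix-0.1.3-py3-none-any.whl")
if wheel.exists():
    ok = sha256_file(wheel) == EXPECTED
    print("hash matches" if ok else "HASH MISMATCH")
```

pip can enforce the same check at install time: list `jax-optix==0.1.3 --hash=sha256:<digest>` in a requirements file and install with `pip install --require-hashes -r requirements.txt`. In that mode pip also requires pinned hashes for every dependency.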

