
Zero-overhead functional lensing for JAX PyTrees

Project description

Optix 🔍

A functional lensing library for JAX/Equinox that lets you focus on and modify nested values within PyTree structures. Code written with Optix compiles to the same HLO as the equivalent direct access, so it adds no runtime overhead.

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (compiles to the same HLO as direct access)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support
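One way to see why a lens can cost nothing at runtime: JAX first traces a function into a jaxpr (which is then lowered to HLO), and any Python-level PyTree restructuring happens during tracing rather than inside the traced program. The sketch below does not use Optix. `set_path` is a hypothetical stand-in for a lens over plain dict PyTrees, and only `jax` is assumed to be installed. It compares the traced programs of a hand-written update and a path-based update.

```python
import jax
import jax.numpy as jnp


def direct(tree):
    # Hand-written update: rebuild the dict with the squared leaf.
    return {"x": tree["x"], "nested": {"y": jnp.square(tree["nested"]["y"])}}


def set_path(tree, path, fn):
    # Hypothetical lens stand-in (not Optix): walk a key path and rebuild
    # each dict level, applying fn at the end of the path.
    key, *rest = path
    new = dict(tree)
    new[key] = set_path(tree[key], rest, fn) if rest else fn(tree[key])
    return new


def via_lens(tree):
    return set_path(tree, ["nested", "y"], jnp.square)


data = {"x": jnp.array([1.0, 2.0]), "nested": {"y": jnp.array(3.0)}}

# The path walking runs in Python at trace time, so only the square
# operation appears in either traced program.
jaxpr_direct = str(jax.make_jaxpr(direct)(data))
jaxpr_lens = str(jax.make_jaxpr(via_lens)(data))
print(jaxpr_direct == jaxpr_lens)
```

An HLO-level version of the same check would compare the output of `jax.jit(f).lower(data).as_text()` for each function instead.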

Example

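import equinox as eqx
import jax
import jax.numpy as jnp
from optix import focus


# Two Equinox modules forming a nested PyTree
class NestedStruct(eqx.Module):
    y: jax.Array


class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct


# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0)),
)

# Focus on data.nested.y and replace it with its square.
# Equinox modules are immutable, so a new MyStruct is returned.
result = focus(data).at(lambda x: x.nested.y).apply(jnp.square)
# result.x        -> Array([1., 2.], dtype=float32)   (unchanged)
# result.nested.y -> Array(9., dtype=float32)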
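This README does not describe how Optix resolves a selector such as `lambda x: x.nested.y`. As a rough, dependency-free illustration of one common technique, the sketch below runs the selector on a proxy object that records attribute names, then rebuilds each level along the recorded path. Frozen dataclasses stand in for JAX PyTrees, and `_PathRecorder`, `record_path` and `apply_at` are hypothetical names, not part of Optix.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


class _PathRecorder:
    """Records the chain of attribute names accessed on it."""

    def __init__(self, path=()):
        self._path = path

    def __getattr__(self, name):
        # Only called for attributes not found normally, i.e. the selector's.
        return _PathRecorder(self._path + (name,))


def record_path(selector: Callable[[Any], Any]) -> tuple:
    # Run the selector on a recorder instead of real data to learn the path.
    return selector(_PathRecorder())._path


def apply_at(obj, path: tuple, fn: Callable[[Any], Any]):
    # Rebuild each frozen dataclass along the path; nothing is mutated.
    if not path:
        return fn(obj)
    head, *rest = path
    return replace(obj, **{head: apply_at(getattr(obj, head), tuple(rest), fn)})


@dataclass(frozen=True)
class NestedStruct:
    y: float


@dataclass(frozen=True)
class MyStruct:
    x: list
    nested: NestedStruct


data = MyStruct(x=[1.0, 2.0], nested=NestedStruct(y=3.0))
path = record_path(lambda s: s.nested.y)
result = apply_at(data, path, lambda v: v * v)
print(path)    # ('nested', 'y')
print(result)  # MyStruct(x=[1.0, 2.0], nested=NestedStruct(y=9.0))
```

Because every level is rebuilt rather than mutated, `data` still holds `y=3.0` afterwards; a library working on real PyTrees may use a different mechanism to locate the focused node.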

Installation

pip install jax-optix

License

MIT License

Credits

Special thanks to Patrick Kidger for providing helpful hints and the Equinox library, which this project builds upon.


Download files


Source Distribution

jax_optix-0.1.6.tar.gz (3.5 kB)

Uploaded Source

Built Distribution


jax_optix-0.1.6-py3-none-any.whl (3.7 kB)

Uploaded Python 3

File details

Details for the file jax_optix-0.1.6.tar.gz.

File metadata

  • Download URL: jax_optix-0.1.6.tar.gz
  • Upload date:
  • Size: 3.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.6.tar.gz

  • SHA256: a0f0db571e04f91d33d4496df51cc3739bc9935758c49474fce9d2765e6a3d31
  • MD5: 022b8230b74dfbe5bcf22d78e354cdb6
  • BLAKE2b-256: 57feb99066d4d46e65672d9ffb832081cbdcb223a5a1df53577819eb52806211


File details

Details for the file jax_optix-0.1.6-py3-none-any.whl.

File metadata

  • Download URL: jax_optix-0.1.6-py3-none-any.whl
  • Upload date:
  • Size: 3.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.6-py3-none-any.whl

  • SHA256: 52c11f159295b9e2e5772189ef6d4dd1d2c9ac1f624354729dfe0903afb34451
  • MD5: 7c7ed9212c67e763484a46b6eae8bc3d
  • BLAKE2b-256: 9800f034de2c15c33718bdc65959b1f5dc031e6108519d1b90b58a1c3bab06ff
