Zero-overhead functional lensing for JAX PyTrees

Optix 🔍

Optix is a functional lensing library for JAX/Equinox. It lets you focus on a nested value inside a PyTree and get back an updated copy of the whole structure. Optix generates the same HLO code as direct access, so the lens adds no runtime overhead.
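
For readers new to lenses, the idea can be sketched in plain Python with frozen dataclasses. This is a conceptual illustration only, not how Optix is implemented: the `Lens` class, its `getter`/`setter` fields, and the `then` and `apply` methods are hypothetical names chosen for this sketch.

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")  # the whole structure
A = TypeVar("A")  # the focused part
B = TypeVar("B")  # a part nested inside A


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    """Hypothetical minimal lens: a getter paired with a non-mutating setter."""

    getter: Callable[[S], A]
    setter: Callable[[S, A], S]

    def apply(self, whole: S, fn: Callable[[A], A]) -> S:
        # Read the focused value, transform it, and build an updated copy.
        return self.setter(whole, fn(self.getter(whole)))

    def then(self, inner: "Lens[A, B]") -> "Lens[S, B]":
        # Compose two lenses so the result focuses one level deeper.
        return Lens(
            getter=lambda whole: inner.getter(self.getter(whole)),
            setter=lambda whole, value: self.setter(
                whole, inner.setter(self.getter(whole), value)
            ),
        )


@dataclass(frozen=True)
class Inner:
    y: float


@dataclass(frozen=True)
class Outer:
    x: tuple
    nested: Inner


nested_lens = Lens(
    getter=lambda o: o.nested,
    setter=lambda o, n: replace(o, nested=n),
)
y_lens = Lens(
    getter=lambda n: n.y,
    setter=lambda n, y: replace(n, y=y),
)
nested_y = nested_lens.then(y_lens)

data = Outer(x=(1.0, 2.0), nested=Inner(y=3.0))
result = nested_y.apply(data, lambda v: v * v)
print(result)  # Outer(x=(1.0, 2.0), nested=Inner(y=9.0))
```

The original `data` is left unchanged; the lens only describes where the update happens and returns a new structure.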

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (generates identical HLO code)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support

Example

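from optix import focus
import equinox as eqx
import jax
import jax.numpy as jnp

# Example structures. The original snippet leaves MyStruct and NestedStruct
# undefined, so they are assumed here to be Equinox modules.
class NestedStruct(eqx.Module):
    y: jax.Array

class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct

# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0))
)

# Focus on a nested value and apply a function to it
result = focus(data).at(lambda x: x.nested.y).apply(jnp.square)

# Resulting structure, with the array values written out:
# MyStruct(
#     x=Array([1., 2.], dtype=float32),
#     nested=NestedStruct(
#         y=Array(9., dtype=float32)
#     )
# )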

Installation

pip install jax-optix

License

MIT License

Credits

Special thanks to Patrick Kidger for providing helpful hints and the Equinox library, which this project builds upon.
