
Zero-overhead functional lensing for JAX PyTrees

Project description

Optix 🔍

A functional lensing library for JAX and Equinox. Optix lets you focus on a value nested inside a PyTree and get back an updated copy of the whole structure. Code written with Optix generates the same HLO as the equivalent direct attribute access, so the lenses add no runtime overhead.
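A lens pairs a getter with a setter that returns a new structure instead of mutating the old one. The following is a library-independent sketch of that idea using only the standard library. Frozen dataclasses stand in for PyTree nodes. `Lens`, `field_lens`, `Inner` and `Outer` are hypothetical names chosen for illustration and are not part of Optix's API.

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")  # the whole structure
A = TypeVar("A")  # the focused part
B = TypeVar("B")  # a part focused further inside A


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    """Hypothetical lens: a getter plus a non-mutating setter."""
    getter: Callable[[S], A]
    setter: Callable[[S, A], S]

    def modify(self, s: S, fn: Callable[[A], A]) -> S:
        # Read the focused value, transform it, and write it into a copy.
        return self.setter(s, fn(self.getter(s)))

    def compose(self, inner: "Lens[A, B]") -> "Lens[S, B]":
        # Focus through this lens first, then through `inner`.
        return Lens(
            getter=lambda s: inner.getter(self.getter(s)),
            setter=lambda s, b: self.setter(s, inner.setter(self.getter(s), b)),
        )


def field_lens(name: str) -> Lens:
    # Lens onto a single attribute of a frozen dataclass.
    return Lens(
        getter=lambda s: getattr(s, name),
        setter=lambda s, v: replace(s, **{name: v}),
    )


@dataclass(frozen=True)
class Inner:
    y: float


@dataclass(frozen=True)
class Outer:
    x: tuple
    nested: Inner


data = Outer(x=(1.0, 2.0), nested=Inner(y=3.0))
nested_y = field_lens("nested").compose(field_lens("y"))
result = nested_y.modify(data, lambda v: v * v)
print(result)  # Outer(x=(1.0, 2.0), nested=Inner(y=9.0))
print(data)    # unchanged: Outer(x=(1.0, 2.0), nested=Inner(y=3.0))
```

Composition is what makes lenses convenient for deep nesting. Each lens knows about only one level, and chaining them reaches arbitrarily deep fields.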

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (generates the same HLO as direct access)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support
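You can check the idea behind the zero-overhead claim without Optix. The sketch below compares a direct field update with a hand-rolled getter/setter update on `NamedTuple` PyTrees using `jax.make_jaxpr`. The names `Nested`, `Struct`, `direct` and `via_lens` are hypothetical. The sketch compares jaxprs rather than HLO text because JAX lowers HLO from the jaxpr. The lowered module from `jax.jit(f).lower(...).as_text()` also embeds the function's name, so two otherwise identical programs would not match as raw text.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# NamedTuples are PyTrees out of the box, so no registration is needed.
class Nested(NamedTuple):
    y: jax.Array


class Struct(NamedTuple):
    x: jax.Array
    nested: Nested


def direct(s: Struct) -> Struct:
    # Rebuild the structure by hand, squaring s.nested.y.
    return Struct(x=s.x, nested=Nested(y=jnp.square(s.nested.y)))


def via_lens(s: Struct) -> Struct:
    # The same update written as a getter/setter pair (a hand-rolled lens).
    get = lambda t: t.nested.y
    put = lambda t, v: t._replace(nested=t.nested._replace(y=v))
    return put(s, jnp.square(get(s)))


data = Struct(x=jnp.array([1.0, 2.0]), nested=Nested(y=jnp.array(3.0)))

# The getter/setter calls run as ordinary Python during tracing and leave
# nothing behind in the traced program, so the two jaxprs should match.
jaxpr_direct = jax.make_jaxpr(direct)(data)
jaxpr_lens = jax.make_jaxpr(via_lens)(data)
print(jaxpr_direct)
print(str(jaxpr_direct) == str(jaxpr_lens))
```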

Example

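import equinox as eqx
import jax.numpy as jnp
from jax import Array
from optix import focus

# Example PyTree node types (any PyTree works; Equinox modules are used here)
class NestedStruct(eqx.Module):
    y: Array

class MyStruct(eqx.Module):
    x: Array
    nested: NestedStruct

# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0)),
)

# Focus on nested.y and square it; this returns a new PyTree and leaves data unchanged
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)
# result:
# MyStruct(
#     x=Array([1., 2.], dtype=float32),
#     nested=NestedStruct(
#         y=Array(9., dtype=float32)
#     )
# )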
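In the example, `.at(...)` receives an ordinary getter lambda, yet the result is an updated copy. This README does not say how Optix turns a getter into an update. The sketch below shows one simple approach in plain Python and is not Optix's implementation. It runs the getter against a recording proxy that logs attribute names, then replays that path with `dataclasses.replace`. The names `focus_sketch`, `_Focus`, `_PathRecorder` and `_set_path` are hypothetical. The sketch supports only attribute access, and it uses floats and frozen dataclasses instead of JAX arrays and Equinox modules so that it runs without JAX.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


class _PathRecorder:
    """Records the chain of attribute names a getter lambda touches."""

    def __init__(self, path: tuple = ()):
        self._path = path

    def __getattr__(self, name: str) -> "_PathRecorder":
        # Only called for missing attributes, i.e. the ones the getter reads.
        return _PathRecorder(self._path + (name,))


def _set_path(obj: Any, path: tuple, fn: Callable[[Any], Any]) -> Any:
    # Rebuild each dataclass along the path; untouched siblings are shared.
    if not path:
        return fn(obj)
    head, *rest = path
    return replace(obj, **{head: _set_path(getattr(obj, head), tuple(rest), fn)})


class _Focus:
    def __init__(self, data: Any, path: tuple = ()):
        self._data, self._path = data, path

    def at(self, where: Callable[[Any], Any]) -> "_Focus":
        recorded = where(_PathRecorder())
        return _Focus(self._data, recorded._path)

    def apply(self, fn: Callable[[Any], Any]) -> Any:
        return _set_path(self._data, self._path, fn)


def focus_sketch(data: Any) -> _Focus:
    """Hypothetical stand-in mimicking focus(data).at(where).apply(fn)."""
    return _Focus(data)


@dataclass(frozen=True)
class NestedStruct:
    y: float


@dataclass(frozen=True)
class MyStruct:
    x: tuple
    nested: NestedStruct


data = MyStruct(x=(1.0, 2.0), nested=NestedStruct(y=3.0))
result = focus_sketch(data).at(lambda s: s.nested.y).apply(lambda v: v * v)
print(result)  # MyStruct(x=(1.0, 2.0), nested=NestedStruct(y=9.0))
```

Because only the nodes along the focused path are rebuilt, the untouched field `x` is the same object in `data` and `result`.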

Installation

pip install jax-optix

License

MIT License

Credits

Special thanks to Patrick Kidger for providing helpful hints and the Equinox library, which this project builds upon.



Download files


Source Distribution

jax_optix-0.1.2.tar.gz (3.3 kB)

Uploaded Source

Built Distribution


jax_optix-0.1.2-py3-none-any.whl (3.5 kB)

Uploaded Python 3

File details

Details for the file jax_optix-0.1.2.tar.gz.

File metadata

  • Download URL: jax_optix-0.1.2.tar.gz
  • Upload date:
  • Size: 3.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for jax_optix-0.1.2.tar.gz
Algorithm Hash digest
SHA256 04f809b57d9dcc9878c2307cf0b50118a783087a4cbe0fc5bdcd61c62be146c8
MD5 f8a27a7a59c1ea1ae18c92b1ff070140
BLAKE2b-256 7092fbbedb0e6e397d7ffbf8be63595d2b82f384a0edb21f246553260918d92b
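To check a downloaded archive against the SHA256 digest listed above, the following standard-library sketch can be used. It assumes the file sits at the path you pass in. The helper names `sha256_of` and `verify` are illustrative. Alternatively, pip can enforce digests itself in hash-checking mode, using `--hash=sha256:...` entries in a requirements file together with `pip install --require-hashes -r requirements.txt`.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for jax_optix-0.1.2.tar.gz
EXPECTED_SHA256 = "04f809b57d9dcc9878c2307cf0b50118a783087a4cbe0fc5bdcd61c62be146c8"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    # Stream the file in chunks so it never has to fit in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    # hexdigest() is lowercase, so normalise the expected value too.
    return sha256_of(path) == expected.lower()


# Usage: verify(Path("jax_optix-0.1.2.tar.gz")) should return True
```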


File details

Details for the file jax_optix-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: jax_optix-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 3.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for jax_optix-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 3ae3d96ffe3e8332b05c2e31d5a7b174a2fdfe7d62b774f24462093cfa5de720
MD5 79909d76c168f68981052f1c5568f988
BLAKE2b-256 f863783bfd4c4695f2ef69918c2a18da924e37b7e68d6b34db461060d6c9651d

