
Zero-overhead functional lensing for JAX PyTrees

Project description

Optix 🔍

A functional lensing library for JAX/Equinox that lets you focus on a value nested inside a PyTree and get back an updated copy of the whole structure. Optix generates the same HLO as direct field access, so using a lens adds no runtime overhead.
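A lens pairs a getter, which reads the focused value, with a setter, which returns an updated copy of the enclosing structure. Lenses compose, so a lens for `outer.nested` followed by one for `nested.y` reaches `outer.nested.y`. The sketch below shows that pattern in plain Python over frozen dataclasses. The `Lens` class and its `modify`/`then` methods are hypothetical illustrations of the general technique. They are not Optix's API or implementation.

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")
A = TypeVar("A")
B = TypeVar("B")


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    # Hypothetical minimal lens: not Optix's API.
    get: Callable[[S], A]     # read the focused value out of S
    set: Callable[[S, A], S]  # return a copy of S with the focused value replaced

    def modify(self, s: S, f: Callable[[A], A]) -> S:
        # Read, transform, write back; the input `s` is never mutated.
        return self.set(s, f(self.get(s)))

    def then(self, inner: "Lens[A, B]") -> "Lens[S, B]":
        # Compose: focus with `self`, then focus further with `inner`.
        return Lens(
            get=lambda s: inner.get(self.get(s)),
            set=lambda s, b: self.set(s, inner.set(self.get(s), b)),
        )


@dataclass(frozen=True)
class Nested:
    y: float


@dataclass(frozen=True)
class Outer:
    x: float
    nested: Nested


nested_l = Lens(get=lambda o: o.nested, set=lambda o, v: replace(o, nested=v))
y_l = Lens(get=lambda n: n.y, set=lambda n, v: replace(n, y=v))
outer_y = nested_l.then(y_l)

data = Outer(x=1.0, nested=Nested(y=3.0))
new = outer_y.modify(data, lambda v: v * v)
print(new)  # Outer(x=1.0, nested=Nested(y=9.0))
```

Because every step builds a new object, `data` still holds `y=3.0` after the update. This is the "functional" part of functional lensing.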

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (generates identical HLO code)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support
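A lens can be free at runtime because, under `jax.jit`, PyTree construction and field access run in Python at trace time. Only the array operations reach the traced program. The sketch below checks this with plain `NamedTuple` PyTrees and two hand-written functions as stand-ins. It does not call Optix, and the type and function names are hypothetical. It compares the jaxpr of a direct reconstruction with that of a functional get-transform-set update.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical stand-in types: NamedTuples are PyTrees out of the box.
class Nested(NamedTuple):
    y: jax.Array


class Data(NamedTuple):
    x: jax.Array
    nested: Nested


def direct(d: Data) -> Data:
    # Hand-written reconstruction of the structure.
    return Data(d.x, Nested(jnp.square(d.nested.y)))


def via_update(d: Data) -> Data:
    # Functional "get, transform, set" through immutable copies.
    y = d.nested.y
    return d._replace(nested=d.nested._replace(y=jnp.square(y)))


data = Data(jnp.array([1.0, 2.0]), Nested(jnp.array(3.0)))

# The jaxprs contain only array operations; the PyTree plumbing
# (field access, _replace) happened in Python while tracing.
jaxpr_direct = str(jax.make_jaxpr(direct)(data))
jaxpr_update = str(jax.make_jaxpr(via_update)(data))
print(jaxpr_direct)
print(jaxpr_direct == jaxpr_update)

# Lowered StableHLO text for manual inspection. The module name
# embeds the function name, so diff the bodies rather than the header.
hlo_text = jax.jit(via_update).lower(data).as_text()
print(hlo_text)
```

To check the HLO claim on your own structures, you can swap the equivalent Optix call into `via_update` and diff the `.lower(...).as_text()` output against the direct version.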

Example

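```python
import equinox as eqx
import jax
import jax.numpy as jnp
from optix import focus


# Hypothetical example types (not part of Optix); any JAX PyTree works.
class NestedStruct(eqx.Module):
    y: jax.Array


class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct


# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0)),
)

# Focus on a nested value and apply a function to it.
# This returns a new MyStruct; `data` itself is left unchanged.
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)
# result:
# MyStruct(
#     x=Array([1., 2.], dtype=float32),
#     nested=NestedStruct(y=Array(9., dtype=float32)),
# )
```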

Installation

pip install jax-optix

License

MIT License

Credits

Special thanks to Patrick Kidger for helpful hints and for the Equinox library, on which this project builds.


Download files


Source Distribution

jax_optix-0.1.9.tar.gz (3.9 kB)

Built Distribution


jax_optix-0.1.9-py3-none-any.whl (3.9 kB)

File details

Details for the file jax_optix-0.1.9.tar.gz.

File metadata

  • Download URL: jax_optix-0.1.9.tar.gz
  • Upload date:
  • Size: 3.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.9.tar.gz
  • SHA256: 9109b026d8b2fce1c3e353c8334710e3b507f82a2969e2ccb59bea6a243720fe
  • MD5: c4e3b2a85af7e7083ac4ca7f2cc90bb0
  • BLAKE2b-256: 588d9c8bc3aa4c0a1c6ecd35d24fce75cd65dbbd23fa9ce7b30698bfcbf0f5b6
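To check a downloaded file against the SHA256 digest listed for the source distribution, you can hash it locally. The sketch below uses only Python's standard `hashlib`, and the helper names `sha256_of` and `verify` are hypothetical.

```python
import hashlib
import hmac
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path, expected_hex: str) -> bool:
    # hexdigest() is lowercase, so normalise the expected value first.
    return hmac.compare_digest(sha256_of(path), expected_hex.lower())


# SHA256 digest published for jax_optix-0.1.9.tar.gz.
SDIST_SHA256 = "9109b026d8b2fce1c3e353c8334710e3b507f82a2969e2ccb59bea6a243720fe"

if __name__ == "__main__":
    sdist = Path("jax_optix-0.1.9.tar.gz")
    if sdist.exists():
        print("OK" if verify(sdist, SDIST_SHA256) else "MISMATCH")
```

pip can enforce the same check at install time through its hash-checking mode. For example, you can use a requirements line of the form `jax-optix==0.1.9 --hash=sha256:<digest>`.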


File details

Details for the file jax_optix-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: jax_optix-0.1.9-py3-none-any.whl
  • Upload date:
  • Size: 3.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.9-py3-none-any.whl
  • SHA256: 39f4ce14a33eb6c9dd4fb6b9e03909a96aa24d1b44594631c6f24acef21c8ce4
  • MD5: f9abb5c56531c43b40cf6dde2d4de54e
  • BLAKE2b-256: abe17e2a708e32b3105e0c6ec51514c9dcc797ad5a33008a6c70b764b5817d44

