
Zero-overhead functional lensing for JAX PyTrees

Project description

Optix 🔍

A functional lensing library for JAX/Equinox that lets you focus on, and modify, values nested deep inside PyTree structures. Optix generates the same HLO as direct attribute access, so using a lens adds no runtime overhead.
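To make the idea of a lens concrete: a lens pairs a getter with an out-of-place setter, and two lenses compose to reach a deeper field. The sketch below is a minimal pure-Python illustration using frozen dataclasses. `Lens` and `field_lens` are hypothetical names, not Optix's API (Optix is used via `focus(...).at(...)`), and the sketch makes no claim about how Optix is implemented internally.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Generic, TypeVar

S = TypeVar("S")
A = TypeVar("A")


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    """A getter plus an out-of-place setter (hypothetical, not Optix's API)."""

    get: Callable[[S], A]
    set: Callable[[S, A], S]

    def apply(self, s: S, fn: Callable[[A], A]) -> S:
        # Read the focused value, transform it, and write it back into a copy.
        return self.set(s, fn(self.get(s)))

    def compose(self, inner: "Lens[A, Any]") -> "Lens[S, Any]":
        # Focus first with `self`, then with `inner`.
        return Lens(
            get=lambda s: inner.get(self.get(s)),
            set=lambda s, b: self.set(s, inner.set(self.get(s), b)),
        )


def field_lens(name: str) -> Lens[Any, Any]:
    # Lens onto a single attribute of a frozen dataclass.
    return Lens(
        get=lambda s: getattr(s, name),
        set=lambda s, v: replace(s, **{name: v}),
    )


# Plain frozen dataclasses standing in for the README's MyStruct/NestedStruct.
@dataclass(frozen=True)
class NestedStruct:
    y: float


@dataclass(frozen=True)
class MyStruct:
    x: tuple
    nested: NestedStruct


data = MyStruct(x=(1.0, 2.0), nested=NestedStruct(y=3.0))
y_lens = field_lens("nested").compose(field_lens("y"))
result = y_lens.apply(data, lambda v: v * v)
print(result)         # MyStruct(x=(1.0, 2.0), nested=NestedStruct(y=9.0))
print(data.nested.y)  # 3.0 (the original is left untouched)
```

Because the setter only rebuilds containers, the original structure is never mutated; this is the "functional" part of functional lensing.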

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (generates identical HLO code)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support

Example

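import equinox as eqx
import jax
import jax.numpy as jnp
from optix import focus

# Illustrative PyTree types (hypothetical definitions; any JAX PyTree works)
class NestedStruct(eqx.Module):
    y: jax.Array

class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct

# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0))
)

# Focus on a nested value and apply a function to it, returning an updated PyTree
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)
# result:
# MyStruct(
#     x=Array([1., 2.], dtype=float32),
#     nested=NestedStruct(
#         y=Array(9., dtype=float32)
#     )
# )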
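One way to sanity-check the zero-overhead idea is to compare the programs JAX traces for a lens-style update and a hand-written one. The sketch below is an illustration under stated assumptions, not a test of Optix itself: it uses plain `NamedTuple` PyTrees and a hypothetical hand-rolled getter/setter pair (`get_y`, `set_y`), and compares the output of `jax.make_jaxpr`. Since the containers are rebuilt in ordinary Python during tracing, only the `jnp.square` operation should appear in the traced program either way.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# NamedTuples are PyTrees out of the box; they stand in for MyStruct/NestedStruct.
class Nested(NamedTuple):
    y: jax.Array


class Outer(NamedTuple):
    x: jax.Array
    nested: Nested


# Hypothetical getter/setter pair focusing on `s.nested.y` (not Optix's API).
def get_y(s: Outer) -> jax.Array:
    return s.nested.y


def set_y(s: Outer, v: jax.Array) -> Outer:
    # Rebuilds the containers in Python; no array operations are traced here.
    return s._replace(nested=s.nested._replace(y=v))


def via_lens(s: Outer) -> Outer:
    return set_y(s, jnp.square(get_y(s)))


def direct(s: Outer) -> Outer:
    return Outer(x=s.x, nested=Nested(y=jnp.square(s.nested.y)))


data = Outer(x=jnp.array([1.0, 2.0]), nested=Nested(y=jnp.array(3.0)))

# Trace both functions and compare the resulting jaxprs as text.
lens_jaxpr = jax.make_jaxpr(via_lens)(data)
direct_jaxpr = jax.make_jaxpr(direct)(data)
print(str(lens_jaxpr) == str(direct_jaxpr))  # True
```

To inspect the lowered form instead, `jax.jit(fn).lower(data).as_text()` prints the lowered module text; expect the module name to differ between the two functions even when the bodies match.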

Installation

pip install jax-optix

License

MIT License

Credits

Special thanks to Patrick Kidger for providing helpful hints and the Equinox library, which this project builds upon.

Download files

Source Distribution

jax_optix-0.1.10.tar.gz (4.0 kB)

Uploaded Source

Built Distribution

jax_optix-0.1.10-py3-none-any.whl (3.9 kB)

Uploaded Python 3

File details

Details for the file jax_optix-0.1.10.tar.gz.

File metadata

  • Download URL: jax_optix-0.1.10.tar.gz
  • Upload date:
  • Size: 4.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.10.tar.gz
Algorithm Hash digest
SHA256 29c0b9351a9d58b3101ba7ef75667b760ddcf5c350c5a0339df3ff0b888a210b
MD5 6ee04d4e744a4126d903b007d55411f9
BLAKE2b-256 46ea6139fd59a5ac1a0749766df598edd6195c57208ba686ac131270466db97b

File details

Details for the file jax_optix-0.1.10-py3-none-any.whl.

File metadata

  • Download URL: jax_optix-0.1.10-py3-none-any.whl
  • Upload date:
  • Size: 3.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.10-py3-none-any.whl
Algorithm Hash digest
SHA256 26931feef9faf4b1ceae0374c3c98efebd371d376c4639b08bdf5171b10c681e
MD5 91ba52e7116d4a9ba69bcb438516a352
BLAKE2b-256 fabbd228b5b8a4446934518c40d754be4e98671b00beec840f880bcaa12aacd8
