
Zero-overhead functional lensing for JAX PyTrees


Optix 🔍

A functional lensing library for JAX/Equinox that lets you focus on, and modify, values nested deep inside PyTree structures. Code written with Optix generates the same HLO as equivalent hand-written field access, so the abstraction adds no runtime overhead.
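To make the lens idea concrete: in the textbook formulation, a lens pairs a getter with an immutable setter for one part of a structure, and lenses compose to reach deeper fields. The pure-Python sketch below follows that generic definition using frozen dataclasses. The Lens class and its modify and compose methods are hypothetical illustrations, not Optix's API (Optix exposes focus(...).at(selector), as in the example further down).

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")  # type of the whole structure
A = TypeVar("A")  # type of the focused part
B = TypeVar("B")  # type of a part nested inside A


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    """Hypothetical getter/setter lens; illustrative only, not Optix's API."""
    get: Callable[[S], A]
    set: Callable[[S, A], S]

    def modify(self, s: S, fn: Callable[[A], A]) -> S:
        # Read the focused value, transform it, and write it back immutably.
        return self.set(s, fn(self.get(s)))

    def compose(self, inner: "Lens[A, B]") -> "Lens[S, B]":
        # Focus through `self` first, then through `inner`.
        return Lens(
            get=lambda s: inner.get(self.get(s)),
            set=lambda s, b: self.set(s, inner.set(self.get(s), b)),
        )


@dataclass(frozen=True)
class Nested:
    y: float


@dataclass(frozen=True)
class Outer:
    x: tuple[float, ...]
    nested: Nested


# One lens per field; dataclasses.replace builds updated copies.
nested_lens: Lens[Outer, Nested] = Lens(lambda o: o.nested, lambda o, n: replace(o, nested=n))
y_lens: Lens[Nested, float] = Lens(lambda n: n.y, lambda n, v: replace(n, y=v))

data = Outer(x=(1.0, 2.0), nested=Nested(y=3.0))
result = nested_lens.compose(y_lens).modify(data, lambda v: v * v)
print(result)  # Outer(x=(1.0, 2.0), nested=Nested(y=9.0))
print(data)    # Outer(x=(1.0, 2.0), nested=Nested(y=3.0))
```

Because each setter returns a new object, the original data is never mutated, which is the same contract JAX code relies on for PyTrees.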

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (generates HLO identical to direct access)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support
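A zero-overhead claim like this can be sanity-checked for any lens-style helper: if the helper and the hand-written update trace to the same jaxpr, XLA is handed the same program. The sketch below uses plain dicts and a hypothetical update_at helper standing in for Optix, so it demonstrates the checking technique rather than Optix itself. To look at the HLO directly, jax.jit(f).lower(tree).as_text() prints the lowered module.

```python
import jax
import jax.numpy as jnp

# A nested PyTree of plain dicts (stand-in for MyStruct / NestedStruct).
tree = {"x": jnp.array([1.0, 2.0]), "nested": {"y": jnp.array(3.0)}}


def direct(t):
    # Hand-written update: rebuild the structure around the new leaf.
    return {"x": t["x"], "nested": {"y": jnp.square(t["nested"]["y"])}}


def update_at(t, path, fn):
    # Hypothetical lens-style helper (not part of Optix): walk `path`,
    # apply `fn` to the focused leaf, and rebuild each dict on the way back.
    if not path:
        return fn(t)
    head, *rest = path
    return {**t, head: update_at(t[head], rest, fn)}


def via_helper(t):
    return update_at(t, ["nested", "y"], jnp.square)


# The helper's Python-level recursion disappears during tracing, so both
# functions should produce the same jaxpr.
jaxpr_direct = str(jax.make_jaxpr(direct)(tree))
jaxpr_helper = str(jax.make_jaxpr(via_helper)(tree))
print(jaxpr_direct == jaxpr_helper)  # True
```

The same comparison works for any pair of functions with identical inputs, which makes it a cheap regression test when refactoring access code.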

Example

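import equinox as eqx
import jax.numpy as jnp
from optix import focus

# Illustrative PyTree types; any Equinox module (or other PyTree) works
class NestedStruct(eqx.Module):
    y: jnp.ndarray

class MyStruct(eqx.Module):
    x: jnp.ndarray
    nested: NestedStruct

# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0))
)

# Focus on data.nested.y and square it; the updated PyTree is returned
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)
# result.x        -> Array([1., 2.], dtype=float32)  (unchanged)
# result.nested.y -> Array(9., dtype=float32)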

Installation

pip install jax-optix

The distribution is published as jax-optix; the import name is optix.

License

MIT License

Credits

Special thanks to Patrick Kidger for helpful hints and for the Equinox library, which this project builds upon.



Download files


Source Distribution

jax_optix-0.1.1.tar.gz (3.3 kB)

Built Distribution


jax_optix-0.1.1-py3-none-any.whl (3.5 kB)

File details

Details for the file jax_optix-0.1.1.tar.gz.

File metadata

  • Download URL: jax_optix-0.1.1.tar.gz
  • Upload date:
  • Size: 3.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for jax_optix-0.1.1.tar.gz
  • SHA256: 919c1e80b0c03426dcc687d7c2bee27ee349bc7290d1c78027d6deb26be67c80
  • MD5: b82c54b9e9daedc54697eb45b1da2db7
  • BLAKE2b-256: 7d53ffb42b1494676397400292fcc7128858e883c3d6cb0cb994fdb685d5a1b5


File details

Details for the file jax_optix-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: jax_optix-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 3.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for jax_optix-0.1.1-py3-none-any.whl
  • SHA256: a096f75bfb53fbebf78a42ce27153e65d1270b6c6117abc52ba8d5d852377bd2
  • MD5: e4a33bdc04dc62474f3616a84b7218ab
  • BLAKE2b-256: 4e746e2f6dd1a360cec0af4d46da9b6baeea51eb220c1d17c0def417ffa5b3d1

