
Zero-overhead functional lensing for JAX PyTrees

Project description

Optix 🔍

A functional lensing library for JAX and Equinox that lets you focus on, and modify, values nested inside PyTree structures. Optix generates the same HLO as direct field access, so it adds no runtime overhead.
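To make the idea of a functional lens concrete, here is a minimal sketch in plain JAX. It is not Optix's implementation or API. `Lens`, `apply`, and `then` are hypothetical names, and the example PyTrees are `NamedTuple`s rather than Equinox modules. A lens pairs a getter with a setter that returns a new structure instead of mutating the old one.

```python
from dataclasses import dataclass
from typing import Callable, Generic, NamedTuple, TypeVar

import jax
import jax.numpy as jnp

S = TypeVar("S")  # the whole structure
A = TypeVar("A")  # the focused part
B = TypeVar("B")  # a part nested inside A


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    """Hypothetical minimal lens: a getter plus a functional (non-mutating) setter."""

    get: Callable[[S], A]
    set: Callable[[S, A], S]

    def apply(self, s: S, fn: Callable[[A], A]) -> S:
        # Read the focused value, transform it, and write it into a new structure.
        return self.set(s, fn(self.get(s)))

    def then(self, inner: "Lens[A, B]") -> "Lens[S, B]":
        # Compose: focus on A inside S, then on B inside A.
        return Lens(
            get=lambda s: inner.get(self.get(s)),
            set=lambda s, b: self.set(s, inner.set(self.get(s), b)),
        )


class NestedStruct(NamedTuple):
    y: jax.Array


class MyStruct(NamedTuple):
    x: jax.Array
    nested: NestedStruct


# One lens per field; NamedTuple._replace returns a modified copy.
nested_lens: Lens[MyStruct, NestedStruct] = Lens(
    get=lambda s: s.nested, set=lambda s, v: s._replace(nested=v)
)
y_lens: Lens[NestedStruct, jax.Array] = Lens(
    get=lambda n: n.y, set=lambda n, v: n._replace(y=v)
)

data = MyStruct(x=jnp.array([1.0, 2.0]), nested=NestedStruct(y=jnp.array(3.0)))
result = nested_lens.then(y_lens).apply(data, jnp.square)
print(result.nested.y)  # 9.0
print(data.nested.y)    # 3.0 (the original structure is untouched)
```

Composition is what lets a lens reach arbitrarily deep into a structure while each individual lens only knows about a single field.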

Features

  • Type-safe lenses for any JAX PyTree structure
  • Zero runtime overhead (generates identical HLO code)
  • Intuitive API for accessing and modifying nested values
  • Complete static typing support
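The zero-overhead property follows from the lens machinery being ordinary Python that runs while JAX traces, so it never appears in the traced computation. The sketch below checks this idea with plain JAX, not with Optix itself. `update_leaf` is a hypothetical helper that selects a leaf with a lambda, similar in spirit to Optix's `.at(...)` selector. Comparing jaxprs is used here as a convenient proxy for comparing HLO, since JAX lowers the jaxpr to HLO.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class NestedStruct(NamedTuple):
    y: jax.Array


class MyStruct(NamedTuple):
    x: jax.Array
    nested: NestedStruct


def update_leaf(tree: Any, where: Callable[[Any], Any], fn: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: apply `fn` to the single leaf selected by `where`."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    target = where(tree)
    # Locate the selected leaf by object identity (tracers are distinct objects).
    matches = [i for i, leaf in enumerate(leaves) if leaf is target]
    if len(matches) != 1:
        raise ValueError("`where` must select exactly one leaf of the tree")
    leaves[matches[0]] = fn(leaves[matches[0]])
    return jax.tree_util.tree_unflatten(treedef, leaves)


def via_helper(s: MyStruct) -> MyStruct:
    # Lens-style update: the selection logic only runs at trace time.
    return update_leaf(s, lambda t: t.nested.y, jnp.square)


def by_hand(s: MyStruct) -> MyStruct:
    # Direct access: rebuild the structure explicitly.
    return MyStruct(x=s.x, nested=NestedStruct(y=jnp.square(s.nested.y)))


data = MyStruct(x=jnp.array([1.0, 2.0]), nested=NestedStruct(y=jnp.array(3.0)))
jaxpr_helper = str(jax.make_jaxpr(via_helper)(data))
jaxpr_by_hand = str(jax.make_jaxpr(by_hand)(data))
print(jaxpr_helper == jaxpr_by_hand)  # True: both trace to a single square op
```

The same comparison could be pointed at any lens-based function and its hand-written equivalent to check that the abstraction compiles away.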

Example

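import equinox as eqx
import jax.numpy as jnp
from jax import Array
from optix import focus

# Example structures, assumed here to be Equinox modules
class NestedStruct(eqx.Module):
    y: Array

class MyStruct(eqx.Module):
    x: Array
    nested: NestedStruct

# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0))
)

# Focus on data.nested.y and square it; the updated PyTree is returned
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)
# result holds:
# MyStruct(
#     x=Array([1., 2.], dtype=float32),
#     nested=NestedStruct(
#         y=Array(9., dtype=float32)
#     )
# )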

Installation

pip install jax-optix

License

MIT License

Credits

Special thanks to Patrick Kidger for his helpful hints and for the Equinox library, which this project builds upon.


Download files


Source Distribution

jax_optix-0.1.5.tar.gz (3.4 kB)

Uploaded Source

Built Distribution


jax_optix-0.1.5-py3-none-any.whl (3.6 kB)

Uploaded Python 3

File details

Details for the file jax_optix-0.1.5.tar.gz.

File metadata

  • Download URL: jax_optix-0.1.5.tar.gz
  • Size: 3.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.5.tar.gz
Algorithm Hash digest
SHA256 34dd85ad4f70b3cf12880d2b34b9da2a734487c8f02cbaa08478a10149b5e8a5
MD5 4fb2f74ae3a651c2c74a8496cee0a270
BLAKE2b-256 e102bd1505afb71ddf27bc1618f32e82ac8becc58faff381725b6a0957ea0c3c


File details

Details for the file jax_optix-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: jax_optix-0.1.5-py3-none-any.whl
  • Size: 3.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.31

File hashes

Hashes for jax_optix-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 21ccf7c0abfc2726a2524426f1e663d298e6cba30650be4ad7f1672682c7985b
MD5 b8fc4dfb58c4d123dd938c36b5fa6812
BLAKE2b-256 d186ed9da35d572e90dbab9c5ef08089d95c710b860a2fcf9ae2f0ffa2bda73e

