JAX-compatible Enumerations.
JAX_ENUMS
A JAX-compatible enumerable: an enumeration whose members can be passed as ordinary arguments to `jax.jit`-compiled functions.
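For context, the sketch below uses only the standard library and JAX, not `jax_enums`, to show the limitation such a type addresses. A plain `enum.Enum` member is not a valid JAX argument type, so `jax.jit` rejects it unless the parameter is marked static with `static_argnums`. A static argument is hashed into the compilation cache key, so each distinct member gets its own compiled version. The enum `Color` and the function `pick_row` are hypothetical names chosen for illustration.

```python
import enum

import jax
import jax.numpy as jnp


class Color(enum.Enum):  # hypothetical example enum, not part of jax_enums
    RED = 0
    BLUE = 1


def pick_row(array: jax.Array, color: Color) -> jax.Array:
    # Use the member's integer value as a row index.
    return array[color.value]


array = jnp.arange(4.0).reshape(2, 2)

# Passing a plain Enum member as a traced argument fails, because an Enum
# member is not a valid JAX type:
# jax.jit(pick_row)(array, Color.BLUE)

# Marking the argument static works: the member is hashed and treated as a
# compile-time constant, so each distinct member triggers its own compilation.
pick_row_static = jax.jit(pick_row, static_argnums=1)
row = pick_row_static(array, Color.BLUE)  # second row of `array`
```

The static approach is correct but recompiles once per member, which is the cost a traceable enumeration type avoids.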
Installation
You can install `jax_enums` directly from GitHub:

```bash
pip install git+https://github.com/epignatelli/jax_enums
```
Example
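```python
from enum import Enum

import jax
import jax.numpy as jnp
from jax_enums import Enumerable  # import path assumed from the package name


class Foo(Enumerable):
    BAR = 0
    BAZ = 1


def f(array: jax.Array, enumerable: Enum) -> jax.Array:
    # Index the array with the member's underlying value.
    return array[enumerable.value]


array = jnp.zeros((2, 2))
enumerable = Foo.BAR
f(array, enumerable)           # eager call
jax.jit(f)(array, enumerable)  # the same member can be passed to a jitted function
```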
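The README does not describe how `Enumerable` makes its members acceptable to `jax.jit`. One plausible mechanism, shown here as a minimal sketch and not as the library's implementation, is to register the enum class as a JAX pytree whose only leaf is the member's value, with the class itself kept as static auxiliary data. The value is then traced rather than baked into the compiled program. `TracedMember`, `register_enum_pytree`, and `_unflatten` are hypothetical names.

```python
import enum

import jax
import jax.numpy as jnp
from jax import tree_util


class TracedMember:
    """Hypothetical stand-in for an enum member whose value JAX has replaced with a tracer."""

    def __init__(self, enum_cls, value):
        self.enum_cls = enum_cls  # static: the enum class the value came from
        self.value = value        # dynamic: the (possibly traced) integer value


def _unflatten(enum_cls, children):
    # Rebuild from pytree parts; inside jit, children[0] is a tracer.
    return TracedMember(enum_cls, children[0])


# TracedMember must itself be a pytree so it can cross transformation boundaries.
tree_util.register_pytree_node(
    TracedMember, lambda m: ((m.value,), m.enum_cls), _unflatten
)


def register_enum_pytree(enum_cls):
    """Hypothetical decorator: flatten each member to (value,), class as static aux data."""
    tree_util.register_pytree_node(
        enum_cls, lambda member: ((member.value,), type(member)), _unflatten
    )
    return enum_cls


@register_enum_pytree
class Foo(enum.Enum):
    BAR = 0
    BAZ = 1


def f(array: jax.Array, enumerable) -> jax.Array:
    # Works for both real members (eager) and TracedMember (under jit).
    return array[enumerable.value]


array = jnp.arange(4.0).reshape(2, 2)
f_jit = jax.jit(f)
row_bar = f_jit(array, Foo.BAR)  # the value 0 is a traced argument
row_baz = f_jit(array, Foo.BAZ)  # same tree structure and leaf type as Foo.BAR
```

Because both members flatten to the same tree structure with a scalar integer leaf, a single compiled function can serve every member. The trade-off in this sketch is that, inside `jax.jit`, `f` receives a `TracedMember` rather than a `Foo` member, so identity checks such as `enumerable is Foo.BAR` do not hold under tracing; only `.value` is meaningful there.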
Cite
```bibtex
@misc{pignatelli2023jax_enums,
  author = {Pignatelli, Eduardo},
  title = {JAX_ENUMS: JAX-compatible enumerations},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/epignatelli/jax_enums}}
}
```
Download files
Source distribution: jax_enums-0.1.0.tar.gz (10.8 kB)
Built distribution: jax_enums-0.1.0-py2.py3-none-any.whl
Hashes for jax_enums-0.1.0-py2.py3-none-any.whl:

| Algorithm | Hash digest |
|---|---|
| SHA256 | 99bf56de49c72800b3b24dad14acb12adb167b2b2b35eb3d8f706af678cfeaf4 |
| MD5 | 65bcf31873e399501bd7cf086208db7e |
| BLAKE2b-256 | 7c85f2634746a15cda4dacfca6a7615d18a4f0ec67f4853b6b5e90278635ae99 |