JAX-compatible Enumerations.
Project description
JAX_ENUMS
A JAX-compatible enumerable: its members can be passed as arguments to functions transformed with `jax.jit`.
Installation
pip install jax_enums
Example
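```python
import jax
import jax.numpy as jnp
from enum import Enum
from jax_enums import Enumerable  # import path assumed from the package name

class Foo(Enumerable):
    BAR = 0
    BAZ = 1

def f(array: jax.Array, enumerable: Enum) -> jax.Array:
    # Select the row indexed by the member's integer value.
    return array[enumerable.value]

array = jnp.zeros((2, 2))
enumerable = Foo.BAR
f(array, enumerable)           # eager call
jax.jit(f)(array, enumerable)  # the same call, compiled with jax.jit
```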
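The README does not say how `Enumerable` makes this work. As a hedged sketch of one possible mechanism (not necessarily the one jax_enums uses), a plain `enum.Enum` class can be registered as a JAX pytree that has no array leaves and carries the member itself as static auxiliary data. The names `static_enum`, `PlainFoo`, `plain_ok` and `out` are hypothetical, introduced only for this illustration; the block also shows that an unregistered `Enum` member is rejected by `jax.jit`.

```python
import enum

import jax
import jax.numpy as jnp


def static_enum(cls):
    """Hypothetical helper (not part of jax_enums): register an Enum class
    as a pytree with no array leaves, so the member travels as static
    auxiliary data through jax.jit."""
    jax.tree_util.register_pytree_node(
        cls,
        lambda member: ((), member),       # flatten: no leaves, member is aux data
        lambda member, _children: member,  # unflatten: return the member unchanged
    )
    return cls


class PlainFoo(enum.Enum):  # ordinary Enum, not registered with JAX
    BAR = 0
    BAZ = 1


@static_enum
class Foo(enum.Enum):  # same members, registered as a leafless pytree
    BAR = 0
    BAZ = 1


def f(array, member):
    # member.value is a plain Python int here, so this is static indexing.
    return array[member.value]


array = jnp.arange(4.0).reshape(2, 2)  # [[0., 1.], [2., 3.]]

# An unregistered Enum member is not a valid argument to a jitted function.
try:
    jax.jit(f)(array, PlainFoo.BAZ)
    plain_ok = True
except TypeError:
    plain_ok = False

out = jax.jit(f)(array, Foo.BAZ)  # selects row 1
```

With this design the member is part of the pytree structure, so `jax.jit` traces and compiles once per distinct member it receives, and `member.value` stays an ordinary Python int inside `f`.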
Cite
```bibtex
@misc{pignatelli2023jax_enums,
  author = {Pignatelli, Eduardo},
  title = {JAX_ENUMS: JAX-compatible enumerations},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/epignatelli/jax_enums}}
}
```
Download files
Source Distribution
jax_enums-0.1.2.tar.gz (15.9 kB)
Built Distribution
jax_enums-0.1.2-py2.py3-none-any.whl (12.9 kB)
File details
Details for the file jax_enums-0.1.2.tar.gz.
File metadata
- Download URL: jax_enums-0.1.2.tar.gz
- Upload date:
- Size: 15.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 75391a59e1f5be3a976715306182d22331d74c001ae3b580e95a88c03d58dabf
MD5 | 1d1230a926ef2557794d46818e352d90
BLAKE2b-256 | c97e86dadf746bba5e0c7cfd08db797dfd9e52c9bbe904ee74998727cac17381
File details
Details for the file jax_enums-0.1.2-py2.py3-none-any.whl.
File metadata
- Download URL: jax_enums-0.1.2-py2.py3-none-any.whl
- Upload date:
- Size: 12.9 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | e9c9374789f9adb5874013c80e2be411af0ef1442cbc3c3c0bb0ac89de7f2f2c
MD5 | 44be5884e9bd2994463b5ce52df2a073
BLAKE2b-256 | c9744cc28410c9f025ccb575123f35941f2812e67afd975f42b99682152e939e
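To check a downloaded file against the SHA256 digests listed above, a minimal sketch using only Python's standard `hashlib` follows. The helper names `sha256_of` and `verify` are hypothetical, and the file is looked up by its base name, assuming it was saved under the name PyPI gives it. pip can also enforce these digests itself, via `--hash=sha256:...` entries in a requirements file installed with `--require-hashes`.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the file-hash tables for jax_enums 0.1.2.
EXPECTED_SHA256 = {
    "jax_enums-0.1.2.tar.gz": "75391a59e1f5be3a976715306182d22331d74c001ae3b580e95a88c03d58dabf",
    "jax_enums-0.1.2-py2.py3-none-any.whl": "e9c9374789f9adb5874013c80e2be411af0ef1442cbc3c3c0bb0ac89de7f2f2c",
}


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """True only if the file's name is listed and its digest matches."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

For example, `verify("jax_enums-0.1.2.tar.gz")` returns `True` only for an unmodified copy of the source archive.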