High-order Polynomial Projection Operators (HiPPO) for JAX
Hippox: High-order Polynomial Projection Operators for JAX
What is Hippox?
Hippox provides a simple dataclass for initializing High-order Polynomial Projection Operators (HiPPOs) as parameters in JAX neural network libraries such as Flax and Haiku.
Example
Here is an example of initializing HiPPO parameters inside a Haiku module:
    import haiku as hk
    from hippox.main import Hippo


    class MyHippoModule(hk.Module):

        def __init__(self, state_size, measure):
            super().__init__()
            _hippo = Hippo(state_size=state_size, measure=measure)
            # Call the Hippo object before using its initializers.
            _hippo()

            # Real and imaginary parts are stored as separate real parameters.
            self._lambda_real = hk.get_parameter(
                'lambda_real',
                shape=[state_size,],
                init=_hippo.lambda_initializer('real')
            )
            self._lambda_imag = hk.get_parameter(
                'lambda_imaginary',
                shape=[state_size,],
                init=_hippo.lambda_initializer('imaginary')
            )
            # Recombine the two parts into the complex state matrix.
            self._state_matrix = self._lambda_real + 1j * self._lambda_imag

            self._input_matrix = hk.get_parameter(
                'input_matrix',
                shape=[state_size, 1],
                init=_hippo.b_initializer()
            )

        def __call__(self, input, prev_state):
            # One step of the state update.
            new_state = self._state_matrix @ prev_state + self._input_matrix @ input
            return new_state
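Haiku's `hk.get_parameter` expects `init` to be a callable that takes `(shape, dtype)` and returns an array. As a rough illustration of that contract only, the sketch below wraps a precomputed array as such a callable. `constant_initializer` is a hypothetical helper written for this example; it is not part of Hippox or Haiku, and it makes no claim about how `lambda_initializer` or `b_initializer` are implemented internally.

```python
import numpy as np


def constant_initializer(values):
    """Wrap a precomputed array as an initializer taking (shape, dtype).

    Hypothetical helper for illustration; not part of Hippox or Haiku.
    """
    values = np.asarray(values)

    def init(shape, dtype):
        # Fail early if the requested parameter shape does not match.
        if tuple(shape) != values.shape:
            raise ValueError(
                f"expected shape {values.shape}, got {tuple(shape)}"
            )
        return values.astype(dtype)

    return init


# Usage: an initializer that always returns the same 4-element vector.
init = constant_initializer([-1.0, -2.0, -3.0, -4.0])
lambda_real = init((4,), np.float32)
```

A callable of this form could be passed as the `init` argument of `hk.get_parameter`; the shape check mirrors the `shape=[state_size,]` argument used in the module above.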
If you are using a library (such as Equinox) that takes JAX arrays directly as parameters instead of initializer functions, you can read the HiPPO matrices as properties once the `Hippo` object has been called:
    import equinox as eqx
    import jax.numpy as jnp
    from hippox.main import Hippo


    class MyHippoModule(eqx.Module):
        A: jnp.ndarray
        B: jnp.ndarray

        def __init__(self, state_size, measure):
            _hippo = Hippo(state_size=state_size, measure=measure)
            # Calling the Hippo object returns the computed parameters.
            _hippo_params = _hippo()
            self.A = _hippo_params.state_matrix
            self.B = _hippo_params.input_matrix

        def __call__(self, input, prev_state):
            # One step of the state update.
            new_state = self.A @ prev_state + self.B @ input
            return new_state
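For intuition about what a state matrix and input matrix of this kind contain, here is a minimal NumPy sketch of the HiPPO-LegS matrices, using the sign convention from the S4 paper listed under References. It is not Hippox's implementation: the function name `hippo_legs` is hypothetical, and nothing here should be read as a statement about which `measure` values Hippox accepts or what layout its properties return.

```python
import numpy as np


def hippo_legs(state_size):
    """Return (A, B) for the HiPPO-LegS measure (S4 paper sign convention).

    A[n, k] = -sqrt((2n+1)(2k+1))  if n > k
              -(n+1)               if n == k
              0                    if n < k
    B[n]    = sqrt(2n+1)

    Hypothetical helper for illustration; not part of Hippox.
    """
    n = np.arange(state_size)
    scale = np.sqrt(2 * n + 1)
    # The strictly lower triangle holds the sqrt((2n+1)(2k+1)) couplings.
    lower = np.tril(np.outer(scale, scale), k=-1)
    A = -(lower + np.diag(n + 1))
    B = scale[:, None]  # column vector of shape (state_size, 1)
    return A, B


A, B = hippo_legs(4)
```

Arrays built this way can be converted with `jnp.asarray` before being stored on a JAX module. They describe a continuous-time system, so they would normally be discretized before being used in a step-by-step recurrence.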
Installation
Hippox can be installed from PyPI:

    pip install hippox
References
Repositories
- https://github.com/HazyResearch/state-spaces - Original paper implementations in PyTorch
- https://github.com/srush/annotated-s4 - JAX implementation of S4 models (S4, S4D, DSS)
Papers
- HiPPO: Recurrent Memory with Optimal Polynomial Projections: https://arxiv.org/abs/2008.07669 - Original paper that introduced HiPPOs
- Efficiently Modeling Long Sequences with Structured State Spaces: https://arxiv.org/abs/2111.00396 - S4 paper, introduces the normal/diagonal plus low-rank decomposition
- How to Train Your HiPPO: State Space Models with Generalized Orthogonal Basis Projections: https://arxiv.org/abs/2206.12037 - Generalizes and explains the core principles behind HiPPO
- On the Parameterization and Initialization of Diagonal State Space Models: https://arxiv.org/abs/2206.11893 - S4D paper, details and explains the diagonal-only parameterization