High-order Polynomial Projection Operators (HiPPO) for JAX

Hippox: High-order Polynomial Projection Operators for JAX

What is Hippox?

Hippox provides a simple dataclass for initializing High-order Polynomial Projection Operators (HiPPOs) as parameters in JAX neural network libraries such as Flax and Haiku.
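For readers unfamiliar with what a HiPPO parameter looks like, here is a minimal sketch of the HiPPO-LegS state and input matrices as defined in the HiPPO paper, written with the negated sign used in the S4 papers (x' = Ax + Bu). It uses NumPy only and is not Hippox's implementation; the function name `legs_matrices` is hypothetical, and which measures Hippox supports is not shown in this README.

```python
import numpy as np

def legs_matrices(state_size):
    """Build HiPPO-LegS (A, B); hypothetical helper, not part of hippox."""
    n = np.arange(state_size)
    scale = np.sqrt(2 * n + 1)  # sqrt(2n + 1) for each index n
    # Strictly lower triangle: sqrt(2n + 1) * sqrt(2k + 1) for n > k.
    lower = np.tril(np.outer(scale, scale), k=-1)
    # Diagonal entries are n + 1; everything above the diagonal is zero.
    A = -(lower + np.diag(n + 1))
    B = scale[:, None]  # column vector of shape (state_size, 1)
    return A, B

A, B = legs_matrices(4)
print(A.shape, B.shape)  # (4, 4) (4, 1)
```

The resulting arrays can be converted with `jnp.asarray` if they are needed as JAX arrays.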

Example

Here is an example of initializing HiPPO parameters inside a Haiku module:

import haiku as hk
from hippox.main import Hippo

class MyHippoModule(hk.Module):
    def __init__(self, state_size, measure, name=None):
        # hk.Module requires the base constructor to run before parameters are created.
        super().__init__(name=name)
        _hippo = Hippo(state_size=state_size, measure=measure)
        _hippo()  # call the instance before requesting its initializers

        # Real and imaginary parts of the state matrix are stored as separate parameters.
        self._lambda_real = hk.get_parameter(
            'lambda_real',
            shape=[state_size,],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            shape=[state_size,],
            init=_hippo.lambda_initializer('imaginary')
        )
        self._state_matrix = self._lambda_real + 1j * self._lambda_imag

        self._input_matrix = hk.get_parameter(
            'input_matrix',
            shape=[state_size, 1],
            init=_hippo.b_initializer()
        )

    def __call__(self, input, prev_state):
        # One step of the linear state update.
        new_state = self._state_matrix @ prev_state + self._input_matrix @ input
        return new_state

If you are using a library (such as Equinox) that does not require an initializer function and instead takes JAX arrays directly as parameters, you can read the HiPPO matrices as attributes once the Hippo instance has been called:

import equinox as eqx
import jax.numpy as jnp
from hippox.main import Hippo

class MyHippoModule(eqx.Module):
    A: jnp.ndarray
    B: jnp.ndarray

    def __init__(self, state_size, measure):
        _hippo = Hippo(state_size=state_size, measure=measure)
        _hippo_params = _hippo()  # calling the instance returns the HiPPO matrices

        self.A = _hippo_params.state_matrix
        self.B = _hippo_params.input_matrix

    def __call__(self, input, prev_state):
        # One step of the linear state update.
        new_state = self.A @ prev_state + self.B @ input
        return new_state
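The update in `__call__` is a discrete recurrence, while HiPPO matrices are usually stated for a continuous-time system x'(t) = A x(t) + B u(t). The S4 paper discretizes that system with the bilinear transform before stepping it. The sketch below shows this in plain NumPy; it assumes nothing about Hippox itself, the helper names `discretize_bilinear` and `run_recurrence` are hypothetical, and the small matrices are stand-ins rather than Hippox output.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    """Bilinear (Tustin) discretization of x'(t) = A x(t) + B u(t)."""
    identity = np.eye(A.shape[0])
    inv = np.linalg.inv(identity - (step / 2.0) * A)
    A_bar = inv @ (identity + (step / 2.0) * A)
    B_bar = inv @ (step * B)
    return A_bar, B_bar

def run_recurrence(A_bar, B_bar, inputs):
    """Apply x_k = A_bar x_{k-1} + B_bar u_k over a 1-D sequence of scalars."""
    state = np.zeros((A_bar.shape[0], 1))
    for u in inputs:
        state = A_bar @ state + B_bar * u
    return state

# Stand-in continuous-time matrices (a small stable system, NOT hippox output).
A = np.array([[-1.0, 0.0], [-1.0, -2.0]])
B = np.array([[1.0], [1.0]])

A_bar, B_bar = discretize_bilinear(A, B, step=0.1)
final_state = run_recurrence(A_bar, B_bar, np.ones(50))
print(final_state.shape)  # (2, 1)
```

In a JAX model the Python loop would typically be replaced by `jax.lax.scan`; the loop is kept here so the arithmetic is easy to follow.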

Installation

Hippox can be installed from PyPI:

pip install hippox

References

Repositories

  1. https://github.com/HazyResearch/state-spaces - Original paper implementations in PyTorch

  2. https://github.com/srush/annotated-s4 - JAX implementation of S4 models (S4, S4D, DSS)

Papers

  1. HiPPO: Recurrent Memory with Optimal Polynomial Projections: https://arxiv.org/abs/2008.07669 - Original paper which introduced HiPPOs

  2. Efficiently Modeling Long Sequences with Structured State Spaces: https://arxiv.org/abs/2111.00396 - S4 paper, introduces the normal/diagonal plus low-rank decomposition

  3. How to Train Your HiPPO: State Space Models with Generalized Orthogonal Basis Projections: https://arxiv.org/abs/2206.12037 - Generalizes and explains the core principles behind HiPPO

  4. On the Parameterization and Initialization of Diagonal State Space Models: https://arxiv.org/abs/2206.11893 - S4D paper, details and explains the diagonal-only parameterization
