A differentiable implementation of an all-pole filter in JAX
jaxpole
This is a Direct-Form I implementation of a time-varying all-pole filter in JAX, based on torchlpc.
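For orientation, the recursion that a time-varying all-pole filter of order P computes can be written as below. This assumes jaxpole follows the same sign convention as torchlpc, which the coefficients in the usage example are consistent with; the symbols a_k[n] (the k-th coefficient at time step n) are notation introduced here, not names from the package.

```latex
y[n] = x[n] - \sum_{k=1}^{P} a_k[n]\, y[n-k]
```

With fixed coefficients a_1 = -2 Re(p) and a_2 = |p|^2, the denominator 1 + a_1 z^{-1} + a_2 z^{-2} factors as (1 - p z^{-1})(1 - p^* z^{-1}), which places a complex-conjugate pole pair at p and p^*.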
Install
pip install jaxpole
or, from a local source checkout, in editable mode with the dev extras
pip install -e '.[dev]'
How to use
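import jax
import jax.numpy as jnp
# The import path below is an assumption; adjust it to wherever the package exposes allpole.
from jaxpole.filter import allpole

# complex-conjugate pole pair at radius 0.99 and angle pi/4
pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
# second-order coefficients [a1, a2] = [-2*Re(p), |p|^2]
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])
x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))  # (B, T)
# repeat the same coefficients at every time step
A = jnp.tile(coeffs, (1, x.shape[-1], 1))  # (B, T, P)
zi = jnp.zeros((1, 2))  # (B, P), initial conditions
# filter the signal
y = allpole(x, A, zi)
print(y.shape)  # (1, 1000)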
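Because the package is described as differentiable, it may help to see how gradients can flow through such a recursion. The sketch below does not call jaxpole: `allpole_reference` is a hypothetical stand-in built on `jax.lax.scan`. It assumes the recursion subtracts the weighted past outputs from the input and that `zi` is ordered as [y[-1], ..., y[-P]]. The actual package may order the initial state differently or use a faster implementation.

```python
import jax
import jax.numpy as jnp


def allpole_reference(x, A, zi):
    """Hypothetical reference filter (not the jaxpole API).

    x: (B, T) input, A: (B, T, P) coefficients,
    zi: (B, P) past outputs, assumed ordered [y[-1], ..., y[-P]].
    """
    def step(state, inputs):
        x_t, a_t = inputs  # shapes (B,) and (B, P)
        y_t = x_t - jnp.sum(a_t * state, axis=-1)
        # shift the state so the newest output comes first
        new_state = jnp.concatenate([y_t[:, None], state[:, :-1]], axis=-1)
        return new_state, y_t

    # lax.scan iterates over the leading axis, so move time to the front
    xs = (jnp.swapaxes(x, 0, 1), jnp.swapaxes(A, 0, 1))
    _, ys = jax.lax.scan(step, zi, xs)
    return jnp.swapaxes(ys, 0, 1)  # back to (B, T)


pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])
x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))  # (B, T)
A = jnp.tile(coeffs, (1, x.shape[-1], 1))  # (B, T, P)
zi = jnp.zeros((1, 2))  # (B, P)

y = allpole_reference(x, A, zi)


def loss(A):
    # mean output power as a toy objective
    return jnp.mean(allpole_reference(x, A, zi) ** 2)


# gradient with respect to every per-step coefficient, same shape as A
grad_A = jax.grad(loss)(A)
```

A scan keeps the example short and works with `jax.grad` out of the box, at the cost of a sequential loop over time. The same `jax.grad` pattern should apply to any filter function with this signature.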
Source distribution: jaxpole-0.0.2.tar.gz (9.5 kB)