
A differentiable implementation of an all-pole filter in JAX


jaxpole

jaxpole implements a differentiable, time-varying all-pole filter in JAX. It is based on torchlpc, a PyTorch library for differentiable time-varying LPC filtering.
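To make concrete what a time-varying all-pole filter computes, here is a minimal, deliberately slow reference sketch built on `jax.lax.scan`. It is not jaxpole's implementation, and the name `allpole_reference` is hypothetical. The sign convention (denominator 1 + a1 z^-1 + ... + aP z^-P) is inferred from the coefficient formula in the usage example, and the ordering of `zi` (most recent past output first) is an assumption.

```python
import jax
import jax.numpy as jnp


def allpole_reference(x, A, zi):
    """Hypothetical reference filter (not jaxpole's implementation).

    For every batch item b and time step n:
        y[b, n] = x[b, n] - sum_{i=1}^{P} A[b, n, i - 1] * y[b, n - i]

    x:  (B, T) input signal
    A:  (B, T, P) per-sample (time-varying) coefficients
    zi: (B, P) initial state, ASSUMED ordered [y[-1], y[-2], ..., y[-P]]
    """
    dtype = jnp.result_type(x, A, zi)

    def step(state, inputs):
        # state holds the last P outputs, most recent first: (B, P)
        x_n, a_n = inputs  # shapes (B,) and (B, P)
        y_n = x_n - jnp.sum(a_n * state, axis=-1)
        # Push y_n onto the front of the state and drop the oldest output
        state = jnp.concatenate([y_n[:, None], state[:, :-1]], axis=-1)
        return state, y_n

    # lax.scan iterates over the leading axis, so move time to axis 0
    xs = (x.T.astype(dtype), jnp.swapaxes(A, 0, 1).astype(dtype))
    _, ys = jax.lax.scan(step, zi.astype(dtype), xs)
    return ys.T  # (T, B) -> (B, T)


# Gradients flow through lax.scan, so jax.grad works w.r.t. x, A and zi
def energy(A, x, zi):
    return jnp.sum(allpole_reference(x, A, zi) ** 2)


grad_energy = jax.grad(energy)  # derivative of the output energy w.r.t. A
```

Because it steps through the signal one sample at a time, this sketch is only meant as a readable baseline, for example to sanity-check outputs on short inputs.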

Install

pip install jaxpole

or install locally from source in editable mode, with the development extras:

pip install -e '.[dev]'

How to use

import jax
import jax.numpy as jnp

from jaxpole import allpole  # import path assumed from the package name

pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])

x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000)) # (B, T)
A = jnp.tile(coeffs, (1, x.shape[-1], 1)) # (B, T, P)
zi = jnp.zeros((1, 2)) # (B, P)

# filter the signal
y = allpole(x, A, zi)
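The coefficient formula in the example encodes a conjugate pole pair: expanding (1 - p z^-1)(1 - conj(p) z^-1) gives 1 - 2 Re(p) z^-1 + |p|^2 z^-2, so `coeffs = [a1, a2]` places the poles at 0.99 e^(±jπ/4). The sketch below, assuming the denominator convention 1 + a1 z^-1 + a2 z^-2, rebuilds the same coefficients from a radius and an angle, checks that the static magnitude response peaks near π/4, and builds a genuinely time-varying `A` by sweeping the pole angle. `resonator_coeffs` and `magnitude_response` are hypothetical helpers, not part of jaxpole.

```python
import jax.numpy as jnp


def resonator_coeffs(radius, angle):
    """Hypothetical helper: [a1, a2] for poles at radius * exp(+/- 1j * angle).

    (1 - p z^-1)(1 - conj(p) z^-1) = 1 - 2 Re(p) z^-1 + |p|^2 z^-2
    """
    radius = jnp.asarray(radius)
    angle = jnp.asarray(angle)
    a1 = -2.0 * radius * jnp.cos(angle)
    a2 = jnp.broadcast_to(radius**2, a1.shape)  # match shapes for stacking
    return jnp.stack([a1, a2], axis=-1)


def magnitude_response(coeffs, w):
    """|H(e^{jw})| = 1 / |1 + a1 e^{-jw} + a2 e^{-2jw}| for static coefficients."""
    z_inv = jnp.exp(-1j * w)
    return 1.0 / jnp.abs(1.0 + coeffs[0] * z_inv + coeffs[1] * z_inv**2)


# Static case: same coefficients as in the usage example
coeffs = resonator_coeffs(0.99, jnp.pi / 4)
w = jnp.linspace(0.0, jnp.pi, 4096)
peak_freq = w[jnp.argmax(magnitude_response(coeffs, w))]  # close to pi / 4

# Time-varying case: sweep the pole angle from pi/8 to pi/2 over T samples,
# giving per-sample coefficients in the (B, T, P) = (1, T, 2) layout
T = 1000
angles = jnp.linspace(jnp.pi / 8, jnp.pi / 2, T)
A_sweep = resonator_coeffs(0.99, angles)[None]  # (1, T, 2)
```

`A_sweep` has the same (B, T, P) layout as `A` in the example, so under these assumptions it could be passed to `allpole` in its place to get a resonance that glides upward over time.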


Download files

Source distribution: jaxpole-0.0.3.tar.gz (9.7 kB)

Built distribution (Python 3): jaxpole-0.0.3-py3-none-any.whl (9.4 kB)
