
Kalman-based neural decoding in JAX

Project description

KalMax: Kalman-based neural decoding in JAX

KalMax = Kalman smoothing of Maximum likelihood estimates in JAX.

You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a continuous variable, e.g. position), and KalMax provides JAX-optimised functions and classes for:

  1. Fitting rate maps using kernel density estimation (KDE)
  2. Calculating likelihood maps $p(\mathbf{s}_t|\mathbf{x})$
  3. Kalman filter / smoother
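To make items 1 and 2 concrete, here is a minimal JAX sketch of a KDE rate map, a Poisson log-likelihood map, and a moment-matched Gaussian fit to that map. It is not KalMax's implementation: the helper names (`gaussian_kernel`, `kde_rate_maps`, `poisson_log_likelihood`, `fit_gaussian`), the unnormalised isotropic Gaussian kernel, the time-bin width `dt`, and the softmax normalisation of the log-likelihood (equivalent to a flat prior) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def gaussian_kernel(x, y, bandwidth):
    # Unnormalised isotropic Gaussian kernel; the normalisation cancels in the rate ratio below
    sq_dist = jnp.sum((x - y) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / bandwidth**2)


def kde_rate_maps(bins, trajectory, spikes, bandwidth, dt=1.0):
    # bins: (N_BINS, D), trajectory: (T, D), spikes: (T, N_CELLS) -> rates: (N_CELLS, N_BINS)
    K = gaussian_kernel(bins[:, None, :], trajectory[None, :, :], bandwidth)  # (N_BINS, T)
    occupancy = K.sum(axis=1) * dt  # kernel-weighted time spent near each bin
    spike_density = K @ spikes  # kernel-weighted spike counts, (N_BINS, N_CELLS)
    return (spike_density / (occupancy[:, None] + 1e-12)).T


def poisson_log_likelihood(spikes, rates, dt=1.0):
    # log p(s_t | x_b) for independent Poisson cells, dropping the constant -log(s!) term
    # spikes: (T, N_CELLS), rates: (N_CELLS, N_BINS) -> (T, N_BINS)
    expected = rates * dt + 1e-12  # expected counts per time bin (epsilon avoids log 0)
    return spikes @ jnp.log(expected) - expected.sum(axis=0)[None, :]


def fit_gaussian(bins, log_likelihood):
    # Moment-match a Gaussian to one likelihood map; bins: (N_BINS, D), log_likelihood: (N_BINS,)
    w = jax.nn.softmax(log_likelihood)  # normalise in log space (avoids exp underflow)
    mean = w @ bins
    mode = bins[jnp.argmax(log_likelihood)]
    centred = bins - mean
    cov = (w[:, None] * centred).T @ centred
    return mean, mode, cov


# Map over time steps, sharing the same bins
fit_gaussian_vmap = jax.vmap(fit_gaussian, in_axes=(None, 0))

# Usage: toy 1D track, two Gaussian place fields, noise-free expected counts (not real spikes)
bins = jnp.linspace(0.0, 1.0, 101)[:, None]  # (101, 1)
trajectory = jnp.linspace(0.0, 1.0, 1000)[:, None]  # (1000, 1), one uniform sweep
centres = jnp.array([0.25, 0.75])
spikes = 5.0 * jnp.exp(-0.5 * ((trajectory - centres) / 0.05) ** 2)  # (1000, 2)

rates = kde_rate_maps(bins, trajectory, spikes, bandwidth=0.02)  # (2, 101)
log_lik = poisson_log_likelihood(jnp.array([[10.0, 0.0]]), rates)  # cell 0 fires, cell 1 silent
means, modes, covs = fit_gaussian_vmap(bins, log_lik)  # (1, 1), (1, 1), (1, 1, 1)
```

Because cell 0 fires strongly while cell 1 is silent, the likelihood concentrates around cell 0's field near 0.25.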

Why are these functionalities combined into one package?...

Because Likelihood Estimation + Kalman filtering = Powerful neural decoding. By Kalman filtering/smoothing the maximum likelihood estimates (as opposed to the spikes themselves) we bypass the weaknesses of both naive Kalman filters (spike counts are rarely a linear function of position) and maximum likelihood decoding (each time step is decoded independently, ignoring the temporal continuity of the trajectory), outperforming both at no extra computational cost.
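The idea can be sketched as treating the per-time-step maximum likelihood estimates as noisy observations $y_t = H z_t + v_t$ of a latent trajectory $z_t = F z_{t-1} + w_t$, then running a standard Kalman filter followed by a Rauch-Tung-Striebel (RTS) smoother. The JAX sketch below is illustrative only: the function names `kalman_filter` and `rts_smoother` are hypothetical, not KalMax's API, and it assumes the convention that `mu0`/`sigma0` describe the state one step before the first observation (each step predicts, then updates).

```python
import jax
import jax.numpy as jnp


def kalman_filter(Y, F, Q, H, R, mu0, sigma0):
    # Y: (T, dim_Y) observations (e.g. MLE positions); returns (T, dim_Z), (T, dim_Z, dim_Z)
    def step(carry, y):
        mu, sigma = carry
        mu_p = F @ mu  # predict
        sigma_p = F @ sigma @ F.T + Q
        S = H @ sigma_p @ H.T + R  # innovation covariance
        K = jnp.linalg.solve(S, H @ sigma_p).T  # Kalman gain sigma_p H^T S^-1
        mu_f = mu_p + K @ (y - H @ mu_p)  # update with observation y
        sigma_f = sigma_p - K @ H @ sigma_p
        return (mu_f, sigma_f), (mu_f, sigma_f)

    _, (mus_f, sigmas_f) = jax.lax.scan(step, (mu0, sigma0), Y)
    return mus_f, sigmas_f


def rts_smoother(mus_f, sigmas_f, F, Q):
    # Backward pass from the last filtered state to the first
    def step(carry, filtered):
        mu_s_next, sigma_s_next = carry
        mu_f, sigma_f = filtered
        mu_p = F @ mu_f
        sigma_p = F @ sigma_f @ F.T + Q
        G = jnp.linalg.solve(sigma_p, F @ sigma_f).T  # smoother gain sigma_f F^T sigma_p^-1
        mu_s = mu_f + G @ (mu_s_next - mu_p)
        sigma_s = sigma_f + G @ (sigma_s_next - sigma_p) @ G.T
        return (mu_s, sigma_s), (mu_s, sigma_s)

    last = (mus_f[-1], sigmas_f[-1])  # smoothed = filtered at the final step
    _, (mus, sigmas) = jax.lax.scan(step, last, (mus_f[:-1], sigmas_f[:-1]), reverse=True)
    return jnp.concatenate([mus, mus_f[-1:]]), jnp.concatenate([sigmas, sigmas_f[-1:]])


# Usage: a 1D random walk observed through very noisy, independent per-step estimates
key_z, key_y = jax.random.split(jax.random.PRNGKey(0))
T = 1000
Z_true = jnp.cumsum(0.1 * jax.random.normal(key_z, (T, 1)), axis=0)
Y = Z_true + jax.random.normal(key_y, (T, 1))  # stand-in for MLE estimates
F = H = jnp.eye(1)
Q, R = 0.01 * jnp.eye(1), jnp.eye(1)

mus_f, sigmas_f = kalman_filter(Y, F, Q, H, R, mu0=Y[0], sigma0=jnp.eye(1))
mus_s, sigmas_s = rts_smoother(mus_f, sigmas_f, F, Q)


def mse(estimate):
    return float(jnp.mean((estimate - Z_true) ** 2))
```

In this toy setting the filtered estimates should be far closer to the true trajectory than the raw per-step estimates, and the smoothed ones closer still, because the smoother also uses future observations.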

Core KalMax functions are optimised and jit-compiled in JAX, making them very fast. For example, KalMax's Kalman filter is more than 13 times faster than the equivalent NumPy implementation in the popular pykalman library (see demo).

Install

pip install kalmax

Development install

git clone https://github.com/TomGeorge1234/KalMax.git
cd KalMax
pip install -e ".[dev]"   # installs with test/lint dependencies

To run tests and linting:

pytest                     # run test suite
ruff check src/            # lint
ruff format --check src/   # check formatting

Usage

A full demo is provided in examples/kalmax_demo.ipynb (it can also be opened in Google Colab). Pseudo-code is provided below.

import kalmax 
import jax.numpy as jnp 
# 0. PREPARE DATA IN JAX ARRAYS
S_train = jnp.array(...) # (T, N_CELLS)      train spike counts
Z_train = jnp.array(...) # (T, DIMS)         train continuous variable (X above, e.g. position)
S_test  = jnp.array(...) # (T_TEST, N_CELLS) test spike counts
bins    = jnp.array(...) # (N_BINS, DIMS)    coordinates at which to estimate receptive fields / likelihoods
DIMS    = Z_train.shape[1] # dimensionality of the continuous variable
# 1. FIT RECEPTIVE FIELDS using kalmax.kde
firing_rate = kalmax.kde.kde(
    bins = bins,
    trajectory = Z_train,
    spikes = S_train,
    kernel = kalmax.kernels.gaussian_kernel,
    kernel_bandwidth = 0.01,
    ) # --> (N_CELLS, N_BINS)
# 2.1 CALCULATE LIKELIHOODS using kalmax.kde.poisson_log_likelihood
log_likelihoods = kalmax.kde.poisson_log_likelihood(
    spikes = S_test,                       
    mean_rate = firing_rate,
    ) # --> (T_TEST, N_BINS)

# 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian_vmap
MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
    x = bins, 
    likelihoods = jnp.exp(log_likelihoods),
    ) # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
# 3. KALMAN FILTER / SMOOTH using kalmax.KalmanFilter
kalman_filter = kalmax.KalmanFilter(
    dim_Z = DIMS,
    dim_Y = DIMS,
    # SEE DEMO FOR HOW TO FIT/SET THESE
    F=F, # state transition matrix
    Q=Q, # state noise covariance
    H=H, # observation matrix
    R=R, # observation noise covariance
    ) 

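# [FILTER] the observations Y are the maximum likelihood estimates from step 2.2
Y = MLE_modes # (T_TEST, DIMS); MLE_means is an alternative
mus_f, sigmas_f = kalman_filter.filter(
    Y = Y, 
    mu0 = mu0,       # initial state mean (SEE DEMO)
    sigma0 = sigma0, # initial state covariance (SEE DEMO)
    ) # --> (T_TEST, DIMS), (T_TEST, DIMS, DIMS)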

# [SMOOTH]
mus_s, sigmas_s = kalman_filter.smooth(
    mus_f = mus_f, 
    sigmas_f = sigmas_f,
    ) # --> (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
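The pseudo-code leaves F, Q, H and R to the demo. One simple way to set them, shown here only as a hedged sketch and not necessarily what the demo does, is to fit a linear-Gaussian transition model to the training trajectory by least squares, take H as the identity (the observations are position estimates in the same space as the state), and estimate R from decoding errors on data where the true position is known (ideally held-out data). The helpers `fit_transition`, `fit_observation_noise` and `simulate_ar` are hypothetical.

```python
import jax
import jax.numpy as jnp


def fit_transition(Z):
    # Least-squares fit of z_{t+1} = F z_t + w_t, w_t ~ N(0, Q); Z: (T, D), no intercept term
    Z_prev, Z_next = Z[:-1], Z[1:]
    F_T, *_ = jnp.linalg.lstsq(Z_prev, Z_next)  # solves Z_prev @ F^T = Z_next in least squares
    residuals = Z_next - Z_prev @ F_T
    Q = jnp.atleast_2d(jnp.cov(residuals, rowvar=False))
    return F_T.T, Q


def fit_observation_noise(Y_est, Z_true):
    # With H = I, R is the covariance of the decoding errors y_t - z_t
    return jnp.atleast_2d(jnp.cov(Y_est - Z_true, rowvar=False))


def simulate_ar(key, T, a, noise_sd, D):
    # Toy trajectory z_{t+1} = a * z_t + noise, a stand-in for Z_train
    noise = noise_sd * jax.random.normal(key, (T, D))

    def step(z, w):
        z_next = a * z + w
        return z_next, z_next

    _, Z = jax.lax.scan(step, jnp.zeros(D), noise)
    return Z


key_z, key_y = jax.random.split(jax.random.PRNGKey(1))
Z_train = simulate_ar(key_z, T=5000, a=0.95, noise_sd=0.1, D=2)
F, Q = fit_transition(Z_train)  # approximately 0.95 * I and 0.01 * I
H = jnp.eye(2)
Y_est = Z_train + 0.5 * jax.random.normal(key_y, Z_train.shape)  # stand-in for MLE estimates
R = fit_observation_noise(Y_est, Z_train)  # approximately 0.25 * I
```

The resulting (DIMS, DIMS) arrays have the form the F, Q, H and R arguments above take, with dim_Z = dim_Y = DIMS.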

Download files

Download the file for your platform.

Source Distribution

kalmax-0.3.0.tar.gz (24.9 kB)

Uploaded Source

Built Distribution

kalmax-0.3.0-py3-none-any.whl (17.7 kB)

Uploaded Python 3

File details

Details for the file kalmax-0.3.0.tar.gz.

File metadata

  • Download URL: kalmax-0.3.0.tar.gz
  • Upload date:
  • Size: 24.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for kalmax-0.3.0.tar.gz
Algorithm Hash digest
SHA256 02ddd649ca80ef43a1ddd04a1325be2649282c40bc8253e500a15dbfe5836f2e
MD5 e0cfc7512627dcfc261103aa141385b1
BLAKE2b-256 3904d09747f91a699d8a0f267312f56e67db7240369b8cea9195829803dd5be1

File details

Details for the file kalmax-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: kalmax-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 17.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for kalmax-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b0b6a87551fca0dadc60e5095fa187abf16edb1191389ec9663930795c2bca75
MD5 bda8bfe0684d5eaceab44e5a5773e64f
BLAKE2b-256 a7b38352e925f1d8b9529b2415ea71bb7d1fa8587678b272f6def483ec1c071b
