KalMax: Kalman based neural decoding in Jax

KalMax = Kalman smoothing of Maximum likelihood estimates in Jax.

You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a continuous variable, e.g. position) and KalMax provides jax-optimised functions and classes for:

  1. Fitting rate maps using kernel density estimation (KDE)
  2. Calculating likelihood maps $p(\mathbf{s}_t|\mathbf{x})$
  3. Kalman filter / smoother
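
For reference, steps 1 and 2 can be written out as follows. This is a standard kernel-smoothed rate map and independent-Poisson likelihood, stated here as an assumption about the general approach; KalMax's exact normalisation (for example, how the time-bin width enters) may differ. Here $k(\cdot,\cdot)$ is the kernel, $\mathbf{x}_t$ is row $t$ of $\mathbf{X}$, $S_{tn}$ is the spike count of cell $n$ at time $t$, and $f_n(\mathbf{x})$ is the expected spike count of cell $n$ per time bin at position $\mathbf{x}$.

$$
f_n(\mathbf{x}) = \frac{\sum_{t=1}^{T} S_{tn}\, k(\mathbf{x}, \mathbf{x}_t)}{\sum_{t=1}^{T} k(\mathbf{x}, \mathbf{x}_t)}
$$

$$
\log p(\mathbf{s}_t|\mathbf{x}) = \sum_{n=1}^{N} \Big[ S_{tn} \log f_n(\mathbf{x}) - f_n(\mathbf{x}) - \log (S_{tn}!) \Big]
$$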

Why are these functionalities combined into one package?

Because likelihood estimation + Kalman filtering = powerful neural decoding. By Kalman filtering/smoothing the maximum likelihood estimates (as opposed to the spikes themselves), we bypass the issues of both naive Kalman filters (spikes are rarely linearly related to position) and maximum likelihood decoding (which does not account for temporal continuity in the trajectory), outperforming both at no extra computational cost.
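
To make the idea concrete, here is a minimal scalar sketch of Kalman filtering and Rauch-Tung-Striebel smoothing applied to a sequence of noisy point estimates, written in plain Python. It is not KalMax's implementation: the function names `kalman_filter_1d` and `rts_smoother_1d`, the parameter values, and the convention that `mu0` describes the state one step before the first observation are all assumptions made for illustration.

```python
def kalman_filter_1d(ys, F, Q, H, R, mu0, sigma0):
    """Scalar Kalman filter; returns filtered means and variances.

    Assumes mu0/sigma0 describe the state one step before the first observation.
    """
    mus, sigmas = [], []
    mu, sigma = mu0, sigma0
    for y in ys:
        # Predict the next state from the linear dynamics
        mu_pred = F * mu
        sigma_pred = F * sigma * F + Q
        # Update with the observation (here: a noisy point estimate)
        K = sigma_pred * H / (H * sigma_pred * H + R)  # Kalman gain
        mu = mu_pred + K * (y - H * mu_pred)
        sigma = (1.0 - K * H) * sigma_pred
        mus.append(mu)
        sigmas.append(sigma)
    return mus, sigmas


def rts_smoother_1d(mus_f, sigmas_f, F, Q):
    """Rauch-Tung-Striebel backward pass over the filtered estimates."""
    mus_s, sigmas_s = list(mus_f), list(sigmas_f)
    for t in range(len(mus_f) - 2, -1, -1):
        sigma_pred = F * sigmas_f[t] * F + Q
        J = sigmas_f[t] * F / sigma_pred  # smoother gain
        mus_s[t] = mus_f[t] + J * (mus_s[t + 1] - F * mus_f[t])
        sigmas_s[t] = sigmas_f[t] + J * (sigmas_s[t + 1] - sigma_pred) * J
    return mus_s, sigmas_s


# Illustrative (made-up) noisy estimates of a slowly drifting 1D position
noisy_estimates = [0.0, 0.3, -0.1, 0.5, 0.2, 0.6, 0.4, 0.9]
mus_f, sigmas_f = kalman_filter_1d(
    noisy_estimates, F=1.0, Q=0.01, H=1.0, R=0.1, mu0=0.0, sigma0=1.0
)
mus_s, sigmas_s = rts_smoother_1d(mus_f, sigmas_f, F=1.0, Q=0.01)
```

The smoothed variances are never larger than the filtered ones, because the backward pass also uses observations that arrive after each time step.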

Core KalMax functions are optimised and jit-compiled in jax, making them very fast. For example, KalMax Kalman filtering is >13 times faster than an equivalent numpy implementation from the popular pykalman library (see demo).

Install

git clone https://github.com/TomGeorge1234/KalMax.git
cd KalMax
pip install -e .

The -e flag is optional and installs the package in editable (developer) mode.

Alternatively

pip install git+https://github.com/TomGeorge1234/KalMax.git

Usage

A full demo is provided in kalmax_demo.ipynb (which can also be opened in Colab). Pseudocode is provided below.

import kalmax 
import jax.numpy as jnp 
# 0. PREPARE DATA IN JAX ARRAYS
S_train = jnp.array(...) # (T, N_CELLS)      train spike counts
Z_train = jnp.array(...) # (T, DIMS)         train continuous variable
S_test  = jnp.array(...) # (T_TEST, N_CELLS) test spike counts
bins    = jnp.array(...) # (N_BINS, DIMS)    coordinates at which to estimate receptive fields / likelihoods
# 1. FIT RECEPTIVE FIELDS using kalmax.kde
firing_rate = kalmax.kde.kde(
    bins = bins,
    trajectory = Z_train,
    spikes = S_train,
    kernel = kalmax.kernels.gaussian_kernel,
    kernel_kwargs = {'covariance': 0.01**2 * jnp.eye(DIMS)}, # kernel bandwidth
    ) # --> (N_CELLS, N_BINS)
# 2.1 CALCULATE LIKELIHOODS using kalmax.kde.poisson_log_likelihood
log_likelihoods = kalmax.kde.poisson_log_likelihood(
    spikes = S_test,
    mean_rate = firing_rate,
    ) # --> (T_TEST, N_BINS)

# 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian_vmap
MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
    x = bins,
    likelihoods = jnp.exp(log_likelihoods),
    ) # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
# 3. KALMAN FILTER / SMOOTH using kalmax.kalman.KalmanFilter
kalman_filter = kalmax.kalman.KalmanFilter(
    dim_Z = DIMS, 
    dim_Y = N_CELLS,
    # SEE DEMO FOR HOW TO FIT/SET THESE
    F=F, # state transition matrix
    Q=Q, # state noise covariance
    H=H, # observation matrix
    R=R, # observation noise covariance
    ) 

# [FILTER]
mus_f, sigmas_f = kalman_filter.filter(
    Y = Y, 
    mu0 = mu0,
    sigma0 = sigma0,
    ) # --> (T, DIMS), (T, DIMS, DIMS)

# [SMOOTH]
mus_s, sigmas_s = kalman_filter.smooth(
    mus_f = mus_f, 
    sigmas_f = sigmas_f,
    ) # --> (T, DIMS), (T, DIMS, DIMS)
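
To show what steps 1 to 2.2 compute, here is a small one-dimensional sketch in plain Python (standard library only). It is independent of KalMax: the function names mirror the ones above for readability only, and the signatures, the unnormalised kernel, the `eps` guard, and the toy data are all assumptions made for illustration rather than the library's API.

```python
import math


def gaussian_kernel(x, x0, bandwidth):
    """Unnormalised 1D Gaussian kernel; the normaliser cancels in the rate-map ratio."""
    return math.exp(-0.5 * ((x - x0) / bandwidth) ** 2)


def kde_rate_map(bins, trajectory, spikes, bandwidth):
    """Kernel-smoothed expected spike count per time bin, shape [n_cells][n_bins].

    spikes[t][n] is the spike count of cell n at time step t.
    """
    n_cells = len(spikes[0])
    rates = [[0.0] * len(bins) for _ in range(n_cells)]
    for b, x in enumerate(bins):
        weights = [gaussian_kernel(x, xt, bandwidth) for xt in trajectory]
        occupancy = max(sum(weights), 1e-12)  # guard against unvisited bins
        for n in range(n_cells):
            spike_density = sum(w * s[n] for w, s in zip(weights, spikes))
            rates[n][b] = spike_density / occupancy
    return rates


def poisson_log_likelihood(spike_counts, rates, eps=1e-9):
    """log p(s_t | x) at every bin for one time step, assuming independent Poisson cells."""
    n_bins = len(rates[0])
    logl = [0.0] * n_bins
    for n, s in enumerate(spike_counts):
        for b in range(n_bins):
            lam = rates[n][b] + eps  # eps avoids log(0)
            logl[b] += s * math.log(lam) - lam - math.lgamma(s + 1)
    return logl


def fit_gaussian(bins, log_likelihood):
    """Moment-match a 1D Gaussian to a likelihood map; also return the mode."""
    peak = max(log_likelihood)
    weights = [math.exp(l - peak) for l in log_likelihood]  # stable exponentiation
    total = sum(weights)
    mean = sum(w * x for w, x in zip(weights, bins)) / total
    var = sum(w * (x - mean) ** 2 for w, x in zip(weights, bins)) / total
    mode = bins[log_likelihood.index(peak)]
    return mean, mode, var


# Toy data: one cell that fires only when the animal is at position 1.0
bins = [0.0, 0.5, 1.0]
trajectory = [0.0, 0.0, 1.0, 1.0]
spikes = [[0], [0], [3], [3]]
rates = kde_rate_map(bins, trajectory, spikes, bandwidth=0.1)
logl = poisson_log_likelihood([3], rates)  # decode a time step with 3 spikes
mean, mode, var = fit_gaussian(bins, logl)
```

The per-time-step means (or modes) and variances produced this way are the quantities that the Kalman filter then treats as noisy observations of the trajectory.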
