KalMax: Kalman based neural decoding in Jax
KalMax = Kalman smoothing of Maximum likelihood estimates in Jax.
You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a continuous variable, e.g. position), and KalMax provides JAX-optimised functions and classes for:
- Fitting rate maps using kernel density estimation (KDE)
- Calculating likelihood maps $p(\mathbf{s}_t|\mathbf{x})$
- Kalman filter / smoother
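As an illustration of what the first two items compute, here is a minimal JAX sketch. It is not KalMax's implementation. The function names (`gaussian_weights`, `kde_rate_maps`, `poisson_log_likelihood_map`), the isotropic Gaussian kernel with a single `bandwidth`, and the time-bin width `dt` are assumptions made for the example. A cell's rate at a bin is its kernel-weighted spike count divided by the kernel-weighted time spent near that bin. Assuming cells fire as independent Poisson processes, the likelihood map is $\log p(\mathbf{s}_t|\mathbf{x}) = \sum_n \left[ s_{tn} \log(\lambda_n(\mathbf{x})\Delta t) - \lambda_n(\mathbf{x})\Delta t - \log s_{tn}! \right]$.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def gaussian_weights(bins, trajectory, bandwidth):
    """Unnormalised isotropic Gaussian weights K[b, t] between bins (B, D) and samples (T, D)."""
    sq_dist = jnp.sum((bins[:, None, :] - trajectory[None, :, :]) ** 2, axis=-1)  # (B, T)
    return jnp.exp(-0.5 * sq_dist / bandwidth**2)


def kde_rate_maps(bins, trajectory, spikes, bandwidth, dt):
    """Rate (Hz) of each cell at each bin: kernel-weighted spikes / kernel-weighted time."""
    K = gaussian_weights(bins, trajectory, bandwidth)  # (B, T)
    spike_density = K @ spikes                          # (B, N) kernel-weighted spike counts
    occupancy = K.sum(axis=1, keepdims=True) * dt       # (B, 1) kernel-weighted time in seconds
    return (spike_density / occupancy).T                # (N, B)


def poisson_log_likelihood_map(spikes, rates, dt):
    """log p(s_t | x_b) for every time step t and bin b, assuming independent Poisson cells."""
    expected = rates.T * dt + 1e-10  # (B, N) expected counts; small constant avoids log(0)
    return (
        spikes @ jnp.log(expected).T                   # sum_n s_tn * log(lambda_n dt)   -> (T, B)
        - expected.sum(axis=1)[None, :]                # - sum_n lambda_n dt             -> (1, B)
        - gammaln(spikes + 1).sum(axis=1, keepdims=True)  # - sum_n log(s_tn!)          -> (T, 1)
    )


# Toy 1D example: one cell that fires only near x = 1.
bins = jnp.array([[0.0], [1.0]])                       # (B=2, D=1)
trajectory = jnp.array([[0.0], [0.0], [1.0], [1.0]])   # (T=4, D=1)
spikes = jnp.array([[0.0], [0.0], [2.0], [2.0]])       # (T=4, N=1)
rates = kde_rate_maps(bins, trajectory, spikes, bandwidth=0.1, dt=0.5)   # (1, 2)
ll = poisson_log_likelihood_map(jnp.array([[2.0]]), rates, dt=0.5)      # (1, 2)
```

In the toy example the cell fires 2 spikes per 0.5 s bin at x = 1, so its estimated rate there is about 4 Hz, and a test observation of 2 spikes is far more likely at x = 1 than at x = 0.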
Why are these functionalities combined into one package?...
Because Likelihood Estimation + Kalman filtering = Powerful neural decoding. By Kalman filtering/smoothing the maximum likelihood estimates (as opposed to the spikes themselves), we bypass the issues of both naive Kalman filters (spikes are rarely linearly related to position) and maximum likelihood decoding (which does not account for temporal continuity in the trajectory), outperforming both for negligible extra computational cost.
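One way to turn a likelihood map into an observation that a Kalman filter can use is to summarise it by a Gaussian: a mean and covariance, plus the mode (the maximum-likelihood bin). The sketch below does this by moment matching. That choice is an assumption, not necessarily KalMax's fitting procedure, and `fit_gaussian` / `fit_gaussian_all` are hypothetical names. The fitted covariance can then serve as a per-time-step observation noise. Sharp likelihoods pull the filtered estimate strongly, while broad ones are smoothed over by the dynamics.

```python
import jax
import jax.numpy as jnp


def fit_gaussian(x, log_likelihood):
    """Moment-match a Gaussian to one likelihood map.

    x: (N_BINS, DIMS) bin coordinates; log_likelihood: (N_BINS,) unnormalised log-likelihoods.
    Returns mean (DIMS,), mode (DIMS,), covariance (DIMS, DIMS).
    """
    w = jax.nn.softmax(log_likelihood)        # normalised weights, computed stably from logs
    mean = w @ x                              # weighted mean position
    centred = x - mean                        # (N_BINS, DIMS)
    cov = (w[:, None] * centred).T @ centred  # weighted covariance
    mode = x[jnp.argmax(log_likelihood)]      # maximum-likelihood bin
    return mean, mode, cov


# Map over time steps: bins are shared, log-likelihoods are batched along axis 0.
fit_gaussian_all = jax.vmap(fit_gaussian, in_axes=(None, 0))

x = jnp.arange(5.0)[:, None]  # 5 bins in 1D at positions 0..4
log_liks = jnp.stack([-0.5 * (x[:, 0] - 2.0) ** 2, -0.5 * (x[:, 0] - 1.0) ** 2])  # (T=2, 5)
means, modes, covs = fit_gaussian_all(x, log_liks)  # (2, 1), (2, 1), (2, 1, 1)
```

Normalising with `jax.nn.softmax` on the log-likelihoods, rather than exponentiating them first, avoids underflow when many cells contribute to each log-likelihood.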
Core KalMax functions are optimised and jit-compiled in JAX, making them very fast. For example, KalMax's Kalman filtering is >13 times faster than an equivalent NumPy implementation in the popular pykalman library (see demo).
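Much of the speed of a JAX Kalman filter comes from writing the time recursion with `jax.lax.scan` and compiling it with `jax.jit`, so the loop runs as one compiled program rather than a Python loop. Below is a minimal sketch that assumes an identity observation matrix (the observations are themselves position estimates) and a separate observation-noise covariance $R_t$ for every step, e.g. the fitted likelihood covariances. `kalman_filter` is a hypothetical name, not KalMax's API, and no timing is implied.

```python
import jax
import jax.numpy as jnp


@jax.jit
def kalman_filter(Y, R, F, Q, mu0, sigma0):
    """Kalman filter with H = I and time-varying observation noise.

    Y: (T, D) observations, R: (T, D, D) their covariances,
    F, Q: (D, D) dynamics and process noise, mu0/sigma0: belief before the first observation.
    """
    I = jnp.eye(F.shape[0])

    def step(carry, inputs):
        mu, sigma = carry
        y, r = inputs
        # Predict
        mu_pred = F @ mu
        sigma_pred = F @ sigma @ F.T + Q
        # Update: K = sigma_pred @ inv(S); both matrices are symmetric, so solve and transpose
        S = sigma_pred + r
        K = jnp.linalg.solve(S, sigma_pred).T
        mu_new = mu_pred + K @ (y - mu_pred)
        sigma_new = (I - K) @ sigma_pred
        return (mu_new, sigma_new), (mu_new, sigma_new)

    _, (mus, sigmas) = jax.lax.scan(step, (mu0, sigma0), (Y, R))
    return mus, sigmas


# Static 1D state (F = 1, Q = 0) observed three times at y = 1 with unit noise.
Y = jnp.ones((3, 1))
R = jnp.tile(jnp.eye(1), (3, 1, 1))
mus, sigmas = kalman_filter(Y, R, F=jnp.eye(1), Q=jnp.zeros((1, 1)), mu0=jnp.zeros(1), sigma0=jnp.eye(1))
```

Starting from a $\mathcal{N}(0, 1)$ prior, the filtered means are 1/2, 2/3 and 3/4 and the variances 1/2, 1/3 and 1/4, which is the running Bayesian average of the observations.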
Install
```bash
git clone https://github.com/TomGeorge1234/KalMax.git
cd KalMax
pip install -e .
```
(The -e flag is optional; use it for an editable developer install.)
Alternatively:
```bash
pip install git+https://github.com/TomGeorge1234/KalMax.git
```
Usage
A full demo is provided in kalmax_demo.ipynb. Pseudocode is provided below.
```python
import kalmax
import jax.numpy as jnp

# 0. PREPARE DATA IN JAX ARRAYS
S_train = jnp.array(...)  # (T, N_CELLS) train spike counts
Z_train = jnp.array(...)  # (T, DIMS) train continuous variable
S_test = jnp.array(...)   # (T_TEST, N_CELLS) test spike counts
bins = jnp.array(...)     # (N_BINS, DIMS) coordinates at which to estimate receptive fields / likelihoods

# 1. FIT RECEPTIVE FIELDS using kalmax.kde.kde
firing_rate = kalmax.kde.kde(
    bins = bins,
    trajectory = Z_train,
    spikes = S_train,
    kernel = kalmax.kernels.gaussian_kernel,
    kernel_kwargs = {'covariance': 0.01**2 * jnp.eye(DIMS)},  # kernel covariance (bandwidth 0.01 per dimension)
)  # --> (N_CELLS, N_BINS)

# 2.1 CALCULATE LIKELIHOODS using kalmax.kde.poisson_log_likelihood
log_likelihoods = kalmax.kde.poisson_log_likelihood(
    spikes = S_test,
    mean_rate = firing_rate,
)  # --> (T_TEST, N_BINS): one log-likelihood map per test time step

# 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian_vmap
MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
    x = bins,
    likelihoods = jnp.exp(log_likelihoods),
)  # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)

# 3. KALMAN FILTER / SMOOTH using kalmax.kalman.KalmanFilter
kalman_filter = kalmax.kalman.KalmanFilter(
    dim_Z = DIMS,     # dimension of the latent state
    dim_Y = N_CELLS,  # dimension of the observations Y (DIMS if Y are the MLE estimates from step 2.2)
    # SEE DEMO FOR HOW TO FIT/SET THESE
    F=F,  # state transition matrix
    Q=Q,  # state noise covariance
    H=H,  # observation matrix
    R=R,  # observation noise covariance
)

# [FILTER]
mus_f, sigmas_f = kalman_filter.filter(
    Y = Y,            # observations (SEE DEMO)
    mu0 = mu0,        # initial state mean
    sigma0 = sigma0,  # initial state covariance
)  # --> (T_TEST, DIMS), (T_TEST, DIMS, DIMS)

# [SMOOTH]
mus_s, sigmas_s = kalman_filter.smooth(
    mus_f = mus_f,
    sigmas_f = sigmas_f,
)  # --> (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
```
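The smoothing call above takes only the filtered means and covariances. Besides the model's F and Q, which the filter object already holds, that is all a Rauch-Tung-Striebel (RTS) backward pass needs. Below is a minimal JAX sketch of such a pass, assuming a standard predict-then-update filter produced the inputs. `rts_smoother` is a hypothetical name, and this is not necessarily how `kalmax.kalman.KalmanFilter.smooth` is implemented.

```python
import jax
import jax.numpy as jnp


@jax.jit
def rts_smoother(mus_f, sigmas_f, F, Q):
    """RTS backward pass over filtered means (T, D) and covariances (T, D, D)."""

    def step(carry, inputs):
        mu_next_s, sigma_next_s = carry  # smoothed estimate at t + 1
        mu_f, sigma_f = inputs           # filtered estimate at t
        mu_pred = F @ mu_f
        sigma_pred = F @ sigma_f @ F.T + Q
        # Smoother gain G = sigma_f @ F.T @ inv(sigma_pred), via a linear solve
        G = jnp.linalg.solve(sigma_pred, F @ sigma_f).T
        mu_s = mu_f + G @ (mu_next_s - mu_pred)
        sigma_s = sigma_f + G @ (sigma_next_s - sigma_pred) @ G.T
        return (mu_s, sigma_s), (mu_s, sigma_s)

    # At the final step the smoothed estimate equals the filtered one; scan backwards from there.
    last = (mus_f[-1], sigmas_f[-1])
    _, (mus_s, sigmas_s) = jax.lax.scan(step, last, (mus_f[:-1], sigmas_f[:-1]), reverse=True)
    return (jnp.concatenate([mus_s, mus_f[-1:]], axis=0),
            jnp.concatenate([sigmas_s, sigmas_f[-1:]], axis=0))


# Filtered output of a static 1D model (F = 1, Q = 0) that observed y = 1 three times
# with unit noise, starting from a N(0, 1) prior.
mus_f = jnp.array([[0.5], [2.0 / 3.0], [0.75]])
sigmas_f = jnp.array([[[0.5]], [[1.0 / 3.0]], [[0.25]]])
mus_s, sigmas_s = rts_smoother(mus_f, sigmas_f, F=jnp.eye(1), Q=jnp.zeros((1, 1)))
```

With Q = 0 the state never changes, so every time step should receive the estimate that uses all three observations (mean 0.75, variance 0.25), which is what the backward pass produces.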
File details
Details for the file kalmax-0.0.0.tar.gz.
File metadata
- Download URL: kalmax-0.0.0.tar.gz
- Size: 16.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1482927282d20762af50f70e2d9c86e75d9c2a0d2a22eddcd04cfad8e385e5a7 |
| MD5 | 45aa30fc3f5bf3218e8972dfab41ba7a |
| BLAKE2b-256 | 245b03151ba22357fe06789672050952520cb78ef5f8884a8a49cefa83c23dbd |
File details
Details for the file kalmax-0.0.0-py3-none-any.whl.
File metadata
- Download URL: kalmax-0.0.0-py3-none-any.whl
- Size: 16.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 94b994a341201d3308a94a287a2607096521fbcc51dc68e00614c5815be776e3 |
| MD5 | ccf8405bf56f78961daae9b0013a8c60 |
| BLAKE2b-256 | 4da4f861f94c0f1b82fce57fd666c7a835d4004e6f5ecfa4c260845f0c96166e |