Kalman Filter implementation with PyTorch
Torch-KF
torch-kf is a PyTorch implementation of classic Kalman filtering and smoothing, designed for batched processing of many independent signals. It supports filtering and Rauch-Tung-Striebel (RTS) smoothing, runs on CPU or GPU (via PyTorch), and natively handles batch dimensions without Python loops.
This project is inspired by Roger R. Labbe Jr.’s excellent work:
- filterpy
- Kalman and Bayesian Filters in Python (interactive book)
Currently, torch-kf focuses on traditional linear Kalman filters with Gaussian noise. In the future, it may extend to a wider range of filters (e.g. EKF, UKF, IMM, ...).
Why torch-kf?
Kalman filtering is inherently sequential in time and typically involves small matrices (often < 10×10 in physics-based models). As a result, a single Kalman filter does not benefit much from GPU acceleration and may even be faster with NumPy-based implementations such as filterpy.
However, many real-world problems involve filtering large batches of independent signals in parallel, such as:
- multi-object tracking,
- ensemble-based inference,
- large-scale simulations,
- batched time-series processing.
This is where torch-kf shines.
Key ideas
- Batch-first design: filter hundreds or thousands of independent signals at once.
- No Python loops over signals: computations are vectorized.
- Automatic parallelization: PyTorch distributes work across multiple CPU cores or runs it on GPU.
- Flexible broadcasting: states, measurements, and even models can be batched.
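As a rough illustration of the batch-first, broadcasting idea, here is a predict step written in plain PyTorch. It is not torch-kf's internal code; the variable names, sizes and noise values are assumptions chosen for the example. `torch.matmul` broadcasts over leading dimensions, so the same line works whether one model is shared by all signals or each signal has its own.

```python
import torch

# Run on GPU if one is available; the same code runs unchanged on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

n_signals, dim_x = 1000, 4

# Batched beliefs: means (n_signals, dim_x, 1) and covariances (n_signals, dim_x, dim_x)
mean = torch.randn(n_signals, dim_x, 1, device=device)
cov = torch.eye(dim_x, device=device).expand(n_signals, dim_x, dim_x) * 10.0

# One constant-velocity model (dt = 1) shared by every signal: shape (dim_x, dim_x)
F = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
], device=device)
Q = torch.eye(dim_x, device=device) * 1.5**2

# Predict step for every signal at once: matmul broadcasts F over the batch dimension
pred_mean = F @ mean                                # (n_signals, dim_x, 1)
pred_cov = F @ cov @ F.transpose(-1, -2) + Q        # (n_signals, dim_x, dim_x)

# A different model per signal works the same way, with F_batch of shape (n_signals, dim_x, dim_x)
F_batch = F.expand(n_signals, dim_x, dim_x).clone()
F_batch[: n_signals // 2, :2, 2:] *= 0.5            # e.g. dt = 0.5 for the first half of the signals
pred_mean_per_signal = F_batch @ mean               # (n_signals, dim_x, 1)
```

No Python loop over signals appears anywhere: the only loop a full filter needs is the one over time steps.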
When many signals are filtered together, torch-kf can be orders of magnitude faster (typically up to 200× on CPU and 500×–1000× on GPU) compared to running independent filters sequentially as in filterpy.
[!WARNING] If you only need to filter a handful of signals (≈ fewer than 10), filterpy may still be faster due to PyTorch’s overhead on very small matrices.
Numerical considerations
[!WARNING] torch-kf runs in float32 by default and prioritizes speed over maximum numerical robustness. It uses fast update schemes and explicit matrix inverses, which are well-suited for small state dimensions but can be less stable in extreme cases. If numerical stability becomes an issue, consider:
- switching to float64, and
- enabling joseph_update=True in KalmanFilter.
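For reference, here is what the two covariance updates compute, as a plain-PyTorch sketch with illustrative values. The helper names `kalman_gain`, `simple_update` and `joseph_update` are our own, not torch-kf functions. The standard form (I - KH) P is cheaper; the Joseph form (I - KH) P (I - KH)^T + K R K^T costs an extra product but is symmetric by construction and stays valid for an inexactly computed gain, which is why it tends to be more robust in low precision. With the exact optimal gain, both forms give the same covariance.

```python
import torch


def kalman_gain(P, H, R):
    # K = P H^T (H P H^T + R)^{-1}, written with an explicit inverse for clarity
    S = H @ P @ H.transpose(-1, -2) + R
    return P @ H.transpose(-1, -2) @ torch.linalg.inv(S)


def simple_update(P, H, K):
    # Standard form: P+ = (I - K H) P
    I = torch.eye(P.shape[-1], dtype=P.dtype)
    return (I - K @ H) @ P


def joseph_update(P, H, R, K):
    # Joseph form: P+ = (I - K H) P (I - K H)^T + K R K^T
    I = torch.eye(P.shape[-1], dtype=P.dtype)
    A = I - K @ H
    return A @ P @ A.transpose(-1, -2) + K @ R @ K.transpose(-1, -2)


# Illustrative 4D state (position, velocity) with a position-only measurement, in float64
P = torch.eye(4, dtype=torch.float64) * 150.0**2
H = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], dtype=torch.float64)
R = torch.eye(2, dtype=torch.float64) * 3.0**2

K = kalman_gain(P, H, R)
P_simple = simple_update(P, H, K)
P_joseph = joseph_update(P, H, R, K)

# With the optimal gain, both forms agree up to rounding
print(torch.allclose(P_simple, P_joseph))
```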
Installation
pip
```bash
pip install torch-kf
```
From source
```bash
git clone git@github.com:raphaelreme/torch-kf.git  # OR https://github.com/raphaelreme/torch-kf.git
cd torch-kf
pip install .
```
Getting started
```python
import torch
from torch_kf import KalmanFilter, GaussianState

# Example: filtering 100 independent 2D trajectories over 1000 timesteps
# Measurements must be column vectors (..., dim, 1)
noisy_data = torch.randn(1000, 100, 2, 1)

# Initialize the Kalman filter model
# Constant-velocity model (dt = 1)
F = torch.tensor([  # Process matrix: x_{t+1} = x_t + v_t * dt (dt = 1)
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
], dtype=torch.float32)
Q = torch.eye(4) * 1.5**2

# Only the position is measured, with some noise R
H = torch.tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
], dtype=torch.float32)
R = torch.eye(2) * 3.0**2

kf = KalmanFilter(F, H, Q, R)

# Initial belief: zero position/velocity with large uncertainty
state = GaussianState(
    mean=torch.zeros(100, 4, 1),
    covariance=torch.eye(4)[None].expand(100, 4, 4) * 150**2,
)

# Filter all signals at once
states = kf.filter(
    state,
    noisy_data,
    update_first=True,
    return_all=True,
)
# states.mean: (1000, 100, 4, 1)
# states.covariance: (1000, 100, 4, 4)

# Optional RTS smoothing
smoothed = kf.rts_smooth(states)
# smoothed.mean: (1000, 100, 4, 1)
# smoothed.covariance: (1000, 100, 4, 4)

# Online filtering: process measurements as they arrive
measurements = noisy_data  # Stand-in: in practice, read measurements from a file or sensor
for t, measure in enumerate(measurements):
    # Prior at timestep t
    state = kf.predict(state)
    # Update with the measurement at timestep t
    state = kf.update(state, measure)
```
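To make the shapes above concrete, here is a from-scratch sketch of what a batched forward pass and an RTS backward pass compute, using the textbook equations in plain PyTorch. It is not torch-kf's implementation: `kf_filter` and `rts_smooth_sketch` are hypothetical names, the model is assumed time-invariant, no measurement is missing, and each step updates before predicting (our reading of what update_first=True selects).

```python
import torch


def kf_filter(mean, cov, measures, F, H, Q, R):
    """Hypothetical batched forward pass: update with z_t, store the posterior, then predict t + 1.

    mean: (B, dx, 1), cov: (B, dx, dx), measures: (T, B, dz, 1).
    Returns the filtered posteriors, shapes (T, B, dx, 1) and (T, B, dx, dx).
    """
    Ht = H.transpose(-1, -2)
    means, covs = [], []
    for z in measures:  # sequential in time, vectorized over the batch
        # Update: innovation covariance S and Kalman gain K (explicit inverse)
        S = H @ cov @ Ht + R
        K = cov @ Ht @ torch.linalg.inv(S)
        mean = mean + K @ (z - H @ mean)
        cov = cov - K @ H @ cov
        means.append(mean)
        covs.append(cov)
        # Predict: prior for the next timestep
        mean = F @ mean
        cov = F @ cov @ F.transpose(-1, -2) + Q
    return torch.stack(means), torch.stack(covs)


def rts_smooth_sketch(means, covs, F, Q):
    """Hypothetical batched Rauch-Tung-Striebel backward pass over filtered posteriors."""
    Ft = F.transpose(-1, -2)
    sm_means, sm_covs = [means[-1]], [covs[-1]]
    for t in range(means.shape[0] - 2, -1, -1):
        # Prior at t + 1 predicted from the filtered posterior at t
        pred_mean = F @ means[t]
        pred_cov = F @ covs[t] @ Ft + Q
        # Smoother gain G = P_{t|t} F^T P_{t+1|t}^{-1}
        G = covs[t] @ Ft @ torch.linalg.inv(pred_cov)
        sm_means.append(means[t] + G @ (sm_means[-1] - pred_mean))
        sm_covs.append(covs[t] + G @ (sm_covs[-1] - pred_cov) @ G.transpose(-1, -2))
    # Results were accumulated backward in time; restore chronological order
    return torch.stack(sm_means[::-1]), torch.stack(sm_covs[::-1])


# Same constant-velocity model as above, on a small batch of random measurements
F = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]])
H = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
Q = torch.eye(4) * 1.5**2
R = torch.eye(2) * 3.0**2

T, B = 50, 8
measures = torch.randn(T, B, 2, 1)
mean0 = torch.zeros(B, 4, 1)
cov0 = torch.eye(4).expand(B, 4, 4) * 150.0**2

f_means, f_covs = kf_filter(mean0, cov0, measures, F, H, Q, R)
s_means, s_covs = rts_smooth_sketch(f_means, f_covs, F, Q)
print(f_means.shape, s_covs.shape)  # torch.Size([50, 8, 4, 1]) torch.Size([50, 8, 4, 4])
```

The only Python loops run over time; every operation inside them acts on the whole batch.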
Tip: For standard motion models (constant velocity, acceleration, jerk), see torch_kf.ckf, which provides helpers to construct well-scaled F, H, Q, and R matrices.
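As an illustration of what such a helper typically builds, here is a hand-written constant-velocity model with state [positions, velocities] and time step dt. Q follows the discrete white-noise acceleration model (the same per-axis block that filterpy's Q_discrete_white_noise returns for dim=2). The function `constant_velocity_model` and its parameters are hypothetical: this is not the torch_kf.ckf API, and the library's scaling choices may differ.

```python
import torch


def constant_velocity_model(spatial_dim: int, dt: float, accel_std: float, measure_std: float):
    """Hypothetical helper (not the torch_kf.ckf API).

    State layout: [positions (spatial_dim), velocities (spatial_dim)].
    Process noise: discrete white-noise acceleration, i.e. a random acceleration
    with standard deviation accel_std held constant during each step.
    """
    eye = torch.eye(spatial_dim)
    zeros = torch.zeros(spatial_dim, spatial_dim)
    # x_{t+1} = x_t + v_t * dt and v_{t+1} = v_t
    F = torch.cat([
        torch.cat([eye, dt * eye], dim=1),
        torch.cat([zeros, eye], dim=1),
    ], dim=0)
    # Only the positions are measured
    H = torch.cat([eye, zeros], dim=1)
    # An acceleration a moves the position by a * dt^2 / 2 and the velocity by a * dt
    G = torch.cat([0.5 * dt**2 * eye, dt * eye], dim=0)  # (2 * spatial_dim, spatial_dim)
    Q = G @ G.T * accel_std**2
    R = eye * measure_std**2
    return F, H, Q, R


F, H, Q, R = constant_velocity_model(spatial_dim=2, dt=1.0, accel_std=1.5, measure_std=3.0)
```

With dt = 1, F and H match the Getting started example; Q couples position and velocity noise instead of using an isotropic diagonal.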
Examples
The examples/ folder contains simple demonstrations of constant-velocity Kalman filters (1D, 2D, …) using batched signals.
Example: filtering and smoothing noisy sinusoidal trajectories with missing (NaN) measurements:
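One common way to handle missing measurements in a batched filter is to compute the update for every signal and keep the prior wherever the measurement contains NaN. The sketch below shows that masking strategy in plain PyTorch; it is an assumption about a reasonable approach (all-or-nothing per measurement), not a description of how torch-kf treats NaN values, and `masked_update` is a hypothetical name.

```python
import torch


def masked_update(mean, cov, z, H, R):
    """Hypothetical batched Kalman update that leaves signals with NaN measurements unchanged.

    mean: (B, dx, 1), cov: (B, dx, dx), z: (B, dz, 1) possibly containing NaN.
    """
    # A signal counts as observed only if every component of its measurement is a number
    observed = ~torch.isnan(z).any(dim=-2, keepdim=True)  # (B, 1, 1)
    z_filled = torch.nan_to_num(z, nan=0.0)  # keep NaN out of the arithmetic

    Ht = H.transpose(-1, -2)
    S = H @ cov @ Ht + R
    K = cov @ Ht @ torch.linalg.inv(S)
    new_mean = mean + K @ (z_filled - H @ mean)
    new_cov = cov - K @ H @ cov

    # Updated belief for observed signals, unchanged prior for the others
    return torch.where(observed, new_mean, mean), torch.where(observed, new_cov, cov)


H = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
R = torch.eye(2) * 9.0
mean = torch.zeros(3, 4, 1)
cov = torch.eye(4).expand(3, 4, 4) * 100.0
z = torch.tensor([[[1.0], [2.0]], [[float("nan")], [0.5]], [[3.0], [4.0]]])

new_mean, new_cov = masked_update(mean, cov, z, H, R)
print(new_mean[1].flatten())  # tensor([0., 0., 0., 0.])
```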
We also benchmark torch-kf against filterpy to highlight when batched execution becomes advantageous:
For small batch sizes, PyTorch overhead dominates. As the number of signals increases, torch-kf can provide 200× speedups on CPU and 500×+ on GPU.
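The batching effect can be reproduced in miniature: the sketch below times the same predict step run signal by signal in a Python loop versus once on the whole batch. It is a toy timing harness with arbitrary sizes, not the project's benchmark, and the timings it prints depend heavily on the hardware.

```python
import time

import torch

n_signals, dim_x = 2000, 4
# Constant-velocity process matrix (dt = 1): identity plus ones on the second superdiagonal
F = torch.eye(dim_x) + torch.diag(torch.ones(dim_x - 2), diagonal=2)
Q = torch.eye(dim_x) * 1.5**2
mean = torch.randn(n_signals, dim_x, 1)
cov = torch.eye(dim_x).expand(n_signals, dim_x, dim_x) * 10.0


def predict(m, P):
    # Works for a single belief (dim_x, ...) or a batch (n_signals, dim_x, ...)
    return F @ m, F @ P @ F.T + Q


# One small predict per signal, as with independent sequential filters
start = time.perf_counter()
looped = [predict(mean[i], cov[i]) for i in range(n_signals)]
loop_time = time.perf_counter() - start

# A single vectorized predict over the whole batch
start = time.perf_counter()
batched_mean, batched_cov = predict(mean, cov)
batch_time = time.perf_counter() - start

# Both approaches compute the same thing
assert torch.allclose(torch.stack([m for m, _ in looped]), batched_mean)
assert torch.allclose(torch.stack([P for _, P in looped]), batched_cov)
print(f"loop: {loop_time * 1e3:.1f} ms, batched: {batch_time * 1e3:.1f} ms")
```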
Contributing
Contributions are very welcome! Feel free to open an issue or submit a pull request.
Many extensions of Kalman filtering and smoothing are not yet implemented (e.g. variants, adaptive models). For a more feature-complete reference, see filterpy.
Citation
This library was originally developed for large-scale object tracking in biology. If you use torch-kf in academic work, please cite:
```bibtex
@inproceedings{reme2024particle,
  title={Particle tracking in biological images with optical-flow enhanced kalman filtering},
  author={Reme, Raphael and Newson, Alasdair and Angelini, Elsa and Olivo-Marin, Jean-Christophe and Lagache, Thibault},
  booktitle={2024 IEEE International Symposium on Biomedical Imaging (ISBI)},
  pages={1--5},
  year={2024},
  organization={IEEE}
}
```
Download files
File details
Details for the file torch_kf-0.4.2.tar.gz.
File metadata
- Download URL: torch_kf-0.4.2.tar.gz
- Upload date:
- Size: 16.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1b889c2e007db63b8b97d16cd3e37c436b4468f31019eb816b62858519290bab |
| MD5 | d9fd39eedf38b3fca886f95cd24174b4 |
| BLAKE2b-256 | 11948576676153708fc275743356b10066ed17a12997b16007cd9cc519fd3dd8 |
File details
Details for the file torch_kf-0.4.2-py3-none-any.whl.
File metadata
- Download URL: torch_kf-0.4.2-py3-none-any.whl
- Upload date:
- Size: 18.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 020480fc23260a2ee81bb4156af11ff1338a496764d7a9e6d2d119f2136aed7a |
| MD5 | 2ae7eaf6e65c6e01e33a3bd64b9e51b8 |
| BLAKE2b-256 | a8e007000d8ad39e3b1868d26bee589641ff5b4cbacba2eb71aa7890abcfa3f4 |