Kalman Filter implementation with PyTorch

Torch-KF

torch-kf is a PyTorch implementation of classic Kalman filtering and smoothing, designed for batched processing of many independent signals. It supports filtering and Rauch-Tung-Striebel (RTS) smoothing, runs on CPU or GPU (via PyTorch), and natively handles batch dimensions without Python loops.

This project is inspired by Roger R. Labbe Jr.'s excellent work on filterpy and Kalman and Bayesian Filters in Python.

Currently, torch-kf focuses on traditional linear Kalman filters with Gaussian noise. In the future, it may extend to a wider range of filters (e.g. EKF, UKF, IMM, ...).


Why torch-kf?

Kalman filtering is inherently sequential in time and typically involves small matrices (often smaller than 10×10 in physics-based models). As a result, a single Kalman filter gains little from GPU acceleration and may even run faster with a NumPy-based implementation such as filterpy.

However, many real-world problems involve filtering large batches of independent signals in parallel, such as:

  • multi-object tracking,
  • ensemble-based inference,
  • large-scale simulations,
  • batched time-series processing.

This is where torch-kf shines.

Key ideas

  • Batch-first design: filter hundreds or thousands of independent signals at once.
  • No Python loops over signals: computations are vectorized.
  • Automatic parallelization: PyTorch distributes work across multiple CPU cores or runs it on GPU.
  • Flexible broadcasting: states, measurements, and even models can be batched.
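The batch-first idea boils down to PyTorch's matmul broadcasting: one batched matrix product advances every signal's state at once, with no Python loop. A minimal sketch in plain PyTorch (illustrative shapes only, not torch-kf's internals):

```python
import torch

B = 1000                          # number of independent signals
F = torch.tensor([                # constant-velocity transition (dt = 1)
    [1., 0., 1., 0.],
    [0., 1., 0., 1.],
    [0., 0., 1., 0.],
    [0., 0., 0., 1.],
])
x = torch.randn(B, 4, 1)          # one state column vector per signal
P = torch.eye(4).expand(B, 4, 4)  # one covariance per signal
Q = torch.eye(4) * 1.5**2

# One vectorized predict step for all B signals at once:
x_pred = F @ x                    # (4, 4) @ (B, 4, 1) -> (B, 4, 1)
P_pred = F @ P @ F.T + Q          # broadcast over the batch dimension
```

The same broadcasting rules let you batch the models themselves (e.g. a per-signal F of shape (B, 4, 4)) without changing the code.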

When many signals are filtered together, torch-kf can be orders of magnitude faster (typically up to 200× on CPU and 500×–1000× on GPU) compared to running independent filters sequentially as in filterpy.

[!WARNING] If you only need to filter a handful of signals (≈ fewer than 10), filterpy may still be faster due to PyTorch’s overhead on very small matrices.


Numerical considerations

[!WARNING] torch-kf runs in float32 by default and prioritizes speed over maximum numerical robustness. It uses fast update schemes and explicit matrix inverses, which are well-suited for small state dimensions but can be less stable in extreme cases.

If numerical stability becomes an issue, consider:

  • switching to float64, and
  • enabling joseph_update=True in KalmanFilter.
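The Joseph form replaces the fast covariance update P⁺ = (I − KH)P with P⁺ = (I − KH)P(I − KH)ᵀ + KRKᵀ, which is algebraically identical for the optimal gain but stays symmetric positive semi-definite under rounding. A plain-PyTorch sketch of the two forms (illustrative, not torch-kf's internals):

```python
import torch

dim_x = 4
H = torch.tensor([[1., 0., 0., 0.],                 # observe position only
                  [0., 1., 0., 0.]], dtype=torch.float64)
P = torch.eye(dim_x, dtype=torch.float64) * 150.0**2  # prior covariance
R = torch.eye(2, dtype=torch.float64) * 3.0**2        # measurement noise

S = H @ P @ H.T + R                 # innovation covariance
K = P @ H.T @ torch.linalg.inv(S)   # Kalman gain
I = torch.eye(dim_x, dtype=torch.float64)

# Fast form: cheap, but can lose symmetry in float32
P_fast = (I - K @ H) @ P

# Joseph form: more flops, symmetric PSD by construction
A = I - K @ H
P_joseph = A @ P @ A.T + K @ R @ K.T
```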

Installation

pip

pip install torch-kf

From source

git clone git@github.com:raphaelreme/torch-kf.git  # OR https://github.com/raphaelreme/torch-kf.git
cd torch-kf
pip install .

Getting started

import torch
from torch_kf import KalmanFilter, GaussianState

# Example: filtering 100 independent 2D trajectories over 1000 timesteps
# Measurements must be column vectors (..., dim, 1)
noisy_data = torch.randn(1000, 100, 2, 1)

# Initialize the Kalman filter model
# Constant-velocity model (dt = 1)
F = torch.tensor([  # Process matrix: x_{t+1} = x_t + v_t * dt  (dt = 1)
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
], dtype=torch.float32)

Q = torch.eye(4) * 1.5**2

# Measurement matrix H: only the position is observed, with noise covariance R
H = torch.tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
], dtype=torch.float32)
R = torch.eye(2) * 3.0**2

kf = KalmanFilter(F, H, Q, R)

# Initial belief: zero position/velocity with large uncertainty
state = GaussianState(
    mean=torch.zeros(100, 4, 1),
    covariance=torch.eye(4)[None].expand(100, 4, 4) * 150**2,
)

# Filter all signals at once
states = kf.filter(
    state,
    noisy_data,
    update_first=True,
    return_all=True,
)

# states.mean:       (1000, 100, 4, 1)
# states.covariance: (1000, 100, 4, 4)

# Optional RTS smoothing
smoothed = kf.rts_smooth(states)
# smoothed.mean:       (1000, 100, 4, 1)
# smoothed.covariance: (1000, 100, 4, 4)
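Under the hood, RTS smoothing is a backward pass over the filtered estimates. A minimal unbatched sketch of the recursions (for intuition only; torch-kf's implementation is vectorized):

```python
import torch

def rts_backward(means_f, covs_f, F, Q):
    """RTS backward pass over filtered means (T, d, 1) and covariances (T, d, d)."""
    T = means_f.shape[0]
    means_s = means_f.clone()
    covs_s = covs_f.clone()
    for t in range(T - 2, -1, -1):
        P_pred = F @ covs_f[t] @ F.T + Q                # predicted covariance at t+1
        C = covs_f[t] @ F.T @ torch.linalg.inv(P_pred)  # smoother gain
        means_s[t] = means_f[t] + C @ (means_s[t + 1] - F @ means_f[t])
        covs_s[t] = covs_f[t] + C @ (covs_s[t + 1] - P_pred) @ C.T
    return means_s, covs_s
```

Smoothing uses future measurements, so the smoothed covariances are never larger than the filtered ones.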


# Online filtering: process measurements as they arrive
generator = ...  # Read measurements from a file / sensor
for t, measure in enumerate(generator):
    # Predict the prior at timestep t
    state = kf.predict(state)

    # Update with the measurement at time t
    state = kf.update(state, measure)
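For reference, one predict/update cycle like the loop above implements the standard linear-Gaussian recursions. A self-contained plain-PyTorch version of the equations (mirroring the math, not torch-kf's exact internals):

```python
import torch

def kf_step(x, P, z, F, Q, H, R):
    """One Kalman predict + update step on (..., d, 1) states."""
    # Predict
    x = F @ x
    P = F @ P @ F.transpose(-2, -1) + Q
    # Update
    y = z - H @ x                                      # innovation
    S = H @ P @ H.transpose(-2, -1) + R                # innovation covariance
    K = P @ H.transpose(-2, -1) @ torch.linalg.inv(S)  # Kalman gain
    x = x + K @ y
    P = P - K @ H @ P
    return x, P
```

Thanks to matmul broadcasting, the same function works unchanged on a batch of states of shape (B, d, 1).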

Tip: For standard motion models (constant velocity, acceleration, jerk), see torch_kf.ckf, which provides helpers to construct well-scaled F, H, Q, and R matrices.


Examples

The examples/ folder contains simple demonstrations of constant-velocity Kalman filters (1D, 2D, …) using batched signals.

Example: filtering and smoothing noisy sinusoidal trajectories with missing (NaN) measurements:

(Figures: filtered and smoothed sinusoidal position; filtered and smoothed sinusoidal velocity.)

We also benchmark torch-kf against filterpy to highlight when batched execution becomes advantageous:

(Figure: computational time, torch-kf vs. filterpy.)

For small batch sizes, PyTorch overhead dominates. As the number of signals increases, torch-kf can provide 200× speedups on CPU and 500×+ on GPU.
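The crossover is easy to reproduce with a toy microbenchmark (illustrative only; real numbers depend on hardware, state size, and batch size): compare a Python loop over signals against one batched matmul.

```python
import time
import torch

B, d = 2048, 4
F = torch.randn(d, d)
x = torch.randn(B, d, 1)

t0 = time.perf_counter()
looped = torch.stack([F @ x[i] for i in range(B)])  # one signal at a time
t_loop = time.perf_counter() - t0

t0 = time.perf_counter()
batched = F @ x                                     # all signals at once
t_batch = time.perf_counter() - t0

print(f"loop: {t_loop * 1e3:.1f} ms, batched: {t_batch * 1e3:.1f} ms")
```

Both variants compute the same result; only the dispatch overhead differs.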


Contributing

Contributions are very welcome! Feel free to open an issue or submit a pull request.

Many extensions of Kalman filtering and smoothing are not yet implemented (e.g. variants, adaptive models). For a more feature-complete reference, see filterpy.


Citation

This library was originally developed for large-scale object tracking in biology. If you use torch-kf in academic work, please cite:

@inproceedings{reme2024particle,
  title={Particle tracking in biological images with optical-flow enhanced kalman filtering},
  author={Reme, Raphael and Newson, Alasdair and Angelini, Elsa and Olivo-Marin, Jean-Christophe and Lagache, Thibault},
  booktitle={2024 IEEE International Symposium on Biomedical Imaging (ISBI)},
  pages={1--5},
  year={2024},
  organization={IEEE}
}
