
Kalman Filter implementation with PyTorch


torch-kf


PyTorch implementation of Kalman filters. It supports filtering batches of signals and runs on GPU (any device supported by PyTorch) or on multiple CPUs.

This is based on rlabbe's filterpy and his interactive book on Kalman filters. Currently, only traditional Kalman filters are implemented, without any smoothing.

This implementation is designed for use cases with multiple signals to filter. By construction, Kalman filter computations are sequential and cannot be parallelized, and they usually involve quite small matrices (for physics-based systems, the state is usually restricted to fewer than 10 dimensions), which cannot benefit from GPU/CPU parallelization. This no longer holds when multiple signals must be filtered in parallel (or multiple filters run in parallel), which happens quite often.
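To make this concrete, here is a minimal sketch of batched predict and update steps in plain PyTorch. It is not torch-kf's code: the helpers `batched_predict` and `batched_update` are hypothetical names, and the equations are the textbook linear Kalman filter. The time loop stays sequential, but each step processes every signal of the batch in one call.

```python
import torch


def batched_predict(mean, cov, F, Q):
    """Predict step for B independent filters.

    mean: (B, n, 1), cov: (B, n, n); F and Q are (n, n) and broadcast over B.
    """
    mean = F @ mean
    cov = F @ cov @ F.mT + Q
    return mean, cov


def batched_update(mean, cov, z, H, R):
    """Update step with measurements z of shape (B, m, 1); H is (m, n), R is (m, m)."""
    y = z - H @ mean                      # innovation, (B, m, 1)
    S = H @ cov @ H.mT + R                # innovation covariance, (B, m, m)
    K = cov @ H.mT @ torch.linalg.inv(S)  # Kalman gain, (B, n, m)
    mean = mean + K @ y
    cov = cov - K @ H @ cov               # short (non-Joseph) covariance update
    return mean, cov


# Usage: 100 signals, 4D constant-velocity state [x, y, vx, vy], 2D position measurements
B = 100
F = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0, 1.0],
                  [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0]])
H = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0]])
Q = torch.eye(4) * 1.5**2
R = torch.eye(2) * 3**2

mean = torch.zeros(B, 4, 1)
cov = torch.eye(4).expand(B, 4, 4) * 150**2  # multiplication yields a regular (B, 4, 4) tensor
measurements = torch.randn(20, B, 2, 1)

for z in measurements:  # sequential over time...
    mean, cov = batched_update(mean, cov, z, H, R)  # ...but all B signals at once
    mean, cov = batched_predict(mean, cov, F, Q)
```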

torch-kf natively supports batched Kalman filter computations (no need to loop over your batch of signals). Moreover, thanks to PyTorch, it automatically distributes the computations across your CPUs, or can run on GPU. It is therefore much faster (up to 1000 times faster) when batches of signals are involved. If you have fewer than 10 signals to filter, filterpy will still be faster (up to 10 times faster for a single signal) because PyTorch has a large overhead when small matrices are involved.
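As an illustration of what "no need to loop over your batch of signals" means, the sketch below (plain PyTorch, not the torch-kf API) checks that a single broadcasted prediction gives the same result as looping over the signals one by one. The random `F` is only there to exercise the arithmetic.

```python
import torch

torch.manual_seed(0)
B, n = 1000, 4
F = torch.randn(n, n)              # arbitrary transition matrix, shared by all signals
Q = torch.eye(n)
means = torch.randn(B, n, 1)       # one column-vector mean per signal
covs = torch.eye(n).repeat(B, 1, 1)

# Looping over signals: B small matrix products
loop_means = torch.stack([F @ m for m in means])
loop_covs = torch.stack([F @ P @ F.T + Q for P in covs])

# Batched: one broadcasted call covers all B signals
batch_means = F @ means
batch_covs = F @ covs @ F.T + Q
```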

This implementation is quite simple, but not very user-friendly for people unfamiliar with PyTorch (or NumPy) broadcasting rules. We highly recommend reading about broadcasting before using this library.
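For instance, the shapes used in the Getting started example rely on `torch.matmul` treating the last two dimensions as matrices and broadcasting all leading dimensions. A small sketch of these rules in plain PyTorch (the tensors here are illustrative, not torch-kf objects):

```python
import torch

B = 100
F = torch.eye(4)                    # (4, 4): one model shared by all signals
mean = torch.zeros(B, 4, 1)         # (B, 4, 1): one column vector per signal
cov = torch.eye(4).expand(B, 4, 4)  # (B, 4, 4): one covariance per signal

# Leading dims broadcast, last two dims are multiplied as matrices:
pred_mean = F @ mean                # (4, 4) @ (B, 4, 1) -> (B, 4, 1)
pred_cov = F @ cov @ F.T            # (4, 4) @ (B, 4, 4) @ (4, 4) -> (B, 4, 4)

# Without the trailing singleton dimension, a (B, 4) tensor is seen as ONE
# (B, 4) matrix, so F @ flat_mean fails (inner sizes 4 and B do not match).
flat_mean = torch.zeros(B, 4)
try:
    F @ flat_mean
    flat_ok = True
except RuntimeError:
    flat_ok = False
```

This is why the measurements and state means carry an extra dimension of size 1: it turns each vector into a proper column for matrix products.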

[!WARNING] torch-kf runs in float32 by default and is implemented with the fastest, but unfortunately not the most stable, numerical scheme. We have not faced any real issue yet, but be aware that this may become a problem for some use cases.
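The warning does not name the scheme. As an assumption, a common "fast but less stable" choice is the short covariance update P+ = (I - KH)P, whereas the Joseph form P+ = (I - KH)P(I - KH)^T + KRK^T costs more but stays symmetric and positive semi-definite under rounding. The sketch below (plain PyTorch, hypothetical helper names) computes both in float64; with the optimal gain they agree mathematically.

```python
import torch


def kalman_gain(P, H, R):
    S = H @ P @ H.mT + R
    return P @ H.mT @ torch.linalg.inv(S)


def update_cov_short(P, K, H):
    """Cheap update: (I - K H) P. Rounding can break symmetry, especially in float32."""
    I = torch.eye(P.shape[-1], dtype=P.dtype)
    return (I - K @ H) @ P


def update_cov_joseph(P, K, H, R):
    """Joseph form: (I - K H) P (I - K H)^T + K R K^T. Symmetric by construction."""
    I = torch.eye(P.shape[-1], dtype=P.dtype)
    A = I - K @ H
    return A @ P @ A.mT + K @ R @ K.mT


H = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0]], dtype=torch.float64)
R = torch.eye(2, dtype=torch.float64) * 3**2
P = torch.eye(4, dtype=torch.float64) * 150**2

K = kalman_gain(P, H, R)
P_short = update_cov_short(P, K, H)
P_joseph = update_cov_joseph(P, K, H, R)
```

If float32 becomes a problem, building all matrices and states in `torch.float64` is the usual PyTorch way to gain precision; that torch-kf simply follows the dtype of the tensors it receives is an assumption worth checking.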

Install

Pip

$ pip install torch-kf

Conda

Not yet available

Getting started

import torch
from torch_kf import KalmanFilter, GaussianState

# Some noisy data to filter
# 1000 time steps, 100 signals, 2D measurements, and an additional trailing dimension so that each measurement is a column vector (required for correct matmul)
noisy_data = torch.randn(1000, 100, 2, 1)

# Create a Kalman filter (for instance a 2D constant-velocity filter; see the full implementation in the examples, or rlabbe's book)
F = torch.tensor([  # x_{t+1} = x_{t} + v_{t} * dt     (dt = 1)
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
], dtype=torch.float32)  # Float dtype required: integer matrices cannot be multiplied with float32 states
Q = torch.eye(4) * 1.5**2  # std of 1.5 on both position and velocity (see the full implementation to build a better Q)
H = torch.tensor([  # Only x and y are measured
    [1, 0, 0, 0],
    [0, 1, 0, 0],
], dtype=torch.float32)
R = torch.eye(2) * 3**2  # std of 3 on the measured positions

kf = KalmanFilter(F, H, Q, R)

# Create an initial belief for each signal
# For instance, let's start from 0 position and 0 velocity with a huge uncertainty
state = GaussianState(
    torch.zeros(100, 4, 1),  # Mean, shape (100, 4, 1)
    torch.eye(4)[None].expand(100, 4, 4) * 150**2,  # Covariance, shape (100, 4, 4)
)

# Now let's filter all our signals at once and save the results

filtered_data = torch.empty(1000, 100, 4, 1)  # One 4D state mean per time step and signal (the state is 4D, the measurements 2D)

for t, measure in enumerate(noisy_data):  # Update first and then predict in this case
    # Update with measure at time t
    state = kf.update(state, measure)

    # Save state at time t
    filtered_data[t] = state.mean

    # Predict for t + 1
    state = kf.predict(state)
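The Q above is deliberately crude. A common alternative for constant-velocity models is the discrete white-noise acceleration model, the same per-axis matrix that filterpy's `Q_discrete_white_noise(dim=2, dt=..., var=...)` returns. The sketch below is a hypothetical helper, not part of torch-kf; it assumes the [x, y, vx, vy] state ordering used above and an acceleration std chosen by you.

```python
import torch


def constant_velocity_q(dt: float, accel_std: float, n_axes: int = 2) -> torch.Tensor:
    """Discrete white-noise acceleration Q for a state ordered [pos_1..pos_k, vel_1..vel_k]."""
    # Per-axis (position, velocity) block for a piecewise-constant acceleration
    q1 = torch.tensor([[dt**4 / 4, dt**3 / 2],
                       [dt**3 / 2, dt**2]]) * accel_std**2
    # kron(q1, I) puts that block on every axis and matches the [x, y, vx, vy] ordering
    return torch.kron(q1, torch.eye(n_axes))


Q = constant_velocity_q(dt=1.0, accel_std=1.5)  # candidate replacement for the Q above
```

Unlike the diagonal Q, this one correlates each position with its own velocity, which reflects how an unknown acceleration perturbs both.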

Examples

We provide a simple example of a constant-velocity Kalman filter (1D, 2D, ...) in the example folder, which uses batches of signals and shows when our implementation is worth using.

On a laptop with fairly good CPUs and a (somewhat dated) GPU, we typically obtain the following performance:

[Figure: Computational time]

Both the CPU and GPU versions have a large overhead when the batch is small, but they can yield a 200x speedup or more when many signals are filtered together.
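To reproduce this kind of comparison on your own hardware, a minimal timing harness might look like the sketch below. It times a plain-PyTorch covariance prediction (looped vs batched), not torch-kf itself, and the batch sizes and repeat count are arbitrary choices; no particular numbers are implied.

```python
import time

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


def sync():
    if device == "cuda":
        torch.cuda.synchronize()  # GPU kernels run asynchronously; wait before reading the clock


def time_it(fn, repeats=5):
    fn()  # warm-up call (allocation, kernel setup)
    sync()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    sync()
    return (time.perf_counter() - start) / repeats


F = torch.eye(4, device=device)
Q = torch.eye(4, device=device)

results = {}
for batch in (1, 10, 100, 1000):
    covs = torch.eye(4, device=device).repeat(batch, 1, 1)
    loop_t = time_it(lambda: [F @ P @ F.T + Q for P in covs])  # one small product per signal
    batch_t = time_it(lambda: F @ covs @ F.T + Q)               # one broadcasted product
    results[batch] = (loop_t, batch_t)
    print(f"batch={batch:5d}  loop={loop_t * 1e3:.3f} ms  batched={batch_t * 1e3:.3f} ms")
```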

Contribute

Please feel free to open a PR or an issue at any time.

Many variants of Kalman filtering/smoothing are still missing, and the documentation is fairly limited. In comparison, filterpy is a much more complete library and may give some ideas of what is missing.

Download files

Source distribution: torch_kf-0.0.1.tar.gz (10.0 kB)

Built distribution: torch_kf-0.0.1-py3-none-any.whl (8.3 kB, Python 3)
