
A differentiable Kalman filter library for auto-tuning Kalman filters.


Differentiable Kalman Filter

A PyTorch-based implementation of a differentiable Kalman filter for both linear and non-linear dynamical systems with Gaussian noise. The module integrates with neural networks, so the dynamics, observation, and noise models can all be learned, with their parameters optimized through Stochastic Variational Inference (SVI).

Features

  • Fully Differentiable: End-to-end differentiable implementation compatible with PyTorch's autograd
  • Flexible Models: Support for both linear and non-linear state transition and observation models
  • Neural Network Integration: Models can be parameterized using neural networks
  • Automatic Jacobian Computation: Uses PyTorch's autograd to compute the Jacobians of the transition and observation functions
  • Monte Carlo Sampling: Estimates the expected joint log-likelihood by sampling, which drives Expectation-Maximization (EM) learning
  • Rauch-Tung-Striebel Smoothing: Implements forward-backward smoothing with the RTS algorithm for improved state estimates
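To illustrate what automatic Jacobian computation buys in a non-linear filter, the sketch below linearizes a hypothetical pendulum transition with torch.autograd.functional.jacobian and propagates the covariance the way an extended Kalman filter does. It is not diffkalman's internal code; the pendulum model, DT, and ekf_predict are illustrative names only.

```python
import torch
from torch.autograd.functional import jacobian

DT = 0.1  # hypothetical time step


def f(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical non-linear transition: one Euler step of a pendulum [angle, angular velocity]."""
    theta, omega = x[0], x[1]
    return torch.stack([theta + DT * omega, omega - DT * 9.81 * torch.sin(theta)])


def ekf_predict(x: torch.Tensor, P: torch.Tensor, Q: torch.Tensor):
    """EKF prediction: push the mean through f and the covariance through its Jacobian."""
    # create_graph=True keeps the Jacobian itself differentiable, so a loss built
    # on P_pred can be backpropagated into any parameters of f
    F = jacobian(f, x, create_graph=True)  # shape (dim_x, dim_x)
    x_pred = f(x)
    P_pred = F @ P @ F.T + Q
    return x_pred, P_pred, F
```

For this model the autograd Jacobian should match the hand-derived one, [[1, DT], [-DT * 9.81 * cos(theta), 1]], which makes it a quick sanity check when writing your own f.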

Installation

pip install torch  # Required dependency
pip install diffkalman

Quick Start

Here's a simple example of using the Differentiable Kalman Filter:

import torch
from diffkalman import DifferentiableKalmanFilter
from diffkalman.utils import SymmetricPositiveDefiniteMatrix
from diffkalman.em_loop import em_updates

# Define custom state transition and observation functions
class StateTransition(torch.nn.Module):
    def forward(self, x, *args):
        # Your state transition logic here (maps a dim_x state to a dim_x state)
        return x

class ObservationModel(torch.nn.Module):
    def forward(self, x, *args):
        # Your observation logic here (maps a dim_x state to a dim_z observation);
        # for example, observe the first two state components
        return x[:2]

# Placeholder data: replace with your own measurements and prior
T = 100
observations = torch.randn(T, 2)      # Shape: (T, dim_z)
initial_state = torch.zeros(4)        # Shape: (dim_x,)
initial_covariance = torch.eye(4)     # Shape: (dim_x, dim_x)

# Learnable process (Q) and observation (R) noise covariances
Q = SymmetricPositiveDefiniteMatrix(dim=4, trainable=True)
R = SymmetricPositiveDefiniteMatrix(dim=2, trainable=True)

# Initialize the filter
f = StateTransition()
h = ObservationModel()
kalman_filter = DifferentiableKalmanFilter(
    dim_x=4,  # State dimension
    dim_z=2,  # Observation dimension
    f=f,      # State transition function
    h=h       # Observation function
)

# Run the filter
results = kalman_filter.sequence_filter(
    z_seq=observations,                     # Shape: (T, dim_z)
    x0=initial_state,                       # Shape: (dim_x,)
    P0=initial_covariance,                  # Shape: (dim_x, dim_x)
    Q=Q().repeat(len(observations), 1, 1),  # Shape: (T, dim_x, dim_x)
    R=R().repeat(len(observations), 1, 1)   # Shape: (T, dim_z, dim_z)
)
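The Quick Start leaves f and h as placeholders. One concrete, hypothetical choice consistent with dim_x=4 and dim_z=2 is a 2-D constant-velocity model with state [px, py, vx, vy] and position-only measurements, sketched below together with a small simulator for synthetic observations. ConstantVelocity, PositionObservation, and simulate are illustrative names; whether diffkalman passes extra positional arguments or batched states to f and h is an assumption, which is why forward keeps the *args signature from the Quick Start.

```python
import torch


class ConstantVelocity(torch.nn.Module):
    """Hypothetical linear transition for the state [px, py, vx, vy]."""

    def __init__(self, dt: float = 0.1):
        super().__init__()
        F = torch.eye(4)
        F[0, 2] = dt  # px += dt * vx
        F[1, 3] = dt  # py += dt * vy
        self.register_buffer("F", F)

    def forward(self, x, *args):
        return self.F @ x


class PositionObservation(torch.nn.Module):
    """Hypothetical observation model that measures only [px, py]."""

    def __init__(self):
        super().__init__()
        H = torch.zeros(2, 4)
        H[0, 0] = 1.0
        H[1, 1] = 1.0
        self.register_buffer("H", H)

    def forward(self, x, *args):
        return self.H @ x


def simulate(f, h, x0, T, q_std=0.01, r_std=0.1, seed=0):
    """Roll the model forward with additive Gaussian process and measurement noise."""
    g = torch.Generator().manual_seed(seed)
    x, states, observations = x0, [], []
    for _ in range(T):
        x = f(x) + q_std * torch.randn(x.shape, generator=g)
        z = h(x)
        observations.append(z + r_std * torch.randn(z.shape, generator=g))
        states.append(x)
    return torch.stack(states), torch.stack(observations)
```

Usage: `states, observations = simulate(ConstantVelocity(), PositionObservation(), torch.tensor([0.0, 0.0, 1.0, 0.5]), T=100)` yields observations of shape (T, dim_z) that can be passed as z_seq in the Quick Start call, with the true states kept aside for evaluating the filter.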

Detailed Usage

State Estimation

The module provides three main estimation methods:

  1. Filtering: Forward pass only
filtered_results = kalman_filter.sequence_filter(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
  2. Smoothing: Forward-backward pass
smoothed_results = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
  3. Single-step Prediction and Update: For real-time applications
step_result = kalman_filter.predict_update(
    z=current_observation,
    x=current_state,
    P=current_covariance,
    Q=process_noise,
    R=observation_noise
)
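The three modes correspond to the textbook linear Kalman recursions. As a sketch in plain PyTorch (a linear model with known matrices F and H, not diffkalman's implementation), a filter whose loop body is one predict/update step, followed by a Rauch-Tung-Striebel backward pass, might look like this:

```python
import torch


def linear_kalman_filter(z_seq, x0, P0, F, H, Q, R):
    """Linear Kalman filter; returns predicted and filtered means and covariances."""
    x, P = x0, P0
    I = torch.eye(x0.shape[0], dtype=x0.dtype)
    x_preds, P_preds, x_filts, P_filts = [], [], [], []
    for z in z_seq:
        # Predict: x_{t|t-1} = F x_{t-1|t-1},  P_{t|t-1} = F P F^T + Q
        x_pred = F @ x
        P_pred = F @ P @ F.T + Q
        # Update with the Kalman gain K = P_{t|t-1} H^T S^{-1}
        S = H @ P_pred @ H.T + R
        K = torch.linalg.solve(S, H @ P_pred).T
        x = x_pred + K @ (z - H @ x_pred)
        P = (I - K @ H) @ P_pred
        x_preds.append(x_pred)
        P_preds.append(P_pred)
        x_filts.append(x)
        P_filts.append(P)
    return (torch.stack(x_preds), torch.stack(P_preds),
            torch.stack(x_filts), torch.stack(P_filts))


def rts_smoother(x_preds, P_preds, x_filts, P_filts, F):
    """Rauch-Tung-Striebel backward pass over the filter outputs."""
    T = x_filts.shape[0]
    # The smoothed estimate at the last step equals the filtered one
    xs, Ps = [x_filts[-1]], [P_filts[-1]]
    for t in range(T - 2, -1, -1):
        # Smoother gain G_t = P_{t|t} F^T P_{t+1|t}^{-1}
        G = torch.linalg.solve(P_preds[t + 1], F @ P_filts[t]).T
        # xs[0] / Ps[0] currently hold the smoothed estimate for step t + 1
        xs.insert(0, x_filts[t] + G @ (xs[0] - x_preds[t + 1]))
        Ps.insert(0, P_filts[t] + G @ (Ps[0] - P_preds[t + 1]) @ G.T)
    return torch.stack(xs), torch.stack(Ps)
```

A single iteration of the filter loop is the analogue of a combined predict/update call. For non-linear f and h, an extended filter would replace F and H with their Jacobians evaluated at the current estimate.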

Parameter Learning

The module supports learning model parameters by backpropagation, using the negative expected joint log-likelihood of the data as the loss function.

# Define an optimizer over the model and noise parameters
optimizer = torch.optim.Adam(params=[
    {'params': kalman_filter.f.parameters()},
    {'params': kalman_filter.h.parameters()},
    {'params': Q.parameters()},
    {'params': R.parameters()}
])

NUM_EPOCHS = 10
NUM_CYCLES = 10

# Run the EM loop
marginal_likelihoods = em_updates(
    kalman_filter=kalman_filter,
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q,
    R=R,
    optimizer=optimizer,
    num_cycles=NUM_CYCLES,
    num_epochs=NUM_EPOCHS
) 
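The documented objective is the negative expected joint log-likelihood, optimized with em_updates. A simpler objective that also appears in the API (marginal_log_likelihood) is the marginal log-likelihood, which a linear filter yields exactly from its innovations. The sketch below swaps in that objective to show how gradients reach Q and R; it is plain PyTorch, assumes float64 tensors and known F and H, and uses a hypothetical CholeskySPD module as a stand-in for SymmetricPositiveDefiniteMatrix, whose internals are not documented here.

```python
import math
import torch


class CholeskySPD(torch.nn.Module):
    """Hypothetical learnable SPD matrix: L @ L.T + eps * I with L lower-triangular."""

    def __init__(self, dim: int, init_scale: float = 1.0, eps: float = 1e-6):
        super().__init__()
        # Unconstrained parameter; only its lower triangle is used
        self.raw = torch.nn.Parameter(math.sqrt(init_scale) * torch.eye(dim, dtype=torch.float64))
        self.eps = eps

    def forward(self) -> torch.Tensor:
        L = torch.tril(self.raw)
        return L @ L.T + self.eps * torch.eye(L.shape[0], dtype=L.dtype)


def neg_log_marginal_likelihood(z_seq, x0, P0, F, H, Q, R):
    """-log p(z_1, ..., z_T) via the prediction-error decomposition of a linear filter."""
    x, P = x0, P0
    nll = torch.zeros((), dtype=x0.dtype)
    for z in z_seq:
        # Predict
        x = F @ x
        P = F @ P @ F.T + Q
        # Innovation v and its covariance S
        v = z - H @ x
        S = H @ P @ H.T + R
        # Accumulate -log N(v; 0, S)
        nll = nll + 0.5 * (v @ torch.linalg.solve(S, v) + torch.linalg.slogdet(S)[1]
                           + v.shape[0] * math.log(2 * math.pi))
        # Update with the Kalman gain K = P H^T S^{-1}
        K = torch.linalg.solve(S, H @ P).T
        x = x + K @ v
        P = P - K @ S @ K.T
    return nll


def fit_noise_covariances(z_seq, x0, P0, F, H, steps=100, lr=0.05):
    """Learn Q and R by gradient descent on the negative marginal log-likelihood."""
    Q, R = CholeskySPD(F.shape[0]), CholeskySPD(H.shape[0])
    optimizer = torch.optim.Adam(list(Q.parameters()) + list(R.parameters()), lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = neg_log_marginal_likelihood(z_seq, x0, P0, F, H, Q(), R())
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return Q().detach(), R().detach(), losses
```

The Cholesky parameterization keeps Q and R symmetric positive definite for any value of the unconstrained parameter, which is what lets a generic optimizer such as Adam update them freely.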
    

API Reference

DifferentiableKalmanFilter

Main class implementing the Kalman Filter algorithm.

Constructor Parameters:

  • dim_x (int): State space dimension
  • dim_z (int): Observation space dimension
  • f (nn.Module): State transition function
  • h (nn.Module): Observation function
  • mc_samples (int, optional): Number of Monte Carlo samples for log-likelihood estimation

Key Methods:

  • predict: State prediction step
  • update: Measurement update step
  • predict_update: Combined prediction and update
  • sequence_filter: Full sequence filtering
  • sequence_smooth: Full sequence smoothing
  • marginal_log_likelihood: Compute marginal log-likelihood
  • monte_carlo_expected_joint_log_likekihood: Estimate expected joint log-likelihood
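For EM, the expected joint log-likelihood contains, for each time step, an observation term E[log p(z_t | x_t)] taken under the smoothed marginal N(m_t, P_t). The sketch below estimates that single term by Monte Carlo with reparameterized samples, which is the general idea behind a sample count such as mc_samples; it is a generic illustration, not the implementation of monte_carlo_expected_joint_log_likekihood, and it omits the transition term, which would need joint samples of consecutive states. gaussian_log_pdf and mc_expected_obs_log_lik are hypothetical names.

```python
import math
import torch


def gaussian_log_pdf(z, mean, cov):
    """log N(z; mean, cov); `mean` may carry a leading batch dimension."""
    diff = z - mean
    sol = torch.linalg.solve(cov, diff.unsqueeze(-1)).squeeze(-1)
    return -0.5 * ((diff * sol).sum(-1) + torch.linalg.slogdet(cov)[1]
                   + z.shape[-1] * math.log(2 * math.pi))


def mc_expected_obs_log_lik(z, h, m, P, R, num_samples=1000, generator=None):
    """Monte Carlo estimate of E_{x ~ N(m, P)}[log N(z; h(x), R)].

    h must map a batch of states (num_samples, dim_x) to (num_samples, dim_z).
    """
    L = torch.linalg.cholesky(P)
    eps = torch.randn(num_samples, m.shape[0], generator=generator, dtype=m.dtype)
    # Reparameterized samples x = m + L eps keep the estimate differentiable in m, P and R
    xs = m + eps @ L.T
    return gaussian_log_pdf(z, h(xs), R).mean()
```

For a linear h(x) = Hx the term has the closed form log N(z; Hm, R) - 0.5 * tr(R^{-1} H P H^T), which is a convenient check on any sampling-based estimate.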

Requirements

  • PyTorch >= 1.9.0
  • Python >= 3.7

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Download files


Source Distribution

diffkalman-0.1.1.tar.gz (67.1 kB)

Uploaded Source

Built Distribution


diffkalman-0.1.1-py3-none-any.whl (14.8 kB)

Uploaded Python 3

File details

Details for the file diffkalman-0.1.1.tar.gz.

File metadata

  • Download URL: diffkalman-0.1.1.tar.gz
  • Upload date:
  • Size: 67.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.14

File hashes

Hashes for diffkalman-0.1.1.tar.gz
Algorithm Hash digest
SHA256 6ed07e7332d439e5b662b70d2cc4eb66cff25d45cf4164d53e8889dfe790ab83
MD5 a334c88d82b52f9037e8129daeee5a0a
BLAKE2b-256 b46681efd482ca734822ebc0f2cdb4e2bff203ba507047d474a78c91eaba8e28


File details

Details for the file diffkalman-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: diffkalman-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 14.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.14

File hashes

Hashes for diffkalman-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 68b8b6dcbfab89f20277c338851d14648307472f3698ea67d7242cfe1804f4d6
MD5 44eea6ffbc4af1c38fac478f9cb3369e
BLAKE2b-256 1320ee4f883396ef569c7c0641b54b8cbac4de53b2f63920dc2549d6a3e9aaa8

