A differentiable Kalman filter library for auto-tuning Kalman filters.
Differentiable Kalman Filter
A PyTorch-based implementation of a differentiable Kalman Filter designed for both linear and non-linear dynamical systems with Gaussian noise. This module seamlessly integrates with neural networks, enabling learnable dynamics, observation, and noise models optimized through Stochastic Variational Inference (SVI).
Features
- Fully Differentiable: End-to-end differentiable implementation compatible with PyTorch's autograd
- Flexible Models: Support for both linear and non-linear state transition and observation models
- Neural Network Integration: Models can be parameterized using neural networks
- Automatic Jacobian Computation: Uses PyTorch's autograd for derivative calculations (see the sketch after this list)
- Monte Carlo Sampling: Supports Monte Carlo evaluation of the expected joint log-likelihood for Expectation-Maximization (EM) learning
- Rauch-Tung-Striebel Smoothing: Implements forward-backward smoothing for improved state estimation via the RTS algorithm
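To make the automatic-Jacobian point concrete, here is a minimal sketch using plain PyTorch (torch.autograd.functional.jacobian, not a diffkalman API): autograd linearizes an arbitrary non-linear transition at the current state, which is the derivative an extended Kalman filter step needs.

import torch
from torch.autograd.functional import jacobian

def f_nonlinear(x):
    # Toy non-linear transition, e.g. a discretized pendulum
    return torch.stack([x[0] + 0.1 * x[1], x[1] - 0.1 * torch.sin(x[0])])

x = torch.tensor([0.5, -0.2])
F = jacobian(f_nonlinear, x)  # (2, 2) Jacobian evaluated at x, no hand-derived derivatives
print(F)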
Installation
pip install torch       # required dependency
pip install diffkalman
Quick Start
Here's a simple example of using the Differentiable Kalman Filter:
import torch
from diffkalman import DifferentiableKalmanFilter
from diffkalman.utils import SymmetricPositiveDefiniteMatrix
from diffkalman.em_loop import em_updates

# Define custom state transition and observation functions
class StateTransition(torch.nn.Module):
    def forward(self, x, *args):
        # Your state transition logic here (identity placeholder)
        return x

class ObservationModel(torch.nn.Module):
    def forward(self, x, *args):
        # Your observation logic here (placeholder: observe the
        # first two state components so that dim_z = 2 holds)
        return x[:2]

# Initialize the filter
f = StateTransition()
h = ObservationModel()
Q = SymmetricPositiveDefiniteMatrix(dim=4, trainable=True)
R = SymmetricPositiveDefiniteMatrix(dim=2, trainable=True)
kalman_filter = DifferentiableKalmanFilter(
    dim_x=4,  # State dimension
    dim_z=2,  # Observation dimension
    f=f,      # State transition function
    h=h       # Observation function
)

# Example inputs (replace with real data)
T = 100
observations = torch.randn(T, 2)
initial_state = torch.zeros(4)
initial_covariance = torch.eye(4)

# Run the filter
results = kalman_filter.sequence_filter(
    z_seq=observations,                     # Shape: (T, dim_z)
    x0=initial_state,                       # Shape: (dim_x,)
    P0=initial_covariance,                  # Shape: (dim_x, dim_x)
    Q=Q().repeat(len(observations), 1, 1),  # Shape: (T, dim_x, dim_x)
    R=R().repeat(len(observations), 1, 1)   # Shape: (T, dim_z, dim_z)
)
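The placeholder models above are identity maps. Because f and h only need to be nn.Modules exposing forward(x, *args), either can be parameterized by a neural network. A sketch of one such learnable transition model (NeuralStateTransition is our illustrative name, not part of diffkalman):

class NeuralStateTransition(torch.nn.Module):
    """Constant-velocity linear prior plus a small learned residual."""

    def __init__(self, dim_x: int = 4, hidden: int = 32, dt: float = 1.0):
        super().__init__()
        F = torch.eye(dim_x)
        F[0, 2] = dt  # position_x += velocity_x * dt
        F[1, 3] = dt  # position_y += velocity_y * dt
        self.register_buffer("F", F)
        self.residual = torch.nn.Sequential(
            torch.nn.Linear(dim_x, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, dim_x),
        )

    def forward(self, x, *args):
        return self.F @ x + self.residual(x)

f = NeuralStateTransition(dim_x=4)  # drop-in replacement for StateTransition above

Its parameters are then picked up by kalman_filter.f.parameters() in the training setup shown under Parameter Learning below.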
Detailed Usage
State Estimation
The module provides three main estimation methods:
- Filtering: Forward pass only
filtered_results = kalman_filter.sequence_filter(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
- Smoothing: Forward-backward pass
smoothed_results = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
- Single-step Prediction: For real-time applications
step_result = kalman_filter.predict_update(
    z=current_observation,
    x=current_state,
    P=current_covariance,
    Q=process_noise,
    R=observation_noise
)
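For streaming applications, predict_update can be called once per incoming measurement, carrying each posterior forward as the next prior. A minimal sketch, assuming the returned object exposes the posterior mean and covariance (the key names below are illustrative, not the library's documented return type; adapt as needed):

x, P = initial_state, initial_covariance
Q_t, R_t = Q(), R()  # single-step noise covariances from the learnable modules

for z_t in observations:  # any iterable yielding (dim_z,) tensors
    result = kalman_filter.predict_update(z=z_t, x=x, P=P, Q=Q_t, R=R_t)
    # Illustrative unpacking; check the actual return structure
    x, P = result["x"], result["P"]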
Parameter Learning
The module supports learning model parameters through backpropagation, using the negative expected joint log-likelihood of the data as the loss function.
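For reference, in standard EM for Gaussian state-space models this loss corresponds to the negative of the expected joint log-likelihood under the smoothing distribution q (our notation, sketching the usual objective rather than quoting the library's internals):

$$\mathcal{Q}(\theta) = \mathbb{E}_{q(x_{0:T})}\!\left[\log p_\theta(x_0) + \sum_{t=1}^{T} \log p_\theta(x_t \mid x_{t-1}) + \sum_{t=1}^{T} \log p_\theta(z_t \mid x_t)\right]$$

with training loss $-\mathcal{Q}(\theta)$, estimated by Monte Carlo sampling from q.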
# Define optimizer over all learnable components
optimizer = torch.optim.Adam(params=[
    {'params': kalman_filter.f.parameters()},
    {'params': kalman_filter.h.parameters()},
    {'params': Q.parameters()},
    {'params': R.parameters()},
])

NUM_EPOCHS = 10
NUM_CYCLES = 10

# Run the EM loop
marginal_likelihoods = em_updates(
    kalman_filter=kalman_filter,
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q,
    R=R,
    optimizer=optimizer,
    num_cycles=NUM_CYCLES,
    num_epochs=NUM_EPOCHS
)
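em_updates returns the per-cycle marginal likelihood values, so a quick sanity check is that they trend upward over cycles (assuming each entry is a scalar; adapt if the return type differs):

for cycle, ll in enumerate(marginal_likelihoods):
    print(f"cycle {cycle}: marginal log-likelihood = {float(ll):.4f}")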
API Reference
DifferentiableKalmanFilter
Main class implementing the Kalman Filter algorithm.
Constructor Parameters:
- dim_x (int): State space dimension
- dim_z (int): Observation space dimension
- f (nn.Module): State transition function
- h (nn.Module): Observation function
- mc_samples (int, optional): Number of Monte Carlo samples for log-likelihood estimation
Key Methods:
- predict: State prediction step
- update: Measurement update step
- predict_update: Combined prediction and update
- sequence_filter: Full sequence filtering
- sequence_smooth: Full sequence smoothing
- marginal_log_likelihood: Compute marginal log-likelihood
- monte_carlo_expected_joint_log_likekihood: Estimate expected joint log-likelihood
Requirements
- PyTorch >= 1.9.0
- Python >= 3.7
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
License
This project is licensed under the MIT License - see the LICENSE file for details.