
Sensor-fused 2D rat tracking with JAX EKF/UKF for SpikeGadgets/Trodes


trodestrack

Sensor-fused 2D rat tracking with JAX EKF/UKF for neuroscience research

trodestrack combines video tracking (Trodes LEDs and/or DeepLabCut keypoints) with IMU data from SpikeGadgets headstages to provide accurate position, velocity, and heading estimates for freely moving rats on behavioral mazes.

Features

  • Sensor Fusion: Extended Kalman Filter (EKF) and Unscented Kalman Filter (UKF) for combining video (~30 Hz) and IMU (~100 Hz effective) measurements
  • 3D IMU Support: Full 6-axis IMU processing (3-axis gyro + 3-axis accel) with gravity compensation
  • Online & Offline Processing: Forward-only EKF filtering and RTS smoothing, both of which run as batch operations over complete input arrays. The trodestrack online CLI is forward-only ("no future-frame dependence"); it is not a streaming or real-time ingest loop.
  • Robust Handling: Occlusions, reflections, and camera/sensor dropout. Config-driven real-data runs can apply persistent LED identity correction before filtering and fail fast when IMU calibration or fused trajectories look implausible.
  • JAX-Accelerated: High-performance, JIT-compiled JAX implementation. The throughput-floor benchmarks (≥10× realtime offline on CPU; ≤33 ms amortized mean per frame online on a 30-minute session) live in tests/benchmark/test_throughput.py and do not run on every PR. Run them locally with JAX_PLATFORMS=cpu uv run pytest -m benchmark. The CPU pin is required because the benchmark errors on accelerator backends, which keeps the documented floors meaningful. A reference run on an M-series Mac CPU, timed with block-until-ready, reached ~38× realtime offline and ~0.41 ms mean per frame online. Absolute throughput is hardware-dependent.
  • Rich Simulation: Comprehensive synthetic data generation for testing and validation
  • Diagnostic Visualization: Publication-quality video output for quality control
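Block-until-ready timing matters because JAX dispatches work asynchronously. A jitted call returns before the computation finishes, so timing the call alone under-reports its cost. The sketch below shows the timing pattern on a toy scalar filter scanned over 30 minutes of 30 Hz frames. `run_filter` and `time_per_frame` are hypothetical names, not trodestrack functions, and the toy step is far cheaper than a real EKF update.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def run_filter(z):
    # Toy scalar recursive filter scanned over frames (stand-in for a filter pass).
    def step(m, zk):
        m = m + 0.1 * (zk - m)
        return m, m

    _, ms = jax.lax.scan(step, jnp.float32(0.0), z)
    return ms


def time_per_frame(fn, z, repeats=5):
    """Mean wall-clock seconds per frame, excluding JIT compilation."""
    fn(z).block_until_ready()          # warm-up call triggers compilation
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(z).block_until_ready()      # wait for the async computation to finish
    elapsed = (time.perf_counter() - t0) / repeats
    return elapsed / z.shape[0]


fps = 30.0
z = jnp.ones(30 * 60 * int(fps))       # 30 minutes of 30 Hz frames
sec_per_frame = time_per_frame(run_filter, z)
realtime_factor = (1.0 / fps) / sec_per_frame
```

To mirror the benchmark's CPU pin, run this with JAX_PLATFORMS=cpu. If you drop the `block_until_ready()` calls, the measured time can shrink to little more than dispatch overhead.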

Hardware Compatibility

Supported IMU Hardware:

trodestrack is designed for SpikeGadgets headstages with integrated 6-axis IMU sensors.

Official Hardware Specifications (source: SpikeGadgets Product Manual):

  • 3-axis accelerometer: ±2g range, 16-bit signed (0.000061g per LSB)
  • 3-axis gyroscope: ±2000 deg/s range, 16-bit signed (0.061 deg/s per LSB)
  • Sensor refresh rate: 104 Hz (when both sensors enabled)
  • Internal sampling: 500 Hz per sensor (both enabled), 1 kHz (single sensor)
  • Output format: Sample-and-hold repeats expand 104 Hz data to ~20-30 kHz nominal rate
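The per-LSB figures above follow from dividing each full-scale range by 2^15, since the samples are signed 16-bit. Below is a minimal conversion sketch that assumes raw signed counts as input. This page does not say whether a given Trodes export already stores scaled values, so check your data before applying it. The function names are hypothetical.

```python
import numpy as np

STANDARD_GRAVITY = 9.80665  # m/s^2 per g


def accel_counts_to_mps2(counts, full_scale_g=2.0):
    """Signed 16-bit counts -> m/s^2 for a +/-full_scale_g accelerometer."""
    g_per_lsb = full_scale_g / 32768.0       # 2 g / 2^15 ~= 0.000061 g per LSB
    return np.asarray(counts, dtype=float) * g_per_lsb * STANDARD_GRAVITY


def gyro_counts_to_radps(counts, full_scale_dps=2000.0):
    """Signed 16-bit counts -> rad/s for a +/-full_scale_dps gyroscope."""
    dps_per_lsb = full_scale_dps / 32768.0   # 2000 / 2^15 ~= 0.061 deg/s per LSB
    return np.deg2rad(np.asarray(counts, dtype=float) * dps_per_lsb)
```

For example, 16384 accelerometer counts is exactly 1 g (9.80665 m/s^2), and 16384 gyro counts is 1000 deg/s.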

Data Processing:

  • Preprocessing removes sample-and-hold duplicates → ~100 Hz effective rate
  • Timestamp-based integration handles variable sampling rates
  • Compatible with both 2D IMU mode (gyro-Z, accel-XY) and full 3D mode (all 6 axes)
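This page does not show trodestrack's actual preprocessing. One simple way to collapse sample-and-hold repeats and then integrate with per-sample time steps is sketched below. Both helpers are hypothetical. The deduplication treats a row as new only if some channel changed, so a genuine sample that is identical on all six channels to its predecessor would also be dropped.

```python
import numpy as np


def drop_sample_and_hold(t, imu):
    """Keep rows where any IMU channel differs from the previous row.

    t: (M,) host timestamps in seconds; imu: (M, C) raw channels.
    Heuristic: a true repeat of an identical sample is dropped as well.
    """
    t = np.asarray(t, dtype=float)
    imu = np.asarray(imu)
    changed = np.ones(len(imu), dtype=bool)
    changed[1:] = np.any(imu[1:] != imu[:-1], axis=1)
    return t[changed], imu[changed]


def integrate_rate(t, rate, angle0=0.0):
    """Integrate an angular rate using each sample's own dt (trapezoidal rule)."""
    t = np.asarray(t, dtype=float)
    rate = np.asarray(rate, dtype=float)
    steps = 0.5 * (rate[1:] + rate[:-1]) * np.diff(t)
    return angle0 + np.concatenate([[0.0], np.cumsum(steps)])


# Demo: one second of 104 Hz samples held at a 3 kHz host rate
# (scaled down from the ~20-30 kHz nominal rate to keep it small).
rng = np.random.default_rng(0)
n_true = 104
samples = rng.integers(-500, 500, size=(n_true, 6))
t_true = np.arange(n_true) / 104.0
host_rate = 3000.0
t_host = np.arange(int(n_true / 104.0 * host_rate)) / host_rate
held = samples[np.searchsorted(t_true, t_host, side="right") - 1]
t_dedup, imu_dedup = drop_sample_and_hold(t_host, held)

# Irregular ~100 Hz timestamps with a constant 0.5 rad/s yaw rate.
t_jitter = np.cumsum(rng.uniform(0.008, 0.011, size=50))
yaw = integrate_rate(t_jitter, np.full(50, 0.5))
```

Each kept timestamp is the first host tick at or after the true sample time, so it lags by less than one host period.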

Simulation Defaults:

  • All synthetic data generation uses realistic SpikeGadgets specifications
  • IMU rate: 104 Hz (matches hardware sensor refresh rate)
  • Noise levels: 0.01 °/s/√Hz gyro, 0.2 mg/√Hz accel (per SpikeGadgets spec)
  • Goal: simulated filter performance that is representative of real-world performance
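A noise density only becomes a per-sample standard deviation once a discretization convention is chosen. The sketch below assumes the common white-noise convention sigma_d = N·√f_s. Some datasheets integrate over the Nyquist bandwidth f_s/2 instead, which differs by a factor of √2, and this page does not state which convention the simulator uses. Under that assumption, the listed densities work out to roughly 0.10 °/s (about 1.7 gyro LSB) and 2.0 mg per 104 Hz sample. `per_sample_sigma` is a hypothetical helper.

```python
import numpy as np


def per_sample_sigma(noise_density, sample_rate_hz):
    """White-noise std per discrete sample, assuming sigma_d = N * sqrt(f_s)."""
    return noise_density * np.sqrt(sample_rate_hz)


fs = 104.0
gyro_sigma_dps = per_sample_sigma(0.01, fs)        # deg/s per sample
accel_sigma_mg = per_sample_sigma(0.2, fs)         # mg per sample
gyro_sigma_lsb = gyro_sigma_dps / (2000.0 / 32768.0)

# Draw gyro noise in rad/s at that per-sample level.
rng = np.random.default_rng(1)
gyro_noise = rng.normal(0.0, np.deg2rad(gyro_sigma_dps), size=10_000)
```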

Video Tracking:

  • Trodes LED detection (dual LED setup for heading)
  • DeepLabCut keypoint tracking (any pose estimation output)
  • Camera rate: typically 30 Hz (configurable)

Installation

Requirements

  • Python ≥ 3.11
  • uv package manager

Install from source

git clone https://github.com/edeno/trodestrack.git
cd trodestrack
uv sync

Quick Start

1. Generate and Filter Synthetic Data (3 minutes)

The fastest way to understand trodestrack is to run the EKF example on synthetic data:

# Clone and setup
git clone https://github.com/edeno/trodestrack.git
cd trodestrack
uv sync

# Run EKF on basic scenarios (stationary, constant velocity, circular)
uv run python examples/03_ekf_basic_scenarios.py

This generates 3 diagnostic PNGs showing filter performance, bias convergence, and NEES consistency checks. Key insight: gyro bias is only observable during rotation!

2. Compare EKF vs UKF

uv run python examples/04_ukf_basic_scenarios.py

Compares sigma-point (UKF) vs Jacobian (EKF) approaches. Verdict: EKF wins 6/9 metrics (UKF: 3/9). Under JIT-compiled JAX with warm dispatch, the wall-clock cost is comparable on these scenarios; on backends without JIT (per-step Python loops) UKF can be several times slower. Start with EKF and re-measure on your target backend.
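The difference between the two approaches shows up when a Gaussian is propagated through a nonlinear function. The sketch below uses a toy polar-to-Cartesian map, similar in spirit to an LED offset rotated by heading but not trodestrack's measurement model. The EKF-style path pushes the mean through f and the covariance through the Jacobian. The UKF-style path pushes 2n+1 sigma points through f and re-fits the mean and covariance. All names and the alpha/beta/kappa values are illustrative choices.

```python
import numpy as np


def f(x):
    """Toy nonlinear map: (offset r, heading theta) -> (x, y) relative to head."""
    r, th = x
    return np.array([r * np.cos(th), r * np.sin(th)])


def jac_f(x):
    r, th = x
    return np.array([[np.cos(th), -r * np.sin(th)],
                     [np.sin(th), r * np.cos(th)]])


def linearized_transform(m, P):
    """EKF-style: mean through f, covariance through the Jacobian at the mean."""
    J = jac_f(m)
    return f(m), J @ P @ J.T


def unscented_transform(m, P, alpha=1.0, beta=2.0, kappa=0.0):
    """UKF-style: propagate 2n+1 sigma points through f, then re-fit moments."""
    n = len(m)
    lam = alpha**2 * (n + kappa) - n
    L = np.linalg.cholesky((n + lam) * P)
    sigmas = np.vstack([m, m + L.T, m - L.T])   # rows: m, m +/- columns of L
    wm = np.full(2 * n + 1, 0.5 / (n + lam))
    wc = wm.copy()
    wm[0] = lam / (n + lam)
    wc[0] = lam / (n + lam) + (1.0 - alpha**2 + beta)
    Y = np.array([f(s) for s in sigmas])
    mean = wm @ Y
    d = Y - mean
    return mean, (wc[:, None] * d).T @ d


m = np.array([0.05, 0.0])                 # 5 cm offset, heading 0 rad
P = np.diag([1e-6, 0.5**2])               # heading std 0.5 rad
lin_mean, lin_cov = linearized_transform(m, P)
ut_mean, ut_cov = unscented_transform(m, P)
true_mean_x = 0.05 * np.exp(-0.5 * 0.5**2)  # E[r cos(theta)] for Gaussian theta
```

With a 0.5 rad heading spread, the unscented mean (~0.0440 m) lands much closer to the analytic mean (~0.0441 m) than the linearized one (0.05 m). This single-step toy does not contradict the EKF's overall win in the examples, which score full filter runs on many metrics.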

3. Test Dropout Robustness

uv run python examples/05_ekf_with_dropouts.py
uv run python examples/06_ukf_with_dropouts.py

Simulates 10%, 20%, and 30% camera dropout to stress-test IMU-only periods.
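A camera-dropout mask with block-shaped gaps can be generated as below. This is a hypothetical sketch: the examples may generate dropout differently. It also assumes that mask_cam uses True to mean "camera sample present", which this page does not state.

```python
import numpy as np


def block_dropout_mask(n_frames, fraction, gap_frames, rng):
    """Valid-frame mask (True = camera sample present) with blanked blocks.

    Frames are split into fixed-length slots, and round(fraction * n / gap)
    slots are chosen at random and blanked. Adjacent chosen slots merge into
    longer gaps, so every gap is a multiple of gap_frames long.
    """
    n_slots = n_frames // gap_frames
    n_gaps = int(round(fraction * n_frames / gap_frames))
    if n_gaps > n_slots:
        raise ValueError("fraction too large for this gap length")
    mask = np.ones(n_frames, dtype=bool)
    for slot in rng.choice(n_slots, size=n_gaps, replace=False):
        mask[slot * gap_frames:(slot + 1) * gap_frames] = False
    return mask


# 100 s of 30 Hz video with 1-second gaps at 10%, 20%, and 30% dropout.
rng = np.random.default_rng(0)
masks = {p: block_dropout_mask(3000, p, 30, rng) for p in (0.10, 0.20, 0.30)}
```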

4. Use Smoothing for Offline Analysis

uv run python examples/07_smoother_demonstration.py

Shows how backward RTS smoothing achieves a 3× drift reduction on a 5-second dropout by using future observations.
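Below is a minimal sketch of the forward-filter / backward-smoother pattern on a 1D constant-velocity model with a 5-second camera gap. It is the textbook linear Kalman filter and Rauch–Tung–Striebel smoother in NumPy, not trodestrack's EKF/RTS code, and it does not reproduce the 3× figure above. It does show the mechanism: during the gap the forward variance keeps growing, while the backward pass folds in the measurements that resume afterwards. `kf_rts_cv` and the q and r values are illustrative.

```python
import numpy as np


def kf_rts_cv(z, valid, dt, q, r):
    """Kalman filter + RTS smoother for a 1D constant-velocity model.

    z: (N,) position measurements; valid: (N,) bool, False = dropout.
    Returns filtered/smoothed means (N, 2) and covariances (N, 2, 2).
    """
    F = np.array([[1.0, dt], [0.0, 1.0]])
    H = np.array([[1.0, 0.0]])
    Q = q * np.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]])  # white accel
    n = len(z)
    m_pred, P_pred = np.zeros((n, 2)), np.zeros((n, 2, 2))
    m_filt, P_filt = np.zeros((n, 2)), np.zeros((n, 2, 2))
    m, P = np.array([z[0], 0.0]), np.diag([r, 1.0])
    for k in range(n):
        if k > 0:                                   # predict
            m, P = F @ m, F @ P @ F.T + Q
        m_pred[k], P_pred[k] = m, P
        if valid[k]:                                # update only with camera data
            S = H @ P @ H.T + r                     # (1, 1) innovation covariance
            K = P @ H.T / S                         # (2, 1) Kalman gain
            m = m + (K * (z[k] - H @ m)).ravel()
            P = (np.eye(2) - K @ H) @ P
        m_filt[k], P_filt[k] = m, P
    m_s, P_s = m_filt.copy(), P_filt.copy()
    for k in range(n - 2, -1, -1):                  # backward RTS pass
        G = P_filt[k] @ F.T @ np.linalg.inv(P_pred[k + 1])
        m_s[k] = m_filt[k] + G @ (m_s[k + 1] - m_pred[k + 1])
        P_s[k] = P_filt[k] + G @ (P_s[k + 1] - P_pred[k + 1]) @ G.T
    return m_filt, P_filt, m_s, P_s


dt = 1 / 30
t = np.arange(0, 20, dt)
truth = 0.3 * np.sin(0.5 * t)                       # slow 1D position (m)
rng = np.random.default_rng(0)
z = truth + rng.normal(0, 0.01, size=t.size)        # 1 cm camera noise
valid = ~((t >= 8.0) & (t < 13.0))                  # 5-second dropout
m_f, P_f, m_s, P_s = kf_rts_cv(z, valid, dt, q=0.5, r=0.01**2)
mid = np.argmin(np.abs(t - 10.5))                   # middle of the gap
```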

5. Generate QA Reports

uv run python examples/08_qa_report_generation.py

Creates a publication-quality PDF with the full set of accuracy metrics, NEES/NIS consistency checks, and time-series plots.
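NEES (normalized estimation error squared) is e_k^T P_k^-1 e_k. Here e_k is the truth-minus-estimate error and P_k is the filter's reported covariance. For a consistent filter it follows a chi-square distribution with d degrees of freedom, so its mean is close to d. NIS is the same statistic applied to measurement innovations and their predicted covariance. The sketch below is a minimal NumPy version for illustration; it is not trodestrack.qa.metrics.compute_nees, whose internals are not shown here.

```python
import numpy as np


def nees(errors, covariances):
    """Per-frame NEES e_k^T P_k^-1 e_k for errors (N, d), covariances (N, d, d)."""
    # Solve P_k x_k = e_k rather than forming P_k^-1 explicitly.
    x = np.linalg.solve(covariances, errors[..., None])[..., 0]
    return np.einsum("nd,nd->n", errors, x)


# Errors actually drawn from the claimed covariance -> mean NEES close to d = 2.
rng = np.random.default_rng(0)
P = np.array([[0.04, 0.01], [0.01, 0.02]])
errors = rng.multivariate_normal(np.zeros(2), P, size=20_000)
covs = np.broadcast_to(P, (errors.shape[0], 2, 2))
scores = nees(errors, covs)
mean_nees = scores.mean()
```

An overconfident filter, for example one reporting a quarter of the true covariance, inflates the mean NEES (to about 8 here). That is the failure mode the consistency checks flag.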

Real Data With a YAML Config

For SpikeGadgets/Trodes-style real data, use a session YAML instead of long per-file CLI flags:

inputs:
  format: spikegadgets_trodes
  imu_file: path/to/imu.parquet
  position_file: path/to/position.parquet
camera:
  meters_per_pixel: 0.0022
filter:
  state_mode: 2d_cam_6dof_imu_orientation
  enable_experimental_accel_translation: false
  use_gravity_orientation_update: true
  use_mahalanobis_gating: false
outputs:
  output_dir: runs/session_001
led_identity:
  mode: auto
  initial_state: auto

Run forward filtering or offline smoothing:

uv run trodestrack online --config session.yaml
uv run trodestrack smooth --config session.yaml

The config loader supports prepared text arrays, SpikeGadgets IMU parquet, and Trodes dual-LED parquet. Real-data IMU-fused runs write loader and calibration diagnostics and run a vision-only plausibility check before accepting the fused output. With outputs.run_safety_checks: true, this roughly doubles filter runtime.

The safety check gates three things: the trajectory envelope, speed, and fused-vs-vision position deviation. Estimating the camera-midpoint envelope requires enough dual-LED frames (outputs.safety_min_dual_led_frames, default 20). Single-LED frames still contribute to the fused-vs-vision deviation check. Enabling accelerometer-driven translation additionally requires two conditions. The stationary gravity vector must have a small horizontal component, and the camera and IMU acceleration axes must correlate above the configured thresholds.

For tilted headstages, start with 2d_cam_6dof_imu_orientation and leave accelerometer-driven translation disabled until the safety check passes. This validated default fuses 6-DOF IMU orientation with camera position; it does not perform accelerometer-driven position integration.
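The three gates described above can be sketched as follows. Everything here is a hypothetical illustration: the function name, the axis-aligned bounding-box envelope, and every threshold. The real checks and their thresholds live in trodestrack's io package and are configured through the example session YAML.

```python
import numpy as np


def plausibility_check(fused_xy, vision_xy, midpoint_xy, dual_led, dt,
                       margin_m=0.10, max_speed_mps=2.0, max_dev_m=0.05,
                       min_dual_led_frames=20):
    """Hypothetical fused-vs-vision gate; thresholds are illustrative only.

    fused_xy, vision_xy, midpoint_xy: (N, 2) positions in meters. vision_xy may
    come from single- or dual-LED frames (NaN where missing). dual_led: (N,)
    bool marking frames where both LEDs were seen.
    """
    if dual_led.sum() < min_dual_led_frames:
        raise ValueError("too few dual-LED frames to estimate the envelope")
    mid = midpoint_xy[dual_led]
    lo = np.nanmin(mid, axis=0) - margin_m
    hi = np.nanmax(mid, axis=0) + margin_m
    in_envelope = bool(np.all((fused_xy >= lo) & (fused_xy <= hi)))
    speed = np.linalg.norm(np.diff(fused_xy, axis=0), axis=1) / dt
    dev = np.linalg.norm(fused_xy - vision_xy, axis=1)   # NaN where vision missing
    return {
        "envelope": in_envelope,
        "speed": bool(np.all(speed <= max_speed_mps)),
        "deviation": bool(np.nanmax(dev) <= max_dev_m),
    }


dt = 1 / 30
t = np.arange(0, 10, dt)
path = np.column_stack([0.5 + 0.3 * np.cos(0.4 * t), 0.5 + 0.3 * np.sin(0.4 * t)])
rng = np.random.default_rng(0)
vision = path + rng.normal(0, 0.005, path.shape)
dual = np.ones(t.size, dtype=bool)
ok = plausibility_check(path, vision, vision, dual, dt)
drifted = path.copy()
drifted[150:] += [0.0, 0.4]                              # fused track jumps away
bad = plausibility_check(drifted, vision, vision, dual, dt)
```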

See examples/session_spikegadgets_trodes.yaml for a runnable template with the real-data safety and LED-identity options spelled out. Set led_identity.initial_state to original or swapped when you know the label convention of the first valid dual-LED frame; auto cannot infer a global, session-wide label reversal from continuity alone.
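A toy continuity tracker shows why auto cannot recover a global reversal. The sketch below is hypothetical: it is not trodestrack's LED-identity algorithm, and it assumes both LEDs are detected in every frame. For each frame it keeps or swaps the labels, whichever moves each LED less from the previous corrected frame. A short run of swapped frames gets repaired. If the first frame is already reversed, however, every later frame is consistently reversed, continuity looks perfect, and nothing is flagged. Only knowing the first frame's convention breaks that symmetry.

```python
import numpy as np


def correct_led_swaps(led1, led2):
    """Greedy frame-to-frame label correction by continuity (hypothetical sketch).

    led1, led2: (N, 2) LED positions. Frame 0 is trusted as-is.
    Returns corrected copies and a per-frame "was swapped" flag.
    """
    a = np.array(led1, dtype=float)
    b = np.array(led2, dtype=float)
    swapped = np.zeros(len(a), dtype=bool)
    for k in range(1, len(a)):
        keep = np.linalg.norm(a[k] - a[k - 1]) + np.linalg.norm(b[k] - b[k - 1])
        swap = np.linalg.norm(b[k] - a[k - 1]) + np.linalg.norm(a[k] - b[k - 1])
        if swap < keep:
            a[k], b[k] = b[k].copy(), a[k].copy()
            swapped[k] = True
    return a, b, swapped


n = 50
led1_true = np.column_stack([2.0 * np.arange(n), np.zeros(n)])   # pixels
led2_true = led1_true + [0.0, 10.0]
obs1, obs2 = led1_true.copy(), led2_true.copy()
obs1[10:15], obs2[10:15] = led2_true[10:15], led1_true[10:15]    # local swap
fixed1, fixed2, swapped = correct_led_swaps(obs1, obs2)

# Global reversal: every frame is swapped, so continuity never objects.
rev1, rev2, rev_swapped = correct_led_swaps(led2_true, led1_true)
```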

Python API Examples

Generate synthetic data

from trodestrack.sim.rat_imu import RatIMUSimConfig, simulate_rat_imu

# Default config matches SpikeGadgets hardware (104 Hz IMU, realistic noise)
config = RatIMUSimConfig(duration_s=10.0)
sim = simulate_rat_imu(config, seed=42)  # seed is passed to simulate_rat_imu, not to the config

Run EKF filter

from trodestrack.models.ekf import extended_kalman_filter, EKFConfig

cfg = EKFConfig()
result = extended_kalman_filter(
    cfg,
    sim["t_imu"],
    sim["U_imu"],
    sim["t_cam_exp"],
    sim["Z_cam_led1"],
    sim["Z_cam_led2"],
    sim["mask_cam"],
)
# result.filtered_means: (N_cam, n_state)
# result.filtered_covariances: (N_cam, n_state, n_state)

Working with State Layouts (Recommended Pattern)

trodestrack uses an explicit state layout system to eliminate hardcoded dimension assumptions and to support multiple tracking modes (5D, 8D, 10D, 14D, 15D, and 16D states). Always use state layouts instead of magic indices like [:, 0:2].

import numpy as np

from trodestrack.models.ekf import extended_kalman_filter, EKFConfig
from trodestrack.models.state_layout import get_layout
from trodestrack.sim.simple import simulate_circular, SimpleSimConfig

# Generate simulation and run filter
sim_config = SimpleSimConfig(duration_s=10.0)
sim = simulate_circular(sim_config)
ekf_config = EKFConfig()
result = extended_kalman_filter(
    ekf_config,
    sim["t_imu"],
    sim["U_imu"],
    sim["t_cam_exp"],
    sim["Z_cam_led1"],
    sim["Z_cam_led2"],
    sim["mask_cam"],
)

# Get state layout from filter config (BEST PRACTICE!)
# EKFConfig defaults to "2d_cam_3d_imu" (10D: [x, y, vx, vy, vz, theta, biases...]).
layout = get_layout(ekf_config.state_mode)

# ✅ GOOD: Extract states using layout indices (dimension-agnostic)
positions = result.filtered_means[:, layout.pos_idx]      # (N, 2) for 2D layouts, (N, 3) for 3D
velocities = result.filtered_means[:, layout.vel_idx]     # (N, 3) for 2d_cam_3d_imu, (N, 2) for 2d_full
# Heading shape depends on layout: scalar yaw for 2D layouts, 3-tuple Euler
# for 3d_euler, 4-tuple quaternion for 3d_quat / 3d_cam_6dof_imu /
# 2d_cam_6dof_imu_orientation. Guard before treating as a 1D angle.
heading_block = result.filtered_means[:, layout.heading_idx]
if layout.has_heading_2d:
    headings = heading_block.squeeze(-1) if heading_block.ndim == 2 else heading_block
    # headings: (N,) yaw in radians
else:
    headings = heading_block  # (N, 3) Euler or (N, 4) quaternion components

# ❌ BAD: Hardcoded indices (breaks when switching state modes!)
# positions = result.filtered_means[:, 0:2]  # Fragile! Don't do this!

# Extract uncertainties (covariances) using layout indices
P = result.filtered_covariances                           # (N, layout.n, layout.n)
pos_cov = P[:, layout.pos_idx, :][:, :, layout.pos_idx] # (N, 2, 2) position covariance
pos_std = np.sqrt(np.diagonal(pos_cov, axis1=1, axis2=2)) # (N, 2) position uncertainty

# Plot position with ±2σ uncertainty bands
import matplotlib.pyplot as plt
t = sim['t_cam_exp']
plt.plot(t, positions[:, 0], label='x')
plt.fill_between(t,
                 positions[:, 0] - 2*pos_std[:, 0],
                 positions[:, 0] + 2*pos_std[:, 0],
                 alpha=0.3, label='±2σ')
plt.xlabel('Time (s)')
plt.ylabel('X Position (m)')
plt.legend()
plt.show()

Available State Layouts:

| Layout string | Dimensions | State vector | Use case |
|---|---|---|---|
| "2d_cam_3d_imu" | 10D | [x, y, vx, vy, vz, θ, b_gz, b_ax, b_ay, b_az] | Default: 2D camera with 3D accel (detect rearing) |
| "2d_full" | 8D | [x, y, vx, vy, θ, b_gz, b_ax, b_ay] | Standard 2D sensor fusion (camera + 2-axis IMU) |
| "vision_only" | 5D | [x, y, vx, vy, θ] | Camera-driven tracking, no IMU integration (the public APIs still require placeholder IMU timestamps/measurements; see docs/user-guide/state-layouts.md) |
| "2d_cam_6dof_imu_orientation" | 14D | [x, y, vx, vy, qw, qx, qy, qz, b_gx, b_gy, b_gz, b_ax, b_ay, b_az] | Experimental: 2D camera + 6-DOF IMU with quaternion orientation |
| "3d_euler" | 15D | [x, y, z, vx, vy, vz, roll, pitch, yaw, b_gx, b_gy, b_gz, b_ax, b_ay, b_az] | State vector for 3D pose with Euler-angle orientation. No public entry point today: the 2D extended_kalman_filter rejects 15D states, and the 3D path requires 3d_cam_6dof_imu. |
| "3d_quat" | 16D | [x, y, z, vx, vy, vz, qw, qx, qy, qz, b_gx, b_gy, b_gz, b_ax, b_ay, b_az] | State vector for 3D pose with quaternion orientation. UKF rejects quaternion layouts; consume via the experimental extended_kalman_filter_3d using "3d_cam_6dof_imu" (same vector, distinct registration). |
| "3d_cam_6dof_imu" | 16D | same as "3d_quat" | Required by the experimental extended_kalman_filter_3d entry point (3D LED observations + 6-channel IMU). This is the only registered mode that flows end-to-end through a 3D camera filter today. |

Why use state layouts?

  1. Dimension-agnostic code: Works with 5D, 8D, 10D, 14D, 15D, 16D states without modification
  2. Self-documenting: layout.pos_idx is clearer than [:, 0:2]
  3. Robust to changes: Switching state modes doesn't break your analysis code
  4. Matches internal implementation: Filters use the same layout system

See src/trodestrack/models/state_layout.py for full API documentation.
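The idea behind layouts can be illustrated with a deliberately tiny class. `ToyLayout` is hypothetical and is not trodestrack's layout class; the real attributes, such as pos_idx and heading_idx, are documented in state_layout.py. Because every index has a name, the heading column moves automatically between modes. A hardcoded [:, 4] would instead silently read vz in "2d_cam_3d_imu".

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyLayout:
    """Hypothetical stand-in for a state layout: a name for every state index."""
    names: tuple[str, ...]

    @property
    def n(self) -> int:
        return len(self.names)

    def idx(self, *names: str) -> list[int]:
        return [self.names.index(name) for name in names]


# Orderings copied from the layout table above.
TOY_LAYOUTS = {
    "2d_full": ToyLayout(("x", "y", "vx", "vy", "theta", "b_gz", "b_ax", "b_ay")),
    "2d_cam_3d_imu": ToyLayout(
        ("x", "y", "vx", "vy", "vz", "theta", "b_gz", "b_ax", "b_ay", "b_az")
    ),
}


def heading_column(means: np.ndarray, layout: ToyLayout) -> np.ndarray:
    """Select yaw by name, so the column follows the layout, not a magic number."""
    return means[:, layout.idx("theta")[0]]


# theta is column 4 in 2d_full but column 5 in 2d_cam_3d_imu.
for layout in TOY_LAYOUTS.values():
    means = np.zeros((3, layout.n))
    means[:, layout.idx("theta")[0]] = 1.0
    assert np.all(heading_column(means, layout) == 1.0)
```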

Generate QA report

This snippet continues the state-layout example above and reuses its sim, result, and ekf_config.

import numpy as np
from trodestrack.models.state_layout import get_layout
from trodestrack.qa.report import generate_qa_report
from trodestrack.qa.metrics import compute_nees

layout = get_layout(ekf_config.state_mode)

# Align ground truth (IMU rate, 5D [x, y, vx, vy, theta]) to camera frames.
X_truth_at_cam = np.array(
    [sim["X_truth"][np.argmin(np.abs(sim["t_imu"] - t_c))] for t_c in sim["t_cam_exp"]]
)
filtered = np.asarray(result.filtered_means)
filtered_cov = np.asarray(result.filtered_covariances)

# Layout-aware indexing handles every scalar-heading state mode.
pos_idx = list(layout.pos_idx)
vel_idx_2d = list(layout.vel_idx)[:2]  # X_truth has only vx, vy
heading_col = int(layout.heading_idx)

# Position-only NEES (state_dim=2): expected mean ~ 2 for a consistent filter.
nees = compute_nees(
    states_true=X_truth_at_cam[:, :2],
    states_est=filtered[:, pos_idx],
    covariances_est=filtered_cov[np.ix_(np.arange(filtered.shape[0]), pos_idx, pos_idx)],
)
generate_qa_report(
    pdf_path="report.pdf",
    t=sim["t_cam_exp"],
    positions_true=X_truth_at_cam[:, :2],
    positions_est=filtered[:, pos_idx],
    velocities_true=X_truth_at_cam[:, 2:4],
    velocities_est=filtered[:, vel_idx_2d],
    headings_true=X_truth_at_cam[:, 4],
    headings_est=filtered[:, heading_col],
    nees=nees,
    state_dim=2,
)

The CLI trodestrack report wraps this; see src/trodestrack/cli/report.py for the full call site.
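The ground-truth alignment in the snippet above scans every IMU timestamp for every camera frame, which is O(N·M) and gets slow on long sessions. Below is a vectorized alternative built on np.searchsorted. It assumes sim["t_imu"] is sorted in ascending order. `nearest_indices` is a hypothetical helper, not part of trodestrack. Ties go to the earlier sample, which matches np.argmin.

```python
import numpy as np


def nearest_indices(t_src, t_query):
    """For each t_query, index of the nearest t_src sample (t_src sorted, len >= 2)."""
    t_src = np.asarray(t_src, dtype=float)
    t_query = np.asarray(t_query, dtype=float)
    right = np.clip(np.searchsorted(t_src, t_query), 1, len(t_src) - 1)
    left = right - 1
    take_left = np.abs(t_query - t_src[left]) <= np.abs(t_src[right] - t_query)
    return np.where(take_left, left, right)


# Check against the brute-force argmin used in the snippet above.
rng = np.random.default_rng(0)
t_imu = np.sort(rng.uniform(0, 10, 1000))
t_cam = np.sort(rng.uniform(-0.5, 10.5, 300))
idx = nearest_indices(t_imu, t_cam)
brute = np.array([np.argmin(np.abs(t_imu - tc)) for tc in t_cam])
```

In the QA snippet, the list comprehension would become X_truth_at_cam = np.asarray(sim["X_truth"])[nearest_indices(sim["t_imu"], sim["t_cam_exp"])].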

Explore All Examples

See examples/README.md for the complete learning path. Examples are numbered to teach concepts progressively:

  • 01-02: Simulation fundamentals
  • 03-04: Filter basics (EKF and UKF)
  • 05-06: Robustness (dropouts and occlusions)
  • 07: Smoothing techniques
  • 08: QA reporting

Status

See CHANGELOG.md for completed features and release history; see docs/plans/ for active work.

Documentation

User Documentation

Developer Documentation

Development

Run tests

uv run pytest tests/ -v

Code quality

# Type checking
uv run mypy src/trodestrack --ignore-missing-imports

# Linting
uv run ruff check src/ tests/

# Formatting (Ruff is the project's formatter; see pyproject.toml [tool.ruff])
uv run ruff format src/ tests/

Development commands

See CLAUDE.md for the complete list of development commands and the project architecture.

Architecture

trodestrack/
  sim/          # Simulation: analytic scenarios + realistic rat IMU
  models/       # EKF, UKF, RTS/IEKS smoothers, state initialization
  runtime/      # Online filter API + offline smoother workflows
  qa/           # Metrics (RMSE, NEES, NIS), plots, PDF reports
  viz/          # Diagnostic videos with multi-panel state visualization
  cli/          # CLI: trodestrack online / smooth / report
  io/           # Session loading, real-data safety checks, LED identity correction
  config/       # YAML session configuration schemas

Contributing

This project follows strict test-driven development (TDD) practices:

  1. Write tests first
  2. Run tests and verify they fail
  3. Implement features
  4. Run tests until they pass
  5. Refactor for clarity

See CLAUDE.md for development guidelines and code style requirements.

Citation

If you use trodestrack in your research, please cite:

@software{trodestrack2025,
  title={trodestrack: Sensor-fused 2D rat tracking with JAX EKF/UKF},
  author={Your Name},
  year={2025},
  url={https://github.com/edeno/trodestrack}
}

License

MIT License - see LICENSE file for details.

Acknowledgments

  • SpikeGadgets for hardware specifications
  • DeepLabCut team for pose estimation framework
  • JAX team for high-performance numerical computing

Contact

For questions, issues, or feature requests, please open an issue on GitHub.



Download files


Source Distribution

trodestrack-0.3.0.tar.gz (7.7 MB)

Uploaded Source

Built Distribution


trodestrack-0.3.0-py3-none-any.whl (257.6 kB)

Uploaded Python 3

File details

Details for the file trodestrack-0.3.0.tar.gz.

File metadata

  • Download URL: trodestrack-0.3.0.tar.gz
  • Size: 7.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for trodestrack-0.3.0.tar.gz
Algorithm Hash digest
SHA256 a3af35a722df503dbab01ab541f72c2c6c618c3f2b29f6365c47a9cb973b8652
MD5 51ee84c5f63dbec739f78968012074fd
BLAKE2b-256 60c6341768ab65dfdc9a347fc2c07dbd5fa906b676f6ad981f5ff34ff9079793


Provenance

The following attestation bundles were made for trodestrack-0.3.0.tar.gz:

Publisher: ci.yml on edeno/trodestrack

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file trodestrack-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: trodestrack-0.3.0-py3-none-any.whl
  • Size: 257.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for trodestrack-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6ae4b5767a982d37847880fb02b5e15ddd458d7d0f622be6e29fa4d8e7cdcad0
MD5 411ed3d5dbc46d588710f5a1a523311f
BLAKE2b-256 b2c9730078b1597e9614f57ef3846bdeeb5844898ae41748c91c670ed3857ff8


Provenance

The following attestation bundles were made for trodestrack-0.3.0-py3-none-any.whl:

Publisher: ci.yml on edeno/trodestrack

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
