Official implementation of the CE-EM algorithm

CE-EM

Official implementation of the CE-EM algorithm and the baseline Particle EM from "Scalable Identification of Partially Observed Systems with Certainty-Equivalent EM".

Website

Usage

Ensure you are using Python 3.6 or later.
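The version requirement can also be checked from Python before installing. This is a minimal sketch; the helper name `meets_minimum_version` is hypothetical and is not part of the CEEM package.

```python
import sys

# Minimum version stated in this README.
MIN_PYTHON = (3, 6)


def meets_minimum_version(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the (major, minor) version is at least `minimum`."""
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    if not meets_minimum_version():
        raise SystemExit("CEEM needs Python %d.%d or later" % MIN_PYTHON)
    print("Python version OK:", sys.version.split()[0])
```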

pip install CEEM

Run python -m pytest to ensure everything works.

A Jupyter notebook demonstrating usage can be found in the examples subfolder.

Code overview

  • ceem/dynamics.py defines the system API used by the CEEM algorithm.
  • ceem/systems/*.py define various systems used in the experiments.
  • ceem/ceem.py contains the CEEM algorithm.
  • ceem/smoother.py defines different smoothing routines used by the CEEM algorithm in the smoothing step.
  • ceem/learner.py defines different learning routines used by the CEEM algorithm in the learning step.
  • ceem/opt_criteria.py defines different optimization criteria used by the CEEM algorithm.
  • ceem/particleem.py implements Particle EM.
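The overview above mentions a smoothing step and a learning step. The toy sketch below illustrates that alternation on a scalar linear system, assuming the smoothing step computes a point estimate of the state trajectory and the learning step refits the parameter to that estimate. It does not use the ceem package or its API: the functions `objective`, `smooth`, `learn` and `toy_ce_em`, the model x[t+1] = a*x[t] + w, y[t] = x[t] + v, and the known noise variances `q` and `r` are all assumptions made for illustration.

```python
import numpy as np


def objective(x, y, a, q, r):
    """Negative joint log-likelihood of states and observations, up to constants."""
    obs = np.sum((y - x) ** 2) / r
    dyn = np.sum((x[1:] - a * x[:-1]) ** 2) / q
    return obs + dyn


def smooth(y, a, q, r):
    """Smoothing step: minimize the objective over the trajectory x for fixed a."""
    T = len(y)
    # D @ x stacks the dynamics residuals x[t+1] - a * x[t].
    D = np.zeros((T - 1, T))
    idx = np.arange(T - 1)
    D[idx, idx] = -a
    D[idx, idx + 1] = 1.0
    # Normal equations of the resulting linear least-squares problem.
    H = np.eye(T) / r + D.T @ D / q
    return np.linalg.solve(H, y / r)


def learn(x):
    """Learning step: minimize the objective over a for a fixed trajectory x."""
    return float(x[1:] @ x[:-1] / (x[:-1] @ x[:-1]))


def toy_ce_em(y, a0=0.5, q=0.01, r=0.01, iters=50):
    """Alternate the two steps and record the objective after each iteration."""
    a, history = a0, []
    for _ in range(iters):
        x = smooth(y, a, q, r)
        a = learn(x)
        history.append(objective(x, y, a, q, r))
    return a, x, history


# Synthetic data from the assumed toy model.
rng = np.random.default_rng(0)
T, a_true = 200, 0.9
x_true = np.zeros(T)
x_true[0] = 1.0
for t in range(T - 1):
    x_true[t + 1] = a_true * x_true[t] + 0.1 * rng.standard_normal()
y = x_true + 0.1 * rng.standard_normal(T)

a_hat, x_hat, history = toy_ce_em(y)
print("estimated a:", a_hat)
```

Because each step exactly minimizes the same objective over one block of variables, the recorded objective values cannot increase from one iteration to the next, which makes a convenient sanity check for a sketch like this.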

Experiments

Lorenz

Unbiased Estimation in Deterministic Settings

To regenerate the data in data/lorenz/bias_experiment run:

python experiments/lorenz/bias_experiment.py

To generate Table 1 run:

python experiments/lorenz/plotting/process_bias.py

Comparison to Particle Based Methods

To regenerate the data in data/lorenz/comp run:

python experiments/lorenz/comp_pem.py
python experiments/lorenz/comp_ceem.py

To generate Figure 2 run:

python experiments/lorenz/plotting/process_comp.py

Convergence of CE-EM on High Dimensional Problems

To regenerate the data in data/lorenz/convergence_experiment run:

python experiments/lorenz/convergence_experiment_pem.py
python experiments/lorenz/convergence_experiment_ceem.py

To generate Figure 3 run:

python experiments/lorenz/plotting/process_convergence.py
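The three Lorenz experiments above each run one or two data-generation scripts followed by a processing script. A small driver can keep that order in one place. This is a sketch only: the script paths are copied from the commands above, while the experiment labels (`bias`, `comp`, `convergence`) and the functions `build_commands` and `run_experiment` are hypothetical names. It assumes it is run from the repository root.

```python
import subprocess
import sys

# Script paths copied from the commands in this README, in the order given.
LORENZ_EXPERIMENTS = {
    "bias": [
        "experiments/lorenz/bias_experiment.py",
        "experiments/lorenz/plotting/process_bias.py",
    ],
    "comp": [
        "experiments/lorenz/comp_pem.py",
        "experiments/lorenz/comp_ceem.py",
        "experiments/lorenz/plotting/process_comp.py",
    ],
    "convergence": [
        "experiments/lorenz/convergence_experiment_pem.py",
        "experiments/lorenz/convergence_experiment_ceem.py",
        "experiments/lorenz/plotting/process_convergence.py",
    ],
}


def build_commands(name):
    """Return the ordered command lines for one experiment."""
    python = sys.executable or "python"
    return [[python, script] for script in LORENZ_EXPERIMENTS[name]]


def run_experiment(name, dry_run=True):
    """Print each command; unless dry_run is set, run it and stop on failure."""
    for cmd in build_commands(name):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Dry run by default so nothing is regenerated by accident.
    run_experiment("bias", dry_run=True)
```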

Helicopter

The following scripts train the models from Section 4.2 of the paper. Pretrained models are provided in the pretrained_models folder.

Data download

The dataset used in our experiments can be downloaded by running:

wget 'https://zenodo.org/record/3662987/files/datasets.zip?download=1' -O datasets.zip
unzip datasets.zip
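On systems without wget or unzip, the same two steps can be done with the Python standard library. A minimal sketch: the URL is the one from the command above, and the function names `download` and `extract` are hypothetical. Unlike `wget -O`, this version skips the download when the archive already exists.

```python
import os
import urllib.request
import zipfile

# URL taken from the wget command above.
DATASET_URL = "https://zenodo.org/record/3662987/files/datasets.zip?download=1"


def download(url=DATASET_URL, archive_path="datasets.zip"):
    """Download the archive unless a file with that name already exists."""
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    return archive_path


def extract(archive_path="datasets.zip", target_dir="."):
    """Unpack the archive into target_dir and return the member names."""
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(target_dir)
        return archive.namelist()
```

Calling `extract(download())` from the repository root mirrors the two shell commands.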

Baselines

Naive

Run the experiment with default parameters:

python experiments/heli/baselines.py --model naive

H25

Run the experiment with default parameters:

python experiments/heli/baselines.py --model H25
cp data/h25/best_net.th trained_models/h25.th

SID

First prepare the data for residual training:

cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py

Ensure you have MATLAB with the System Identification Toolbox installed, then run the following from within MATLAB:

run_n4sid.m

LSTM

python experiments/heli/train_lstm.py
cp data/heli_lstm/ckpts/best_model.th trained_models/lstm.th

NL (Ours)

First prepare the data for residual training:

cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py

Run the experiment with default parameters:

python experiments/heli/ceemnl.py

Copy the best model to trained_models:

cp data/NLobsLdyn/ckpts/best_model.th trained_models/NL_model.th
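The cp commands in the helicopter sections above all copy a best checkpoint into trained_models. They can be gathered into one helper that also creates the destination folder when it is missing. This is a sketch under assumptions: the source and destination paths are copied from the commands in this README, and `collect_checkpoints` is a hypothetical name, not a script shipped with the repository.

```python
import os
import shutil

# Source -> destination pairs copied from the cp commands in this README.
CHECKPOINTS = {
    "data/h25/best_net.th": "trained_models/h25.th",
    "data/naive/best_net.th": "trained_models/naive_baseline.th",
    "data/heli_lstm/ckpts/best_model.th": "trained_models/lstm.th",
    "data/NLobsLdyn/ckpts/best_model.th": "trained_models/NL_model.th",
}


def collect_checkpoints(root=".", mapping=CHECKPOINTS):
    """Copy every checkpoint that exists; return the destinations written."""
    copied = []
    for src, dst in mapping.items():
        src_path = os.path.join(root, src)
        dst_path = os.path.join(root, dst)
        if not os.path.isfile(src_path):
            continue  # this model has not been trained yet
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        shutil.copyfile(src_path, dst_path)
        copied.append(dst)
    return copied
```

Note that the naive checkpoint must be in place before experiments/heli/prepare_residual_dataset.py is run, so the helper may need to be called more than once during a full training pass.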

Evaluating and plotting test trajectories

First evaluate the models (the pretrained models are used by default) by running:

python experiments/heli/evaluate_models.py
python experiments/heli/plotting/plotbar.py

Then plot the nth trajectory in the test set (here n = 9) by running:

python experiments/heli/plotting/plot_trajectories.py --trajectory 9

To plot the circular acceleration prediction (instead of horizontal) on the nth trajectory in the test set:

python experiments/heli/plotting/plot_trajectories.py --trajectory 9 --moments
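To plot several test trajectories in one go, the plotting command can be built in a loop. A sketch only: the script path and the `--trajectory` and `--moments` flags are taken from the commands above, while `plot_command` and `plot_trajectories` are hypothetical helper names. Which trajectory indices are valid depends on the test set and is not checked here.

```python
import subprocess
import sys

# Script path copied from the commands above; run from the repository root.
PLOT_SCRIPT = "experiments/heli/plotting/plot_trajectories.py"


def plot_command(trajectory, moments=False):
    """Build the command line for one trajectory."""
    cmd = [sys.executable or "python", PLOT_SCRIPT, "--trajectory", str(trajectory)]
    if moments:
        cmd.append("--moments")
    return cmd


def plot_trajectories(indices, moments=False, dry_run=True):
    """Print each command; unless dry_run is set, run it and stop on failure."""
    for n in indices:
        cmd = plot_command(n, moments)
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
```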
