Official implementation of the CE-EM algorithm
CE-EM
Official implementation of the CE-EM algorithm and the Particle EM baseline from "Scalable Identification of Partially Observed Systems with Certainty-Equivalent EM".
Usage
Ensure you are using at least Python 3.6
pip install CEEM
Run python -m pytest to ensure everything works.
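If you want to fail fast on an older interpreter before installing, a check like the following can be run first. It is a small standalone sketch using only the standard library; the function name is hypothetical and nothing like it ships with the package.

```python
import sys


def python_version_ok(minimum=(3, 6)):
    """Return True if the running interpreter is at least `minimum` (major, minor)."""
    return sys.version_info[:2] >= tuple(minimum)


if __name__ == "__main__":
    if not python_version_ok():
        sys.exit("CEEM needs Python 3.6 or newer, found %d.%d" % sys.version_info[:2])
    print("Python version OK")
```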
A Jupyter notebook demonstrating usage can be found in the examples subfolder.
Code overview
- ceem/dynamics.py: defines the system API used by the CEEM algorithm.
- ceem/systems/*.py: define various systems used in the experiments.
- ceem/ceem.py: contains the CEEM algorithm.
- ceem/smoother.py: defines different smoothing routines used by the CEEM algorithm in the smoothing step.
- ceem/learner.py: defines different learning routines used by the CEEM algorithm in the learning step.
- ceem/opt_criteria.py: defines different optimization criteria used by the CEEM algorithm.
- ceem/particleem.py: implements Particle EM.
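The alternation between a smoothing step and a learning step can be illustrated on a toy problem. The sketch below is not the ceem package API: it assumes a scalar linear model x[t+1] = a x[t] + w, y[t] = x[t] + v with known noise variances q and r, and every function name in it is hypothetical. The smoothing step computes a single most-likely state trajectory for the current a (a tridiagonal least-squares problem), and the learning step refits a to that fixed trajectory, which is the certainty-equivalence idea of using a point estimate of the states.

```python
import random
from typing import List, Sequence


def solve_tridiagonal(lower: Sequence[float], diag: Sequence[float],
                      upper: Sequence[float], rhs: Sequence[float]) -> List[float]:
    """Thomas algorithm; lower[0] and upper[-1] are never used."""
    n = len(diag)
    c, d = [0.0] * n, [0.0] * n
    c[0] = upper[0] / diag[0]
    d[0] = rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - lower[i] * c[i - 1]
        c[i] = upper[i] / denom
        d[i] = (rhs[i] - lower[i] * d[i - 1]) / denom
    x = [0.0] * n
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


def smooth(y: Sequence[float], a: float, q: float, r: float) -> List[float]:
    """Smoothing step: minimise sum (y_t - x_t)^2/r + sum (x_{t+1} - a x_t)^2/q over x."""
    n = len(y)
    diag = [1.0 / r
            + (1.0 / q if t >= 1 else 0.0)
            + (a * a / q if t <= n - 2 else 0.0)
            for t in range(n)]
    off = -a / q
    lower = [0.0] + [off] * (n - 1)
    upper = [off] * (n - 1) + [0.0]
    return solve_tridiagonal(lower, diag, upper, [yt / r for yt in y])


def learn(x: Sequence[float]) -> float:
    """Learning step: least-squares estimate of a with the states held fixed."""
    num = sum(x[t + 1] * x[t] for t in range(len(x) - 1))
    den = sum(x[t] * x[t] for t in range(len(x) - 1))
    return num / den


def ce_em(y, a0, q, r, iters=20):
    a, x_hat = a0, list(y)
    for _ in range(iters):
        x_hat = smooth(y, a, q, r)  # point estimate of the state trajectory
        a = learn(x_hat)            # refit the parameter to that estimate
    return a, x_hat


def simulate(a, T, sigma_w, sigma_v, seed=0):
    """Generate a synthetic trajectory and noisy observations of it."""
    rng = random.Random(seed)
    x, xs, ys = 0.0, [], []
    for _ in range(T):
        xs.append(x)
        ys.append(x + rng.gauss(0.0, sigma_v))
        x = a * x + rng.gauss(0.0, sigma_w)
    return xs, ys
```

For example, `ce_em(simulate(0.9, 500, 0.5, 0.1)[1], a0=0.5, q=0.25, r=0.01)` should return an estimate of a near 0.9. Noise variances are treated as known here only to keep the sketch short.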
Experiments
Lorenz
Unbiased Estimation in Deterministic Settings
To regenerate the data in data/lorenz/bias_experiment, run:
python experiments/lorenz/bias_experiment.py
To generate Table 1 run:
python experiments/lorenz/plotting/process_bias.py
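To give a feel for the kind of system these experiments identify, here is a minimal sketch that simulates the classic three-state Lorenz-63 equations with the textbook parameters sigma = 10, rho = 28, beta = 8/3 using a fourth-order Runge-Kutta step, and then observes only the first coordinate with Gaussian noise. Which Lorenz variant, parameter values, integrator, and observation model the repository actually uses is an assumption here; the function names are hypothetical.

```python
import random
from typing import List, Tuple

State = Tuple[float, float, float]


def lorenz63(s: State, sigma: float = 10.0, rho: float = 28.0,
             beta: float = 8.0 / 3.0) -> State:
    """Time derivative of the Lorenz-63 state (x, y, z)."""
    x, y, z = s
    return (sigma * (y - x), x * (rho - z) - y, x * y - beta * z)


def rk4_step(s: State, dt: float, **params: float) -> State:
    """One classical fourth-order Runge-Kutta step."""
    k1 = lorenz63(s, **params)
    k2 = lorenz63(tuple(si + 0.5 * dt * ki for si, ki in zip(s, k1)), **params)
    k3 = lorenz63(tuple(si + 0.5 * dt * ki for si, ki in zip(s, k2)), **params)
    k4 = lorenz63(tuple(si + dt * ki for si, ki in zip(s, k3)), **params)
    return tuple(si + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
                 for si, a, b, c, d in zip(s, k1, k2, k3, k4))


def simulate(s0: State, dt: float, steps: int, **params: float) -> List[State]:
    traj = [tuple(s0)]
    for _ in range(steps):
        traj.append(rk4_step(traj[-1], dt, **params))
    return traj


def observe_first_coordinate(traj: List[State], noise_std: float, seed: int = 0) -> List[float]:
    """Partial observation: only x is measured, with additive Gaussian noise."""
    rng = random.Random(seed)
    return [s[0] + rng.gauss(0.0, noise_std) for s in traj]
```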
Comparison to Particle Based Methods
To regenerate the data in data/lorenz/comp, run:
python experiments/lorenz/comp_pem.py
python experiments/lorenz/comp_ceem.py
To generate Figure 2 run:
python experiments/lorenz/plotting/process_comp.py
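Particle-based EM methods approximate the posterior over states with weighted samples instead of the single point estimate CE-EM uses. As a hypothetical illustration of that basic ingredient (this is not ceem/particleem.py), here is a bootstrap particle filter for an assumed scalar linear-Gaussian model x[t+1] = a x[t] + w, y[t] = x[t] + v, with an assumed N(0, 1) prior on the initial state. It returns filtered means and an estimate of the log-likelihood of the observations.

```python
import math
import random
from typing import List, Sequence, Tuple


def simulate_linear(a, T, sigma_w, sigma_v, seed=0):
    rng = random.Random(seed)
    x, xs, ys = 0.0, [], []
    for _ in range(T):
        xs.append(x)
        ys.append(x + rng.gauss(0.0, sigma_v))
        x = a * x + rng.gauss(0.0, sigma_w)
    return xs, ys


def bootstrap_filter(y: Sequence[float], a: float, q: float, r: float,
                     n_particles: int = 500, seed: int = 0) -> Tuple[List[float], float]:
    rng = random.Random(seed)
    particles = [rng.gauss(0.0, 1.0) for _ in range(n_particles)]  # assumed prior on x_0
    means, loglik = [], 0.0
    for t, yt in enumerate(y):
        if t > 0:  # propagate through the dynamics
            particles = [a * p + rng.gauss(0.0, math.sqrt(q)) for p in particles]
        # Gaussian observation log-likelihood of each particle
        logw = [-0.5 * (yt - p) ** 2 / r - 0.5 * math.log(2.0 * math.pi * r) for p in particles]
        m = max(logw)  # subtract the max before exponentiating to avoid underflow
        w = [math.exp(lw - m) for lw in logw]
        s = sum(w)
        loglik += m + math.log(s / n_particles)
        means.append(sum(wi * p for wi, p in zip(w, particles)) / s)
        particles = rng.choices(particles, weights=w, k=n_particles)  # multinomial resampling
    return means, loglik
```

Comparing the log-likelihood estimate across parameter values shows why particle methods get expensive: every evaluation costs a full pass with many particles, and the estimate is itself noisy.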
Convergence of CE-EM on High Dimensional Problems
To regenerate the data in data/lorenz/convergence_experiment, run:
python experiments/lorenz/convergence_experiment_pem.py
python experiments/lorenz/convergence_experiment_ceem.py
To generate Figure 3 run:
python experiments/lorenz/plotting/process_convergence.py
Helicopter
The following scripts train the models from Section 4.2.
Pretrained models are provided in the pretrained_models folder.
Data download
The dataset used in our experiments can be downloaded by running:
wget 'https://zenodo.org/record/3662987/files/datasets.zip?download=1' -O datasets.zip
unzip datasets.zip
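On systems without wget or unzip, the same two steps can be done with Python's standard library. This is a sketch; extracting into the current working directory mirrors what unzip does by default, and the helper names are hypothetical.

```python
import urllib.request
import zipfile
from pathlib import Path
from typing import List

DATASET_URL = "https://zenodo.org/record/3662987/files/datasets.zip?download=1"


def extract(archive: Path, target_dir: Path) -> List[str]:
    """Unpack the archive (creating target_dir if needed) and return its member names."""
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target_dir)
        return zf.namelist()


def fetch_datasets(target_dir: str = ".") -> List[str]:
    """Equivalent of the wget + unzip commands above."""
    target = Path(target_dir)
    archive = target / "datasets.zip"
    urllib.request.urlretrieve(DATASET_URL, str(archive))
    return extract(archive, target)
```

Calling `fetch_datasets()` from the repository root downloads and unpacks the archive in one go.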
Baselines
Naive
Run the experiment with default parameters:
python experiments/heli/baselines.py --model naive
H25
Run the experiment with default parameters:
python experiments/heli/baselines.py --model H25
cp data/h25/best_net.th trained_models/h25.th
SID
Prepare the data first for residual training:
cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py
Ensure you have MATLAB with the System Identification Toolbox installed, then run from within MATLAB:
run_n4sid.m
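Both the SID and NL models start from data prepared for residual training on top of the naive baseline. We assume this means the model is fit to the part of the measured signal that the naive baseline fails to predict, and that its output is added back to the baseline at prediction time; prepare_residual_dataset.py may organize the data differently. A minimal sketch of that bookkeeping, with hypothetical names:

```python
from typing import List, Sequence

Trajectory = Sequence[Sequence[float]]  # one vector of signals per time step


def _check_shapes(a: Trajectory, b: Trajectory) -> None:
    if len(a) != len(b) or any(len(u) != len(v) for u, v in zip(a, b)):
        raise ValueError("trajectories must have the same shape")


def residual_targets(measured: Trajectory, baseline_pred: Trajectory) -> List[List[float]]:
    """Training targets for the residual model: measurement minus baseline prediction."""
    _check_shapes(measured, baseline_pred)
    return [[m - p for m, p in zip(mt, pt)] for mt, pt in zip(measured, baseline_pred)]


def combined_prediction(baseline_pred: Trajectory, residual_pred: Trajectory) -> List[List[float]]:
    """Final prediction: baseline output plus the learned correction."""
    _check_shapes(baseline_pred, residual_pred)
    return [[p + e for p, e in zip(pt, et)] for pt, et in zip(baseline_pred, residual_pred)]
```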
LSTM
python experiments/heli/train_lstm.py
cp data/heli_lstm/ckpts/best_model.th trained_models/lstm.th
NL (Ours)
Prepare the data first for residual training:
cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py
Run the experiment with default parameters:
python experiments/heli/ceemnl.py
Copy the best model to trained_models:
cp data/NLobsLdyn/ckpts/best_model.th trained_models/NL_model.th
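The cp steps above need a Unix shell, and cp fails if trained_models does not exist yet. A cross-platform sketch that performs the same copies, creates the directory when needed, and reports checkpoints that have not been trained yet could look like this; the source/destination pairs come from the commands above, and the helper name is hypothetical.

```python
import shutil
from pathlib import Path
from typing import List, Sequence, Tuple

# (source, destination) pairs taken from the cp commands above
CHECKPOINTS = [
    ("data/h25/best_net.th", "trained_models/h25.th"),
    ("data/naive/best_net.th", "trained_models/naive_baseline.th"),
    ("data/heli_lstm/ckpts/best_model.th", "trained_models/lstm.th"),
    ("data/NLobsLdyn/ckpts/best_model.th", "trained_models/NL_model.th"),
]


def promote_checkpoints(root: str = ".",
                        pairs: Sequence[Tuple[str, str]] = CHECKPOINTS
                        ) -> Tuple[List[str], List[str]]:
    """Copy each existing checkpoint; return (copied destinations, missing sources)."""
    base = Path(root)
    copied, missing = [], []
    for src, dst in pairs:
        src_path, dst_path = base / src, base / dst
        if not src_path.is_file():
            missing.append(src)
            continue
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_path, dst_path)  # also preserves file timestamps
        copied.append(dst)
    return copied, missing
```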
Evaluating and plotting test trajectories
First evaluate the models (pretrained models are used by default) by running:
python experiments/heli/evaluate_models.py
python experiments/heli/plotting/plotbar.py
Then plot the n-th trajectory in the test set (here n = 9) by running:
python experiments/heli/plotting/plot_trajectories.py --trajectory 9
To plot the circular acceleration prediction (instead of the horizontal one) for the n-th trajectory in the test set, run:
python experiments/heli/plotting/plot_trajectories.py --trajectory 9 --moments
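To plot several test trajectories in one go, the command above can be driven from Python. This sketch only uses the --trajectory and --moments flags shown above; which trajectory indices are valid depends on the test set and is not specified here, and the helper names are hypothetical.

```python
import subprocess
import sys
from typing import Iterable, List

SCRIPT = "experiments/heli/plotting/plot_trajectories.py"


def plot_commands(trajectories: Iterable[int], moments: bool = False) -> List[List[str]]:
    """Build one plot_trajectories.py invocation per trajectory index."""
    cmds = []
    for n in trajectories:
        cmd = [sys.executable, SCRIPT, "--trajectory", str(n)]
        if moments:
            cmd.append("--moments")
        cmds.append(cmd)
    return cmds


def run_all(trajectories: Iterable[int], moments: bool = False) -> None:
    """Run the commands sequentially, stopping at the first failure."""
    for cmd in plot_commands(trajectories, moments):
        subprocess.run(cmd, check=True)
```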
Download files
File details
Details for the file CEEM-0.0.4.tar.gz.
File metadata
- Download URL: CEEM-0.0.4.tar.gz
- Upload date:
- Size: 31.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0295557e081b31b15f3405aa828fd4c792f843923d93fb33fb8d5c93bb57d0a4
MD5 | 88d713d952c145c03e08400baac03d52
BLAKE2b-256 | f5aaf2714448c0b4756a933bf43086ce4aa01df9f789f83886555ef529e6d44d
File details
Details for the file CEEM-0.0.4-py3-none-any.whl.
File metadata
- Download URL: CEEM-0.0.4-py3-none-any.whl
- Upload date:
- Size: 37.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | c4253fc8d8e1aefac77e8a943125525e38dcc7f5ea1720df54ec620e31b2b6f7
MD5 | 054cee6e718ec8c69c4d346ba5dc8eb8
BLAKE2b-256 | 24fd3e9780fdac4e4e3b83b02aecc00688d17fd69f2cb4b01ad21b7e7d7e5670
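A downloaded file can be checked against the SHA256 digests listed in the tables above with a few lines of standard-library Python. The file is read in chunks so large archives need not fit in memory; the helper names are hypothetical.

```python
import hashlib
from pathlib import Path

# SHA256 digests from the tables above
EXPECTED_SHA256 = {
    "CEEM-0.0.4.tar.gz": "0295557e081b31b15f3405aa828fd4c792f843923d93fb33fb8d5c93bb57d0a4",
    "CEEM-0.0.4-py3-none-any.whl": "c4253fc8d8e1aefac77e8a943125525e38dcc7f5ea1720df54ec620e31b2b6f7",
}


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, computed incrementally."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path) -> bool:
    """True if the file matches the published digest for its file name."""
    name = Path(path).name
    expected = EXPECTED_SHA256.get(name)
    if expected is None:
        raise KeyError("no published digest for " + name)
    return sha256_of(path) == expected
```

pip can also enforce these digests itself in hash-checking mode, via `--hash=sha256:...` entries in a requirements file.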