Official implementation of the CE-EM algorithm
CE-EM
Official implementation of the CE-EM algorithm and the Particle EM baseline from "Scalable Identification of Partially Observed Systems with Certainty-Equivalent EM".
Usage
Ensure you are using Python 3.6 or later, then install the package:
pip install CEEM
Run python -m pytest to ensure everything works.
A Jupyter notebook demonstrating usage can be found in the examples subfolder.
Code overview
- ceem/dynamics.py defines the system API used by the CE-EM algorithm.
- ceem/systems/*.py define the various systems used in the experiments.
- ceem/ceem.py contains the CE-EM algorithm.
- ceem/smoother.py defines the smoothing routines used by CE-EM in the smoothing step.
- ceem/learner.py defines the learning routines used by CE-EM in the learning step.
- ceem/opt_criteria.py defines the optimization criteria used by CE-EM.
- ceem/particleem.py implements Particle EM.
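To make the smoothing/learning split concrete, here is a minimal sketch of the certainty-equivalent EM idea on a toy scalar model, assuming x_{t+1} = a x_t + w_t and y_t = x_t + v_t with known noise variances q and r. Instead of averaging over state trajectories, each iteration takes a point (MAP) estimate of the whole trajectory for the current parameter (smoothing step), then refits a to that trajectory (learning step). The functions smooth, learn, objective and ce_em are hypothetical illustrations written with NumPy; they are not the ceem package API, and plain linear least squares stands in for the repository's smoothers and learners.

```python
import numpy as np


def smooth(y, a, q, r):
    """Smoothing step: MAP state trajectory for a fixed dynamics parameter a.

    Minimizes sum_t (y_t - x_t)^2 / r + sum_t (x_{t+1} - a x_t)^2 / q
    over the whole trajectory x, which is a linear least-squares problem.
    """
    T = len(y)
    # One row per observation residual y_t - x_t.
    obs = np.eye(T) / np.sqrt(r)
    # One row per dynamics residual x_{t+1} - a x_t.
    dyn = np.zeros((T - 1, T))
    idx = np.arange(T - 1)
    dyn[idx, idx] = -a
    dyn[idx, idx + 1] = 1.0
    A = np.vstack([obs, dyn / np.sqrt(q)])
    b = np.concatenate([y / np.sqrt(r), np.zeros(T - 1)])
    x, *_ = np.linalg.lstsq(A, b, rcond=None)
    return x


def learn(x):
    """Learning step: least-squares estimate of a given the smoothed states."""
    return float(x[:-1] @ x[1:] / (x[:-1] @ x[:-1]))


def objective(x, y, a, q, r):
    """Joint criterion shared by both steps (negative log-density up to constants)."""
    return float(np.sum((y - x) ** 2) / r + np.sum((x[1:] - a * x[:-1]) ** 2) / q)


def ce_em(y, a0, q, r, iters=20):
    """Alternate smoothing and learning; return a, the last trajectory, and the objective history."""
    a = a0
    history = []
    for _ in range(iters):
        x = smooth(y, a, q, r)  # point estimate of the states for the current a
        a = learn(x)            # refit a to that point estimate
        history.append(objective(x, y, a, q, r))
    return a, x, history
```

Because both steps minimize the same criterion (exactly, in this toy case), the recorded objective should never increase from one iteration to the next, which is a handy sanity check when experimenting with your own smoothers and learners.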
Experiments
Lorenz
Unbiased Estimation in Deterministic Settings
To regenerate the data in data/lorenz/bias_experiment run:
python experiments/lorenz/bias_experiment.py
To generate Table 1 run:
python experiments/lorenz/plotting/process_bias.py
Comparison to Particle Based Methods
To regenerate the data in data/lorenz/comp run:
python experiments/lorenz/comp_pem.py
python experiments/lorenz/comp_ceem.py
To generate Figure 2 run:
python experiments/lorenz/plotting/process_comp.py
Convergence of CE-EM on High Dimensional Problems
To regenerate the data in data/lorenz/convergence_experiment run:
python experiments/lorenz/convergence_experiment_pem.py
python experiments/lorenz/convergence_experiment_ceem.py
To generate Figure 3 run:
python experiments/lorenz/plotting/process_convergence.py
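The README does not state which Lorenz variant the experiments use. As an assumption for illustration, the sketch below simulates the Lorenz-96 system, a standard benchmark whose state dimension K can be increased freely, with forcing F = 8 and a fourth-order Runge-Kutta step of dt = 0.05. The repository's own system definitions live in ceem/systems/*.py; the function names here are hypothetical.

```python
import numpy as np


def lorenz96_rhs(x, forcing=8.0):
    """dx_i/dt = (x_{i+1} - x_{i-2}) x_{i-1} - x_i + F, with cyclic indices."""
    return (np.roll(x, -1) - np.roll(x, 2)) * np.roll(x, 1) - x + forcing


def rk4_step(x, dt, forcing=8.0):
    """Advance the state by one classical fourth-order Runge-Kutta step."""
    k1 = lorenz96_rhs(x, forcing)
    k2 = lorenz96_rhs(x + 0.5 * dt * k1, forcing)
    k3 = lorenz96_rhs(x + 0.5 * dt * k2, forcing)
    k4 = lorenz96_rhs(x + dt * k3, forcing)
    return x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)


def simulate(x0, steps, dt=0.05, forcing=8.0):
    """Return an array of shape (steps + 1, K) holding the trajectory from x0."""
    x0 = np.asarray(x0, dtype=float)
    traj = np.empty((steps + 1, x0.size))
    traj[0] = x0
    for t in range(steps):
        traj[t + 1] = rk4_step(traj[t], dt, forcing)
    return traj
```

Starting exactly at x_i = F for every i keeps the system at that equilibrium, which is a quick correctness check; a small perturbation of one component produces the usual chaotic behavior. Noisy measurements of a subset of the states could then serve as data for a partially observed identification problem.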
Helicopter
The following scripts train the models from Section 4.2 of the paper.
Pretrained models are provided in the pretrained_models folder.
Data download
The dataset used in our experiments can be downloaded by running:
wget 'https://zenodo.org/record/3662987/files/datasets.zip?download=1' -O datasets.zip
unzip datasets.zip
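If wget or unzip is not available, the same download can be done with Python's standard library. This is a minimal sketch: download_and_extract is a hypothetical helper, and it assumes the archive should be extracted into the current directory, as the unzip command above does.

```python
import urllib.request
import zipfile
from pathlib import Path


def download_and_extract(url, archive="datasets.zip", dest="."):
    """Download a zip archive from `url` to `archive`, then extract it into `dest`."""
    archive = Path(archive)
    urllib.request.urlretrieve(url, archive)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return archive
```

Call it as download_and_extract('https://zenodo.org/record/3662987/files/datasets.zip?download=1').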
Baselines
Naive
Run the experiment with default parameters:
python experiments/heli/baselines.py --model naive
H25
Run the experiment with default parameters:
python experiments/heli/baselines.py --model H25
cp data/h25/best_net.th trained_models/h25.th
SID
Prepare the data first for residual training:
cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py
Ensure MATLAB and the System Identification Toolbox are installed, then run from within MATLAB:
run_n4sid.m
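Both the SID and NL (Ours) steps begin by preparing data for residual training on top of the naive baseline. The README does not describe what prepare_residual_dataset.py computes; assuming residual training has its usual meaning (a second model learns the correction the baseline misses), the idea can be sketched as follows. The helper names are hypothetical and the actual script may differ.

```python
import numpy as np


def make_residual_targets(measured, baseline_pred):
    """Residual target: the part of the measurement the baseline fails to predict."""
    measured = np.asarray(measured, dtype=float)
    baseline_pred = np.asarray(baseline_pred, dtype=float)
    if measured.shape != baseline_pred.shape:
        raise ValueError("measured and baseline_pred must have the same shape")
    return measured - baseline_pred


def combine(baseline_pred, residual_pred):
    """Final prediction: baseline prediction plus the learned correction."""
    return np.asarray(baseline_pred, dtype=float) + np.asarray(residual_pred, dtype=float)
```

A residual model trained on these targets only has to capture what the baseline gets wrong, and its output is added back to the baseline at prediction time.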
LSTM
python experiments/heli/train_lstm.py
cp data/heli_lstm/ckpts/best_model.th trained_models/lstm.th
NL (Ours)
Prepare the data first for residual training:
cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py
Run the experiment with default parameters:
python experiments/heli/ceemnl.py
Copy the best model to trained_models:
cp data/NLobsLdyn/ckpts/best_model.th trained_models/NL_model.th
Evaluating and plotting test trajectories
First evaluate the models (the pretrained models are used by default), then produce the bar plot, by running:
python experiments/heli/evaluate_models.py
python experiments/heli/plotting/plotbar.py
Then plot the n-th trajectory in the test set (here n = 9) by running:
python experiments/heli/plotting/plot_trajectories.py --trajectory 9
To plot the circular acceleration prediction (instead of the horizontal acceleration) for the n-th trajectory in the test set:
python experiments/heli/plotting/plot_trajectories.py --trajectory 9 --moments