Some toy examples of score matching algorithms written in TensorFlow 2.0
toy_gradlogp_tf2
This repo implements some toy examples of the following score matching algorithms in TensorFlow 2.0:
- ssm-vr: sliced score matching with variance reduction
- ssm: sliced score matching
- deen: deep energy estimator networks
- dsm: denoising score matching
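To make the objectives concrete, here is a minimal pure-Python sketch of the sliced (ssm, ssm-vr) and denoising (dsm) losses for a score function s(x) that approximates ∇ log p(x). It is not the repo's TensorFlow code. The Jacobian term is approximated with a central finite difference instead of autodiff, and the helper names (sample_slice, sliced_loss, dsm_loss) and the sqrt(d) scaling of the sphere noise are assumptions. The noise names mirror the --noise choices of the training script, including its "radermacher" spelling. deen is not covered here.

```python
import math
import random


def sample_slice(dim, noise, rng):
    """Draw one projection vector v, scaled so that E[v v^T] = I."""
    if noise == "radermacher":  # spelled as in the --noise choices
        return [rng.choice((-1.0, 1.0)) for _ in range(dim)]
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    if noise == "sphere":
        # Rescale the Gaussian draw onto the sphere of radius sqrt(dim) (assumed scaling).
        norm = math.sqrt(sum(c * c for c in v))
        v = [c / norm * math.sqrt(dim) for c in v]
    elif noise != "gaussian":
        raise ValueError(f"unknown noise type: {noise}")
    return v


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def sliced_loss(score, xs, rng, n_slices=1, noise="radermacher", vr=True, h=1e-3):
    """ssm (vr=False) or ssm-vr (vr=True) objective for a score function s(x).

    Per sample and slice: v^T J_s(x) v + 0.5 * (v^T s(x))^2 for ssm; ssm-vr
    replaces the second term by its expectation over v, 0.5 * ||s(x)||^2.
    """
    total = 0.0
    for x in xs:
        s = score(x)
        for _ in range(n_slices):
            v = sample_slice(len(x), noise, rng)
            # Directional derivative of s along v via central finite difference
            # (a TensorFlow version would use autodiff instead).
            plus = score([a + h * b for a, b in zip(x, v)])
            minus = score([a - h * b for a, b in zip(x, v)])
            vjv = dot(v, [(p - m) / (2.0 * h) for p, m in zip(plus, minus)])
            second = 0.5 * dot(s, s) if vr else 0.5 * dot(v, s) ** 2
            total += vjv + second
    return total / (len(xs) * n_slices)


def dsm_loss(score, xs, rng, sigma=0.5):
    """Denoising score matching with Gaussian perturbation x~ = x + sigma * z."""
    total = 0.0
    for x in xs:
        z = [rng.gauss(0.0, 1.0) for _ in x]
        x_tilde = [a + sigma * b for a, b in zip(x, z)]
        # Target score of the perturbation kernel: -(x~ - x) / sigma^2 = -z / sigma.
        residual = [si + zi / sigma for si, zi in zip(score(x_tilde), z)]
        total += 0.5 * dot(residual, residual)
    return total / len(xs)
```

For 2D data drawn from N(0, I), sliced_loss with the true score s(x) = -x comes out near -d/2 = -1, while a mis-scaled score such as s(x) = -2x scores higher. For dsm with sigma = 0.5, the minimizer among scores of the form s(x) = -a x is a = 1 / (1 + sigma^2) = 0.8, the score of the noise-perturbed data rather than of the clean data.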
Related projects:
- toy_gradlogp: PyTorch Implementation.
Installation
Basic requirements:
- Python >= 3.6
- TensorFlow >= 2.3.0
Install from PyPI
pip install toy_gradlogp_tf2
Or install the latest version from this repo:
pip install git+https://github.com/Ending2015a/toy_gradlogp_tf2.git@master
Examples
The example scripts are located in toy_gradlogp/run/.
Train an energy model
Run ssm-vr on the 2spirals dataset:
python -m toy_gradlogp.run.train_energy --loss ssm-vr --data 2spirals
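To train every loss on every dataset, a small driver can loop over the --data and --loss choices listed in the --help output. This is a sketch, not part of the package: build_commands and run_sweep are hypothetical helpers, and the per-run --logdir layout (logs/<data>_<loss>) is an assumption, since the script's default log directory is not documented here. It only prints the commands unless dry_run=False.

```python
import itertools
import shlex
import subprocess
import sys

# Choices copied from the --data and --loss options of train_energy.
DATASETS = ["8gaussians", "2spirals", "checkerboard", "rings"]
LOSSES = ["ssm-vr", "ssm", "deen", "dsm"]


def build_commands(logdir_root="logs"):
    """Return one argv list per (dataset, loss) pair; the logdir layout is hypothetical."""
    cmds = []
    for data, loss in itertools.product(DATASETS, LOSSES):
        cmds.append([
            sys.executable, "-m", "toy_gradlogp.run.train_energy",
            "--loss", loss, "--data", data,
            "--logdir", f"{logdir_root}/{data}_{loss}",
        ])
    return cmds


def run_sweep(dry_run=True):
    """Print the commands, or run them one after another when dry_run=False."""
    for cmd in build_commands():
        if dry_run:
            print(" ".join(shlex.quote(part) for part in cmd))
        else:
            subprocess.run(cmd, check=True)
```

Calling run_sweep() prints the 16 commands; run_sweep(dry_run=False) executes them sequentially and stops at the first failing run because of check=True.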
To list all options, pass --help:
python -m toy_gradlogp.run.train_energy --help
usage: train_energy.py [-h] [--logdir LOGDIR]
[--data {8gaussians,2spirals,checkerboard,rings}]
[--loss {ssm-vr,ssm,deen,dsm}]
[--noise {radermacher,sphere,gaussian}] [--lr LR]
[--size SIZE] [--eval_size EVAL_SIZE]
[--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
[--n_slices N_SLICES] [--n_steps N_STEPS] [--eps EPS]
[--log_freq LOG_FREQ] [--eval_freq EVAL_FREQ]
[--vis_freq VIS_FREQ]
optional arguments:
-h, --help show this help message and exit
--logdir LOGDIR
--data {8gaussians,2spirals,checkerboard,rings}
dataset
--loss {ssm-vr,ssm,deen,dsm}
loss type
--noise {radermacher,sphere,gaussian}
noise type
--lr LR learning rate
--size SIZE dataset size
--eval_size EVAL_SIZE
dataset size for evaluation
--batch_size BATCH_SIZE
training batch size
--n_epochs N_EPOCHS number of epochs to train
--n_slices N_SLICES number of slices for sliced score matching
--n_steps N_STEPS number of steps for langevin dynamics
--eps EPS noise scale for langevin dynamics
--log_freq LOG_FREQ logging frequency (unit: epoch)
--eval_freq EVAL_FREQ
evaluation frequency (unit: epoch)
--vis_freq VIS_FREQ visualization frequency (unit: epoch)
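The --n_steps and --eps options configure Langevin dynamics, presumably used to draw samples from the learned model. Below is a sketch of the unadjusted Langevin update under the assumption that eps acts as the step size, with injected Gaussian noise of standard deviation sqrt(eps); the help text calls eps a "noise scale", so the script's exact parameterization may differ. langevin_sample is a hypothetical helper, and for an energy model the score is -∇E(x).

```python
import math
import random


def langevin_sample(score, x0, n_steps, eps, rng):
    """Unadjusted Langevin dynamics starting from x0.

    Each step: x <- x + (eps / 2) * score(x) + sqrt(eps) * z, with z ~ N(0, I).
    For an energy model, score(x) = -grad E(x).
    """
    x = list(x0)
    for _ in range(n_steps):
        s = score(x)
        x = [xi + 0.5 * eps * si + math.sqrt(eps) * rng.gauss(0.0, 1.0)
             for xi, si in zip(x, s)]
    return x


# Example: E(x) = 0.5 * ||x||^2, so score(x) = -x and the target is N(0, I).
rng = random.Random(0)
starts = [[rng.uniform(-5, 5), rng.uniform(-5, 5)] for _ in range(1000)]
samples = [langevin_sample(lambda x: [-c for c in x], x0, n_steps=200, eps=0.1, rng=rng)
           for x0 in starts]
```

Chains started far from the mode drift toward low-energy regions; because the update is not Metropolis-corrected, a finite eps leaves a small bias (here the stationary variance is 1 / (1 - eps / 4), slightly above 1).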
Results
Tip: regions of higher density have lower energy.
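The tip follows from the usual energy-based convention p(x) ∝ exp(-E(x)), assumed here; under it, density and energy move in opposite directions and the score is ∇ log p(x) = -∇E(x). A small sketch that turns energies evaluated on a grid into normalized probabilities, as one might when turning an energy plot into a density plot (energy_to_density is a hypothetical helper):

```python
import math


def energy_to_density(energies):
    """Map grid energies E_i to probabilities p_i proportional to exp(-E_i)."""
    m = min(energies)  # subtract the minimum energy so exp() cannot overflow
    weights = [math.exp(-(e - m)) for e in energies]
    z = sum(weights)  # grid approximation of the partition function
    return [w / z for w in weights]


energies = [3.0, 0.5, 1.0, 4.0]
probs = energy_to_density(energies)
```

Shifting by the minimum energy leaves the normalized result unchanged, since the constant factor exp(m) cancels in the ratio.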
8gaussians
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
2spirals
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
checkerboard
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
rings
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
Download files
Source Distribution
toy-gradlogp-tf2-0.1.0.tar.gz (11.4 kB)
File details
Details for the file toy-gradlogp-tf2-0.1.0.tar.gz.
File metadata
- Download URL: toy-gradlogp-tf2-0.1.0.tar.gz
- Upload date:
- Size: 11.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.7.10
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 583b5f2b6eedcf4c809982f78cca0809f6f5adb9ce3655130e249c7953ee2ab3 |
MD5 | 5ac53fed05b7b6c459d11f635c44333d |
BLAKE2b-256 | 80e6eb170f3e9f5a38f7b8e33a456042b0245801f3002065bcae26e8a8f4351c |