Some toy examples of score matching algorithms written in PyTorch
toy_gradlogp
This repo implements some toy examples of the following score matching algorithms in PyTorch:
- ssm-vr: sliced score matching with variance reduction
- ssm: sliced score matching
- deen: deep energy estimator networks (denoising score matching)
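For reference, the objectives behind these three names are usually written as below. This is a sketch in the notation of the score matching literature, not copied from this repo's code: $E_\theta$ is the energy model, $s_\theta(x) = -\nabla_x E_\theta(x)$ its score, $p_d$ the data distribution, $v \sim p_v$ a random projection vector, and $\sigma$ the noise scale used for denoising. All of these symbols are assumptions introduced here, and the repo's exact scaling and constants may differ.

```latex
\begin{aligned}
\text{ssm:}\quad & J(\theta) = \mathbb{E}_{x \sim p_d}\, \mathbb{E}_{v \sim p_v}\!\left[ v^\top \nabla_x s_\theta(x)\, v + \tfrac{1}{2}\left(v^\top s_\theta(x)\right)^2 \right] \\
\text{ssm-vr:}\quad & J(\theta) = \mathbb{E}_{x \sim p_d}\, \mathbb{E}_{v \sim p_v}\!\left[ v^\top \nabla_x s_\theta(x)\, v + \tfrac{1}{2}\left\lVert s_\theta(x) \right\rVert_2^2 \right] \\
\text{deen:}\quad & J(\theta) = \mathbb{E}_{x \sim p_d}\, \mathbb{E}_{\tilde{x} \sim \mathcal{N}(x, \sigma^2 I)}\!\left[ \left\lVert x - \tilde{x} + \sigma^2 \nabla_{\tilde{x}} E_\theta(\tilde{x}) \right\rVert_2^2 \right]
\end{aligned}
```

The variance-reduced form replaces the projected squared score with its expectation over $v$, which is valid when $\mathbb{E}[v v^\top] = I$ (for example Gaussian or Rademacher projections).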
Installation
Basic requirements:
- Python >= 3.6
- TensorFlow >= 2.3.0
- PyTorch >= 1.8.0
Install from PyPI
pip install toy_gradlogp
Or install the latest version from this repo
pip install git+https://github.com/Ending2015a/toy_gradlogp.git@master
Examples
Clone this repo to run the example code:
git clone https://github.com/Ending2015a/toy_gradlogp.git
Train an energy model
Pass --help to see the usage message:
usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--gpu] [--log_freq LOG_FREQ]
                       [--eval_freq EVAL_FREQ] [--vis_freq VIS_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
  --loss {ssm-vr,ssm,deen}
                        Loss type
  --noise {radermacher,sphere,gaussian}
                        Noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --gpu
  --log_freq LOG_FREQ
  --eval_freq EVAL_FREQ
  --vis_freq VIS_FREQ
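
To make the --noise and --n_slices options more concrete, here is a minimal, dependency-free sketch of how projection vectors for sliced score matching could be drawn. The function names `sample_projection` and `sample_slices` are hypothetical, and the repo's own sampling code is not shown in this README, so its details may differ. In particular, rescaling the `sphere` noise to radius sqrt(dim) is an assumed convention.

```python
import math
import random


def sample_projection(dim, noise="radermacher", rng=random):
    """Draw one projection vector v of length `dim`.

    The noise names mirror the --noise choices of the CLI; how the repo
    itself samples them is an assumption of this sketch.
    """
    if noise == "radermacher":  # spelling follows the CLI flag
        # Each entry is -1 or +1 with equal probability.
        return [rng.choice((-1.0, 1.0)) for _ in range(dim)]
    g = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    if noise == "gaussian":
        return g
    if noise == "sphere":
        # Normalize a Gaussian draw to get a uniform direction, then rescale
        # to radius sqrt(dim) so that E[v v^T] = I (assumed convention).
        norm = math.sqrt(sum(x * x for x in g))
        scale = math.sqrt(dim) / norm
        return [x * scale for x in g]
    raise ValueError(f"unknown noise type: {noise}")


def sample_slices(dim, n_slices, noise="radermacher", rng=random):
    """Draw `n_slices` independent projection vectors for one data point."""
    return [sample_projection(dim, noise, rng) for _ in range(n_slices)]
```

Averaging the loss over more slices per data point lowers the variance of the estimate at the cost of extra gradient computations.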
Run ssm-vr on 2spirals dataset
python -m examples.train_energy --gpu --loss ssm-vr --data 2spirals
Results
Tip: regions of higher density have lower energy.
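The relation between density and energy follows from the usual energy-based model convention that the density is proportional to exp(-E(x)). That convention is assumed here and not quoted from the repo. A minimal sketch with a hypothetical 1D quadratic energy (all function names are illustrative):

```python
import math


def energy(x, mu=0.0, sigma=1.0):
    """Hypothetical quadratic energy; its density is the Gaussian N(mu, sigma^2)."""
    return (x - mu) ** 2 / (2.0 * sigma ** 2)


def unnormalized_density(x, mu=0.0, sigma=1.0):
    """exp(-E(x)), which is proportional to the density p(x)."""
    return math.exp(-energy(x, mu, sigma))


def score(x, mu=0.0, sigma=1.0):
    """Score d/dx log p(x) = -dE/dx; the normalizing constant drops out."""
    return (mu - x) / sigma ** 2


# Moving away from the mode raises the energy and lowers the density.
for x in (0.0, 1.0, 2.0):
    print(f"x={x:.1f}  energy={energy(x):.3f}  "
          f"density~{unnormalized_density(x):.3f}  score={score(x):.1f}")
```

Because the score does not depend on the normalizing constant, score matching can fit the energy without ever computing it.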
8gaussians
| Algorithm | Results |
|---|---|
| ssm-vr | (image not available) |
| ssm | (image not available) |
| deen | (image not available) |
2spirals
| Algorithm | Results |
|---|---|
| ssm-vr | (image not available) |
| ssm | (image not available) |
| deen | (image not available) |
checkerboard
| Algorithm | Results |
|---|---|
| ssm-vr | (image not available) |
| ssm | (image not available) |
| deen | (image not available) |
rings
| Algorithm | Results |
|---|---|
| ssm-vr | (image not available) |
| ssm | (image not available) |
| deen | (image not available) |
Download files
Source Distribution
toy_gradlogp-0.1.0.tar.gz (9.4 kB)
File details
Details for the file toy_gradlogp-0.1.0.tar.gz.
File metadata
- Download URL: toy_gradlogp-0.1.0.tar.gz
- Upload date:
- Size: 9.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.4.0 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.7.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fdd5f1e6c79e90369319b1691f4feee4d54207851eb7d725562a1738c311d29b |
| MD5 | 180022f504a7848ef9e158fc9fcb0904 |
| BLAKE2b-256 | 11525ba50b26a15479a81009ec314b3f58331c608be87d3a01d1532bacc706d1 |