Some toy examples of score matching algorithms written in PyTorch
toy_gradlogp
This repo implements some toy examples of the following score matching algorithms in PyTorch:
- ssm-vr: sliced score matching with variance reduction
- ssm: sliced score matching
- deen: deep energy estimator networks
- dsm: denoising score matching
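All four objectives avoid the full Jacobian trace that plain score matching requires. The sliced variants project it onto random directions v, while DSM and DEEN regress onto the score of a Gaussian-perturbed sample. As a rough, dependency-free illustration (plain Python rather than PyTorch), the sketch below evaluates each per-sample loss for a hypothetical linear score model s(x) = A x + b, whose Jacobian is simply A, so no autograd is needed. The formulas follow the standard definitions (SSM: v^T J v + 0.5 (v^T s)^2; SSM-VR replaces the second term by 0.5 ||s||^2; DSM regresses s(x + sigma*eps) onto -eps/sigma; DEEN minimises ||x - y + sigma^2 dE/dy||^2). The helper names and the exact meaning of the radermacher/sphere/gaussian noise choices are assumptions, not this package's API.

```python
import math
import random

# Hypothetical linear score model s(x) = A x + b in d dimensions.
# Its Jacobian is A, so the term v^T (ds/dx) v needs no autograd.

def matvec(A, x):
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def score(A, b, x):
    return [ax + bi for ax, bi in zip(matvec(A, x), b)]

def sample_slice(d, noise="radermacher", rng=random):
    """Random projection v with E[v v^T] = I (assumed noise definitions)."""
    if noise == "radermacher":
        return [rng.choice((-1.0, 1.0)) for _ in range(d)]
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    if noise == "sphere":
        # Gaussian direction rescaled to norm sqrt(d)
        norm = math.sqrt(dot(g, g))
        return [gi * math.sqrt(d) / norm for gi in g]
    if noise == "gaussian":
        return g
    raise ValueError(f"unknown noise type: {noise}")

def ssm_loss(A, b, x, v):
    # sliced score matching: v^T J v + 0.5 * (v^T s(x))^2
    s = score(A, b, x)
    return dot(v, matvec(A, v)) + 0.5 * dot(v, s) ** 2

def ssm_vr_loss(A, b, x, v):
    # variance reduction: E_v[(v^T s)^2] = ||s||^2 is used in closed form
    s = score(A, b, x)
    return dot(v, matvec(A, v)) + 0.5 * dot(s, s)

def dsm_loss(A, b, x, eps, sigma):
    # denoising score matching with x~ = x + sigma * eps; target score is -eps / sigma
    xt = [xi + sigma * ei for xi, ei in zip(x, eps)]
    s = score(A, b, xt)
    return 0.5 * sum((si + ei / sigma) ** 2 for si, ei in zip(s, eps))

def deen_loss(A, b, x, eps, sigma):
    # DEEN: ||x - y + sigma^2 * dE/dy||^2 with y = x + sigma * eps and dE/dy = -s(y)
    y = [xi + sigma * ei for xi, ei in zip(x, eps)]
    s = score(A, b, y)
    return sum((xi - yi - sigma ** 2 * si) ** 2 for xi, yi, si in zip(x, y, s))

# Usage: the exact score of N(0, I) is s(x) = -x, i.e. A = -I and b = 0.
A, b = [[-1.0, 0.0], [0.0, -1.0]], [0.0, 0.0]
rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(2)]
v = sample_slice(2, "sphere", rng)
print(ssm_loss(A, b, x, v), ssm_vr_loss(A, b, x, v))
```

Averaged over v, the SSM and SSM-VR losses agree; SSM-VR simply removes the sampling noise of the second term. Note that DEEN equals 2 * sigma^4 times the DSM loss here, which is why the two behave similarly in practice.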
Related projects:
- toy_gradlogp_tf2: TensorFlow 2.0 Implementation.
Installation
Basic requirements:
- Python >= 3.6
- TensorFlow >= 2.3.0
- PyTorch >= 1.8.0
Install from PyPI
pip install toy_gradlogp
Or install the latest version from this repo:
pip install git+https://github.com/Ending2015a/toy_gradlogp.git@master
Examples
The example scripts are located in toy_gradlogp/run/.
Train an energy model
Run ssm-vr on the 2spirals dataset (don't forget to add --gpu to train on the GPU):
python -m toy_gradlogp.run.train_energy --gpu --loss ssm-vr --data 2spirals
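The 2spirals target is a standard 2-D toy density. The sketch below draws samples with the recipe popularised by the FFJORD toy datasets (two point-reflected Archimedean arms with angles up to 540 degrees, plus small uniform and Gaussian jitter). That this package generates 2spirals the same way is an assumption, and sample_2spirals is a hypothetical helper, not part of toy_gradlogp.

```python
import math
import random

def sample_2spirals(n, rng=random):
    """Draw n points (n even) from two interleaved spirals, following the
    FFJORD-style toy recipe (assumed; the package's generator may differ)."""
    points = []
    for _ in range(n // 2):
        # angle t in [0, 3*pi] (540 degrees); the sqrt spreads points roughly evenly along the arm
        t = math.sqrt(rng.random()) * 540 * (2 * math.pi) / 360
        x = -math.cos(t) * t + rng.random() * 0.5
        y = math.sin(t) * t + rng.random() * 0.5
        # the second spiral is the point reflection of the first through the origin
        for px, py in ((x, y), (-x, -y)):
            points.append((px / 3 + rng.gauss(0.0, 0.1), py / 3 + rng.gauss(0.0, 0.1)))
    return points

points = sample_2spirals(1000, random.Random(0))
```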
To see the full list of options, pass --help:
python -m toy_gradlogp.run.train_energy --help
usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen,dsm}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--n_steps N_STEPS] [--eps EPS]
                       [--gpu] [--log_freq LOG_FREQ] [--eval_freq EVAL_FREQ]
                       [--vis_freq VIS_FREQ]
optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
                        dataset
  --loss {ssm-vr,ssm,deen,dsm}
                        loss type
  --noise {radermacher,sphere,gaussian}
                        noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --n_steps N_STEPS     number of steps for langevin dynamics
  --eps EPS             noise scale for langevin dynamics
  --gpu                 enable gpu
  --log_freq LOG_FREQ   logging frequency (unit: epoch)
  --eval_freq EVAL_FREQ
                        evaluation frequency (unit: epoch)
  --vis_freq VIS_FREQ   visualization frequency (unit: epoch)
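The --n_steps and --eps options configure Langevin dynamics, which draws samples by repeatedly following the score and injecting Gaussian noise. The help text does not say exactly how --eps enters the update, so the sketch below assumes one common parameterization, x <- x + (eps/2) * s(x) + sqrt(eps) * z, and runs it on a hypothetical Gaussian target whose score is known in closed form. The function names are illustrative, not the package's API.

```python
import math
import random

def langevin(score, x0, n_steps, eps, rng=random):
    """Unadjusted Langevin dynamics (assumed parameterization):
    x <- x + (eps / 2) * score(x) + sqrt(eps) * z, with z ~ N(0, I)."""
    x = list(x0)
    for _ in range(n_steps):
        s = score(x)
        x = [xi + 0.5 * eps * si + math.sqrt(eps) * rng.gauss(0.0, 1.0)
             for xi, si in zip(x, s)]
    return x

# Hypothetical target N(MU, I): its score is -(x - MU), the negative gradient
# of the energy E(x) = ||x - MU||^2 / 2.
MU = (2.0, -1.0)

def gaussian_score(x):
    return [-(xi - mi) for xi, mi in zip(x, MU)]

rng = random.Random(0)
# 1000 independent chains, all started far from the mode at (10, 10)
samples = [langevin(gaussian_score, (10.0, 10.0), n_steps=200, eps=0.1, rng=rng)
           for _ in range(1000)]
mean = [sum(p[i] for p in samples) / len(samples) for i in range(2)]
print(mean)  # should be near MU
```

Passing the score as a callable keeps the sampler independent of how the score is obtained; with a trained energy network it would be the negated gradient of the energy with respect to the input.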
Results
Tip: regions of higher density have lower energy, since the model density is p(x) ∝ exp(-E(x)).
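Because exp(-E) is monotonically decreasing in E, ranking points by energy ranks them in reverse by density. A minimal sketch, assuming the usual energy-based form p(x) = exp(-E(x)) / Z and a hypothetical 1-D double-well energy, normalizes energies on a grid (Z approximated by a Riemann sum) to make this concrete:

```python
import math

def density_from_energy(energies, dx):
    """Turn energies on a uniform 1-D grid into p(x) = exp(-E(x)) / Z,
    approximating Z with a Riemann sum of step dx."""
    e_min = min(energies)  # shift for numerical stability; cancels after normalization
    weights = [math.exp(-(e - e_min)) for e in energies]
    z = sum(weights) * dx
    return [w / z for w in weights]

# Hypothetical double-well energy E(x) = (x^2 - 1)^2 sampled on [-2, 2]
dx = 0.01
xs = [-2.0 + i * dx for i in range(401)]
energies = [(x * x - 1.0) ** 2 for x in xs]
p = density_from_energy(energies, dx)

# The wells at x = -1 and x = 1 (E = 0) get the highest density,
# the barrier at x = 0 (E = 1) a density lower by a factor of e.
print(p[100], p[200])
```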
8gaussians
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
2spirals
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
checkerboard
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
rings
Algorithm | Results |
---|---|
ssm-vr | |
ssm | |
deen | |
dsm | |
Download files
Source Distribution
toy_gradlogp-0.2.2.tar.gz (12.0 kB)
File details
Details for the file toy_gradlogp-0.2.2.tar.gz.
File metadata
- Download URL: toy_gradlogp-0.2.2.tar.gz
- Upload date:
- Size: 12.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.7.4
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | d88d0de1fc88992cb5a145e63a53f783b2f61758dd77adf26ff4eb10d5e15ba2 |
MD5 | 052e775e04870f37be8512fc7b458238 |
BLAKE2b-256 | fb0555e4a7769e61f2597909e75f57412d0102ce53ba47f5efec3d04a313c056 |
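To check a downloaded source distribution against the SHA256 digest in the table above, a short standard-library sketch is enough. The helper names sha256_of and verify are illustrative only.

```python
import hashlib

# SHA256 digest published for toy_gradlogp-0.2.2.tar.gz in the file-hashes table
EXPECTED_SHA256 = "d88d0de1fc88992cb5a145e63a53f783b2f61758dd77adf26ff4eb10d5e15ba2"

def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 of a file, read in chunks so large files need little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()

# Example: verify("toy_gradlogp-0.2.2.tar.gz")
```

pip can perform the same check itself in hash-checking mode, via a requirements line such as toy_gradlogp==0.2.2 --hash=sha256:<digest>.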