Some toy examples of score matching algorithms written in PyTorch

Project description

toy_gradlogp

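This repo implements toy examples of the following score matching algorithms in PyTorch:

  • ssm-vr: sliced score matching with variance reduction
  • ssm: sliced score matching
  • deen: deep energy estimator networks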

Installation

Basic requirements:

  • Python >= 3.6
  • TensorFlow >= 2.3.0
  • PyTorch >= 1.8.0
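A quick way to check an environment against the minimums listed above is sketched below. The helper names (`parse_version`, `meets_minimum`, `check_environment`) are hypothetical and not part of toy_gradlogp, and the distribution names `tensorflow` and `torch` are assumptions about how the packages were installed.

```python
import re
import sys

# Minimum versions copied from the requirements list; the keys are the
# assumed pip distribution names.
MINIMUMS = {"tensorflow": "2.3.0", "torch": "1.8.0"}


def parse_version(text):
    """Turn '1.10.0+cu113' into (1, 10, 0), ignoring local and pre-release suffixes."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare versions numerically, so that 1.10.0 counts as newer than 1.8.0."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment(minimums=MINIMUMS):
    """Map each requirement to True, False, or None (None means not installed)."""
    from importlib import metadata  # standard library from Python 3.8

    report = {"python": sys.version_info >= (3, 6)}
    for name, minimum in minimums.items():
        try:
            report[name] = meets_minimum(metadata.version(name), minimum)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    print(check_environment())
```

Comparing parsed integer tuples rather than raw strings avoids the common mistake of ranking "1.10.0" below "1.8.0".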

Install from PyPI

pip install toy_gradlogp

Or install the latest version from this repo

pip install git+https://github.com/Ending2015a/toy_gradlogp.git@master

Examples

Clone this repo to run the examples:

git clone https://github.com/Ending2015a/toy_gradlogp.git

Train an energy model

Run the script with --help to see its usage message:

usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--gpu] [--log_freq LOG_FREQ]
                       [--eval_freq EVAL_FREQ] [--vis_freq VIS_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
  --loss {ssm-vr,ssm,deen}
                        Loss type
  --noise {radermacher,sphere,gaussian}
                        Noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --gpu
  --log_freq LOG_FREQ
  --eval_freq EVAL_FREQ
  --vis_freq VIS_FREQ
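The usage message above can be mirrored with a small `argparse` parser, which is handy for validating option combinations without launching a training run. This is a reconstruction from the help text only, not the code in `train_energy.py`: the argument types are assumptions, and defaults are left unset because the help text does not state them.

```python
import argparse

# Choices copied from the usage message, including its spelling of "radermacher".
DATASETS = ["8gaussians", "2spirals", "checkerboard", "rings"]
LOSSES = ["ssm-vr", "ssm", "deen"]
NOISES = ["radermacher", "sphere", "gaussian"]


def build_parser():
    """Build a parser that accepts the same flags as the usage message."""
    parser = argparse.ArgumentParser(prog="train_energy.py")
    parser.add_argument("--logdir")
    parser.add_argument("--data", choices=DATASETS)
    parser.add_argument("--loss", choices=LOSSES, help="Loss type")
    parser.add_argument("--noise", choices=NOISES, help="Noise type")
    # Types below are assumed; the help text only gives the metavar names.
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument("--size", type=int, help="dataset size")
    parser.add_argument("--eval_size", type=int, help="dataset size for evaluation")
    parser.add_argument("--batch_size", type=int, help="training batch size")
    parser.add_argument("--n_epochs", type=int, help="number of epochs to train")
    parser.add_argument("--n_slices", type=int,
                        help="number of slices for sliced score matching")
    parser.add_argument("--gpu", action="store_true")
    parser.add_argument("--log_freq", type=int)
    parser.add_argument("--eval_freq", type=int)
    parser.add_argument("--vis_freq", type=int)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--gpu", "--loss", "ssm-vr", "--data", "2spirals"]
    )
    print(args.loss, args.data, args.gpu)
```

An unknown value such as `--loss bogus` makes `argparse` exit with an error that lists the valid choices.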

Run ssm-vr on the 2spirals dataset:

python -m examples.train_energy --gpu --loss ssm-vr --data 2spirals
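For orientation, here is a minimal sketch of what a sliced score matching objective typically looks like in PyTorch. It follows the standard formulation of the method and is not taken from this repo's source. The function names, the `energy_fn` argument, and the toy network are hypothetical, and the sketch assumes the energy model treats each row of the batch independently (no batch normalization).

```python
import torch


def sample_noise(x, kind="gaussian"):
    """Draw one projection vector per row of x; kinds follow the --noise choices."""
    if kind == "radermacher":  # CLI spelling; entries are +1 or -1
        return torch.randint(0, 2, x.shape, device=x.device).to(x.dtype) * 2 - 1
    if kind not in ("gaussian", "sphere"):
        raise ValueError(f"unknown noise type: {kind}")
    v = torch.randn_like(x)
    if kind == "sphere":  # uniform direction, rescaled so that ||v||^2 = D
        v = v / v.norm(dim=-1, keepdim=True) * x.shape[-1] ** 0.5
    return v


def sliced_score_matching_loss(energy_fn, x, n_slices=1, noise="gaussian",
                               variance_reduction=True):
    """Sliced score matching for an energy model, with score(x) = -dE/dx.

    energy_fn maps a (N, D) tensor to N energies. With variance_reduction=True
    the squared projection is replaced by 0.5 * ||score||^2 (the "vr" variant).
    """
    x = x.repeat(n_slices, 1).detach().requires_grad_(True)  # (n_slices * N, D)
    v = sample_noise(x, noise)
    energy = energy_fn(x)
    score = -torch.autograd.grad(energy.sum(), x, create_graph=True)[0]
    proj = (score * v).sum(dim=-1)  # v^T score, one value per row
    grad_proj = torch.autograd.grad(proj.sum(), x, create_graph=True)[0]
    hess_term = (grad_proj * v).sum(dim=-1)  # v^T (d score / dx) v
    if variance_reduction:
        norm_term = 0.5 * (score ** 2).sum(dim=-1)
    else:
        norm_term = 0.5 * proj ** 2
    return (hess_term + norm_term).mean()


if __name__ == "__main__":
    # Hypothetical toy energy network on 2-D data.
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 32), torch.nn.Softplus(), torch.nn.Linear(32, 1)
    )
    data = torch.randn(128, 2)
    loss = sliced_score_matching_loss(lambda z: model(z).squeeze(-1), data, n_slices=2)
    loss.backward()  # gradients now populate model.parameters()
```

The variance-reduced form relies on the projection vectors having identity covariance, which holds for all three noise types as sampled here.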

Results

Tip: regions of higher density have lower energy.
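The tip above follows from the usual energy-based parameterization of a density, assumed here since the README does not spell it out. The symbols $E_\theta$ (energy model) and $Z_\theta$ (normalizing constant) are notation introduced for illustration:

```latex
p_\theta(x) = \frac{\exp\bigl(-E_\theta(x)\bigr)}{Z_\theta},
\qquad
\nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x)
```

Because $Z_\theta$ does not depend on $x$, it drops out of the gradient of the log-density. Score matching methods can therefore train the energy without ever computing the normalizing constant.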

Result figures (images not included here) are provided for each dataset and algorithm:

Dataset        Algorithms
8gaussians     ssm-vr, ssm, deen
2spirals       ssm-vr, ssm, deen
checkerboard   ssm-vr, ssm, deen
rings          ssm-vr, ssm, deen

Download files

Source Distribution

toy_gradlogp-0.1.0.tar.gz (9.4 kB)
