Some toy examples of score matching algorithms written in PyTorch

toy_gradlogp

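This repo implements toy examples of the following score matching algorithms in PyTorch:

  • ssm-vr: sliced score matching with variance reduction
  • ssm: sliced score matching
  • deen: deep energy estimator networks
  • dsm: denoising score matching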

Installation

Basic requirements:

  • Python >= 3.6
  • TensorFlow >= 2.3.0
  • PyTorch >= 1.8.0

Install from PyPI

pip install toy_gradlogp

Or install the latest version from this repo

pip install git+https://github.com/Ending2015a/toy_gradlogp.git@master

Examples

The examples are placed in gradlogp/run/

Train an energy model

Run ssm-vr on the 2spirals dataset (add --gpu to train on a GPU):

python -m gradlogp.run.train_energy --gpu --loss ssm-vr --data 2spirals
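
For orientation, the ssm-vr objective named in the command above (sliced score matching with variance reduction) can be sketched without any framework. The snippet below is a minimal illustration, not this package's implementation: it assumes Rademacher projection vectors and swaps the autograd call a PyTorch version would use for a central finite difference. The names ssm_vr_loss, score_fn, n_slices, and h are chosen here for illustration only.

```python
import random

def rademacher(dim, rng):
    """Random projection vector with entries drawn from {-1, +1}."""
    return [rng.choice((-1.0, 1.0)) for _ in range(dim)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def ssm_vr_loss(score_fn, data, n_slices=1, h=1e-3, seed=0):
    """Monte Carlo estimate of E[v^T (ds/dx) v + 0.5 * ||s(x)||^2].

    Assumption of this sketch: the directional derivative v^T (ds/dx) v
    is approximated by a central finite difference instead of autograd.
    """
    rng = random.Random(seed)
    total = 0.0
    for x in data:
        s = score_fn(x)
        half_sq_norm = 0.5 * dot(s, s)
        for _ in range(n_slices):
            v = rademacher(len(x), rng)
            x_plus = [xi + h * vi for xi, vi in zip(x, v)]
            x_minus = [xi - h * vi for xi, vi in zip(x, v)]
            dir_deriv = (dot(v, score_fn(x_plus)) - dot(v, score_fn(x_minus))) / (2 * h)
            total += dir_deriv + half_sq_norm
    return total / (len(data) * n_slices)

# Toy check: samples from a 2-D standard normal, whose true score is -x.
rng = random.Random(1)
data = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(2000)]

def true_score(x):
    return [-xi for xi in x]

def shifted_score(x):
    # Deliberately wrong model: a Gaussian centered at (3, 3).
    return [-(xi - 3.0) for xi in x]

loss_true = ssm_vr_loss(true_score, data)
loss_shifted = ssm_vr_loss(shifted_score, data)
print(loss_true, loss_shifted)
```

The score that matches the data distribution should give the smaller loss, which is the property training exploits.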

To see all options, pass --help:

python -m gradlogp.run.train_energy --help
usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen,dsm}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--n_steps N_STEPS] [--eps EPS]
                       [--gpu] [--log_freq LOG_FREQ] [--eval_freq EVAL_FREQ]
                       [--vis_freq VIS_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
                        dataset
  --loss {ssm-vr,ssm,deen,dsm}
                        loss type
  --noise {radermacher,sphere,gaussian}
                        noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --n_steps N_STEPS     number of steps for langevin dynamics
  --eps EPS             noise scale for langevin dynamics
  --gpu                 enable gpu
  --log_freq LOG_FREQ   logging frequency (unit: epoch)
  --eval_freq EVAL_FREQ
                        evaluation frequency (unit: epoch)
  --vis_freq VIS_FREQ   visualization frequency (unit: epoch)
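
The --n_steps and --eps options refer to Langevin dynamics, which draws samples by repeatedly stepping along the score (the negative energy gradient) and adding noise. A minimal sketch in plain Python follows. It assumes the common unadjusted update x <- x + (eps / 2) * score(x) + sqrt(eps) * z, which may differ in scaling from what train_energy actually does; langevin_sample and score_fn are illustrative names, not part of this package.

```python
import math
import random

def langevin_sample(score_fn, x0, n_steps=200, eps=0.1, rng=None):
    """Run unadjusted Langevin dynamics from x0 for n_steps steps."""
    if rng is None:
        rng = random.Random()
    x = list(x0)
    for _ in range(n_steps):
        s = score_fn(x)
        # Drift along the score, then add Gaussian noise scaled by sqrt(eps).
        x = [xi + 0.5 * eps * si + math.sqrt(eps) * rng.gauss(0, 1)
             for xi, si in zip(x, s)]
    return x

# Toy target: 1-D Gaussian with mean 2 and unit variance, so score(x) = -(x - 2).
def score(x):
    return [-(xi - 2.0) for xi in x]

rng = random.Random(0)
samples = [langevin_sample(score, [0.0], rng=rng)[0] for _ in range(1000)]
mean = sum(samples) / len(samples)
print(round(mean, 2))
```

Although every chain starts at 0, the sample mean should end up close to the target mean of 2. A smaller eps reduces discretization bias but needs more steps to mix.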

Results

Tip: regions of higher density have lower energy.
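
The tip follows from the usual energy-based model convention, assumed here, that the density is proportional to exp(-E(x)). The quadratic energy below is a hypothetical stand-in for illustration, not a model from this repo.

```python
import math

def energy(x, mean=0.0):
    """Hypothetical quadratic energy; its minimum sits at `mean`."""
    return 0.5 * (x - mean) ** 2

def unnormalized_density(x):
    """exp(-E(x)); the normalizing constant is omitted."""
    return math.exp(-energy(x))

# Moving away from the energy minimum raises the energy and lowers the density.
for x in (0.0, 1.0, 2.0):
    print(f"x={x:.1f}  energy={energy(x):.3f}  density={unnormalized_density(x):.3f}")
```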

[Figures: results for each dataset (8gaussians, 2spirals, checkerboard, rings), one per algorithm: ssm-vr, ssm, deen, dsm]

