
Some toy examples of score matching algorithms written in PyTorch


toy_gradlogp

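This repo implements toy examples of the following score matching algorithms in PyTorch:

  • ssm-vr: sliced score matching with variance reduction
  • ssm: sliced score matching
  • deen: deep energy estimator networks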

Installation

Basic requirements:

  • Python >= 3.6
  • TensorFlow >= 2.3.0
  • PyTorch >= 1.8.0

Install from PyPI

pip install toy_gradlogp

Or install the latest version from this repo:

pip install git+https://github.com/Ending2015a/toy_gradlogp.git@master

Examples

Clone this repo to run the example code:

git clone https://github.com/Ending2015a/toy_gradlogp.git

Train an energy model

Run the script with --help to print this usage message:

usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--gpu] [--log_freq LOG_FREQ]
                       [--eval_freq EVAL_FREQ] [--vis_freq VIS_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
  --loss {ssm-vr,ssm,deen}
                        Loss type
  --noise {radermacher,sphere,gaussian}
                        Noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --gpu
  --log_freq LOG_FREQ
  --eval_freq EVAL_FREQ
  --vis_freq VIS_FREQ
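For ssm and ssm-vr, --noise selects the distribution of the slicing vectors v, and --n_slices sets how many are drawn per sample ("radermacher" is the CLI's spelling of Rademacher noise and must be typed that way). The sketch below shows one way to draw each kind so that E[v v^T] = I, which is what the ssm-vr simplification relies on. The function name sample_slices and the sqrt(D) scaling of the sphere option are assumptions, not taken from the package.

```python
import torch


def sample_slices(shape, noise="gaussian", generator=None):
    """Draw slicing vectors v whose second moment E[v v^T] is the identity.

    The choice names mirror the --noise flag; the scaling is an assumption.
    """
    if noise == "gaussian":
        return torch.randn(shape, generator=generator)
    if noise == "radermacher":  # spelled as in the CLI, i.e. Rademacher noise
        signs = torch.randint(0, 2, shape, generator=generator)
        return signs.float() * 2 - 1  # uniform over {-1, +1}
    if noise == "sphere":
        v = torch.randn(shape, generator=generator)
        # Normalize to the unit sphere, then scale by sqrt(D) so E[v v^T] = I.
        return v / v.norm(dim=-1, keepdim=True) * shape[-1] ** 0.5
    raise ValueError(f"unknown noise type: {noise}")


g = torch.Generator().manual_seed(0)
v = sample_slices((4, 2), noise="sphere", generator=g)
```

Rademacher and sphere vectors have a fixed norm, which tends to give lower-variance estimates of the trace term than Gaussian vectors.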

Run ssm-vr on the 2spirals dataset:

python -m examples.train_energy --gpu --loss ssm-vr --data 2spirals
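To get a feel for the 2spirals data, here is a sketch of a common two-spirals sampler. The constants (a 540-degree sweep, uniform jitter of 0.5, scaling by 1/3, Gaussian noise of 0.1) and the function name sample_2spirals are assumptions borrowed from typical toy-data code, not read from toy_gradlogp.

```python
import math

import torch


def sample_2spirals(n, generator=None):
    """Two interleaved spirals in 2-D; returns 2 * (n // 2) points."""
    half = n // 2
    # Drawing t as sqrt(u) spreads points roughly evenly along each arm;
    # the angle sweeps 540 degrees (1.5 turns).
    t = torch.sqrt(torch.rand(half, 1, generator=generator)) * 540 * (2 * math.pi) / 360
    x = -torch.cos(t) * t + torch.rand(half, 1, generator=generator) * 0.5
    y = torch.sin(t) * t + torch.rand(half, 1, generator=generator) * 0.5
    arm = torch.cat([x, y], dim=1)
    # The second spiral is the first one rotated by 180 degrees.
    data = torch.cat([arm, -arm], dim=0) / 3
    return data + 0.1 * torch.randn(data.shape, generator=generator)


g = torch.Generator().manual_seed(0)
batch = sample_2spirals(1000, generator=g)  # shape (1000, 2)
```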

Results

Tip: regions of higher density have lower energy.
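The tip follows from the energy-based form p(x) = exp(-E(x)) / Z. The short check below uses a hypothetical quadratic energy to show that density falls as energy rises and that the score -grad_x E(x) does not involve Z.

```python
import torch


def energy(x):
    # Hypothetical energy E(x) = 0.5 * ||x||^2 (a standard Gaussian up to Z).
    return 0.5 * (x ** 2).sum(dim=-1)


x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]], requires_grad=True)
e = energy(x)  # energies 0, 1, 2
# p(x) = exp(-E(x)) / Z: lower energy means higher density; Z cancels when comparing points.
unnormalized_p = torch.exp(-e)
# The score grad_x log p(x) = -grad_x E(x) does not depend on Z at all.
score = -torch.autograd.grad(e.sum(), x)[0]
```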

8gaussians

[Result figures: ssm-vr, ssm, deen]

2spirals

[Result figures: ssm-vr, ssm, deen]

checkerboard

[Result figures: ssm-vr, ssm, deen]

rings

[Result figures: ssm-vr, ssm, deen]


Download files


Source Distribution

toy_gradlogp-0.1.0.tar.gz (9.4 kB)


File details

Details for the file toy_gradlogp-0.1.0.tar.gz.

File metadata

  • Download URL: toy_gradlogp-0.1.0.tar.gz
  • Upload date:
  • Size: 9.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.4.0 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.7.4

File hashes

Hashes for toy_gradlogp-0.1.0.tar.gz
  • SHA256: fdd5f1e6c79e90369319b1691f4feee4d54207851eb7d725562a1738c311d29b
  • MD5: 180022f504a7848ef9e158fc9fcb0904
  • BLAKE2b-256: 11525ba50b26a15479a81009ec314b3f58331c608be87d3a01d1532bacc706d1
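The published SHA256 digest can be used to check a downloaded archive before installing it. A minimal sketch with Python's standard hashlib follows; the helper names sha256_of and matches_published_hash are hypothetical.

```python
import hashlib

# SHA256 digest published for toy_gradlogp-0.1.0.tar.gz
EXPECTED_SHA256 = "fdd5f1e6c79e90369319b1691f4feee4d54207851eb7d725562a1738c311d29b"


def sha256_of(path, chunk_size=1 << 16):
    """Hash a file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected=EXPECTED_SHA256):
    # hexdigest() is lowercase, so normalize the expected value too.
    return sha256_of(path) == expected.lower()
```

For example, matches_published_hash("toy_gradlogp-0.1.0.tar.gz") should return True for an intact download. pip can also enforce this itself in hash-checking mode, with a --hash=sha256:... entry in a requirements file.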

