Some toy examples of score matching algorithms written in TensorFlow 2.0

toy_gradlogp_tf2

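This repo implements toy examples of the following score matching algorithms in TensorFlow 2.0:

  • ssm-vr: sliced score matching with variance reduction
  • ssm: sliced score matching
  • deen: deep energy estimator networks
  • dsm: denoising score matching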

Installation

Basic requirements:

  • Python >= 3.6
  • TensorFlow >= 2.3.0

Install from PyPI

pip install toy_gradlogp_tf2

Or install the latest version from this repo

pip install git+https://github.com/Ending2015a/toy_gradlogp_tf2.git@master

Examples

The examples are located in toy_gradlogp/run/.

Train an energy model

Run ssm-vr on the 2spirals dataset:

python -m toy_gradlogp.run.train_energy --loss ssm-vr --data 2spirals
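As a rough illustration of how the ssm and ssm-vr losses differ, here is a minimal pure-Python sketch. It assumes a hypothetical 2-D linear score model s(x) = -A x with a hand-written Jacobian; the matrix A and every function name below are made up for illustration and are not taken from this repo, which trains a TensorFlow energy network instead. Both objectives share the projected Jacobian term v^T J v. ssm adds 0.5 (v^T s)^2 per slice, while ssm-vr replaces that term with the exact 0.5 ||s||^2.

```python
import itertools

# Hypothetical linear score model s(x) = -A x in 2-D (an assumption for
# illustration, not the network used by this repo).
A = [[2.0, 0.5],
     [0.5, 1.0]]

def score(x):
    """Score s(x) = -A x."""
    return [-(A[i][0] * x[0] + A[i][1] * x[1]) for i in range(2)]

def jacobian(x):
    """Jacobian of s(x); constant (-A) because the model is linear."""
    return [[-A[i][j] for j in range(2)] for i in range(2)]

def slice_terms(x, v):
    """Return (v^T J v, v^T s) for one projection vector v."""
    s = score(x)
    J = jacobian(x)
    vJv = sum(v[i] * J[i][j] * v[j] for i in range(2) for j in range(2))
    vs = sum(v[i] * s[i] for i in range(2))
    return vJv, vs

def ssm_loss(x, slices):
    """SSM: mean over slices of v^T J v + 0.5 (v^T s)^2."""
    total = 0.0
    for v in slices:
        vJv, vs = slice_terms(x, v)
        total += vJv + 0.5 * vs ** 2
    return total / len(slices)

def ssm_vr_loss(x, slices):
    """SSM-VR: mean over slices of v^T J v, plus the exact 0.5 ||s||^2."""
    s = score(x)
    mean_vJv = sum(slice_terms(x, v)[0] for v in slices) / len(slices)
    return mean_vJv + 0.5 * sum(si ** 2 for si in s)

# Enumerate all four Rademacher (+1/-1) vectors in 2-D, so the slice
# average is an exact expectation rather than a Monte Carlo estimate.
rademacher = [list(v) for v in itertools.product([-1.0, 1.0], repeat=2)]
x = [0.5, -1.0]
loss_ssm = ssm_loss(x, rademacher)
loss_ssm_vr = ssm_vr_loss(x, rademacher)
print(loss_ssm, loss_ssm_vr)
```

Because Rademacher vectors satisfy E[v v^T] = I, the two losses agree in expectation. With only a few random slices per sample, the ssm-vr form removes the sampling noise of the squared term, which is the motivation for the variance-reduced variant.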

To see all available options, pass the --help flag:

python -m toy_gradlogp.run.train_energy --help
usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen,dsm}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--n_steps N_STEPS] [--eps EPS]
                       [--log_freq LOG_FREQ] [--eval_freq EVAL_FREQ]
                       [--vis_freq VIS_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
                        dataset
  --loss {ssm-vr,ssm,deen,dsm}
                        loss type
  --noise {radermacher,sphere,gaussian}
                        noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --n_steps N_STEPS     number of steps for langevin dynamics
  --eps EPS             noise scale for langevin dynamics
  --log_freq LOG_FREQ   logging frequency (unit: epoch)
  --eval_freq EVAL_FREQ
                        evaluation frequency (unit: epoch)
  --vis_freq VIS_FREQ   visualization frequency (unit: epoch)
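The --n_steps and --eps options control Langevin dynamics sampling. The sketch below shows one common form of the update, x ← x + (eps / 2) · score(x) + sqrt(eps) · z with z drawn from a standard normal, on a 1-D standard Gaussian whose score is known in closed form. How this repo scales eps is not stated above, so treat the step-size convention, the function names, and the parameter values here as assumptions.

```python
import math
import random

def langevin_sample(score_fn, x0, n_steps, eps, rng):
    """Run unadjusted Langevin dynamics for a single 1-D chain."""
    x = x0
    for _ in range(n_steps):
        z = rng.gauss(0.0, 1.0)
        # Drift along the score plus Gaussian noise scaled by sqrt(eps).
        x = x + 0.5 * eps * score_fn(x) + math.sqrt(eps) * z
    return x

def gaussian_score(x):
    """Score of a standard normal: d/dx log p(x) = -x."""
    return -x

rng = random.Random(0)
samples = [
    langevin_sample(gaussian_score, rng.uniform(-3.0, 3.0),
                    n_steps=200, eps=0.1, rng=rng)
    for _ in range(2000)
]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f} var={var:.3f}")
```

With a trained energy model, the score would be the negative gradient of the energy instead of the closed-form -x. A larger eps mixes faster but biases the samples more, since no Metropolis correction is applied in this sketch.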

Results

Tip: regions of higher density have lower energy.
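The relation between energy and density can be checked on a small example. An energy model defines the density up to a constant as p(x) ∝ exp(-E(x)), so the score is the negative energy gradient. The quadratic energy below is a hypothetical stand-in chosen for illustration, not one of the models trained by this repo.

```python
import math

def energy(x):
    """Hypothetical 1-D energy with its minimum at x = 0."""
    return 0.5 * x * x

def unnormalized_density(x):
    """p(x) is proportional to exp(-E(x))."""
    return math.exp(-energy(x))

def score(x, h=1e-5):
    """Score d/dx log p(x) = -dE/dx, estimated by central differences."""
    return -(energy(x + h) - energy(x - h)) / (2.0 * h)

for x in (0.0, 1.0, 2.0):
    print(x, energy(x), unnormalized_density(x), score(x))
```

Moving away from x = 0, the energy rises and the unnormalized density falls, and the score points back toward the low-energy region.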

One result figure per dataset and algorithm (figures not shown):

Dataset        Algorithms
8gaussians     ssm-vr, ssm, deen, dsm
2spirals       ssm-vr, ssm, deen, dsm
checkerboard   ssm-vr, ssm, deen, dsm
rings          ssm-vr, ssm, deen, dsm
