SGLD as a PyTorch Optimizer
SGLD in PyTorch
This package implements Stochastic Gradient Langevin Dynamics (SGLD) and cyclical SGLD (cSGLD) as a PyTorch Optimizer.
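For reference, a common way to write the SGLD update for parameters θ, a minibatch estimate Ũ of the energy U(θ) = −log p(θ | data), and step size ε_t is shown below. This is the textbook form; the exact noise scaling used inside torch_sgld (for example any temperature factor) is an assumption here and may differ. In cSGLD, ε_t follows a cyclical schedule.

```latex
\theta_{t+1} = \theta_t - \epsilon_t \nabla_\theta \tilde{U}(\theta_t) + \sqrt{2\epsilon_t}\,\xi_t,
\qquad \xi_t \sim \mathcal{N}(0, I)
```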
Installation
Install from pip as:
pip install torch-sgld
To install the latest version directly from source, run:
pip install git+https://github.com/activatedgeek/torch-sgld.git
Usage
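The general idea is to use the SGLD optimizer in place of a standard optimizer in the usual gradient-based PyTorch update loop. The loss is treated as an energy, so the parameters are sampled from a distribution proportional to exp(-energy) instead of being driven to a minimum.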
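```python
import torch
from torch_sgld import SGLD

f = module()  # Construct a PyTorch nn.Module whose forward pass returns a scalar energy.
sgld = SGLD(f.parameters(), lr=lr, momentum=.9)  # Add momentum to make it SG-HMC.
# Optionally add a step-size scheduler, e.g. any scheduler from torch.optim.lr_scheduler.
sgld_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(sgld, T_max=num_steps)

for _ in range(num_steps):
    sgld.zero_grad()  # Clear gradients accumulated in the previous step.
    energy = f()
    energy.backward()
    sgld.step()
    sgld_scheduler.step()  # Optional scheduler step.
```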
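To illustrate what happens at each `step()`, here is a minimal, dependency-free sketch of an SGLD update, with an optional momentum term in the spirit of SG-HMC (Chen et al., 2014), sampling a one-dimensional standard normal, i.e. the energy U(θ) = θ²/2. The function names `grad_energy` and `sgld_sample` are hypothetical, and the noise scaling is the textbook one, assumed rather than taken from torch_sgld's source.

```python
import math
import random
from typing import List


def grad_energy(theta: float) -> float:
    """Gradient of the toy energy U(theta) = theta**2 / 2 (standard normal target)."""
    return theta


def sgld_sample(num_steps: int, lr: float, momentum: float = 0.0, seed: int = 0) -> List[float]:
    """Run SGLD (momentum=0) or an SGHMC-style update (momentum>0) on the toy energy.

    Assumed update (a sketch, not necessarily torch_sgld's exact scaling):
        v     <- momentum * v - lr * grad U(theta) + Normal(0, 2 * (1 - momentum) * lr)
        theta <- theta + v
    With momentum=0 this reduces to plain SGLD:
        theta <- theta - lr * grad U(theta) + Normal(0, 2 * lr)
    """
    rng = random.Random(seed)
    theta, v = 0.0, 0.0
    noise_std = math.sqrt(2.0 * (1.0 - momentum) * lr)  # Injected Gaussian noise scale.
    samples = []
    for _ in range(num_steps):
        v = momentum * v - lr * grad_energy(theta) + rng.gauss(0.0, noise_std)
        theta = theta + v
        samples.append(theta)
    return samples


# Usage: after burn-in, the samples should have mean near 0 and variance near 1.
samples = sgld_sample(100_000, lr=0.05)[1000:]
```

Unlike gradient descent, the iterates never settle at the minimum θ = 0; the injected noise keeps them spread according to exp(−U), up to a small discretization bias that shrinks with the step size.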
cSGLD can be implemented by using a cyclical learning rate schedule. See the toy_csgld.ipynb notebook for a complete example.
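As a sketch of such a schedule, the helper below computes the cosine cyclical step size proposed in the original cSGLD paper (Zhang et al., 2020). Whether toy_csgld.ipynb uses exactly this form is an assumption, and `cyclical_lr` with its arguments is a hypothetical helper, not part of torch_sgld.

```python
import math


def cyclical_lr(step: int, total_steps: int, num_cycles: int, lr0: float) -> float:
    """Cosine cyclical step size for a 0-indexed `step`.

    alpha_k = lr0 / 2 * (cos(pi * (k mod C) / C) + 1),  with C = ceil(total_steps / num_cycles).
    Each cycle starts at lr0 and decays towards 0 before restarting.
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    pos = (step % cycle_len) / cycle_len  # Fraction of the current cycle completed.
    return lr0 / 2.0 * (math.cos(math.pi * pos) + 1.0)


# Usage: 100 steps split into 5 cycles of 20 steps each.
schedule = [cyclical_lr(k, total_steps=100, num_cycles=5, lr0=0.1) for k in range(100)]
```

Since SGLD is a PyTorch Optimizer, one way to plug this in is `torch.optim.lr_scheduler.LambdaLR(sgld, lambda k: cyclical_lr(k, num_steps, num_cycles, 1.0))`: `LambdaLR` multiplies the optimizer's base `lr` by the returned factor, so `lr0=1.0` turns the helper into a multiplier (`num_cycles` is again a hypothetical name).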
License
Apache 2.0
Download files
Source Distribution
torch-sgld-0.1.0.tar.gz (8.2 kB)
Built Distribution
torch_sgld-0.1.0-py3-none-any.whl (8.9 kB)
File details
Details for the file torch-sgld-0.1.0.tar.gz.
File metadata
- Download URL: torch-sgld-0.1.0.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3bea70771e8215f64313c3ec7fbd88cd16e6f7a41cddd246eaffffaf856a02ce
MD5 | b28266e9fdb97fe750889f863e0e0acd
BLAKE2b-256 | a9f4d2029ce2535111854d305749f0392ee650beef2401a0aaf817591e5bc21f
File details
Details for the file torch_sgld-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_sgld-0.1.0-py3-none-any.whl
- Upload date:
- Size: 8.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 566882f7995e08907911dc90965fd7b90062312941099ff1eed1a1d47a7a6396
MD5 | 16a9d0bd2a02dc7d4924b35195c0e4f7
BLAKE2b-256 | 27180b0d1191defc8a54ec4dddc56c36f87e8baa35b0f716a09e48235eef6e66