SGLD as a PyTorch Optimizer
SGLD in PyTorch
This package implements Stochastic Gradient Langevin Dynamics (SGLD) and cyclical SGLD (cSGLD) as a PyTorch `Optimizer`.
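For intuition, one SGLD step is a gradient-descent step on the energy U(θ) (the negative log-density being sampled) plus Gaussian noise whose scale is tied to the step size. The plain-Python sketch below assumes one common parameterization, θ ← θ − η∇U(θ) + √(2η)·ξ with ξ ~ N(0, 1), and a hypothetical one-dimensional standard Gaussian target, U(θ) = θ²/2. It is not the package's implementation: `sgld_chain` and `grad_energy` are illustrative names, and `torch_sgld` may scale `lr` differently.

```python
import math
import random


def grad_energy(theta: float) -> float:
    """Gradient of the hypothetical energy U(theta) = theta**2 / 2 (standard Gaussian)."""
    return theta


def sgld_chain(num_steps: int, lr: float, seed: int = 0) -> list[float]:
    """Run a 1-D SGLD chain and return every iterate.

    Assumed update (one common parameterization):
        theta <- theta - lr * grad U(theta) + sqrt(2 * lr) * N(0, 1)
    """
    rng = random.Random(seed)
    theta = 0.0
    samples = []
    for _ in range(num_steps):
        noise = rng.gauss(0.0, 1.0)
        theta = theta - lr * grad_energy(theta) + math.sqrt(2.0 * lr) * noise
        samples.append(theta)
    return samples


samples = sgld_chain(num_steps=100_000, lr=0.05)
burned = samples[1_000:]  # drop early iterates as burn-in
mean = sum(burned) / len(burned)
var = sum((x - mean) ** 2 for x in burned) / len(burned)
print(f"mean={mean:.3f} var={var:.3f}")  # compare with the target's mean 0 and variance 1
```

Dropping the noise term turns this into plain gradient descent, which would settle at the mode θ = 0; the injected noise is what keeps the iterates spread out like draws from exp(−U). A finite step size leaves a small bias: for this target the stationary variance of the chain is 1/(1 − η/2) rather than exactly 1.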
Installation
Install from pip:
```shell
pip install torch-sgld
```
To install the latest version directly from source, run:
```shell
pip install git+https://github.com/activatedgeek/torch-sgld.git
```
Usage
The general idea is to take the usual gradient-based update loop in PyTorch and use the SGLD optimizer in place of a standard optimizer.
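```python
import torch
from torch_sgld import SGLD

lr, num_steps = 1e-3, 1000  ## Example values; tune these for your problem.
f = module()  ## Construct your PyTorch nn.Module; f() should return the energy.

sgld = SGLD(f.parameters(), lr=lr, momentum=.9)  ## Add momentum to make it SG-HMC.
## Optionally add a step-size scheduler, e.g. cosine annealing:
sgld_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(sgld, T_max=num_steps)

for _ in range(num_steps):
    sgld.zero_grad()  ## Clear gradients left over from the previous step.
    energy = f()
    energy.backward()
    sgld.step()
    sgld_scheduler.step()  ## Optional scheduler step.
```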
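The usage example notes that passing `momentum` makes the sampler SG-HMC. As a rough illustration, the sketch below runs one standard SGHMC discretization on a hypothetical 1-D standard Gaussian target, U(θ) = θ²/2: a velocity v is damped by a friction α, and noise with variance 2αη is injected into v. Mapping `momentum` to 1 − α is an assumption made here for readability, not a statement about how `torch_sgld` interprets the argument; `sghmc_chain` is a hypothetical name.

```python
import math
import random


def sghmc_chain(num_steps: int, lr: float, momentum: float, seed: int = 0) -> list[float]:
    """1-D SGHMC on the hypothetical energy U(theta) = theta**2 / 2.

    Assumed discretization (friction alpha = 1 - momentum):
        v     <- momentum * v - lr * grad U(theta) + sqrt(2 * alpha * lr) * N(0, 1)
        theta <- theta + v
    """
    rng = random.Random(seed)
    alpha = 1.0 - momentum
    theta, v = 0.0, 0.0
    samples = []
    for _ in range(num_steps):
        grad = theta  # grad U(theta) for U(theta) = theta**2 / 2
        noise = math.sqrt(2.0 * alpha * lr) * rng.gauss(0.0, 1.0)
        v = momentum * v - lr * grad + noise
        theta = theta + v
        samples.append(theta)
    return samples


samples = sghmc_chain(num_steps=100_000, lr=0.01, momentum=0.9)
burned = samples[1_000:]  # drop early iterates as burn-in
mean = sum(burned) / len(burned)
var = sum((x - mean) ** 2 for x in burned) / len(burned)
print(f"mean={mean:.3f} var={var:.3f}")  # compare with the target's mean 0 and variance 1
```

With `momentum=0` the friction is α = 1 and each step reduces to θ ← θ − η∇U(θ) + √(2η)·ξ, i.e. plain SGLD; larger momentum lets the chain carry velocity across steps.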
cSGLD can be implemented by pairing SGLD with a cyclical step-size (learning-rate) schedule. See the `toy_csgld.ipynb` notebook for a complete example.
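A cyclical schedule restarts the step size at the beginning of each cycle and anneals it toward zero within the cycle, so the sampler alternates between large exploratory moves and small, accurate sampling steps. The sketch below assumes the cosine cyclical schedule described for cSGLD, lr = (lr0/2)·(cos(π·r) + 1), where r ∈ [0, 1) is the position within the current cycle; `cyclical_lr`, `is_sampling_stage`, and `explore_fraction` are hypothetical names, and the notebook may use a different schedule.

```python
import math


def cyclical_lr(step: int, lr0: float, total_steps: int, num_cycles: int) -> float:
    """Cosine cyclical step size for a 0-indexed `step` (assumed cSGLD-style schedule).

    lr(step) = lr0 / 2 * (cos(pi * r) + 1), with r in [0, 1) the position within the cycle.
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    r = (step % cycle_len) / cycle_len
    return 0.5 * lr0 * (math.cos(math.pi * r) + 1.0)


def is_sampling_stage(step: int, total_steps: int, num_cycles: int,
                      explore_fraction: float = 0.5) -> bool:
    """Hypothetical split: the first `explore_fraction` of each cycle explores,
    the remainder of the cycle collects samples."""
    cycle_len = math.ceil(total_steps / num_cycles)
    return (step % cycle_len) / cycle_len >= explore_fraction


total_steps, num_cycles = 1000, 4  # each cycle is 250 steps long
for step in (0, 125, 249, 250):
    lr = cyclical_lr(step, 0.1, total_steps, num_cycles)
    print(step, f"{lr:.6f}", is_sampling_stage(step, total_steps, num_cycles))
```

In PyTorch, the same factor could be attached to an existing optimizer with `torch.optim.lr_scheduler.LambdaLR`, e.g. `lr_lambda=lambda step: cyclical_lr(step, 1.0, total_steps, num_cycles)`, since `LambdaLR` multiplies the initial learning rate by the returned value. In the cSGLD formulation, steps in the exploration stage omit the injected noise and samples are kept only from the sampling stage; this sketch only computes the schedule and the stage flag.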
License
Apache 2.0
Download files
Source Distribution: torch-sgld-0.1.0.tar.gz
Built Distribution: torch_sgld-0.1.0-py3-none-any.whl
File details
Details for the file torch-sgld-0.1.0.tar.gz.
File metadata
- Download URL: torch-sgld-0.1.0.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3bea70771e8215f64313c3ec7fbd88cd16e6f7a41cddd246eaffffaf856a02ce |
| MD5 | b28266e9fdb97fe750889f863e0e0acd |
| BLAKE2b-256 | a9f4d2029ce2535111854d305749f0392ee650beef2401a0aaf817591e5bc21f |
File details
Details for the file torch_sgld-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_sgld-0.1.0-py3-none-any.whl
- Upload date:
- Size: 8.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 566882f7995e08907911dc90965fd7b90062312941099ff1eed1a1d47a7a6396 |
| MD5 | 16a9d0bd2a02dc7d4924b35195c0e4f7 |
| BLAKE2b-256 | 27180b0d1191defc8a54ec4dddc56c36f87e8baa35b0f716a09e48235eef6e66 |