pyrfd
PyTorch implementation of RFD (Random Function Descent, see arXiv)
Covariance model
Provides an implementation of the `SquaredExponential` covariance model
with an `auto_fit` function, which requires only:
- a `model_factory`, which returns a freshly and randomly initialized instance of the same model every time it is called,
- a `loss` function, e.g. `torch.nn.functional.nll_loss`, which accepts a prediction and a true value,
- `data`, which can be passed to `torch.utils.data.DataLoader` with different batch size parameters such that it returns `(x, y)` tuples when iterated on,
- a `cache` CSV filename, which acts as the cache for the covariance model of this unique (model, data, loss) combination.
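The three `auto_fit` inputs listed above can be illustrated with a small stand-in setup. This is a sketch under assumptions: the linear `model_factory` and the random `TensorDataset` are hypothetical placeholders for something like `CNN3` and MNIST, and the code only exercises the contract described in the list. It does not call pyrfd.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def model_factory() -> nn.Module:
    """Hypothetical factory: same architecture, fresh random weights on every call."""
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 10),
        nn.LogSoftmax(dim=1),  # nll_loss expects log-probabilities
    )


# loss: called as loss(prediction, target) and returning a scalar tensor.
loss = F.nll_loss

# data: a Dataset whose items are (x, y) pairs; random MNIST-shaped stand-ins here.
data = TensorDataset(torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,)))

# The same dataset must batch correctly for several batch sizes.
for batch_size in (8, 32):
    x, y = next(iter(DataLoader(data, batch_size=batch_size, shuffle=True)))
    assert x.shape == (batch_size, 1, 28, 28) and y.shape == (batch_size,)

# Two factory calls give two independently initialized models.
m1, m2 = model_factory(), model_factory()
value = loss(m1(x), y)  # scalar tensor
```

The factory is a plain function rather than a model instance so that each call produces new random weights, which is what "randomly initialized model every time it is called" asks for.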
Implementation of RFD
Such a covariance model can then be passed to `RFD`, which implements the
PyTorch optimizer interface, so the result can be used like `torch.optim.Adam`.
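Because `RFD` implements the optimizer interface, a training loop for it should look like one for Adam. The sketch below uses `torch.optim.Adam` on a toy regression problem as a stand-in and calls `step` with a closure that re-evaluates the loss. The closure form is accepted by every `torch.optim` optimizer. Whether pyrfd's `RFD` needs the closure, for example because its step size may use the current loss value, or accepts a plain `step()` after `backward()`, is an assumption to check against the package.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy linear-regression data standing in for a real dataset.
x_all = torch.randn(256, 3)
y_all = x_all @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(256)
loader = DataLoader(TensorDataset(x_all, y_all), batch_size=32, shuffle=True)

model = nn.Linear(3, 1)
# Stand-in optimizer. Per the description, RFD(model.parameters(),
# covariance_model=cov_model) is meant to be usable in the same position.
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)


def full_loss() -> float:
    """Mean squared error over the whole toy dataset, without tracking gradients."""
    with torch.no_grad():
        return F.mse_loss(model(x_all).squeeze(-1), y_all).item()


initial = full_loss()
for epoch in range(30):
    for x, y in loader:
        def closure():
            # Re-evaluate loss and gradients at the current parameters.
            optimizer.zero_grad()
            loss = F.mse_loss(model(x).squeeze(-1), y)
            loss.backward()
            return loss

        optimizer.step(closure)
final = full_loss()
print(f"loss before: {initial:.3f}, after: {final:.3f}")
```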
Example usage
```python
import torch
import torchvision as tv

from benchmaking.classification.mnist.models.cnn3 import CNN3  # not part of the pyrfd package
from pyrfd import RFD, SquaredExponential

cov_model = SquaredExponential()
cov_model.auto_fit(
    model_factory=CNN3,
    loss=torch.nn.functional.nll_loss,
    data=tv.datasets.MNIST(
        root="mnistSimpleCNN/data",
        train=True,
        download=True,  # fetch MNIST into root if it is not there yet
        transform=tv.transforms.ToTensor(),
    ),
    cache="cache/CNN3_mnist.csv",  # should be unique for (model, data, loss)
)

model = CNN3()  # keep a reference so the trained model remains usable
rfd = RFD(
    model.parameters(),
    covariance_model=cov_model,
)
```
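The `cache` argument should be unique for each (model, data, loss) combination. One way to keep cache files from colliding is to derive the filename from all three components. `cache_path` below is a hypothetical helper, not part of pyrfd. It uses each callable's `__name__` and a caller-supplied `data_tag`, because the dataset object itself is not inspected and choices such as the train/test split or transforms would otherwise be invisible in the name.

```python
from pathlib import Path


def cache_path(model_factory, loss, data_tag: str, root: str = "cache") -> Path:
    """Hypothetical helper: one CSV cache path per (model, data, loss) combination.

    data_tag should describe everything about the data that affects the loss
    (dataset, split, transforms).
    """
    def name_of(obj) -> str:
        # Functions and classes have __name__; fall back to the type name otherwise.
        return getattr(obj, "__name__", type(obj).__name__)

    return Path(root) / f"{name_of(model_factory)}_{data_tag}_{name_of(loss)}.csv"
```

With the objects from the example, `cache_path(CNN3, torch.nn.functional.nll_loss, "mnist-train")` should give `cache/CNN3_mnist-train_nll_loss.csv`.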
How to cite
```bibtex
@inproceedings{benningRandomFunctionDescent2024,
  title = {Random {{Function Descent}}},
  booktitle = {Advances in {{Neural Information Processing Systems}}},
  author = {Benning, Felix and D{\"o}ring, Leif},
  year = {2024},
  month = dec,
  volume = {37},
  primaryclass = {cs, math, stat},
  publisher = {Curran Associates, Inc.},
  address = {Vancouver, Canada},
}
```
Download files
File details
Details for the file pyrfd-1.0.1.tar.gz.
File metadata
- Download URL: pyrfd-1.0.1.tar.gz
- Size: 15.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 69b9c05918ab524dec76fb9050e4146ee12751115b82d03d2148f0a35b6ed1f3 |
| MD5 | f1a78786f4ea396a52fa6cdbf1f9256a |
| BLAKE2b-256 | 98511e2bb8cd10c5316f99980c6d60a8f4a3ce91401fc1ae76a6d1711a2858c7 |
File details
Details for the file pyrfd-1.0.1-py3-none-any.whl.
File metadata
- Download URL: pyrfd-1.0.1-py3-none-any.whl
- Size: 17.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | bc704ce7fc2ed5a57bbb2ba308460db216edba743da789adf21524e5ac54bc24 |
| MD5 | 9d20055a7d5891bad60ec68dcb1037b8 |
| BLAKE2b-256 | 747af27d012915ee4782562edd1a604156794709bf0f1f57e6f2e66edc446702 |
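The published digests let you check that a downloaded file is intact. A minimal sketch using the standard library's `hashlib` follows. The SHA256 values are copied from the "File hashes" listings for the two pyrfd 1.0.1 files. The helper names `sha256_of` and `verify` are illustrative, not part of any tool. The MD5 and BLAKE2b-256 digests could be checked the same way with `hashlib.md5()` or `hashlib.blake2b(digest_size=32)`.

```python
import hashlib
from pathlib import Path

# SHA256 digests published for the pyrfd 1.0.1 distribution files.
EXPECTED_SHA256 = {
    "pyrfd-1.0.1.tar.gz": "69b9c05918ab524dec76fb9050e4146ee12751115b82d03d2148f0a35b6ed1f3",
    "pyrfd-1.0.1-py3-none-any.whl": "bc704ce7fc2ed5a57bbb2ba308460db216edba743da789adf21524e5ac54bc24",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash a file in chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path) -> bool:
    """Return True if the file's SHA256 matches the published digest for its name."""
    name = Path(path).name
    expected = EXPECTED_SHA256.get(name)
    if expected is None:
        raise KeyError(f"no published digest for {name}")
    return sha256_of(path) == expected
```

pip can enforce the same check at install time through its hash-checking mode, using `--hash=sha256:...` options in a requirements file.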