pyrfd

PyTorch implementation of Random Function Descent (RFD) (see arXiv).

Covariance model

Provides an implementation of the SquaredExponential covariance model with an auto_fit function, which requires only:

  1. A model_factory, which returns the same model architecture with a fresh random initialization every time it is called.
  2. A loss function, e.g. torch.nn.functional.nll_loss, which accepts a prediction and a true value.
  3. data, which can be passed to torch.utils.data.DataLoader with different batch size parameters such that it returns (x, y) tuples when iterated on.
  4. A CSV filename, which acts as the cache for the covariance model of this unique (model, data, loss) combination.
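
For orientation, the squared exponential covariance is commonly written as C(x, y) = variance * exp(-||x - y||^2 / (2 * length_scale^2)). The sketch below evaluates that textbook form in plain Python. The function name and the parameters variance and length_scale are illustrative assumptions, not pyrfd's API, and the quantities that auto_fit estimates may be parameterized differently.

```python
import math
from typing import Sequence


def squared_exponential(
    x: Sequence[float],
    y: Sequence[float],
    variance: float = 1.0,
    length_scale: float = 1.0,
) -> float:
    """Textbook squared exponential covariance between two points.

    Hypothetical helper for illustration only; not part of pyrfd.
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same dimension")
    # Squared Euclidean distance between the two points.
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    # Covariance decays smoothly with distance, at a rate set by length_scale.
    return variance * math.exp(-sq_dist / (2.0 * length_scale**2))
```

The covariance equals variance when both points coincide and decreases as they move apart; a larger length_scale makes that decay slower.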

Implementation of RFD

Such a covariance model can then be passed to RFD, which implements the PyTorch optimizer interface. The resulting optimizer can be used like torch.optim.Adam.

Example usage

import torch
import torchvision as tv

from benchmarking.classification.mnist.models.cnn3 import CNN3
from pyrfd import RFD, SquaredExponential

# Fit the covariance model; the result is cached in the CSV file.
cov_model = SquaredExponential()
cov_model.auto_fit(
    model_factory=CNN3,
    loss=torch.nn.functional.nll_loss,
    data=tv.datasets.MNIST(
        root="mnistSimpleCNN/data",
        train=True,
        transform=tv.transforms.ToTensor(),
    ),
    cache="cache/CNN3_mnist.csv",  # should be unique for (model, data, loss)
)

# Keep a reference to the model whose parameters RFD optimizes.
model = CNN3()
rfd = RFD(
    model.parameters(),
    covariance_model=cov_model,
)
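
Because RFD follows the PyTorch optimizer interface, it should slot into an ordinary training loop. The following is a minimal sketch under that assumption: train_one_epoch is a hypothetical helper, not part of pyrfd, and it assumes that step() can be called without a closure, as with torch.optim.Adam. The helper is written against the interface only, so it does not import torch itself.

```python
from typing import Any, Callable, Iterable, Tuple


def train_one_epoch(
    model: Callable[[Any], Any],
    optimizer: Any,
    batches: Iterable[Tuple[Any, Any]],
    loss_fn: Callable[[Any, Any], Any],
) -> float:
    """Run one pass over `batches` and return the mean batch loss.

    `optimizer` is any object with zero_grad() and step(), e.g. a
    torch.optim optimizer; `loss_fn` must return an object with
    backward() and item(), e.g. a scalar torch.Tensor.
    """
    total, count = 0.0, 0
    for x, y in batches:
        optimizer.zero_grad()        # clear gradients from the previous step
        loss = loss_fn(model(x), y)  # forward pass and loss evaluation
        loss.backward()              # populate parameter gradients
        optimizer.step()             # assumed: no closure argument is needed
        total += loss.item()
        count += 1
    # Avoid division by zero when `batches` is empty.
    return total / max(count, 1)
```

With a model, an RFD instance built from that model's parameters, and a torch.utils.data.DataLoader over the MNIST dataset, a call might look like train_one_epoch(model, rfd, loader, torch.nn.functional.nll_loss).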

How to cite

@inproceedings{benningRandomFunctionDescent2024,
  title = {Random {{Function Descent}}},
  booktitle = {Advances in {{Neural Information Processing Systems}}},
  author = {Benning, Felix and D{\"o}ring, Leif},
  year = {2024},
  month = dec,
  volume = {37},
  primaryclass = {cs, math, stat},
  publisher = {Curran Associates, Inc.},
  address = {Vancouver, Canada},
}
