
pyrfd

PyTorch implementation of RFD (Random Function Descent).

Covariance model

pyrfd provides an implementation of the SquaredExponential covariance model with an auto_fit method, which requires only:

  1. model_factory: a callable which returns a new, randomly initialized instance of the same model every time it is called
  2. loss: a loss function, e.g. torch.nn.functional.nll_loss, which accepts a prediction and a true value
  3. data: a dataset which can be passed to torch.utils.data.DataLoader with different batch sizes, such that iterating over the loader yields (x, y) pairs
  4. cache: a CSV filename which acts as the cache for the covariance model of this unique (model, data, loss) combination
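The first three requirements can be illustrated with plain PyTorch. The sketch below is an assumption-laden stand-in, not pyrfd code: the tiny classifier, the random TensorDataset and the batch sizes 8 and 32 are all hypothetical. It shows a model_factory that builds a fresh, randomly initialized model on each call, a dataset whose items are (x, y) pairs, and the same dataset wrapped in DataLoaders with different batch sizes and scored with nll_loss.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def model_factory() -> torch.nn.Module:
    """Hypothetical model factory: every call builds a fresh, randomly initialized model."""
    return torch.nn.Sequential(
        torch.nn.Linear(4, 3),
        torch.nn.LogSoftmax(dim=1),  # nll_loss expects log-probabilities
    )


# Toy dataset (assumption): each item is an (x, y) pair.
data = TensorDataset(torch.randn(64, 4), torch.randint(0, 3, (64,)))

batch_losses = {}
for batch_size in (8, 32):
    # The same dataset is wrapped with different batch sizes.
    loader = DataLoader(data, batch_size=batch_size, shuffle=True)
    x, y = next(iter(loader))  # each batch unpacks into inputs and targets
    prediction = model_factory()(x)
    batch_losses[batch_size] = torch.nn.functional.nll_loss(prediction, y).item()
    print(batch_size, tuple(x.shape), tuple(y.shape))
```

Because the factory constructs a new module on every call, two calls yield models with different initial weights. That is the property the first requirement asks for.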

Implementation of RFD

A fitted covariance model can then be passed to RFD, which implements the PyTorch optimizer interface. As a result, RFD can be used like torch.optim.Adam.
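Since RFD follows the torch.optim optimizer interface, it should fit into the usual training loop. The sketch below uses torch.optim.Adam as a stand-in so that it runs without pyrfd. The toy model, random data, learning rate and step count are assumptions. Following the statement above, replacing the optimizer with RFD(model.parameters(), covariance_model=cov_model) is intended to be the only change. This README does not say whether RFD's step() needs a closure, so the loop passes one. torch.optim optimizers accept a closure in step() and return its loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy classifier and data (assumptions, for illustration only).
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.LogSoftmax(dim=1))
x = torch.randn(32, 4)
y = torch.randint(0, 3, (32,))

# Stand-in optimizer; per the README, RFD is meant to be usable in its place.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)


def closure() -> torch.Tensor:
    """Recompute the loss and its gradients for the current parameters."""
    optimizer.zero_grad()
    loss = F.nll_loss(model(x), y)
    loss.backward()
    return loss


losses = []
for _ in range(50):
    loss = optimizer.step(closure)  # step() evaluates the closure and returns its loss
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Passing a closure works whether or not an optimizer needs to re-evaluate the loss, so the loop stays the same when the optimizer is swapped.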

Example usage

```python
# ModelM3 is the example model, cf. the mnistSimpleCNN directory
from mnistSimpleCNN.models.modelM3 import ModelM3

import torch
import torchvision as tv

from pyrfd import RFD, SquaredExponential

cov_model = SquaredExponential()
cov_model.auto_fit(
    model_factory=ModelM3,  # calling ModelM3() returns a freshly initialized model
    loss=torch.nn.functional.nll_loss,
    data=tv.datasets.MNIST(
        root="mnistSimpleCNN/data",
        train=True,
        transform=tv.transforms.ToTensor(),
        download=True,  # fetch MNIST if it is not already present under root
    ),
    cache="cache/CNN3_mnist.csv",  # should be unique for each (model, data, loss) combination
)

model = ModelM3()  # keep a reference to the model that will be trained
rfd = RFD(
    model.parameters(),
    covariance_model=cov_model,
)
```
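In the example, SquaredExponential() is constructed without arguments and auto_fit then fits it for the given model, loss and data. For intuition, a squared exponential covariance is commonly written as C(x, y) = σ² exp(−‖x − y‖² / (2ℓ²)). The covariance takes its largest value, σ², for identical inputs (for RFD, points in parameter space) and decays smoothly with distance. The pure-Python sketch below evaluates this textbook form. The helper name and the parameters variance and length_scale are hypothetical, and pyrfd's own parametrization may differ.

```python
import math
from typing import Sequence


def squared_exponential(
    x: Sequence[float],
    y: Sequence[float],
    variance: float = 1.0,
    length_scale: float = 1.0,
) -> float:
    """Textbook squared exponential covariance (hypothetical helper, not pyrfd's API)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same dimension")
    squared_distance = sum((a - b) ** 2 for a, b in zip(x, y))
    return variance * math.exp(-squared_distance / (2.0 * length_scale**2))


# Identical inputs give the full variance; covariance decays with distance.
print(squared_exponential([0.0, 0.0], [0.0, 0.0], variance=2.0))  # 2.0
print(squared_exponential([0.0, 0.0], [3.0, 4.0]))  # exp(-12.5), close to 0
```

A larger length_scale makes the covariance decay more slowly, so distant points remain more strongly correlated.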

Download files

Source distribution: pyrfd-0.1.0.tar.gz (9.4 kB)

Built distribution: pyrfd-0.1.0-py3-none-any.whl (11.0 kB, Python 3)
