
Implementation of Gradient Origin Networks in PyTorch


Gradient Origin Networks in PyTorch

Unofficial PyTorch implementation of Gradient Origin Networks.
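For orientation, here is a sketch of the core idea from the GON paper (Bond-Taylor and Willcocks, 2021). The notation is chosen here rather than taken from this package: F is the implicit decoder, c the pixel coordinates, x the image, and L a reconstruction loss. The latent is the negative gradient of the loss at the origin of latent space. The decoder is then trained on the reconstruction from that latent, and this gradient step is kept differentiable so that training backpropagates through it:

$$
z = -\nabla_{z_0}\, \mathcal{L}\big(x, F(c, z_0)\big)\Big|_{z_0 = \mathbf{0}},
\qquad
\mathcal{L}_{\mathrm{GON}} = \mathcal{L}\big(x, F(c, z)\big)
$$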

(Figures: Reconstructions, Samples)

Usage

Training

Requirements:

After cloning the repository, a GON can be trained using the train_gon.py script:

python train_gon.py dataset.name=<MNIST|FashionMNIST|CIFAR10> dataset.root=<data-root>

All configuration options are listed in config/config.yaml. See the Hydra documentation for more information on configuration.

From Code

Install the package:

pip install gon-pytorch

Instantiate a GON with NeRF positional encodings:

import torch
from gon_pytorch import NeRFPositionalEncoding, ImplicitDecoder, GON, SirenBlockFactory

pos_encoder = NeRFPositionalEncoding(in_dim=2)
decoder = ImplicitDecoder(
    latent_dim=128,
    out_dim=3,
    hidden_dim=128,
    num_layers=4,
    block_factory=SirenBlockFactory(),
    pos_encoder=pos_encoder
)
gon = GON(decoder)

coords = torch.randn(1, 32, 32, 2)
image = torch.rand(1, 32, 32, 3)

# Obtain latent
latent, latent_loss = gon.infer_latents(coords, image)

# Reconstruct from latent
recon = gon(coords, latent)

# Backpropagate the reconstruction loss (an optimizer step would follow)
loss = ((recon - image) ** 2).mean()
loss.backward()
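The positional encoding step can be illustrated with a minimal NumPy sketch of the standard NeRF encoding, gamma(p) = (sin(2^k pi p), cos(2^k pi p)) for k from 0 to L-1. This is not the package's `NeRFPositionalEncoding`: the function name `nerf_positional_encoding`, the default of 6 frequencies, the feature ordering, and appending the raw input are assumptions made here, so the output size may differ from the library's.

```python
import numpy as np

def nerf_positional_encoding(coords: np.ndarray, num_frequencies: int = 6,
                             include_input: bool = True) -> np.ndarray:
    """Hypothetical NeRF-style encoding (not the gon_pytorch API).

    coords: array of shape (..., in_dim), e.g. (H, W, 2) coordinates in [-1, 1].
    Returns shape (..., in_dim * (2 * num_frequencies + int(include_input))).
    """
    freqs = (2.0 ** np.arange(num_frequencies)) * np.pi      # (L,)
    scaled = coords[..., None] * freqs                         # (..., in_dim, L)
    # Per input dimension: L sines followed by L cosines.
    encoded = np.concatenate([np.sin(scaled), np.cos(scaled)], axis=-1)
    encoded = encoded.reshape(*coords.shape[:-1], -1)          # (..., in_dim * 2L)
    if include_input:
        encoded = np.concatenate([coords, encoded], axis=-1)   # prepend raw coords
    return encoded

# Usage: encode a 32x32 grid of (y, x) coordinates in [-1, 1].
lin = np.linspace(-1.0, 1.0, 32)
grid = np.stack(np.meshgrid(lin, lin, indexing="ij"), axis=-1)  # (32, 32, 2)
enc = nerf_positional_encoding(grid, num_frequencies=6)
print(enc.shape)  # (32, 32, 26)
```

Each coordinate contributes 2L sine/cosine features plus itself, so in this sketch the decoder's first layer sees in_dim * (2L + 1) inputs instead of in_dim.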

Differences from the original implementation

  • Cross-entropy is used as the loss instead of MSE, as this seems to improve results.
  • The original implementation obtains the gradients with respect to the origin from the mean of the latent loss over the batch. This makes the inferred latents depend on the batch size: in the backward pass, the mean scales each latent's gradient by 1/B for a batch of size B. This is fixed by summing the latent loss over the batch dimension instead of averaging it.
  • Latent modulation from Modulated Periodic Activations for Generalizable Local Functional Representations is implemented and can optionally be used.
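The batch-size effect described in the second point can be checked on a toy problem. The NumPy sketch below is a stand-in, not the package's code: it replaces the decoder with a linear map `x_hat = z @ W + b` and the cross-entropy with a squared error, so the gradient at the origin has a closed form. `infer_latents_linear` and its `reduction` argument are hypothetical names.

```python
import numpy as np

def infer_latents_linear(x, W, b, reduction="sum"):
    """Hypothetical gradient-origin inference for a toy linear decoder.

    Decoder: x_hat = z @ W + b, with W of shape (d, n) and b of shape (n,).
    Per-sample loss L_i = 0.5 * ||x_hat_i - x_i||^2, so at z = 0 the gradient
    of the batch loss w.r.t. z_i is (b - x_i) @ W.T (times 1/B for a mean).
    The latent is the negative of that gradient.
    """
    batch_size = x.shape[0]
    z0 = np.zeros((batch_size, W.shape[0]))  # every latent starts at the origin
    residual = z0 @ W + b - x                # dL_i/dx_hat_i at the origin, (B, n)
    grad = residual @ W.T                    # dL_i/dz_i, (B, d)
    if reduction == "mean":
        grad = grad / batch_size             # mean over the batch adds a 1/B factor
    return -grad

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 10))
b = rng.normal(size=10)
x_small = rng.normal(size=(8, 10))
x_large = np.concatenate([x_small] * 8)      # the same 8 images repeated, B = 64

z_sum_small = infer_latents_linear(x_small, W, b, "sum")
z_sum_large = infer_latents_linear(x_large, W, b, "sum")
z_mean_small = infer_latents_linear(x_small, W, b, "mean")
z_mean_large = infer_latents_linear(x_large, W, b, "mean")
print(np.allclose(z_sum_small, z_sum_large[:8]))    # True
print(np.allclose(z_mean_small, z_mean_large[:8]))  # False
```

With `reduction="sum"`, each latent depends only on its own image. With `"mean"`, growing the batch from 8 to 64 shrinks every latent by a factor of eight.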

Citations

@misc{bondtaylor2021gradient,
      title={Gradient Origin Networks}, 
      author={Sam Bond-Taylor and Chris G. Willcocks},
      year={2021},
      eprint={2007.02798},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{sitzmann2020implicit,
      title={Implicit Neural Representations with Periodic Activation Functions}, 
      author={Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
      year={2020},
      eprint={2006.09661},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{mildenhall2020nerf,
      title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis}, 
      author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
      year={2020},
      eprint={2003.08934},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{mehta2021modulated,
      title={Modulated Periodic Activations for Generalizable Local Functional Representations},
      author={Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
      year={2021},
      eprint={2104.03960},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}


Download files


Source Distribution

gon-pytorch-0.1.1.tar.gz (7.4 kB)

Uploaded Source

Built Distribution

gon_pytorch-0.1.1-py3-none-any.whl (7.2 kB)

Uploaded Python 3

File details

Details for the file gon-pytorch-0.1.1.tar.gz.

File metadata

  • Download URL: gon-pytorch-0.1.1.tar.gz
  • Upload date:
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.5

File hashes

Hashes for gon-pytorch-0.1.1.tar.gz
Algorithm Hash digest
SHA256 3e7d123a285a227ecbe8258d1b590d6a9ed3502f0ce02785d6fef4adc12dfd44
MD5 ebb76eb1cb1f76e76ec8af7a6268a3d9
BLAKE2b-256 e3f2171fe768a6d9421a6b9805e85dbf59f0ac2d4ed1c4973093b2b661f544b1


File details

Details for the file gon_pytorch-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: gon_pytorch-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.5

File hashes

Hashes for gon_pytorch-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 55cea2f6c30af560ab9e59a3808cddde6cf60f06a2b62565bfb3f393eca464df
MD5 da62262b06a97fb2782298a313179dfc
BLAKE2b-256 51af848d3699a8910b5c69effa647cd3a07426cdc5a23382ff448f270deebb5e

