
TensorFlow 2.0 implementation of Sinusoidal Representation Networks (SIREN).

Project description

TensorFlow Sinusoidal Representation Networks (SIREN)

TensorFlow 2.0 implementation of Sinusoidal Representation Networks (SIREN) from the paper Implicit Neural Representations with Periodic Activation Functions.

Installation

  • Pip install
$ pip install --upgrade tf_siren
  • Pip install (test support)
$ pip install --upgrade tf_siren[tests]

Usage

For general usage equivalent to the paper, import and use either SinusodialRepresentationDense or SIRENModel.

import tensorflow as tf

from tf_siren import SinusodialRepresentationDense
from tf_siren import SIRENModel

# You can use SinusodialRepresentationDense exactly like you ordinarily use Dense layers.
ip = tf.keras.layers.Input(shape=[2])
x = SinusodialRepresentationDense(32,
                                  activation='sine', # default activation function
                                  w0=1.0)(ip)        # w0 represents sin(w0 * x) where x is the input.

model = tf.keras.Model(inputs=ip, outputs=x)

# Or directly use the model class to build a multi layer SIREN
model = SIRENModel(units=256, final_units=3, final_activation='sigmoid',
                   num_layers=5, w0=1.0, w0_initial=30.0)
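
To make the role of w0 concrete, the following is a minimal, framework-free sketch of what a single sine layer computes: y = sin(w0 * (W x + b)). It is not the package's code. The helper names init_sine_layer and sine_layer are hypothetical, the biases are set to zero for simplicity, and the initialization bounds follow the scheme described in the SIREN paper (uniform in +/- 1/fan_in for the first layer, +/- sqrt(6/fan_in)/w0 otherwise), which tf_siren may or may not apply in exactly this way.

```python
import math
import random


def init_sine_layer(fan_in, units, w0=1.0, is_first=False, seed=0):
    """Return (weights, biases) for one sine layer (hypothetical helper).

    The bounds follow the paper's scheme; they are assumed here,
    not read from the tf_siren source.
    """
    rng = random.Random(seed)
    bound = 1.0 / fan_in if is_first else math.sqrt(6.0 / fan_in) / w0
    weights = [[rng.uniform(-bound, bound) for _ in range(fan_in)]
               for _ in range(units)]
    biases = [0.0] * units  # zero biases: a simplification for this sketch
    return weights, biases


def sine_layer(x, weights, biases, w0=1.0):
    """Compute sin(w0 * (W x + b)) for a single input vector x."""
    return [math.sin(w0 * (sum(w * xi for w, xi in zip(row, x)) + b))
            for row, b in zip(weights, biases)]


# A 2-D coordinate passed through a first layer with w0 = 30,
# mirroring w0_initial=30.0 in the SIRENModel call.
weights, biases = init_sine_layer(fan_in=2, units=32, w0=30.0, is_first=True)
features = sine_layer([0.25, -0.5], weights, biases, w0=30.0)
print(len(features), min(features), max(features))
```

A larger w0 stretches the pre-activation over more periods of the sine, which is why the first layer typically uses a higher value than the hidden layers.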

For the (experimental) kernel scaled variants, import and use either ScaledSinusodialRepresentationDense or ScaledSIRENModel.

import tensorflow as tf

from tf_siren import ScaledSinusodialRepresentationDense
from tf_siren import ScaledSIRENModel

# You can use ScaledSinusodialRepresentationDense exactly like you ordinarily use Dense layers.
ip = tf.keras.layers.Input(shape=[2])
x = ScaledSinusodialRepresentationDense(32,
                                        scale=1.0,         # scale value should be carefully chosen in range [1, 2]
                                        activation='sine', # default activation function
                                        w0=1.0)(ip)        # w0 represents sin(w0 * x) where x is the input.

model = tf.keras.Model(inputs=ip, outputs=x)

# Or directly use the model class to build a multi layer Scaled SIREN
model = ScaledSIRENModel(units=256, final_units=3, final_activation='sigmoid', scale=1.0,
                         num_layers=5, w0=1.0, w0_initial=30.0)

Results on Image Inpainting task

A partial implementation of the image inpainting task is available as the train_inpainting_siren.py and eval_inpainting_siren.py scripts inside the scripts directory.

Weight files are available in the repository under the Release tab of the project. Extract the weights and place the checkpoints folder in the scripts directory.

These weights generate the following output after 5000 epochs of training with batch size 8192, using only 10% of the available pixels in the image during training.


If we train using only 20% of the available pixels in the image -


If we train using only 30% of the available pixels in the image -
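
The pixel subsampling used in these experiments can be sketched as follows. This is an illustration rather than the code in train_inpainting_siren.py: the helper names, the 256 x 256 image size, and the scaling of coordinates to [-1, 1] are all assumptions made for the example.

```python
import random


def make_coordinate_grid(height, width):
    """Return one (row, col) coordinate per pixel, each scaled to [-1, 1]."""
    def scale(i, n):
        return 2.0 * i / (n - 1) - 1.0 if n > 1 else 0.0
    return [(scale(r, height), scale(c, width))
            for r in range(height) for c in range(width)]


def sample_training_indices(num_pixels, fraction, seed=0):
    """Pick a random subset of pixel indices without replacement."""
    rng = random.Random(seed)
    count = int(num_pixels * fraction)
    return rng.sample(range(num_pixels), count)


def batches(indices, batch_size):
    """Yield consecutive batches of indices."""
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]


coords = make_coordinate_grid(256, 256)  # image size is an assumption
train_idx = sample_training_indices(len(coords), fraction=0.10)
train_coords = [coords[i] for i in train_idx]
num_batches = sum(1 for _ in batches(train_idx, 8192))
print(len(train_coords), num_batches)
```

Only the sampled coordinates and their pixel values are shown to the network during training; the remaining pixels are what the model has to fill in at evaluation time.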

SIREN Hyper Network

We can use a Hyper Network in order to encode an entire dataset into the weights of a SIREN model. The weights for the SIREN model are generated by this hyper network, which computes these weights based on an encoded representation.

Support for the Hyper Network is available by using NeuralProcessHyperNet, which uses the SetEncoder from the paper as the encoder.
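
The idea can be illustrated with a toy, framework-free sketch. It does not reproduce NeuralProcessHyperNet or SetEncoder: every name below is hypothetical, the weights are random rather than trained, and the layer sizes are arbitrary. It only shows the data flow, in which a set of context pixels is pooled into a latent vector and that latent vector is mapped to the weights of a small SIREN.

```python
import math
import random

rng = random.Random(0)


def linear(in_dim, out_dim, scale=0.1):
    """Random weight matrix (out_dim x in_dim), for illustration only."""
    return [[rng.uniform(-scale, scale) for _ in range(in_dim)]
            for _ in range(out_dim)]


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def set_encoder(context, embed):
    """Embed each (x, y, value) element, then mean-pool over the set.

    Mean pooling makes the latent independent of the order of the elements.
    """
    embedded = [[math.sin(h) for h in matvec(embed, list(item))]
                for item in context]
    return [sum(col) / len(embedded) for col in zip(*embedded)]


def hyper_siren(latent, hyper, coord, hidden=4, w0=30.0):
    """Map the latent to SIREN weights, then evaluate the SIREN at coord."""
    flat = matvec(hyper, latent)                         # generated parameters
    w1 = [flat[2 * i:2 * i + 2] for i in range(hidden)]  # hidden x 2
    w2 = flat[2 * hidden:3 * hidden]                     # 1 x hidden
    h = [math.sin(w0 * s) for s in matvec(w1, coord)]
    return sum(w * v for w, v in zip(w2, h))


LATENT, HIDDEN = 8, 4
embed = linear(3, LATENT)
hyper = linear(LATENT, 3 * HIDDEN)
context = [(-1.0, -1.0, 0.2), (0.0, 0.5, 0.9), (1.0, 1.0, 0.4)]
latent = set_encoder(context, embed)
print(hyper_siren(latent, hyper, [0.3, -0.7], hidden=HIDDEN))
```

Because the SIREN weights are a function of the latent vector, a different set of context pixels yields a different SIREN without any per-image optimization.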

Training and evaluation scripts for the CIFAR-10 dataset are available in the scripts directory: train_cifar_inpainting_siren.py and eval_cifar_inpainting_siren.py.

Pre-trained weights are available in the Release tab under assets.

When evaluated on the test set with 1000 context pixels, this model reaches an average MSE of 0.009. With 100 context pixels, the MSE increases to 0.019.
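
For reference, the metric and the notion of context pixels can be sketched as below. This is not the evaluation script: the helper names are hypothetical, pixel values are assumed to lie in [0, 1], and the "prediction" is a constant grey image standing in for a model output. A CIFAR-10 image has 32 x 32 = 1024 pixels, so 1000 context pixels cover almost the whole image while 100 cover roughly a tenth of it.

```python
import random


def mean_squared_error(predicted, target):
    """Average squared difference over all pixels and channels."""
    total, count = 0.0, 0
    for p_pixel, t_pixel in zip(predicted, target):
        for p, t in zip(p_pixel, t_pixel):
            total += (p - t) ** 2
            count += 1
    return total / count


def sample_context(pixels, num_context, seed=0):
    """Choose num_context (index, pixel) pairs to condition the model on."""
    rng = random.Random(seed)
    idx = rng.sample(range(len(pixels)), num_context)
    return [(i, pixels[i]) for i in idx]


# Toy 32 x 32 RGB image with random values in [0, 1] (not real CIFAR-10 data).
image_rng = random.Random(1)
image = [[image_rng.random() for _ in range(3)] for _ in range(32 * 32)]
context = sample_context(image, num_context=100)

# Stand-in prediction: a constant grey image, not a trained model's output.
prediction = [[0.5, 0.5, 0.5] for _ in image]
mse = mean_squared_error(prediction, image)
print(len(context), mse)
```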

The following image uses 1000 context pixels on the test set:

(Experimental) Comparison of convergence between original and kernel scaled SIRENs

The kernel scaled variants of the model converge faster than the original SIREN under certain circumstances. All the models below are trained with the Adam optimizer at a constant learning rate of 5e-5 for 5000 epochs with a batch size of 8192 on the same image pixels (10% of the celtic spiral image).

The tensorboard logs can be found here -

Citation

@inproceedings{sitzmann2019siren,
    author = {Sitzmann, Vincent
              and Martel, Julien N.P.
              and Bergman, Alexander W.
              and Lindell, David B.
              and Wetzstein, Gordon},
    title = {Implicit Neural Representations
              with Periodic Activation Functions},
    booktitle = {arXiv},
    year={2020}
}

Requirements

  • TensorFlow 2.0+
  • Matplotlib, to visualize evaluation results
