Frechet Inception Distance in JAX.

FID JAX

Clean implementation of the Frechet Inception Distance in JAX.

  • Reproduces OpenAI's TensorFlow implementation.
  • Pure JAX implementation runs on CPU/GPU/TPU and inside JIT.
  • Can load weights from GCS using the pathlib API.
  • Clean and simple code.
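For reference, FID fits a Gaussian to the Inception activations of each image set and compares the two: FID = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^½). The NumPy sketch below implements that standard formula from precomputed means and covariances. It is not fidjax's code, and the name `frechet_distance` is hypothetical. Instead of a matrix square root, it uses the identity that Tr((Σ₁Σ₂)^½) equals the sum of the square roots of the eigenvalues of Σ₁Σ₂ (real and non-negative for PSD covariances), which avoids a SciPy dependency.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Hypothetical helper illustrating the standard FID formula; not fidjax's API.
    """
    mu1, mu2 = np.asarray(mu1, np.float64), np.asarray(mu2, np.float64)
    sigma1 = np.asarray(sigma1, np.float64)
    sigma2 = np.asarray(sigma2, np.float64)
    # Squared Euclidean distance between the means.
    mean_term = np.sum((mu1 - mu2) ** 2)
    # Tr((sigma1 @ sigma2)^(1/2)) is the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2; clip round-off (tiny negative or
    # imaginary parts) before taking the square root.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(mean_term + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)


# Toy 2-D example: shifted means, identical unit covariances.
print(frechet_distance([0.0, 0.0], np.eye(2), [3.0, 4.0], np.eye(2)))  # approximately 25.0
```

With identical covariances only the mean term remains, so the toy example reduces to 3² + 4² = 25.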

Instructions

1️⃣ FID JAX is a single file, so you can copy it into your project directory, or install the package:

pip install fidjax

2️⃣ Download the Inception weights (credits to Matthias Wright):

wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1

3️⃣ Download the ImageNet reference stats of the desired resolution (generate your own for other datasets):

wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
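To check a downloaded reference file before using it, you can read its statistics directly with NumPy. The sketch below assumes the arrays are stored under the keys `mu` and `sigma` (the same keys the custom-dataset example writes); the helper `load_reference_stats` is hypothetical and not part of fidjax. The demo writes a small synthetic file so it runs without the real download.

```python
import os
import tempfile

import numpy as np


def load_reference_stats(path):
    """Read (mu, sigma) from a reference .npz file.

    Hypothetical helper; assumes the keys 'mu' and 'sigma'.
    """
    with np.load(path) as data:
        mu = np.asarray(data["mu"])
        sigma = np.asarray(data["sigma"])
    # Sanity-check that mu is a vector and sigma the matching square matrix.
    if mu.ndim != 1 or sigma.shape != (mu.shape[0], mu.shape[0]):
        raise ValueError(f"unexpected shapes: mu {mu.shape}, sigma {sigma.shape}")
    return mu, sigma


# Round-trip demo with a small synthetic file instead of the real download.
path = os.path.join(tempfile.mkdtemp(), "toy_reference.npz")
np.savez(path, mu=np.zeros(4), sigma=np.eye(4))
mu, sigma = load_reference_stats(path)
print(mu.shape, sigma.shape)  # (4,) (4, 4)
```

For the ImageNet files, expect 2048 dimensions, the size of the Inception v3 pool features used by FID.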

4️⃣ Compute activations, statistics, and scores in JAX:

import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax.FID(weights, reference)

fid_total = 50000  # total number of samples to evaluate
fid_batch = 1000   # samples per batch
acts = []
for _ in range(fid_total // fid_batch):
  samples = ...  # (B, H, W, 3) jnp.uint8 batch from your model
  acts.append(fid.compute_acts(samples))
stats = fid.compute_stats(acts)   # statistics of all collected activations
score = fid.compute_score(stats)  # compared against the reference statistics

print(float(score))  # FID
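The loop above expects each batch as uint8 images of shape (B, H, W, 3). Many generative models output floats instead. The sketch below assumes model outputs in [-1, 1] (adjust the scaling if your model uses [0, 1]) and maps them to uint8 with rounding and clipping. The helper `to_uint8` is hypothetical and not part of fidjax.

```python
import numpy as np


def to_uint8(images):
    """Map float images in [-1, 1] to uint8 in [0, 255].

    Hypothetical helper; assumes the model's output range is [-1, 1].
    """
    images = np.asarray(images, dtype=np.float32)
    # Rescale to [0, 255], round to the nearest integer, clip out-of-range values.
    scaled = np.round((images + 1.0) * 127.5)
    return np.clip(scaled, 0, 255).astype(np.uint8)


batch = np.random.uniform(-1.0, 1.0, size=(8, 64, 64, 3))
samples = to_uint8(batch)
print(samples.dtype, samples.shape)  # uint8 (8, 64, 64, 3)
```

`np.asarray` copies JAX arrays to the host; to keep the conversion on device or inside `jax.jit`, the same operations are available as `jnp.round`, `jnp.clip`, and `.astype`.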

Accuracy

| Dataset | Model | FID JAX | OpenAI TF |
|---|---|---|---|
| ImageNet 256 | ADM (guided, upsampled) | 3.937 | 3.943 |

Tutorials

Using Cloud Storage

Point to the files via a pathlib.Path implementation that supports your cloud storage. For example, for GCS:

import elements  # pip install elements
import fidjax

weights = elements.Path('gs://bucket/fid/inception_v3_weights_fid.pickle')
reference = elements.Path('gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz')

fid = fidjax.FID(weights, reference)

Custom Datasets

Generate reference statistics for custom datasets:

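import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax.FID(weights)  # no reference statistics needed here

images = ...  # your dataset: (N, H, W, 3) uint8 array
acts = fid.compute_acts(images)
mu, sigma = fid.compute_stats(acts)

# Save under the keys 'mu' and 'sigma' so the file can be used as a reference.
np.savez('reference.npz', mu=mu, sigma=sigma)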
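Reference statistics are the mean and covariance of the Inception activations. When a dataset is too large to keep all activations in memory, they can be accumulated batch by batch from running sums, using Σ(x − μ)(x − μ)ᵀ = Σxxᵀ − nμμᵀ. The sketch below is a hypothetical helper (`RunningStats` is not part of fidjax) that matches `np.cov`'s default N − 1 normalization; whether fidjax normalizes the same way is an assumption worth checking. It accumulates in float64 because the sum-of-outer-products form can lose precision in float32.

```python
import numpy as np


class RunningStats:
    """Accumulate mean and covariance of activation batches without storing them.

    Hypothetical helper; not part of fidjax.
    """

    def __init__(self, dim):
        self.n = 0
        self.total = np.zeros(dim, np.float64)         # running sum of x
        self.outer = np.zeros((dim, dim), np.float64)  # running sum of x x^T

    def update(self, acts):
        acts = np.asarray(acts, np.float64)  # shape (B, dim)
        self.n += acts.shape[0]
        self.total += acts.sum(axis=0)
        self.outer += acts.T @ acts

    def finalize(self):
        mu = self.total / self.n
        # Unbiased covariance, matching np.cov(acts, rowvar=False).
        sigma = (self.outer - self.n * np.outer(mu, mu)) / (self.n - 1)
        return mu, sigma


rng = np.random.default_rng(0)
all_acts = rng.normal(size=(1000, 16))
stats = RunningStats(dim=16)
for batch in np.split(all_acts, 10):
    stats.update(batch)
mu, sigma = stats.finalize()
print(np.allclose(sigma, np.cov(all_acts, rowvar=False)))  # True
```

In practice you would pass the output of `fid.compute_acts` for each batch to `update`, assuming it returns a 2-D array of shape (B, D), and then save the result with `np.savez('reference.npz', mu=mu, sigma=sigma)`.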

Resources

Questions

Please file an issue on GitHub.


Source Distribution

fidjax-1.0.1.tar.gz (5.3 kB)

File details

Details for the file fidjax-1.0.1.tar.gz.

File metadata

  • Download URL: fidjax-1.0.1.tar.gz
  • Upload date:
  • Size: 5.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.8

File hashes

Hashes for fidjax-1.0.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7772735343251fec87fdf631fc1b31a5809f6394066dd4ee513539f0cd14fced |
| MD5 | 2326d80678159905c5eeba2ce890e3e6 |
| BLAKE2b-256 | 4789a0a21d023367d4ff321955ce9755f37936761dedf3fdf6e52a897a41aafc |
