Frechet Inception Distance in JAX.
FID JAX
Clean implementation of the Frechet Inception Distance in JAX.
- Reproduces OpenAI's TensorFlow implementation.
- Pure JAX implementation runs on CPU/GPU/TPU and inside JIT.
- Can load weights from GCS through any pathlib-compatible path implementation.
- Clean and simple code.
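FID fits a Gaussian to the Inception activations of each image set and measures the Fréchet distance between the two Gaussians: FID = ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 (Sigma_1 Sigma_2)^(1/2)). The NumPy sketch below evaluates that formula from two `(mu, sigma)` pairs. It is an illustration, not fidjax's code. `frechet_distance` is a hypothetical name, and the trace of the matrix square root is computed from the eigenvalues of `Sigma_1 Sigma_2` rather than from a full matrix square root.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Hypothetical helper for illustration; not part of fidjax.
    """
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2
    # Tr((sigma1 @ sigma2)^(1/2)) is the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2. For PSD inputs those eigenvalues are
    # real and non-negative, so drop round-off imaginary parts and clip.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * tr_covmean)


# Identical statistics give 0. Shifting the mean by 1 in each of 3 dims and
# scaling the covariance to 4*I gives 3 + (3 + 12 - 2 * 6) = 6.
print(frechet_distance(np.zeros(3), np.eye(3), np.zeros(3), np.eye(3)))
print(frechet_distance(np.zeros(3), np.eye(3), np.ones(3), 4.0 * np.eye(3)))
```

Real FID statistics are 2048-dimensional, so the same function applies unchanged, only with larger arrays.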
Instructions
1️⃣ FID JAX is a single file, so you can copy it into your project directory, or install the package:
```sh
pip install fidjax
```
2️⃣ Download the Inception weights (credits to Matthias Wright):
```sh
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1
```
3️⃣ Download the ImageNet reference stats of the desired resolution (generate your own for other datasets):
```sh
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
```
4️⃣ Compute activations, statistics, and scores in JAX:
```python
import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax.FID(weights, reference)

fid_total = 50000
fid_batch = 1000
acts = []
for _ in range(fid_total // fid_batch):
    samples = ...  # Your generated images: (B, H, W, 3) jnp.uint8
    acts.append(fid.compute_acts(samples))
stats = fid.compute_stats(acts)
score = fid.compute_score(stats)
print(float(score))  # FID
```
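The loop expects each `samples` batch as `uint8` pixels with shape (B, H, W, 3). If your generator produces floats instead, a conversion along these lines may help. This is a sketch that assumes the model outputs values in [-1, 1] (pass other bounds via `low`/`high`), and `to_uint8` is a hypothetical helper, not part of fidjax. It returns a NumPy array, which JAX functions accept as input.

```python
import numpy as np


def to_uint8(images, low=-1.0, high=1.0):
    """Map float images in [low, high] to uint8 pixels in [0, 255].

    Hypothetical helper; assumes channels-last images of shape (B, H, W, 3).
    """
    images = np.asarray(images, dtype=np.float32)
    scaled = (images - low) / (high - low) * 255.0
    # Round to the nearest integer and clip anything outside the valid range.
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)


batch = np.array([[[[-1.0, 0.0, 1.0]]]])  # one 1x1 RGB image, shape (1, 1, 1, 3)
print(to_uint8(batch))  # pixel values 0, 128, 255
```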
Accuracy
Dataset | Model | FID JAX | OpenAI TF |
---|---|---|---|
ImageNet 256 | ADM (guided, upsampled) | 3.937 | 3.943 |
Tutorials
Using Cloud Storage
Point to the files via a pathlib.Path implementation that supports your cloud storage. For example, for GCS:
```python
import elements  # pip install elements
import fidjax

weights = elements.Path('gs://bucket/fid/inception_v3_weights_fid.pickle')
reference = elements.Path('gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz')
fid = fidjax.FID(weights, reference)
```
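Reference statistics files store the mean and covariance under the keys `mu` and `sigma` (the OpenAI `.npz` files also contain other arrays). The sketch below reads them through any object with a pathlib-style `open('rb')` method, so a cloud path object could stand in for a local one. This assumes the path object supports `open('rb')` as `pathlib.Path` does. `load_reference_stats` is a hypothetical helper, not fidjax's loader, and the usage runs against a temporary local file.

```python
import pathlib
import tempfile

import numpy as np


def load_reference_stats(path):
    """Read (mu, sigma) from a reference .npz via a pathlib-style object.

    Hypothetical helper; assumes `path` supports open('rb') like pathlib.Path.
    """
    with path.open('rb') as f:
        data = np.load(f)
        # NpzFile reads lazily, so extract the arrays before the file closes.
        return data['mu'], data['sigma']


# Usage with a local file standing in for a cloud path.
with tempfile.TemporaryDirectory() as tmp:
    ref = pathlib.Path(tmp) / 'reference.npz'
    np.savez(ref, mu=np.zeros(4), sigma=np.eye(4))
    mu, sigma = load_reference_stats(ref)
    print(mu.shape, sigma.shape)  # (4,) (4, 4)
```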
Custom Datasets
Generate reference statistics for custom datasets:
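```python
import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax.FID(weights)
acts = fid.compute_acts(images)  # images: (N, H, W, 3) uint8 from your dataset
mu, sigma = fid.compute_stats(acts)
# Pass mu and sigma as keyword arguments so they are stored under those keys;
# a dict passed positionally would be pickled into a single 'arr_0' entry.
np.savez('reference.npz', mu=mu, sigma=sigma)
```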
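Under the standard FID definition, the reference statistics are the mean and covariance of the Inception activations over all images. The NumPy sketch below computes them and checks that concatenating per-batch activations, as in step 4, gives the same result as one large batch. `compute_stats_np` is a hypothetical stand-in, not fidjax's `compute_stats`, and small random arrays replace real 2048-dimensional activations.

```python
import numpy as np


def compute_stats_np(acts):
    """Mean and covariance over the rows of an activation array or list of them.

    Hypothetical stand-in for illustration; not fidjax's implementation.
    """
    if isinstance(acts, (list, tuple)):
        acts = np.concatenate(acts, axis=0)
    acts = np.asarray(acts, dtype=np.float64)
    mu = acts.mean(axis=0)
    # rowvar=False: rows are images, columns are activation dimensions.
    sigma = np.cov(acts, rowvar=False)
    return mu, sigma


rng = np.random.default_rng(0)
acts = rng.normal(size=(1000, 16))  # stand-in for (N, 2048) activations
batches = np.split(acts, 10)  # ten batches of 100 rows each
mu_full, sigma_full = compute_stats_np(acts)
mu_batched, sigma_batched = compute_stats_np(batches)
print(np.allclose(mu_full, mu_batched), np.allclose(sigma_full, sigma_batched))
```

Because the covariance is computed over all rows at once, activations from every batch must be kept (or accumulated as running sums) until the end; averaging per-batch covariances would not give the same matrix.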
Resources
Questions
Please file an issue on GitHub.
Source Distribution
File details
Details for the file fidjax-1.0.1.tar.gz.
File metadata
- Download URL: fidjax-1.0.1.tar.gz
- Size: 5.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.8
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 7772735343251fec87fdf631fc1b31a5809f6394066dd4ee513539f0cd14fced |
MD5 | 2326d80678159905c5eeba2ce890e3e6 |
BLAKE2b-256 | 4789a0a21d023367d4ff321955ce9755f37936761dedf3fdf6e52a897a41aafc |