LPIPS Similarity metric for Jax
Project description
LPIPS-Jax
A Jax port of the original PyTorch implementation of LPIPS (Learned Perceptual Image Patch Similarity). The current version supports pretrained AlexNet and VGG16 backbones together with the pretrained linear layers.
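For readers new to the metric, LPIPS compares two images through deep features. Activations from several backbone layers are unit-normalized along the channel axis, and their squared difference is weighted per channel by a learned linear layer. The result is averaged over space and summed over layers. The sketch below shows only that final distance computation in JAX. It is not the lpips_jax code: the function names are hypothetical, and random arrays stand in for the AlexNet or VGG16 activations and the pretrained linear weights.

```python
import jax
import jax.numpy as jnp


def normalize_channels(feat, eps=1e-10):
    # Unit-normalize each spatial feature vector along the channel axis (NHWC layout).
    norm = jnp.sqrt(jnp.sum(feat ** 2, axis=-1, keepdims=True))
    return feat / (norm + eps)


def lpips_distance(feats_0, feats_1, weights):
    """Hypothetical LPIPS-style distance.

    feats_0, feats_1: lists of (N, H_l, W_l, C_l) activations, one per layer.
    weights: list of non-negative (C_l,) vectors, one per layer.
    Returns one distance per image, shape (N,).
    """
    total = 0.0
    for f0, f1, w in zip(feats_0, feats_1, weights):
        diff = (normalize_channels(f0) - normalize_channels(f1)) ** 2  # (N, H, W, C)
        per_pixel = jnp.sum(diff * w, axis=-1)  # channel-weighted sum -> (N, H, W)
        total = total + jnp.mean(per_pixel, axis=(1, 2))  # spatial average -> (N,)
    return total


# Random stand-ins for two layers of backbone activations and linear weights.
key = jax.random.PRNGKey(0)
k0, k1, kw = jax.random.split(key, 3)
shapes = [(2, 8, 8, 16), (2, 4, 4, 32)]
feats_0 = [jax.random.normal(k, s) for k, s in zip(jax.random.split(k0, 2), shapes)]
feats_1 = [jax.random.normal(k, s) for k, s in zip(jax.random.split(k1, 2), shapes)]
weights = [jnp.abs(jax.random.normal(k, (s[-1],))) for k, s in zip(jax.random.split(kw, 2), shapes)]

d = lpips_distance(feats_0, feats_1, weights)       # shape (2,)
d_self = lpips_distance(feats_0, feats_0, weights)  # identical inputs -> zeros
```

Identical inputs give a distance of zero; the per-channel weights are the part supplied by the pretrained linear layers.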
Installation
pip install lpips_jax
Usage
For replicate=False:
import lpips_jax
import numpy as np
images_0 = np.random.randn(4, 224, 224, 3)  # batch of 4 images, (N, H, W, C) layout
images_1 = np.random.randn(4, 224, 224, 3)  # random data, for illustration only
lpips = lpips_jax.LPIPSEvaluator(replicate=False, net='alexnet')  # net: 'alexnet' or 'vgg16'
out = lpips(images_0, images_1)
For replicate=True:
import lpips_jax
import numpy as np
import jax
n_devices = jax.local_device_count()
# leading axis = number of local devices, then 4 images of (H, W, C) per device
images_0 = np.random.randn(n_devices, 4, 224, 224, 3)
images_1 = np.random.randn(n_devices, 4, 224, 224, 3)
# replicate=True is the default setting
lpips = lpips_jax.LPIPSEvaluator(net='alexnet')  # net: 'alexnet' or 'vgg16'
out = lpips(images_0, images_1) # internally calls jax.pmap
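With replicate=True, the example above gives the inputs a leading axis of size jax.local_device_count(), which jax.pmap splits across devices. The sketch below shows one way to bring a flat batch into that layout and back, assuming the batch size is divisible by the device count. The helper shard_batch and the toy_distance metric are hypothetical illustrations, not part of lpips_jax; the toy metric simply stands in for a per-image score.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shard_batch(images, n_devices):
    """Hypothetical helper: reshape (B, H, W, C) into (n_devices, B // n_devices, H, W, C)."""
    b = images.shape[0]
    if b % n_devices != 0:
        raise ValueError(f"batch size {b} is not divisible by {n_devices} devices")
    return images.reshape((n_devices, b // n_devices) + images.shape[1:])


def toy_distance(x, y):
    # Hypothetical stand-in for a per-image metric: mean squared error over H, W, C.
    return jnp.mean((x - y) ** 2, axis=(1, 2, 3))


n_devices = jax.local_device_count()
flat_0 = np.random.randn(4 * n_devices, 32, 32, 3).astype(np.float32)
flat_1 = np.random.randn(4 * n_devices, 32, 32, 3).astype(np.float32)

sharded_0 = shard_batch(flat_0, n_devices)  # (n_devices, 4, 32, 32, 3)
sharded_1 = shard_batch(flat_1, n_devices)

# pmap maps over the leading (device) axis; each device receives a (4, 32, 32, 3) slice.
per_device = jax.pmap(toy_distance)(sharded_0, sharded_1)  # (n_devices, 4)
scores = np.asarray(per_device).reshape(-1)  # back to one score per image, original order
```

Because the reshape keeps images in order, flattening the per-device results recovers one score per input image in the original batch order.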
Download files
Source Distribution
lpips_jax-0.1.0.tar.gz (63.8 MB)