LPIPS similarity metric for JAX
Project description
LPIPS-Jax
JAX port of the original PyTorch implementation of LPIPS (Learned Perceptual Image Patch Similarity). The current version supports pretrained AlexNet and VGG16 backbones, together with their pretrained linear layers.
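LPIPS scores a pair of images in several steps. Both images pass through a backbone (AlexNet or VGG16 here). Each layer's activations are unit-normalized along the channel axis, and the squared differences are weighted per channel by the learned linear layer. The result is averaged over spatial positions and summed over layers. The sketch below reproduces only that final scoring step in NumPy. It assumes the backbone features are already extracted in channel-last (NHWC) layout. `lpips_from_features` and `normalize_channels` are hypothetical names, not part of lpips_jax, and the features and weights are random placeholders rather than real pretrained values.

```python
import numpy as np


def normalize_channels(feat, eps=1e-10):
    # Unit-normalize each spatial feature vector along the channel axis (last axis, NHWC).
    norm = np.sqrt(np.sum(feat ** 2, axis=-1, keepdims=True))
    return feat / (norm + eps)


def lpips_from_features(feats_0, feats_1, weights):
    """Hypothetical helper: LPIPS-style distance from precomputed backbone activations.

    feats_0, feats_1: lists of arrays, one per layer, each shaped (N, H_l, W_l, C_l).
    weights: list of non-negative arrays, one per layer, each shaped (C_l,).
    Returns an array of shape (N,), one distance per image pair.
    """
    total = 0.0
    for f0, f1, w in zip(feats_0, feats_1, weights):
        diff_sq = (normalize_channels(f0) - normalize_channels(f1)) ** 2
        # Channel-weighted sum, playing the role of the bias-free 1x1 "linear layer".
        per_location = np.sum(diff_sq * w, axis=-1)  # (N, H_l, W_l)
        # Average over spatial positions, then accumulate across layers.
        total = total + per_location.mean(axis=(1, 2))  # (N,)
    return total


# Placeholder features for two layers and a batch of 2 image pairs.
rng = np.random.default_rng(0)
shapes = [(2, 8, 8, 4), (2, 4, 4, 6)]
feats_0 = [rng.standard_normal(s) for s in shapes]
feats_1 = [rng.standard_normal(s) for s in shapes]
weights = [rng.random(s[-1]) for s in shapes]
print(lpips_from_features(feats_0, feats_1, weights))
```

Identical inputs give a distance of zero. Non-negative weights keep every distance non-negative.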
Installation
pip install lpips_jax
Usage
For replicate=False:
import lpips_jax
import numpy as np
images_0 = np.random.randn(4, 224, 224, 3)
images_1 = np.random.randn(4, 224, 224, 3)
lpips = lpips_jax.LPIPSEvaluator(replicate=False, net='alexnet') # ['alexnet', 'vgg16']
out = lpips(images_0, images_1)
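The example above feeds standard-normal noise directly. Real images usually arrive as `uint8` arrays in [0, 255]. The original PyTorch LPIPS expects inputs scaled to [-1, 1]. The sketch below assumes lpips_jax follows the same convention, which you should verify against the lpips_jax source before relying on it. `to_lpips_range` is a hypothetical helper, not part of the package.

```python
import numpy as np


def to_lpips_range(images_uint8):
    """Hypothetical helper: map uint8 images in [0, 255] to float32 in [-1, 1].

    Assumes lpips_jax uses the PyTorch LPIPS input convention of [-1, 1].
    """
    images = np.asarray(images_uint8, dtype=np.float32)
    return images / 127.5 - 1.0


# A batch of 4 RGB images in NHWC layout, matching the replicate=False example.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 224, 224, 3), dtype=np.uint8)
images_0 = to_lpips_range(raw)
print(images_0.min() >= -1.0, images_0.max() <= 1.0)
```

The rescaled arrays keep the (4, 224, 224, 3) shape, so they can be passed to `lpips(images_0, images_1)` as before.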
For replicate=True:
import lpips_jax
import numpy as np
import jax
n_devices = jax.local_device_count()
images_0 = np.random.randn(n_devices, 4, 224, 224, 3)
images_1 = np.random.randn(n_devices, 4, 224, 224, 3)
# replicate=True is the default setting
lpips = lpips_jax.LPIPSEvaluator(net='alexnet') # ['alexnet', 'vgg16']
out = lpips(images_0, images_1) # internally calls jax.pmap
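With replicate=True the inputs carry a leading device axis, because `jax.pmap` maps over the first dimension, one slice per local device. If your images start as a flat (B, H, W, C) batch, a reshape like the one below produces that layout. `shard_for_devices` is a hypothetical helper. It assumes B divides evenly by the device count and raises an error otherwise rather than padding.

```python
import numpy as np


def shard_for_devices(images, n_devices):
    """Hypothetical helper: reshape (B, H, W, C) to (n_devices, B // n_devices, H, W, C).

    Raises ValueError if B is not divisible by n_devices (no padding is attempted).
    """
    images = np.asarray(images)
    batch = images.shape[0]
    if batch % n_devices != 0:
        raise ValueError(f"batch size {batch} is not divisible by n_devices={n_devices}")
    return images.reshape(n_devices, batch // n_devices, *images.shape[1:])


# Placeholder device count; in practice pass jax.local_device_count().
flat = np.zeros((8, 224, 224, 3), dtype=np.float32)
sharded = shard_for_devices(flat, n_devices=2)
print(sharded.shape)  # (2, 4, 224, 224, 3)
```

Raising on an uneven batch keeps the sketch simple. Padding the last shard would also work, but the padded entries would then need to be masked out of the returned scores.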
Project details
Download files
Source Distribution
lpips_jax-0.1.0.tar.gz (63.8 MB)
File details
Details for the file lpips_jax-0.1.0.tar.gz.
File metadata
- Download URL: lpips_jax-0.1.0.tar.gz
- Upload date:
- Size: 63.8 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.7.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a286e44ce15db862b3b5244d175b0f9abbc13c0e737c8355a7dd69fb62fc693b |
| MD5 | 28940a02fd0113855802925aaccbb600 |
| BLAKE2b-256 | e44f5d98bdde23129144b73bc992f759f04b5fbb0db3df3051114d7d30426e0b |