
LPIPS similarity metric for JAX


LPIPS-Jax

A JAX port of the original PyTorch implementation of LPIPS (Learned Perceptual Image Patch Similarity). The current version supports pretrained AlexNet and VGG16 backbones, together with the pretrained linear layers.
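For orientation, the distance that LPIPS computes from backbone features can be sketched in a few lines of jax.numpy. This is a minimal sketch of the formulation from the original LPIPS work, not the lpips_jax source. It assumes NHWC feature maps have already been extracted from the backbone layers and omits the input scaling and the network forward pass. The names normalize_channels and lpips_distance are hypothetical, as are the random features and uniform weights in the example.

```python
import jax
import jax.numpy as jnp


def normalize_channels(feat, eps=1e-10):
    """Scale each spatial feature vector (last axis, NHWC layout) to unit length."""
    norm = jnp.sqrt(jnp.sum(feat ** 2, axis=-1, keepdims=True))
    return feat / (norm + eps)


def lpips_distance(feats_0, feats_1, lin_weights):
    """LPIPS-style distance from per-layer feature maps (illustrative sketch).

    feats_0, feats_1: lists of (N, H_l, W_l, C_l) arrays, one per backbone layer.
    lin_weights: list of (C_l,) non-negative channel weights (the "linear layers").
    Returns an (N,) array with one distance per image pair.
    """
    total = jnp.zeros(feats_0[0].shape[0])
    for f0, f1, w in zip(feats_0, feats_1, lin_weights):
        # Squared difference of channel-normalized features: (N, H, W, C).
        diff = (normalize_channels(f0) - normalize_channels(f1)) ** 2
        # A 1x1 linear layer without bias is a weighted sum over channels: (N, H, W).
        weighted = jnp.sum(diff * w, axis=-1)
        # Average over spatial positions, then accumulate across layers: (N,).
        total = total + jnp.mean(weighted, axis=(1, 2))
    return total


# Example with hypothetical layer shapes, random features, and uniform weights.
k0, k1 = jax.random.split(jax.random.PRNGKey(0))
shapes = [(2, 8, 8, 16), (2, 4, 4, 32)]
feats_0 = [jax.random.normal(k0, s) for s in shapes]
feats_1 = [jax.random.normal(k1, s) for s in shapes]
weights = [jnp.ones(s[-1]) / s[-1] for s in shapes]
print(lpips_distance(feats_0, feats_1, weights))  # one non-negative value per image pair
```

Identical feature maps give a distance of exactly zero, and the learned per-channel weights decide how much each channel's difference counts.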

Installation

pip install lpips_jax

Usage

For replicate=False:

import lpips_jax
import numpy as np

images_0 = np.random.randn(4, 224, 224, 3)
images_1 = np.random.randn(4, 224, 224, 3)

lpips = lpips_jax.LPIPSEvaluator(replicate=False, net='alexnet') # ['alexnet', 'vgg16']
out = lpips(images_0, images_1)

For replicate=True (inputs carry a leading device axis of size jax.local_device_count()):

import lpips_jax
import numpy as np
import jax

n_devices = jax.local_device_count()
images_0 = np.random.randn(n_devices, 4, 224, 224, 3)
images_1 = np.random.randn(n_devices, 4, 224, 224, 3)

# replicate=True is the default setting
lpips = lpips_jax.LPIPSEvaluator(net='alexnet') # ['alexnet', 'vgg16']
out = lpips(images_0, images_1) # internally calls jax.pmap
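With replicate=True the inputs need a leading device axis, as in the shapes above. If images arrive as one flat batch, a reshape along these lines produces that layout. This is a plain NumPy sketch; shard_batch is a hypothetical helper, not part of lpips_jax, and it assumes the batch size divides evenly by the device count.

```python
import numpy as np


def shard_batch(images, n_devices):
    """Reshape a flat (B, H, W, C) batch into (n_devices, B // n_devices, H, W, C).

    The leading axis is the device axis that jax.pmap maps over; consecutive
    images end up on the same device.
    """
    batch = images.shape[0]
    if batch % n_devices != 0:
        raise ValueError(f"batch size {batch} is not divisible by n_devices={n_devices}")
    return images.reshape(n_devices, batch // n_devices, *images.shape[1:])


# Example with a hypothetical device count of 2 and small images.
flat = np.random.randn(8, 32, 32, 3)
sharded = shard_batch(flat, n_devices=2)
print(sharded.shape)  # (2, 4, 32, 32, 3)
```

In practice n_devices would be jax.local_device_count(), matching the example above.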



Download files


Source Distribution

lpips_jax-0.1.0.tar.gz (63.8 MB)


File details

Details for the file lpips_jax-0.1.0.tar.gz.

File metadata

  • Download URL: lpips_jax-0.1.0.tar.gz
  • Size: 63.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.7.13

File hashes

Hashes for lpips_jax-0.1.0.tar.gz:

  • SHA256: a286e44ce15db862b3b5244d175b0f9abbc13c0e737c8355a7dd69fb62fc693b
  • MD5: 28940a02fd0113855802925aaccbb600
  • BLAKE2b-256: e44f5d98bdde23129144b73bc992f759f04b5fbb0db3df3051114d7d30426e0b
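To check a downloaded archive against the published SHA256 digest, a short hashlib sketch can help. The function names sha256_file and matches_published_digest and the local file path are illustrative assumptions; pip can also enforce digests itself through its hash-checking mode.

```python
import hashlib

# SHA256 digest published for lpips_jax-0.1.0.tar.gz (listed above).
EXPECTED_SHA256 = "a286e44ce15db862b3b5244d175b0f9abbc13c0e737c8355a7dd69fb62fc693b"


def sha256_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Stream the file so a large archive never has to fit in memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path, expected=EXPECTED_SHA256):
    """True if the file at `path` has the expected SHA256 digest."""
    return sha256_file(path) == expected.lower()
```

Typical use would be matches_published_digest("lpips_jax-0.1.0.tar.gz") after downloading the source distribution into the current directory.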

