
A method to generate counterfactuals

Project description

Latent Shift - A Simple Autoencoder Approach to Counterfactual Generation


The idea

Read the paper: https://arxiv.org/abs/2102.09475

Watch a video: https://www.youtube.com/watch?v=1fxSDP8DheI

The main diagram: latentshift.gif
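As we read the linked paper, the core of Latent Shift is a single update. The image is encoded, the classifier's prediction for the target is differentiated with respect to the latent code, and latent codes shifted along that gradient are decoded. The notation is our own summary: $E$ and $D$ are the autoencoder's encoder and decoder, $f$ is the classifier output for the target, and $\lambda$ is the shift amount.

```latex
z = E(x), \qquad
x_\lambda = D\!\left( z + \lambda \, \frac{\partial f\big(D(z)\big)}{\partial z} \right)
```

At $\lambda = 0$ this is simply the reconstruction $D(E(x))$. To first order, negative $\lambda$ moves the image toward a lower prediction and positive $\lambda$ toward a higher one.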

Animations/GIFs

Smiling, Arched Eyebrows, Mouth Slightly Open, Young

Generating a transition sequence

For a prediction of smiling:

gen_sequence.png
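A transition sequence can be sketched by decoding latent codes shifted by a range of λ values. The sketch below is minimal and dependency-free, and it rests on several assumptions. `encode`, `decode` and `classifier` are hypothetical toy functions, not the package's models. The latent gradient is estimated with central finite differences, whereas a real implementation would use autograd.

```python
import math
from typing import Callable, List

Vector = List[float]


def encode(x: Vector) -> Vector:
    """Hypothetical encoder E: keep the first two coordinates as the latent code."""
    return list(x[:2])


def decode(z: Vector) -> Vector:
    """Hypothetical decoder D: embed a 2-D latent code into a 4-D "image"."""
    return [z[0], z[1], 0.0, 0.0]


def classifier(x: Vector) -> float:
    """Hypothetical classifier f: logistic score for a single attribute."""
    weights = [2.0, -1.0, 0.5, 0.0]
    logit = sum(w * xi for w, xi in zip(weights, x))
    return 1.0 / (1.0 + math.exp(-logit))


def latent_gradient(f: Callable[[Vector], float], dec: Callable[[Vector], Vector],
                    z: Vector, eps: float = 1e-5) -> Vector:
    """Central finite-difference estimate of d f(D(z)) / dz."""
    grad = []
    for i in range(len(z)):
        up, down = list(z), list(z)
        up[i] += eps
        down[i] -= eps
        grad.append((f(dec(up)) - f(dec(down))) / (2 * eps))
    return grad


def latent_shift_sequence(x: Vector, lambdas: List[float]) -> List[Vector]:
    """Decode D(z + lambda * grad) for each lambda, where z = E(x)."""
    z = encode(x)
    g = latent_gradient(classifier, decode, z)
    return [decode([zi + lam * gi for zi, gi in zip(z, g)]) for lam in lambdas]


# Usage: five frames from a lower to a higher prediction.
x = [0.3, 0.1, 0.0, 0.0]
lambdas = [-2.0, -1.0, 0.0, 1.0, 2.0]
frames = latent_shift_sequence(x, lambdas)
scores = [classifier(frame) for frame in frames]
```

In this toy setup the scores rise monotonically with λ, and the λ = 0 frame equals the reconstruction D(E(x)). Playing the frames in order gives the kind of animation shown above.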

Multiple different targets

Comparison to traditional methods

For a prediction of pointy_nose:

comparison.png
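The contrast with a traditional input-gradient saliency map can be illustrated on the same kind of toy models. In this sketch the decoder, encoder and classifier are hypothetical stand-ins. The "traditional" baseline is plain |∂f/∂x|, and we chose it ourselves; it is not necessarily the method used in the figure. The latent-shift map is |x_λ − D(E(x))|, measured against the reconstruction so that autoencoder reconstruction error does not appear in it.

```python
import math
from typing import List

Vector = List[float]
WEIGHTS = [2.0, -1.0, 0.5, 0.0]  # hypothetical classifier weights


def encode(x: Vector) -> Vector:
    return list(x[:2])  # hypothetical encoder: first two pixels


def decode(z: Vector) -> Vector:
    return [z[0], z[1], 0.0, 0.0]  # hypothetical decoder: can only set pixels 0 and 1


def classifier(x: Vector) -> float:
    logit = sum(w * xi for w, xi in zip(WEIGHTS, x))
    return 1.0 / (1.0 + math.exp(-logit))


def input_gradient_saliency(x: Vector) -> Vector:
    """Traditional saliency |df/dx|, computed analytically for the logistic model."""
    s = classifier(x)
    return [abs(s * (1.0 - s) * w) for w in WEIGHTS]


def latent_shift_difference(x: Vector, lam: float = -1.0, eps: float = 1e-5) -> Vector:
    """|D(z + lam * df(D(z))/dz) - D(z)| with z = E(x), gradient by finite differences."""
    z = encode(x)
    grad = []
    for i in range(len(z)):
        up, down = list(z), list(z)
        up[i] += eps
        down[i] -= eps
        grad.append((classifier(decode(up)) - classifier(decode(down))) / (2 * eps))
    recon = decode(z)
    shifted = decode([zi + lam * gi for zi, gi in zip(z, grad)])
    return [abs(a - b) for a, b in zip(shifted, recon)]


x = [0.3, 0.1, 0.0, 0.0]
saliency = input_gradient_saliency(x)
difference = latent_shift_difference(x)
```

The input-gradient map assigns weight to pixel 2, which the autoencoder cannot produce at all. The latent-shift map is exactly zero there, because it only changes what the decoder can generate.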

Getting Started

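import torch
import captum.attr
from latentshift import classifiers, autoencoders  # modules of the latentshift package

# Load classifier and autoencoder
model = classifiers.FaceAttribute()
ae = autoencoders.Transformer(weights="celeba")

# Load image (a random tensor stands in for a real 1x3x1024x1024 image here)
image = torch.randn(1, 3, 1024, 1024)

# Define the Latent Shift attribution module
attr = captum.attr.LatentShift(model, ae)

# Compute a counterfactual for output class (target index) 3
output = attr.attribute(image, target=3)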



Download files


Source Distribution

latentshift-0.0.2.tar.gz (5.3 kB)

Uploaded Source

Built Distribution

latentshift-0.0.2-py3-none-any.whl (6.1 kB)

Uploaded Python 3

File details

Details for the file latentshift-0.0.2.tar.gz.

File metadata

  • Download URL: latentshift-0.0.2.tar.gz
  • Upload date:
  • Size: 5.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.0

File hashes

Hashes for latentshift-0.0.2.tar.gz
Algorithm Hash digest
SHA256 0bab149d88228fae0b50505abb248210f0d0c34bc601cc2ce1124b6cedac8f7d
MD5 45907d34795f09ae154a7a4872f6c988
BLAKE2b-256 547103d7566e8f5ef5c9573e8dbdb78024dde047786694a976ee7c339686afae


File details

Details for the file latentshift-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: latentshift-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 6.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.0

File hashes

Hashes for latentshift-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 813ab872a655a0e43b38febd62bdd15cd07960088bed60db70bfa2389346fddf
MD5 d0d2ce9cfc1f724c59b6c6a7d75b87e9
BLAKE2b-256 c493fa27eb46377a1d704f320aa9bd793f021df4c078a3e9e6a41e79e9e2d405

