Erasing concepts from neural representations with provable guarantees

Project description

Least-Squares Concept Erasure (LEACE)

Concept erasure aims to remove specified features from a representation. It can be used to improve fairness (e.g. preventing a classifier from using gender or race) and interpretability (e.g. removing a concept to observe changes in model behavior). This is the repo for LEAst-squares Concept Erasure (LEACE), a closed-form method which provably prevents all linear classifiers from detecting a concept while inflicting the least possible damage to the representation. You can check out the paper here.
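
Installation

The package is published on PyPI (this page), so it can be installed with pip:

pip install concept-erasure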

Usage

ConceptEraser is the central class in this repo. It keeps track of the covariance and cross-covariance statistics needed to erase a concept, and lazily computes the LEACE parameters when needed.
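
To make "the covariance and cross-covariance statistics needed to erase a concept" concrete, here is a rough, self-contained sketch of how a LEACE-style eraser can be assembled from those statistics: whiten the features, project out the directions correlated with the concept, then un-whiten. This is only an illustration of the idea, not ConceptEraser's actual implementation; the helper name fit_leace_like and numerical details such as the eigendecomposition-based whitening and the 1e-8 cutoffs are assumptions made for this sketch.

import torch

def fit_leace_like(X, Z):
    """Illustrative LEACE-style eraser built from (cross-)covariance statistics.

    X: (n, d) float tensor of features; Z: (n, k) float tensor of concept labels.
    Conceptual sketch only, not the library's internal code.
    """
    mu_x = X.mean(0)
    X_c = X - mu_x
    Z_c = Z - Z.mean(0)

    sigma_xx = X_c.T @ X_c / len(X)  # feature covariance, (d, d)
    sigma_xz = X_c.T @ Z_c / len(X)  # cross-covariance with the concept, (d, k)

    # Whitening map W ~ sigma_xx^(-1/2) and its inverse, via an eigendecomposition
    eigvals, eigvecs = torch.linalg.eigh(sigma_xx)
    keep = eigvals > 1e-8
    inv_sqrt = torch.where(keep, eigvals.clamp(min=1e-8).rsqrt(), torch.zeros_like(eigvals))
    sqrt_ = torch.where(keep, eigvals.clamp(min=0.0).sqrt(), torch.zeros_like(eigvals))
    W = eigvecs @ torch.diag(inv_sqrt) @ eigvecs.T
    W_inv = eigvecs @ torch.diag(sqrt_) @ eigvecs.T

    # Orthogonal projection onto the column space of W @ sigma_xz
    u, s, _ = torch.linalg.svd(W @ sigma_xz, full_matrices=False)
    u = u[:, s > 1e-8]
    proj = u @ u.T

    # Erasure map x -> x - W_inv @ proj @ W @ (x - mu), applied to the rows of a batch
    M = W_inv @ proj @ W
    def erase(x):
        return x - (x - mu_x) @ M.T
    return erase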

Batch usage

In most cases, you have a batch of feature vectors X and concept labels (called Y in the example below), and want to erase the concept from X. The easiest way to do this is using ConceptEraser.fit() followed by ConceptEraser.forward():

import torch
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from concept_erasure import ConceptEraser

n, d, k = 2048, 128, 2

X, Y = make_classification(
    n_samples=n,
    n_features=d,
    n_classes=k,
    random_state=42,
)
X_t = torch.from_numpy(X)
Y_t = torch.from_numpy(Y)

# Logistic regression does learn something before concept erasure
real_lr = LogisticRegression(max_iter=1000).fit(X, Y)
beta = torch.from_numpy(real_lr.coef_)
assert beta.norm(p=torch.inf) > 0.1

eraser = ConceptEraser.fit(X_t, Y_t)
X_ = eraser(X_t)

# But learns nothing after
null_lr = LogisticRegression(max_iter=1000, tol=0.0).fit(X_.numpy(), Y)
beta = torch.from_numpy(null_lr.coef_)
assert beta.norm(p=torch.inf) < 1e-4
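
The fitted eraser is an ordinary callable module, so the same transformation can be reused on held-out features drawn from the same distribution. X_new_t below is a hypothetical tensor of new feature vectors, not defined in the snippet above:

# Apply the already-fitted eraser to new feature vectors
X_new_erased = eraser(X_new_t)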

Streaming usage

If you have a stream of data, you can use ConceptEraser.update() to update the statistics and ConceptEraser.forward() to erase the concept. This is useful if you have a large dataset and want to avoid storing it all in memory.

from concept_erasure import ConceptEraser
from sklearn.datasets import make_classification
import torch

n, d, k = 2048, 128, 2

X, Y = make_classification(
    n_samples=n,
    n_features=d,
    n_classes=k,
    random_state=42,
)
X_t = torch.from_numpy(X)
Y_t = torch.from_numpy(Y)

# Construct the eraser up front: d-dimensional features, 1-dimensional concept label
eraser = ConceptEraser(d, 1, dtype=X_t.dtype)

# Accumulate the covariance and cross-covariance statistics in batches
for x, y in zip(X_t.chunk(2), Y_t.chunk(2)):
    eraser.update(x, y)

# Erase the concept from a single feature vector
x_ = eraser(X_t[0])
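
Once the statistics have been accumulated, the fitted eraser can also be applied chunk by chunk, so the erased dataset never has to be produced in one shot; the concatenation below is only for illustration:

# Erase the concept from each chunk using the fitted eraser
X_erased = torch.cat([eraser(x) for x in X_t.chunk(2)])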

Paper replication

Scripts used to generate the part-of-speech tags for the concept scrubbing experiments can be found in the experiments folder. We plan to upload the tagged datasets to the HuggingFace Hub shortly.

Concept scrubbing

The concept scrubbing code is a bit messy right now, and will probably be refactored soon. We found it necessary to write bespoke implementations for different HuggingFace model families. So far we've implemented LLaMA and GPT-NeoX. These can be found in the concept_erasure.scrubbing submodule.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

concept-erasure-0.0.1.tar.gz (17.4 kB)

Built Distribution

concept_erasure-0.0.1-py3-none-any.whl (17.3 kB)

File details

Details for the file concept-erasure-0.0.1.tar.gz.

File metadata

  • Download URL: concept-erasure-0.0.1.tar.gz
  • Upload date:
  • Size: 17.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.10

File hashes

Hashes for concept-erasure-0.0.1.tar.gz

  • SHA256: 9330372075619eb4c556ebb443e9738eb562ece6e8932e49d4077e7e6d8c26b2
  • MD5: 16ff822440f54b3717f00c6c8692b90e
  • BLAKE2b-256: dbb7e6b38e5cb1191976d02af89e30e4d467fc13ad3cd6a55fdb9ca879faa804

File details

Details for the file concept_erasure-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: concept_erasure-0.0.1-py3-none-any.whl
  • Size: 17.3 kB
  • Tags: Python 3

File hashes

Hashes for concept_erasure-0.0.1-py3-none-any.whl

  • SHA256: daec7d5d5448bbc55bd4a9c28678b0c31832e81e7321caf632557b43bb06acb9
  • MD5: a1387e008be7654d0259d648be74a95e
  • BLAKE2b-256: e136086f548c35710c4accbb62030f34a083ec70ce8010322c0773735b46ef48
