Testing with PCA projected Concept Activation Vectors

TPCAV (Testing with PCA projected Concept Activation Vectors)

This repository contains code to compute TPCAV (Testing with PCA projected Concept Activation Vectors) on deep learning models. TPCAV extends the original TCAV method: it uses PCA to reduce the dimensionality of the activations at a selected intermediate layer before computing Concept Activation Vectors (CAVs), which improves the consistency of the results.
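To make the idea concrete, here is a minimal NumPy sketch of "project activations with PCA, then fit a CAV in the projected space". It is not tpcav's implementation: the function names (`fit_pca`, `project`, `fit_cav`), the synthetic activations, the choice of 8 components, and the use of plain logistic regression as the linear classifier are all assumptions made for illustration.

```python
import numpy as np

def fit_pca(activations, n_components):
    """SVD-based PCA on row-wise activations; returns (mean, components)."""
    mean = activations.mean(axis=0)
    centered = activations - mean
    # Rows of vt are principal directions, ordered by decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return mean, vt[:n_components]

def project(activations, mean, components):
    """Map layer activations into the PCA subspace."""
    return (activations - mean) @ components.T

def fit_cav(concept_proj, control_proj, lr=0.1, steps=500):
    """Fit logistic regression by gradient descent; the unit-norm weight vector is the CAV."""
    x = np.vstack([concept_proj, control_proj])
    y = np.concatenate([np.ones(len(concept_proj)), np.zeros(len(control_proj))])
    w = np.zeros(x.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(x @ w + b)))
        grad = p - y
        w -= lr * (x.T @ grad) / len(y)
        b -= lr * grad.mean()
    return w / np.linalg.norm(w)

# Synthetic stand-in for layer activations (hypothetical data):
# concept examples are shifted along one hidden direction.
rng = np.random.default_rng(0)
direction = rng.normal(size=64)
direction /= np.linalg.norm(direction)
control = rng.normal(size=(200, 64))
concept = rng.normal(size=(200, 64)) + 3.0 * direction

mean, components = fit_pca(np.vstack([concept, control]), n_components=8)
cav = fit_cav(project(concept, mean, components), project(control, mean, components))

# Map the CAV back to the original activation space for inspection.
cav_in_layer_space = components.T @ cav
```

Because the classifier only sees the leading principal components, low-variance directions of the activation space cannot influence the CAV in this sketch.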

Installation

pip install tpcav

Quick start

tpcav only works with PyTorch models. If your model is built with another library, port it to PyTorch first. For TensorFlow models, you can use tf2onnx and onnx2pytorch for the conversion.

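import torch
# Assumption: `helper` is importable from tpcav; the snippet below uses it without importing it.
from tpcav import helper, run_tpcav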

class DummyModelSeq(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(1024, 1)
        self.layer2 = torch.nn.Linear(4, 1)

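    def forward(self, seq):
        # seq: one-hot sequences, expected shape [batch, 4, 1024]
        y_hat = self.layer1(seq)      # -> [batch, 4, 1]
        y_hat = y_hat.squeeze(-1)     # -> [batch, 4]
        y_hat = self.layer2(y_hat)    # -> [batch, 1]
        return y_hat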

# transformation function to obtain one-hot encoded sequences
def transform_fasta_to_one_hot_seq(seq, chrom):
    # `seq` is a list of fasta sequences
    # `chrom` is a numpy array of bigwig signals of shape [-1, # bigwigs, len]
    return (helper.fasta_to_one_hot_sequences(seq),) # it has to return a tuple of inputs, even if there is only one input

motif_path = "data/motif-clustering-v2.1beta_consensus_pwms.test.meme"
bed_seq_concept = "data/hg38_rmsk.head500k.bed"
genome_fasta = "data/hg38.analysisSet.fa"
model = DummyModelSeq() # load the model
layer_name = "layer1"   # name of the layer to be interpreted

# concept_fscores_dataframe: F-scores of each concept
# motif_cav_trainers: one trainer per insertion count, each holding the CAV weights of the inserted motifs
# bed_cav_trainer: trainer holding the CAV weights of the sequence concepts provided in the BED file
concept_fscores_dataframe, motif_cav_trainers, bed_cav_trainer = run_tpcav(
    model=model,
    layer_name=layer_name,
    meme_motif_file=motif_path,
    genome_fasta=genome_fasta,
    num_motif_insertions=[4, 8],
    bed_seq_file=bed_seq_concept, 
    output_dir="test_run_tpcav_output/",
    input_transform_func=transform_fasta_to_one_hot_seq
)

# check each trainer for detailed weights
print(bed_cav_trainer.cav_weights)
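The transform function above delegates the encoding to `helper.fasta_to_one_hot_sequences`. For readers who want to see what such an encoder might produce, here is a standalone sketch. The `A, C, G, T` channel order, the all-zero encoding of ambiguous bases, and the names `one_hot_encode` and `transform_sketch` are assumptions for illustration; tpcav's own helper may behave differently. The `[batch, 4, length]` layout matches what `DummyModelSeq` expects, since its first layer acts on the last axis and its second on the 4 channels.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order; tpcav's helper may use a different one

def one_hot_encode(sequences):
    """Encode equal-length DNA strings as a float32 array of shape [batch, 4, length]."""
    length = len(sequences[0])
    encoded = np.zeros((len(sequences), len(BASES), length), dtype=np.float32)
    for i, seq in enumerate(sequences):
        if len(seq) != length:
            raise ValueError("all sequences must have the same length")
        for j, base in enumerate(seq.upper()):
            channel = BASES.find(base)
            if channel >= 0:  # ambiguous bases such as N stay all-zero
                encoded[i, channel, j] = 1.0
    return encoded

def transform_sketch(seq, chrom):
    # Follows the tuple-returning contract of the transform function above;
    # `chrom` is accepted but unused in this sketch.
    return (one_hot_encode(seq),)

batch = one_hot_encode(["ACGT", "NNAC"])
```

To feed the result to a PyTorch model, the array can be converted with `torch.from_numpy`.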

Detailed Usage

For detailed usage, please refer to this Jupyter notebook.

If you run into a problem, please open an issue (strongly preferred) or contact Jianyu Yang.

Download files

Source Distribution

tpcav-0.2.1.tar.gz (25.9 kB)

Uploaded Source

Built Distribution

tpcav-0.2.1-py3-none-any.whl (25.2 kB)

Uploaded Python 3

File details

Details for the file tpcav-0.2.1.tar.gz.

File metadata

  • Download URL: tpcav-0.2.1.tar.gz
  • Upload date:
  • Size: 25.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for tpcav-0.2.1.tar.gz
Algorithm Hash digest
SHA256 209c894f0c6d5cce652614626dd398f4dab2c979471b6893704415d723e4101b
MD5 6d8bfdcef57a69469e3e10738119d0ba
BLAKE2b-256 69b147beb1fff2a1b6e86d2022a4b2d13a2a39cd0da213365c5e6a3c08ec5a49

File details

Details for the file tpcav-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: tpcav-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 25.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for tpcav-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 63dd9b21d3b8e08a064085c6c21d93a07cc1c682b7133be7ab22e2706c74c4fb
MD5 d0cb7e040046b5ee5332d670713163c0
BLAKE2b-256 e5344b57d3357c0b36bd1700c872195f7b72d9231c6e9a24bcc3882ef5767d5b
