
Library to explain a dataset in natural language.

📖 Demo notebooks

| Model | Reference | Output | Description |
|---|---|---|---|
| iPrompt | 📖, 🗂️, 🔗, 📄 | Explanation | Generates a prompt that explains patterns in data (Official) |
| D3 | 📖, 🗂️, 🔗, 📄 | Explanation | Explains the difference between two distributions |
| AutoPrompt | 🗂️, 🔗, 📄 | Explanation | Finds a natural-language prompt using input gradients (⌛ In progress) |
| Aug-GAM | 📖, 🗂️, 🔗, 📄 | Linear model | Fits a better linear model using an LLM to extract embeddings (Official) |
| Aug-Tree | 📖, 🗂️, 🔗, 📄 | Decision tree | Fits a better decision tree using an LLM to expand features (⌛ In progress) |
| Linear Finetune | 📖, 🗂️ | Black-box model | Finetunes a single linear layer on top of LLM embeddings |

📖 Demo notebooks   🗂️ Doc   🔗 Reference code   📄 Research paper
⌛ We plan to support other interpretable algorithms like RLPrompt, CBMs, and NBDT. If you want to contribute an algorithm, feel free to open a PR 😄

Quickstart

Installation: pip install imodelsx (or, for more control, clone and install from source)

Demos: see the demo notebooks

iPrompt

from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B', # which language model to use
    num_learned_tokens=3, # how long of a prompt to learn
    n_shots=3, # shots per example
    n_epochs=15, # how many epochs to search
    verbose=0, # how much to print
    llm_float16=True, # whether to load the model in float16
)

prompts is a list of the natural-language prompt strings that were found.
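
To sanity-check a run, you can print the top candidates; a minimal sketch, assuming prompts is ordered best-first (metadata carries additional information about the search):

# print the top few candidate explanations
for prompt in prompts[:3]:
    print(prompt)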

D3 (DescribeDistributionalDifferences)

import imodelsx

# toy placeholder inputs -- replace with your own two lists of strings
positive_samples = ['the food was delicious', 'service was quick and friendly']
negative_samples = ['the food was cold', 'we waited an hour to be seated']

hypotheses, hypothesis_scores = imodelsx.explain_dataset_d3(
    pos=positive_samples, # List[str] of positive examples
    neg=negative_samples, # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
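
To see which hypotheses best separate the two distributions, pair each one with its score; a minimal sketch, assuming the two return values are parallel lists:

# print candidate descriptions, highest-scoring first
for score, hypothesis in sorted(zip(hypothesis_scores, hypotheses), reverse=True):
    print(round(score, 3), hypothesis)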

Aug-models

Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test time they are extremely fast and completely transparent.

from imodelsx import AugGAMClassifier, AugTreeClassifier, AugGAMRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
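
Since the fitted model's behavior is captured by its ngram coefficients, it can be saved and reused cheaply; a minimal sketch using pickle (this assumes the fitted estimator is picklable, as standard scikit-learn models are):

import pickle

# save the fitted model
with open('aug_gam.pkl', 'wb') as f:
    pickle.dump(m, f)

# reload and predict as usual
with open('aug_gam.pkl', 'rb') as f:
    m_loaded = pickle.load(f)
print(m_loaded.predict(dset_val['text'][:5]))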

Linear finetune

An easy-to-fit baseline that follows the same API.

from imodelsx import LinearFinetuneClassifier

# fit a single linear layer on top of LLM embeddings
m = LinearFinetuneClassifier(
    checkpoint='distilbert-base-uncased',
)
m.fit(dset['text'], dset['label'])
preds = m.predict(dset_val['text'])
acc = (preds == dset_val['label']).mean()
print('validation acc', acc)
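
For a fuller picture than accuracy alone, the predictions work with standard scikit-learn metrics; a short sketch (assumes dset_val and preds from the snippet above):

from sklearn.metrics import classification_report

# per-class precision, recall, and F1 on the validation split
print(classification_report(dset_val['label'], preds))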

Related work

  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
  • Transformation importance (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • Hierarchical interpretations (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning
