Library to explain a dataset in natural language.
| Model | Reference | Output | Description |
|---|---|---|---|
| iPrompt | 📖, 🗂️, 🔗, 📄 | Explanation | Generates a prompt that explains patterns in data (Official) |
| D3 | 📖, 🗂️, 🔗, 📄 | Explanation | Explains the difference between two distributions |
| AutoPrompt | ⌛, 🗂️, 🔗, 📄 | Explanation | Finds a natural-language prompt using input-gradients (⌛ In progress) |
| Aug-GAM | 📖, 🗂️, 🔗, 📄 | Linear model | Fits a better linear model by using an LLM to extract embeddings (Official) |
| Aug-Tree | 📖, 🗂️, 🔗, 📄 | Decision tree | Fits a better decision tree by using an LLM to expand features (⌛ In progress) |
| Linear Finetune | 📖, 🗂️, ⌛, ⌛ | Black-box model | Finetunes a single linear layer on top of LLM embeddings |

📖 Demo notebooks · 🗂️ Doc · 🔗 Reference code · 📄 Research paper

⌛ We plan to support other interpretable algorithms like RLPrompt, CBMs, and NBDT. If you want to contribute an algorithm, feel free to open a PR 😄
## Quickstart
**Installation**: `pip install imodelsx` (or, for more control, clone and install from source)

**Demos**: see the demo notebooks
### iPrompt
```python
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B',  # which language model to use
    num_learned_tokens=3,              # how long of a prompt to learn
    n_shots=3,                         # shots per example
    n_epochs=15,                       # how many epochs to search
    verbose=0,                         # how much to print
    llm_float16=True,                  # whether to load the model in float16
)
```
`prompts` is a list of the natural-language prompt strings that were found.
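To use the result, you can simply iterate over that list; a minimal sketch (the best-first ordering here is an assumption, not something the snippet above guarantees):

```python
# print a few of the recovered explanations
# (best-first ordering is assumed, not guaranteed)
for prompt in prompts[:3]:
    print(prompt)
```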
### D3 (DescribeDistributionalDifferences)
```python
import imodelsx

hypotheses, hypothesis_scores = imodelsx.explain_dataset_d3(
    pos=positive_samples,  # List[str] of positive examples
    neg=negative_samples,  # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
```
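Here `positive_samples` and `negative_samples` are plain lists of strings that you define before the call above. The sketch below uses made-up toy data, purely for illustration, and assumes the returned `hypotheses` and `hypothesis_scores` are index-aligned:

```python
# toy input distributions (hypothetical data, purely for illustration)
positive_samples = ['a moving, brilliant film', 'the acting was superb']
negative_samples = ['a dull, plodding mess', 'the plot made no sense']

# after calling explain_dataset_d3, rank hypotheses by score
# (index alignment of the two returned lists is assumed)
for score, hyp in sorted(zip(hypothesis_scores, hypotheses), reverse=True):
    print(round(score, 3), hyp)
```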
### Aug-models
Use these just like scikit-learn models. During training, they fit better features via LLMs, but at test time they are extremely fast and completely transparent.
```python
from imodelsx import AugGAMClassifier, AugTreeClassifier, AugGAMRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,  # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
```
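The other estimators imported above follow the same fit/predict pattern; for instance, a minimal sketch for the tree variant (constructor arguments beyond `checkpoint` are omitted, and it is an assumption that the scikit-learn-style API shown above carries over unchanged):

```python
# fit the decision-tree variant on the same data
# (assumes the same sklearn-style API as AugGAMClassifier above)
m_tree = AugTreeClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
)
m_tree.fit(dset['text'], dset['label'])
preds_tree = m_tree.predict(dset_val['text'])
print('acc_val (tree)', np.mean(preds_tree == dset_val['label']))
```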
### Linear finetune
An easy-to-fit baseline that follows the same API.
```python
from imodelsx import LinearFinetuneClassifier

# fit a simple one-layer finetune
m = LinearFinetuneClassifier(
    checkpoint='distilbert-base-uncased',
)
m.fit(dset['text'], dset['label'])
preds = m.predict(dset_val['text'])
acc = (preds == dset_val['label']).mean()
print('validation acc', acc)
```
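Because the classifier follows the scikit-learn interface, a fitted model can be persisted the usual way; a minimal sketch using joblib (standard sklearn practice, not an imodelsx-specific API):

```python
import joblib

# save the fitted classifier to disk and reload it for later prediction
joblib.dump(m, 'linear_finetune.joblib')
m_loaded = joblib.load('linear_finetune.joblib')
print(m_loaded.predict(dset_val['text'])[:5])
```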
## Related work
- imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
- Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
- Transformation importance (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
- Hierarchical interpretations (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
- Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
- PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning