A weak supervision learning benchmark
🔧 New
2/13/22
- Add a script to generate LFs for any tabular dataset, along with 5 new tabular datasets: mushroom, spambase, PhishingWebsites, Bioresponse, and bank-marketing.
11/04/21
- (beta) Add `parallel_fit` for torch models to support PyTorch DistributedDataParallel (see the example).
10/15/21
- A batch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
- Support image classification (dataset class / torchvision backbone) as well as the DomainNet and Animals-with-Attributes2 datasets (check out the `datasets` folder)
🔧 What is it?
Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development and evaluation of your own weak supervision models within the benchmark.
For more information, check out our publications:
- WRENCH: A Comprehensive Benchmark for Weak Supervision (NeurIPS 2021)
If you find this repository helpful, feel free to cite our publication:
```
@inproceedings{zhang2021wrench,
  title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
  author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
  booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  year={2021},
  url={https://openreview.net/forum?id=Q9SKS5k8io}
}
```
🔧 What is weak supervision?
Weak Supervision is a paradigm for automated training data creation without manual annotations.
For a brief overview, please check out this blog.
For more context, please check out this survey.
To track recent advances in weak supervision, please follow this repo.
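To make the paradigm concrete, here is a minimal, self-contained sketch of how labeling functions produce noisy votes that a label model can aggregate. The spam-keyword rules and the ABSTAIN convention are illustrative assumptions, not Wrench's API:

```python
# Illustrative weak supervision sketch (hypothetical rules, not Wrench's API).
# Each labeling function (LF) votes on an example or abstains; the resulting
# noisy vote matrix is what label models such as majority voting aggregate.
ABSTAIN, HAM, SPAM = -1, 0, 1

def lf_contains_link(text: str) -> int:
    # Heuristic: messages with URLs are often spam.
    return SPAM if "http" in text else ABSTAIN

def lf_short_message(text: str) -> int:
    # Heuristic: very short messages are usually benign.
    return HAM if len(text.split()) < 5 else ABSTAIN

lfs = [lf_contains_link, lf_short_message]
texts = ["click http://spam.example to claim your free prize now", "see you tonight"]
# Vote matrix: one row per example, one column per LF.
L = [[lf(t) for lf in lfs] for t in texts]
print(L)  # [[1, -1], [-1, 0]]
```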
🔧 Installation
[1] Install anaconda: Instructions here: https://www.anaconda.com/download/
[2] Clone the repository:
```bash
git clone https://github.com/JieyuZ2/wrench.git
cd wrench
```
[3] Create the virtual environment:
```bash
conda env create -f environment.yml
source activate wrench
```
If this does not work, or you only want to use a subset of Wrench's modules, check out this wiki page.
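Alternatively, the package is published on PyPI as ws-benchmark (the distribution this page describes), so a plain pip install should also work, assuming your environment satisfies the pinned dependencies:

```bash
pip install ws-benchmark
```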
🔧 Available Datasets
The datasets can be downloaded via this.
Note that some datasets may have more training examples than reported in the README/paper because we include the dev set, whose indices can be found in labeled_id.json if it exists.
Documentation of the dataset format and usage can be found in this wiki page.
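For instance, to recover the exact training split reported in the paper, one could drop the dev-set examples. A minimal sketch, assuming labeled_id.json stores a JSON list of dev-set indices and using a hypothetical local path:

```python
import json
import os

dataset_path = '../datasets/youtube'  # hypothetical local path
dev_id_file = os.path.join(dataset_path, 'labeled_id.json')
if os.path.exists(dev_id_file):
    with open(dev_id_file) as f:
        dev_ids = set(json.load(f))  # assumed format: a JSON list of indices
    print(f'{len(dev_ids)} training examples belong to the dev set')
```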
classification:
Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
---|---|---|---|---|---|---|---|---|
Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
SemEval | relation classification | 9 | 164 | 1749 | 200 | 692 | link | link |
CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |
DomainNet | image classification | - | - | - | - | - | link | link |
sequence tagging:
Name | # class | # LF | # train | # validation | # test | data source | LF source |
---|---|---|---|---|---|---|---|
CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |
The detailed documentation is coming soon.
🔧 Available Models
If you find any of the implementations wrong or problematic, don't hesitate to raise an issue or pull request; we really appreciate it!
TODO-list: check this out!
classification:
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Majority Voting | Label Model | -- | link |
Weighted Majority Voting | Label Model | -- | link |
Dawid-Skene | Label Model | link | link |
Data Programming | Label Model | link | link |
MeTaL | Label Model | link | link |
FlyingSquid | Label Model | link | link |
Logistic Regression | End Model | -- | link |
MLP | End Model | -- | link |
BERT | End Model | link | link |
COSINE | End Model | link | link |
Denoise | Joint Model | link | link |
WeaSEL | Joint Model | link | link |
sequence tagging:
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Hidden Markov Model | Label Model | link | link |
Conditional Hidden Markov Model | Label Model | link | link |
LSTM-CNNs-CRF | End Model | link | link |
BERT-CRF | End Model | link | link |
LSTM-ConNet | Joint Model | link | link |
BERT-ConNet | Joint Model | link | link |
classification-to-sequence-tagging wrapper:
Wrench also provides a SeqLabelModelWrapper that adapts a label model for the classification task to the sequence tagging task.
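To illustrate the idea behind such a wrapper, here is a generic sketch of ours (not Wrench's actual implementation): token-level LF votes are flattened into a classification-style vote matrix, a label model is fit on that matrix, and its predictions are regrouped into sequences.

```python
import numpy as np

def seq_to_flat(seq_L):
    """Flatten per-sequence LF vote matrices into one (n_tokens, n_lfs) matrix."""
    lengths = [len(L) for L in seq_L]
    return np.concatenate(seq_L, axis=0), lengths

def flat_to_seq(flat_preds, lengths):
    """Regroup flat token-level predictions back into per-sequence arrays."""
    out, start = [], 0
    for n in lengths:
        out.append(flat_preds[start:start + n])
        start += n
    return out

# Example: two sequences of token-level votes from 3 LFs (-1 = abstain).
seq_L = [np.array([[1, -1, 1], [0, 0, -1]]), np.array([[1, 1, 1]])]
flat, lengths = seq_to_flat(seq_L)
# A classification label model would be fit on `flat` here; we fake its
# predictions for the demo.
flat_preds = flat.max(axis=1)  # stand-in for label_model.predict(flat)
print(flat_to_seq(flat_preds, lengths))  # [array([1, 0]), array([1])]
```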
methods from related domains:
Robust Learning methods as end model:
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Meta-Weight-Net | End Model | link | link |
Learning2ReWeight | End Model | link | link |
Semi-Supervised Learning methods as end model:
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
MeanTeacher | End Model | link | link |
Weak Supervision with cleaned labels (Semi-Weak Supervision):
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
ImplyLoss | Joint Model | link | link |
ASTRA | Joint Model | link | link |
🔧 Quick examples
🔧 Label model with parallel grid search for hyper-parameters
```python
import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    meter.update(target=metric_value)
metrics = meter.get_results()
pprint.pprint(metrics)
```
For detailed guidance on grid_search, please check out this wiki page.
🔧 Run a standard supervised learning pipeline
```python
import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using a pre-trained BERT model and cache them
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Train an MLP classifier
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target,
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)
```
🔧 Build a two-stage weak supervision pipeline
```python
import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using a pre-trained BERT model and cache them
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)

#### Train an MLP classifier with soft labels
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label,
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds

hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label,
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
```
🔧 Procedural labeling function generator
```python
import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75,        # mean accuracy
    beta=0.1,          # mean propensity
    alpha_radius=0.2,  # radius of accuracy
    beta_radius=0.1    # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load real-world dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)

#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
```
🔧 Contact
Contact person: Jieyu Zhang, jieyuzhang97@gmail.com
Don't hesitate to send us an e-mail if you have any questions.
We're also open to any collaboration!
🔧 Contributing Dataset and Model
We sincerely welcome any contribution to the datasets or models!