Perform cross-validation on MIL data using sklearn, assuming single-instance inference for the bag label.
Project description
Bag cross validation
Introduction
Multiple instance labeling (MIL) refers to labeling data arranged in sets, or bags. In MIL supervised learning, labels are known for bags of instances, and the goal is to assign bag-level labels to unobserved bags.
A simple solution to the MIL problem is to treat each instance in a bag as a single instance (SI) that inherits the label of its bag. Each SI in a bag is labeled with a single-instance estimator, and the bag label is reduced from some aggregation (mode, threshold, presence) of the SI predictions. The formal terms are presence-based, threshold-based, and count-based concepts (see "A two-level learning method for generalized multi-instance problems" by Nils Weidmann et al.).
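For illustration, here is a minimal sketch of such a bag-label reduction from single-instance predictions. The helper reduce_bag_label and its method names are hypothetical and are not part of this package's API:
# A minimal sketch of reducing SI predictions to one bag label.
# `reduce_bag_label` is a hypothetical helper for illustration only.
import numpy as np

def reduce_bag_label(si_predictions, method="mode", threshold=0.5):
    """Reduce single-instance predictions to a single bag-level label."""
    si_predictions = np.asarray(si_predictions)
    if method == "mode":
        # Majority vote over the instance-level labels
        values, counts = np.unique(si_predictions, return_counts=True)
        return values[np.argmax(counts)]
    if method == "presence":
        # Bag is positive if any instance is predicted positive
        return int(np.any(si_predictions == 1))
    if method == "threshold":
        # Bag is positive if the positive fraction exceeds a threshold
        return int(np.mean(si_predictions == 1) >= threshold)
    raise ValueError("Unknown reduction method: {}".format(method))

# Example: one bag whose instances were labeled by an SI estimator
print(reduce_bag_label([1, 0, 1, 1, 0]))             # mode -> 1
print(reduce_bag_label([0, 0, 0, 1], "presence"))    # presence -> 1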
This package fits into the scikit-learn cross-validation framework and allows traditional single-instance classifiers to predict bag-level labels for multiple-instance data.
This package supports:
- The use of scikit-learn estimators in a cross-validation framework for multiple-instance labeling, with bag-level labels inferred from single-instance labels
- The use of scikit-learn evaluation metrics, with cross-validation scores measured against the bag-level (MIL) labels
Motivation
scikit-learn is a popular tool for data analysis and includes APIs for SI estimators. It also provides a convenient API for evaluating SI estimators, namely cross_validate.
This package uses sklearn's cross_validate method and extends it to MIL for SI estimators.
Usage example
# Python imports
# Third party imports
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from sklearn.dummy import DummyClassifier
from sklearn.neighbors import RadiusNeighborsClassifier
from sklearn.metrics import make_scorer
# Local imports
from bag_cross_validate import cross_validate_bag, BagScorer, bags_2_si
# Global definitions
#%%
# Create some dummy data
"""Generate some dummy data
Create bags and single-instance data
A set of bags have a shape [n x (m x p)], and can be through of as an
array of bag instances.
n is the number of bags
m is the number of instances within each bag (this can vary between bags)
p is the feature space of each instance"""
n_bags = 1000
m_instances = 8 # Static number of instances per bag
p = 5
# Bags are created with random data, of shape (n, (m, p))
# Labels are split 50% positive class, 50% negative class
labels = np.concatenate((np.ones(int(n_bags * 0.5)),
                         np.zeros(int(n_bags * 0.5)),
                         ))
bags = np.random.randint(low=0, high=2, size=(n_bags, m_instances, p))
print("This is what a bag looks like: \n{}".format(bags[0]))
# Split the dummy dataset into train and test sets
rs = ShuffleSplit(n_splits=1, test_size=0.2, train_size=0.8)
train_index, test_index = next(rs.split(bags, labels))
train_bags, train_labels = bags[train_index], labels[train_index]
test_bags, test_labels = bags[test_index], labels[test_index]
# Create an estimator
dumb = DummyClassifier(strategy='constant', constant=1)
radiusNeighbor = RadiusNeighborsClassifier(weights='distance',
                                           algorithm='auto',
                                           p=1,  # Manhattan distance
                                           )
# Create an evaluation metric
# Multiple evaluation metrics are allowed
accuracy_scorer = make_scorer(accuracy_score)
bagAccScorer = BagScorer(accuracy_scorer) # Accuracy score, no factory function
precision_scorer = make_scorer(precision_score, average='binary')
bagPreScorer = BagScorer(precision_scorer)
jaccard_scorer = make_scorer(jaccard_score, average='binary')
bagJacScorer = BagScorer(jaccard_scorer)
scoring = {'bag_accuracy': bagAccScorer,
           'bag_precision': bagPreScorer,
           'bag_jaccard': bagJacScorer,
           }
#%%
# Cross validate the dummy data and estimator
result_dumb = cross_validate_bag(estimator=dumb,
                                 X=train_bags,
                                 y=train_labels,
                                 groups=None,
                                 scoring=scoring,  # Custom scorer...
                                 cv=2,
                                 n_jobs=3,
                                 verbose=0,
                                 fit_params=None,
                                 pre_dispatch='2*n_jobs',
                                 return_train_score=False,
                                 return_estimator=False,
                                 error_score=np.nan)
result_neighbor = cross_validate_bag(estimator=radiusNeighbor,
                                     X=train_bags,
                                     y=train_labels,
                                     groups=None,
                                     scoring=scoring,  # Custom scorer...
                                     cv=3,
                                     n_jobs=2,
                                     verbose=0,
                                     fit_params=None,
                                     pre_dispatch='2*n_jobs',
                                     return_train_score=False,
                                     return_estimator=False,
                                     error_score=np.nan)
# Display the results
msg=("\nOur dummy estimator tried his best, and predicted {} percent of bags "
"correctly")
msg = msg.format(result_dumb['test_bag_accuracy'])
print(msg)
msg=("\nOur neighbor estimator didnt fair well either, and predicted {} percent "
"of bags correctly")
msg = msg.format(result_neighbor['test_bag_accuracy'])
print(msg)
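The returned objects follow the dictionary convention of sklearn's cross_validate: one array of per-fold values per key. The 'test_bag_*' keys come from the scoring dict defined above; the 'fit_time' and 'score_time' keys are assumed to mirror sklearn's cross_validate output. A quick way to summarize the runs:
# Summarize the per-fold results from both estimators.
# Key names other than 'test_bag_*' are assumed to follow sklearn's
# cross_validate conventions.
for name, result in [('dummy', result_dumb), ('neighbor', result_neighbor)]:
    for key, values in result.items():
        print("{:<10s} {:<20s} mean={:.3f}".format(name, key, np.mean(values)))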