
Perform cross-validation on MIL data using scikit-learn, assuming the bag label is inferred from single-instance predictions.


Bag cross validation

Introduction

Multiple instance labeling (MIL), more commonly called multiple-instance learning, refers to learning from data arranged in sets, or bags, of instances. In supervised MIL, labels are known for whole bags rather than for individual instances, and the goal is to predict labels for unseen bags.

A simple solution to the MIL problem is to treat each instance in a bag as a single instance (SI) that inherits the label of its bag. A single-instance estimator labels each SI in a bag, and the bag label is then derived from the SI predictions with an aggregation rule (mode, threshold, or presence). In the literature these are known as presence-based, threshold-based, and count-based concepts (see "A two-level learning method for generalized multi-instance problems" by Nils Weidmann et al.).
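To make the aggregation step concrete, the sketch below reduces a vector of 0/1 instance predictions for one bag to a bag label under four simplified, single-concept rules. The function names and default parameters (`min_count=3`, `lower=1`, `upper=3`) are hypothetical choices for this illustration, not part of the bag_cross_validate API.

```python
import numpy as np


def presence_rule(instance_preds):
    """Bag is positive if at least one instance is predicted positive."""
    return int(np.sum(np.asarray(instance_preds) == 1) >= 1)


def threshold_rule(instance_preds, min_count=3):
    """Bag is positive if at least `min_count` instances are predicted positive."""
    return int(np.sum(np.asarray(instance_preds) == 1) >= min_count)


def count_rule(instance_preds, lower=1, upper=3):
    """Bag is positive if the number of positive instances lies in [lower, upper]."""
    n_pos = int(np.sum(np.asarray(instance_preds) == 1))
    return int(lower <= n_pos <= upper)


def mode_rule(instance_preds):
    """Bag label is the most frequent instance label (ties go to the smaller label)."""
    values, counts = np.unique(np.asarray(instance_preds), return_counts=True)
    return int(values[np.argmax(counts)])


bag_preds = [0, 0, 1, 0, 1, 0, 0, 0]  # SI predictions for one bag of 8 instances
print(presence_rule(bag_preds))   # 1: two positives >= 1
print(threshold_rule(bag_preds))  # 0: two positives < 3
print(count_rule(bag_preds))      # 1: two positives within [1, 3]
print(mode_rule(bag_preds))       # 0: six of eight instances are 0
```

The same instance predictions can yield different bag labels depending on the rule, which is why the choice of rule matters when scoring bags.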

This package fits into the scikit-learn cross-validation framework and lets traditional single-instance classifiers predict bag-level labels for multiple instance labeling.

This package supports:

  1. The use of scikit-learn estimators in a cross-validation framework for multiple instance labeling, with bag labels inferred from single-instance predictions
  2. The use of scikit-learn evaluation metrics during cross-validation, scored against bag labels rather than instance labels
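Under the SI assumption, training data for a scikit-learn estimator can be produced by stacking every bag's instances into one 2-D array and repeating each bag label once per instance. The helper below is a hypothetical illustration of that idea (the package ships its own `bags_2_si`, whose exact signature is not documented here). It also returns the bag index of every instance so that predictions can be grouped back into bags, and it handles bags of different sizes.

```python
import numpy as np


def expand_bags(bags, bag_labels):
    """Hypothetical helper: flatten bags into single-instance (SI) data.

    bags: sequence of 2-D arrays, each of shape (m_i, p); m_i may differ per bag.
    bag_labels: sequence with one label per bag.
    Returns X of shape (sum(m_i), p), y with each bag label repeated m_i times,
    and bag_ids mapping every instance row back to its bag.
    """
    X = np.vstack([np.asarray(bag) for bag in bags])
    sizes = [len(bag) for bag in bags]
    y = np.repeat(np.asarray(bag_labels), sizes)
    bag_ids = np.repeat(np.arange(len(bags)), sizes)
    return X, y, bag_ids


# Three bags with 2, 3 and 1 instances, each instance having p = 2 features
bags = [np.array([[0, 1], [1, 1]]),
        np.array([[0, 0], [1, 0], [1, 1]]),
        np.array([[1, 0]])]
bag_labels = [1, 0, 1]
X, y, bag_ids = expand_bags(bags, bag_labels)
print(X.shape)   # (6, 2)
print(y)         # [1 1 0 0 0 1]
print(bag_ids)   # [0 0 1 1 1 2]
```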

Motivation

scikit-learn is a popular data-analysis library that provides many SI estimators. It also includes a convenient function for evaluating them, cross_validate.
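For reference, this is how cross_validate is typically called on ordinary SI data, where every row of X carries its own label. The synthetic data and the choice of LogisticRegression are arbitrary and only serve to show the call and the structure of the returned dictionary.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))            # 200 single instances, 5 features
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # label depends on the first two features

results = cross_validate(LogisticRegression(), X, y, cv=5,
                         scoring=["accuracy", "precision"])
# One entry per fold for each metric, plus fit and score timings
print(sorted(results.keys()))  # ['fit_time', 'score_time', 'test_accuracy', 'test_precision']
print(results["test_accuracy"].mean())
```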

This package builds on scikit-learn's cross_validate function and extends it to MIL problems solved with SI estimators.
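One way to carry this idea over to MIL, sketched below under the SI assumption, is to split the bags (not the instances) into folds, train an SI estimator on the expanded instances of the training bags, and reduce each test bag's instance predictions to a bag label before scoring. Splitting by bag keeps instances of one bag from appearing in both the training and test folds. This is a hand-written illustration using standard scikit-learn and NumPy calls; the function name `bag_cross_val_accuracy` and the majority-vote reduction are assumptions, and it is not the implementation of `cross_validate_bag`.

```python
import numpy as np
from sklearn.base import clone
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier


def bag_cross_val_accuracy(estimator, bags, bag_labels, n_splits=3, seed=0):
    """Hypothetical bag-level cross-validation with majority-vote bag labels."""
    bag_labels = np.asarray(bag_labels)
    scores = []
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # Split bag indices, so the instances of one bag stay in a single fold
    for train_idx, test_idx in kfold.split(np.arange(len(bags))):
        # Expand training bags to SI data: each instance inherits its bag label
        X_train = np.vstack([bags[i] for i in train_idx])
        y_train = np.concatenate([np.full(len(bags[i]), bag_labels[i])
                                  for i in train_idx])
        model = clone(estimator).fit(X_train, y_train)
        # Predict every instance of a test bag, then take the majority vote
        # (np.bincount(...).argmax() resolves ties to the smaller label)
        y_pred = []
        for i in test_idx:
            instance_preds = model.predict(bags[i]).astype(int)
            y_pred.append(np.bincount(instance_preds).argmax())
        scores.append(accuracy_score(bag_labels[test_idx], y_pred))
    return np.array(scores)


rng = np.random.default_rng(0)
bags = [rng.integers(0, 2, size=(8, 5)) for _ in range(60)]  # 60 bags of 8 x 5
bag_labels = rng.integers(0, 2, size=60)                      # random bag labels
print(bag_cross_val_accuracy(DecisionTreeClassifier(random_state=0), bags, bag_labels))
```

Because the labels here are random, the per-fold accuracies should hover around chance; the point is the structure of the loop, not the scores.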

Usage example

# Python imports

# Third party imports
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from sklearn.dummy import DummyClassifier
from sklearn.neighbors import RadiusNeighborsClassifier
from sklearn.metrics import make_scorer

# Local imports
from bag_cross_validate import cross_validate_bag, BagScorer, bags_2_si

# Global definitions

#%%

# Create some dummy data
"""Generate some dummy data.
Create bags and single-instance data.
A set of bags has shape [n x (m x p)] and can be thought of as an
array of bags, each holding m instances.
n is the number of bags
m is the number of instances within each bag (this can vary between bags)
p is the number of features of each instance"""
n_bags = 1000
m_instances = 8  # Fixed number of instances per bag
p = 5  # Number of features per instance
# 50% positive class, 50% negative class
# Bags are created with random data, of shape (n, (m,p))
labels = np.concatenate((np.ones(int(n_bags*0.5)),
                         np.zeros(int(n_bags*(1-0.5))),
                         ))
bags = np.random.randint(low=0, high=2, size=(n_bags, m_instances, p))
print("This is what a bag looks like: \n{}".format(bags[0]))

# Split the dummy dataset into training and test bags
rs = ShuffleSplit(n_splits=1, test_size=0.2, train_size=0.8)
train_index, test_index = next(rs.split(bags, labels))
train_bags, train_labels = bags[train_index], labels[train_index]
test_bags, test_labels = bags[test_index], labels[test_index]
        
# Create two estimators: a constant baseline and a radius-neighbors classifier
dumb = DummyClassifier(strategy='constant', constant=1)
radiusNeighbor = RadiusNeighborsClassifier(weights='distance', 
                                           algorithm='auto',
                                           p=1, # Manhattan distance
                                           )

# Create an evaluation metric
# Multiple evaluation metrics are allowed
accuracy_scorer = make_scorer(accuracy_score)
bagAccScorer = BagScorer(accuracy_scorer) # Accuracy score, no factory function
precision_scorer = make_scorer(precision_score, average='binary')
bagPreScorer = BagScorer(precision_scorer)
jaccard_scorer = make_scorer(jaccard_score, average='binary')
bagJacScorer = BagScorer(jaccard_scorer)
scoring = {'bag_accuracy':bagAccScorer,
           'bag_precision':bagPreScorer,
           'bag_jaccard':bagJacScorer,
           }


#%%

# Cross validate the dummy data and estimator
result_dumb = cross_validate_bag(
    estimator=dumb,
    X=train_bags,
    y=train_labels,
    groups=None,
    scoring=scoring,  # Dictionary of bag-level scorers
    cv=2,
    n_jobs=3,
    verbose=0,
    fit_params=None,
    pre_dispatch='2*n_jobs',
    return_train_score=False,
    return_estimator=False,
    error_score=np.nan)

result_neighbor = cross_validate_bag(
    estimator=radiusNeighbor,
    X=train_bags,
    y=train_labels,
    groups=None,
    scoring=scoring,  # Dictionary of bag-level scorers
    cv=3,
    n_jobs=2,
    verbose=0,
    fit_params=None,
    pre_dispatch='2*n_jobs',
    return_train_score=False,
    return_estimator=False,
    error_score=np.nan)

# Display the mean bag accuracy across folds, as a percentage
msg = ("\nOur dummy estimator tried its best, and predicted {:.1f} percent of "
       "bags correctly")
msg = msg.format(np.mean(result_dumb['test_bag_accuracy']) * 100)
print(msg)

msg = ("\nOur neighbor estimator didn't fare well either, and predicted {:.1f} "
       "percent of bags correctly")
msg = msg.format(np.mean(result_neighbor['test_bag_accuracy']) * 100)
print(msg)
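The BagScorer objects in the example evaluate metrics on bag labels rather than instance labels. Their internals are not described on this page, so as a hedged illustration of the general idea, the hypothetical `make_bag_scorer` below wraps any scikit-learn metric function into a callable with scikit-learn's scorer signature `(estimator, X, y)`. It assumes X is a 3-D array of fixed-size bags, as in the example above, and takes each bag label as the majority vote of its instance predictions.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score


def make_bag_scorer(metric, **metric_kwargs):
    """Hypothetical: build scorer(estimator, bags, bag_labels) for bag-level metrics."""
    def scorer(estimator, bags, bag_labels):
        bags = np.asarray(bags)
        n_bags, m, p = bags.shape  # assumes every bag has the same m instances
        # Predict all instances at once, then regroup the predictions per bag
        instance_preds = estimator.predict(bags.reshape(n_bags * m, p))
        instance_preds = np.asarray(instance_preds).astype(int).reshape(n_bags, m)
        # Majority vote per bag; ties resolve to the smaller label
        bag_preds = np.array([np.bincount(row).argmax() for row in instance_preds])
        return metric(bag_labels, bag_preds, **metric_kwargs)
    return scorer


bags = np.random.default_rng(0).integers(0, 2, size=(10, 8, 5))
bag_labels = np.array([1] * 5 + [0] * 5)
# Fit on SI data: every instance inherits its bag label
X_si = bags.reshape(-1, 5)
y_si = np.repeat(bag_labels, 8)
dumb = DummyClassifier(strategy="constant", constant=1).fit(X_si, y_si)
bag_accuracy = make_bag_scorer(accuracy_score)
print(bag_accuracy(dumb, bags, bag_labels))  # 0.5: every bag is predicted positive
```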


Download files

Source distribution: bag_cross_validate-0.0.3.tar.gz (19.2 kB)

Built distribution (Python 3): bag_cross_validate-0.0.3-py3-none-any.whl (19.6 kB)
