
Perform cross-validation on MIL data using sklearn, assuming single-instance inference for the bag label.

Project description

Bag cross validation

Introduction

Multiple instance learning (MIL) refers to learning from data arranged in sets, or bags. In supervised MIL, labels are known only for bags of instances, and the goal is to assign bag-level labels to unseen bags.

A simple solution to the MIL problem is to treat each instance in a bag as a single instance (SI) that inherits the label of its bag. Each SI in a bag is labeled with a single-instance estimator, and the bag label is then derived from an aggregate of the SI predictions, such as their mode, a threshold, or the presence of at least one positive prediction. The literature formalizes these as presence-based, threshold-based, and count-based concepts (see "A Two-Level Learning Method for Generalized Multi-instance Problems" by Nils Weidmann et al.).
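The flatten-then-reduce idea can be sketched in a few lines of NumPy. The helpers `flatten_bags` and `reduce_bag` below are hypothetical (they are not this package's `bags_2_si` or its internal reduction), and the rules are simplified single-concept, binary analogues of the presence-, threshold- and count-based concepts, plus the mode.

```python
import numpy as np


def flatten_bags(bags, bag_labels):
    """Hypothetical helper: turn bags into single-instance (SI) data.

    Each bag is an (m_i, p) array; every instance inherits its bag's label.
    Returns X with shape (sum of m_i, p), y with one label per instance, and
    bag_ids mapping each instance back to the index of its bag.
    """
    X = np.concatenate([np.asarray(b) for b in bags], axis=0)
    y = np.concatenate([np.full(len(b), lbl) for b, lbl in zip(bags, bag_labels)])
    bag_ids = np.concatenate([np.full(len(b), i) for i, b in enumerate(bags)])
    return X, y, bag_ids


def reduce_bag(si_predictions, rule="presence", min_count=1, max_count=None):
    """Hypothetical helper: reduce the 0/1 SI predictions of one bag to a bag label."""
    preds = np.asarray(si_predictions).astype(int)
    n_pos = int(preds.sum())
    if rule == "presence":   # positive if at least one instance is positive
        return int(n_pos >= 1)
    if rule == "threshold":  # positive if at least `min_count` instances are positive
        return int(n_pos >= min_count)
    if rule == "count":      # positive if the number of positives lies in [min_count, max_count]
        upper = len(preds) if max_count is None else max_count
        return int(min_count <= n_pos <= upper)
    if rule == "mode":       # most common SI label; ties go to the smaller label
        values, counts = np.unique(preds, return_counts=True)
        return int(values[np.argmax(counts)])
    raise ValueError(f"unknown rule: {rule}")


# Two bags of two-feature instances: bag 0 is positive, bag 1 is negative
bags = [np.array([[0, 1], [1, 1]]), np.array([[0, 0], [1, 0], [1, 1]])]
X, y, bag_ids = flatten_bags(bags, [1, 0])
```

With SI predictions `[0, 0, 1]` for one bag, the presence rule yields a positive bag while the mode yields a negative one, which is why the choice of reduction matters as much as the SI estimator.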

This package fits into the scikit-learn cross-validation framework and lets you use traditional single-instance classifiers to predict bag-level labels for multiple instance learning.

This package supports:

  1. The use of scikit-learn estimators in a cross-validation framework for multiple instance learning, with bag labels inferred from single-instance labels
  2. The use of scikit-learn evaluation metrics, with cross-validation scores measured against the bag-level (MIL) problem
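To illustrate point 2, a bag-level scorer can keep scikit-learn's `scorer(estimator, X, y)` calling convention while scoring bags instead of instances. `BagLevelScorer` below is a hypothetical sketch, not this package's `BagScorer`; it assumes binary 0/1 labels and a presence rule (a bag is positive if any of its instances is predicted positive).

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier


class BagLevelScorer:
    """Hypothetical bag-level scorer with the sklearn signature scorer(estimator, X, y).

    X is a sequence of bags, each an (m_i, p) array, and y holds one 0/1 label
    per bag. Each instance is predicted on its own, the predictions are reduced
    to one label per bag, and `metric` compares bag predictions with bag labels.
    """

    def __init__(self, metric=accuracy_score):
        self.metric = metric

    def __call__(self, estimator, X, y):
        # Presence rule: a bag is positive if any of its instances is predicted 1
        bag_preds = [int(np.any(estimator.predict(np.asarray(bag)) == 1)) for bag in X]
        return self.metric(y, bag_preds)


# Toy single-instance training data: an instance is positive when feature 0 is 1
X_si = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_si = np.array([0, 0, 1, 1])
tree = DecisionTreeClassifier(random_state=0).fit(X_si, y_si)
always_neg = DummyClassifier(strategy="constant", constant=0).fit(X_si, y_si)

# Three test bags and their bag labels
test_bags = [np.array([[0, 0], [1, 0]]), np.array([[0, 0], [0, 1]]), np.array([[1, 1]])]
test_labels = [1, 0, 1]

scorer = BagLevelScorer(accuracy_score)
tree_score = scorer(tree, test_bags, test_labels)        # every bag labeled correctly
dummy_score = scorer(always_neg, test_bags, test_labels)  # only the negative bag is right
```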

Motivation

scikit-learn is a popular data-analysis library with APIs for SI estimators, including a convenient function for evaluating them: cross_validate.

This package builds on sklearn's cross_validate function and extends it to MIL problems solved with SI estimators.
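A minimal sketch of what extending cross-validation to bags involves, assuming K-fold splits over bags, label inheritance for training instances, and a presence rule for test bags. `bag_cv_sketch` is a hypothetical name and this is not the package's implementation; the key point is that folds are drawn over bags, so no bag contributes instances to both the training and the test side of a fold.

```python
import numpy as np
from sklearn.base import clone
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier


def bag_cv_sketch(estimator, bags, bag_labels, n_splits=3, random_state=0):
    """Hypothetical bag-level K-fold cross-validation.

    Training bags are flattened into single instances that inherit their bag's
    label; test bags are labeled positive if any instance is predicted 1.
    Returns one bag-level accuracy per fold.
    """
    bag_labels = np.asarray(bag_labels)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    scores = []
    # Split bag indices, not instance indices, so each bag stays in one fold
    for train_idx, test_idx in kf.split(np.arange(len(bags))):
        X_tr = np.concatenate([bags[i] for i in train_idx])
        y_tr = np.concatenate([np.full(len(bags[i]), bag_labels[i]) for i in train_idx])
        model = clone(estimator).fit(X_tr, y_tr)  # fresh, unfitted copy per fold
        preds = [int(np.any(model.predict(bags[i]) == 1)) for i in test_idx]
        scores.append(accuracy_score(bag_labels[test_idx], preds))
    return np.array(scores)


# Separable toy data: positive bags have all features in [5, 6), negative in [0, 1)
rng = np.random.default_rng(0)
demo_bags = [rng.random((4, 3)) + (5.0 if i % 2 else 0.0) for i in range(30)]
demo_labels = [i % 2 for i in range(30)]
scores = bag_cv_sketch(DecisionTreeClassifier(random_state=0), demo_bags, demo_labels)
```

Splitting on instances instead would let near-duplicate instances from one bag leak into both sides of a fold and inflate the scores.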

Usage example

# Python imports

# Third party imports
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from sklearn.dummy import DummyClassifier
from sklearn.neighbors import RadiusNeighborsClassifier
from sklearn.metrics import make_scorer

# Local imports
from bag_cross_validate import cross_validate_bag, BagScorer, bags_2_si

# Global definitions

#%%

# Create some dummy data
"""Generate some dummy data
Create bags and single-instance data
A set of bags has shape [n x (m x p)] and can be thought of as an
array of bag instances.
n is the number of bags
m is the number of instances within each bag (this can vary between bags)
p is the number of features of each instance"""
n_bags = 1000
m_instances = 8 # Fixed number of instances per bag
p = 5
# 50% positive class, 50% negative class
# Bags are filled with random binary features, shape (n, m, p)
labels = np.concatenate((np.ones(int(n_bags*0.5)),
                         np.zeros(int(n_bags*(1-0.5))),
                         ))
bags = np.random.randint(low=0, high=2, size=(n_bags, m_instances, p))
print("This is what a bag looks like: \n{}".format(bags[0]))

# Split the dummy dataset into training and test bags
rs = ShuffleSplit(n_splits=1, test_size=0.2, train_size=0.8)
train_index, test_index = next(rs.split(bags, labels))
train_bags, train_labels = bags[train_index], labels[train_index]
test_bags, test_labels = bags[test_index], labels[test_index]
        
# Create an estimator
dumb = DummyClassifier(strategy='constant', constant=1)
radiusNeighbor = RadiusNeighborsClassifier(weights='distance', 
                                           algorithm='auto',
                                           p=1, # Manhattan distance
                                           )

# Create an evaluation metric
# Multiple evaluation metrics are allowed
accuracy_scorer = make_scorer(accuracy_score)
bagAccScorer = BagScorer(accuracy_scorer) # Accuracy score, no factory function
precision_scorer = make_scorer(precision_score, average='binary')
bagPreScorer = BagScorer(precision_scorer)
jaccard_scorer = make_scorer(jaccard_score, average='binary')
bagJacScorer = BagScorer(jaccard_scorer)
scoring = {'bag_accuracy':bagAccScorer,
           'bag_precision':bagPreScorer,
           'bag_jaccard':bagJacScorer,
           }


#%%

# Cross validate the dummy data and estimator
result_dumb = cross_validate_bag(estimator=dumb, 
                            X=train_bags, 
                            y=train_labels, 
                            groups=None, 
                            scoring=scoring, # Custom scorer... 
                            cv=2,
                            n_jobs=3, 
                            verbose=0, 
                            fit_params=None,
                            pre_dispatch='2*n_jobs', 
                            return_train_score=False,
                            return_estimator=False, 
                            error_score=np.nan)

result_neighbor = cross_validate_bag(estimator=radiusNeighbor, 
                            X=train_bags, 
                            y=train_labels, 
                            groups=None, 
                            scoring=scoring, # Custom scorer... 
                            cv=3,
                            n_jobs=2, 
                            verbose=0, 
                            fit_params=None,
                            pre_dispatch='2*n_jobs', 
                            return_train_score=False,
                            return_estimator=False, 
                            error_score=np.nan)

# Display the results
# Each 'test_bag_accuracy' entry holds one accuracy per fold (a fraction),
# so report the mean across folds as a percentage
msg = ("\nOur dummy estimator tried its best, and predicted {:.1f} percent of "
       "bags correctly")
msg = msg.format(100 * np.mean(result_dumb['test_bag_accuracy']))
print(msg)

msg = ("\nOur neighbor estimator didn't fare well either, and predicted {:.1f} "
       "percent of bags correctly")
msg = msg.format(100 * np.mean(result_neighbor['test_bag_accuracy']))
print(msg)
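For comparison, plain scikit-learn can already keep every bag inside a single fold: flatten the bags, pass one bag ID per instance as `groups`, and use `GroupKFold`. The scores it returns are instance-level accuracies, though, not bag-level ones, which is the gap a bag-aware wrapper such as `cross_validate_bag` is meant to fill. The smaller random dataset and the `DecisionTreeClassifier` below are assumptions for illustration only.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, cross_validate
from sklearn.tree import DecisionTreeClassifier

# Random binary bags, as in the usage example above but smaller
rng = np.random.default_rng(0)
n_bags, m_instances, p = 20, 8, 5
bags = rng.integers(0, 2, size=(n_bags, m_instances, p))
bag_labels = np.array([1] * 10 + [0] * 10)

# Flatten to single instances; each instance inherits its bag's label
X = bags.reshape(-1, p)
y = np.repeat(bag_labels, m_instances)
groups = np.repeat(np.arange(n_bags), m_instances)  # bag ID of every instance

# GroupKFold never puts instances of the same bag in both train and test
result = cross_validate(DecisionTreeClassifier(random_state=0), X, y,
                        groups=groups, cv=GroupKFold(n_splits=4),
                        scoring="accuracy")
print(result["test_score"])  # instance-level accuracy per fold
```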



Download files


Source Distribution

bag_cross_validate-0.0.2.tar.gz (17.5 kB)


Built Distribution

bag_cross_validate-0.0.2-py3-none-any.whl (18.7 kB)


File details

Details for the file bag_cross_validate-0.0.2.tar.gz.

File metadata

  • Download URL: bag_cross_validate-0.0.2.tar.gz
  • Upload date:
  • Size: 17.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.10

File hashes

Hashes for bag_cross_validate-0.0.2.tar.gz
  • SHA256: 13654d91118720eb5bc9ccfc5780aab21db881f58c93950a55e2f590a20e5402
  • MD5: 4ee9bded15197f3a0cef392a6ab69f3b
  • BLAKE2b-256: 797993cfb1b8313a53dda8f9a5783ed55ac9319ef063db80fe2e7683fd6cfd02


File details

Details for the file bag_cross_validate-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: bag_cross_validate-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 18.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.10

File hashes

Hashes for bag_cross_validate-0.0.2-py3-none-any.whl
  • SHA256: e3457d5371ec55c584228d8ff391da03d13ecab237be15ceae3efbfc7e342cfb
  • MD5: 99097ddce4be4c975e43367710940d01
  • BLAKE2b-256: b98f7c139f5e40c598508bd4196e993c2a5d932e9ed169b19412a6db4d4069bf

