
Perform cross-validation on MIL data using sklearn, assuming the bag label is inferred from single-instance predictions.

Project description

Bag cross validation

Introduction

Multiple instance labeling (MIL) refers to labeling data arranged in sets, or bags. In MIL supervised learning, labels are known for bags of instances, and the goal is to assign bag-level labels to unobserved bags.

A simple solution to the MIL problem is to treat each instance in a bag as a single instance (SI) that inherits the label of its bag. Each SI in a bag is labeled with a single-instance estimator, and the bag label is then derived from the SI predictions with an aggregation rule (mode, threshold, presence). Related terms in the literature are presence-based, threshold-based, and count-based concepts (see "A two-level learning method for generalized multi-instance problems" by Nils Weidmann et al.).
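The single-instance approach can be sketched in a few lines of NumPy. This is a minimal, simplified sketch, not this package's implementation: `bags_to_instances` and `reduce_to_bag_labels` are hypothetical helpers, SI predictions are assumed to be binary (0/1), and the three rules (presence, mode, and a threshold on the fraction of positive instances) are simplified binary versions of the aggregation rules named above, not the more general concepts defined by Weidmann et al.

```python
import numpy as np


def bags_to_instances(bags, bag_labels):
    """Hypothetical helper: flatten a list of (m_i x p) bags into SI arrays.

    Every instance inherits the label of its bag, and `bag_index` records
    which bag each row came from so predictions can be regrouped later.
    """
    X = np.vstack(bags)
    y = np.concatenate([np.full(len(b), lab) for b, lab in zip(bags, bag_labels)])
    bag_index = np.concatenate([np.full(len(b), i) for i, b in enumerate(bags)])
    return X, y, bag_index


def reduce_to_bag_labels(si_pred, bag_index, rule="presence", threshold=0.5):
    """Hypothetical helper: aggregate binary (0/1) SI predictions per bag."""
    n_bags = bag_index.max() + 1
    # Number of positive predictions and number of instances in each bag.
    positives = np.bincount(bag_index, weights=si_pred, minlength=n_bags)
    sizes = np.bincount(bag_index, minlength=n_bags)
    if rule == "presence":   # positive if any instance is predicted positive
        return (positives >= 1).astype(int)
    if rule == "mode":       # positive if a strict majority of instances is positive
        return (2 * positives > sizes).astype(int)
    if rule == "threshold":  # positive if the positive fraction reaches `threshold`
        return (positives / sizes >= threshold).astype(int)
    raise ValueError(f"unknown rule: {rule!r}")


# Two toy bags of different sizes with bag labels 1 and 0.
bags = [np.array([[0, 1], [1, 1]]), np.array([[0, 0], [0, 1], [1, 0]])]
X, y, bag_index = bags_to_instances(bags, [1, 0])
# y -> [1 1 0 0 0], bag_index -> [0 0 1 1 1]

# Pretend an SI classifier produced these per-instance predictions.
si_pred = np.array([0, 1, 0, 1, 1])
print(reduce_to_bag_labels(si_pred, bag_index, "presence"))                  # [1 1]
print(reduce_to_bag_labels(si_pred, bag_index, "mode"))                      # [0 1]
print(reduce_to_bag_labels(si_pred, bag_index, "threshold", threshold=0.6))  # [0 1]
```

The same SI predictions yield different bag labels depending on the rule, which is why the choice of reduction matters as much as the choice of SI estimator.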

This package fits into the scikit-learn cross-validation framework and lets you use traditional single-instance classifiers to predict bag-level labels for multiple instance labeling.

This package supports:

  1. The use of scikit-learn estimators in a cross-validation framework for multiple instance labeling, with bag labels inferred from single-instance predictions
  2. The use of scikit-learn evaluation metrics in cross-validation, with scores computed on bag-level labels for the MIL problem

Motivation

scikit-learn is a popular data-analysis library that provides many SI estimators. It also includes a convenient API for evaluating SI estimators, namely cross_validate.

This package builds on sklearn's cross_validate function and extends it to MIL for SI estimators.
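Conceptually, extending cross-validation to bags means drawing the folds over bags rather than instances, fitting the SI estimator on the flattened training bags, and scoring the reduced bag-level predictions. The sketch below illustrates that idea with plain scikit-learn. It is an assumption-laden illustration, not this package's code: `naive_bag_cross_validate` is a hypothetical helper, it assumes equal-size bags stored in a 3-D array of shape (n, m, p), it always uses a presence rule to reduce predictions, and `cross_validate_bag` may split, reduce, and score differently.

```python
import numpy as np
from sklearn.base import clone
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier


def naive_bag_cross_validate(estimator, bags, bag_labels, n_splits=3, seed=0):
    """Hypothetical helper: cross-validate an SI estimator on equal-size bags.

    `bags` has shape (n, m, p). Folds are drawn over bags, so the instances of
    one bag never appear in both the training and the test split.
    """
    _, m, p = bags.shape
    folds = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    scores = []
    for train_idx, test_idx in folds.split(bags):
        est = clone(estimator)
        # Flatten training bags to SI rows; each instance inherits its bag label.
        est.fit(bags[train_idx].reshape(-1, p), np.repeat(bag_labels[train_idx], m))
        # Predict every instance of the test bags, then regroup rows by bag.
        si_pred = est.predict(bags[test_idx].reshape(-1, p)).reshape(len(test_idx), m)
        # Presence rule: a bag is positive if any of its instances is predicted 1.
        bag_pred = si_pred.max(axis=1)
        # Score against bag labels, not instance labels.
        scores.append(accuracy_score(bag_labels[test_idx], bag_pred))
    return np.array(scores)


# Random binary data shaped like the usage example (with fewer bags).
rng = np.random.default_rng(0)
bags = rng.integers(0, 2, size=(200, 8, 5))
labels = np.concatenate([np.ones(100), np.zeros(100)])
scores = naive_bag_cross_validate(DecisionTreeClassifier(random_state=0), bags, labels)
print(scores, scores.mean())
```

Splitting over bags is the key design choice here: splitting individual instances would let instances of a test bag leak into training and inflate the bag-level scores.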

Usage example

# Python imports

# Third party imports
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from sklearn.dummy import DummyClassifier
from sklearn.neighbors import RadiusNeighborsClassifier
from sklearn.metrics import make_scorer

# Local imports
from bag_cross_validate import cross_validate_bag, BagScorer, bags_2_si

# Global definitions

#%%

# Create some dummy data
"""Generate some dummy data
Create bags and single-instance data
A set of bags has shape [n x (m x p)] and can be thought of as an
array of bag instances.
n is the number of bags
m is the number of instances within each bag (this can vary between bags)
p is the number of features of each instance"""
n_bags = 1000
m_instances = 8 # Static number of instances per bag
p = 5
# 50% positive class, 50% negative class
# Bags are created with random binary data, of shape (n, m, p)
labels = np.concatenate((np.ones(int(n_bags*0.5)),
                         np.zeros(int(n_bags*(1-0.5))),
                         ))
bags = np.random.randint(low=0, high=2, size=(n_bags, m_instances, p))
print("This is what a bag looks like: \n{}".format(bags[0]))

# Split the dummy dataset into training and test sets
rs = ShuffleSplit(n_splits=1, test_size=0.2, train_size=0.8)
train_index, test_index = next(rs.split(bags, labels))
train_bags, train_labels = bags[train_index], labels[train_index]
test_bags, test_labels = bags[test_index], labels[test_index]
        
# Create an estimator
dumb = DummyClassifier(strategy='constant', constant=1)
radiusNeighbor = RadiusNeighborsClassifier(weights='distance', 
                                           algorithm='auto',
                                           p=1, # Manhattan distance
                                           )

# Create an evaluation metric
# Multiple evaluation metrics are allowed
accuracy_scorer = make_scorer(accuracy_score)
bagAccScorer = BagScorer(accuracy_scorer) # Accuracy score, no factory function
precision_scorer = make_scorer(precision_score, average='binary')
bagPreScorer = BagScorer(precision_scorer)
jaccard_scorer = make_scorer(jaccard_score, average='binary')
bagJacScorer = BagScorer(jaccard_scorer)
scoring = {'bag_accuracy':bagAccScorer,
           'bag_precision':bagPreScorer,
           'bag_jaccard':bagJacScorer,
           }


#%%

# Cross validate the dummy data and estimator
result_dumb = cross_validate_bag(estimator=dumb,
                                 X=train_bags,
                                 y=train_labels,
                                 groups=None,
                                 scoring=scoring, # Custom scorer...
                                 cv=2,
                                 n_jobs=3,
                                 verbose=0,
                                 fit_params=None,
                                 pre_dispatch='2*n_jobs',
                                 return_train_score=False,
                                 return_estimator=False,
                                 error_score=np.nan)

result_neighbor = cross_validate_bag(estimator=radiusNeighbor,
                                     X=train_bags,
                                     y=train_labels,
                                     groups=None,
                                     scoring=scoring, # Custom scorer...
                                     cv=3,
                                     n_jobs=2,
                                     verbose=0,
                                     fit_params=None,
                                     pre_dispatch='2*n_jobs',
                                     return_train_score=False,
                                     return_estimator=False,
                                     error_score=np.nan)

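# Display the results
# As with sklearn's cross_validate, each score entry holds one value per fold
# (a fraction between 0 and 1), so report the mean as a percentage
msg = ("\nOur dummy estimator tried its best, and predicted {:.1f} percent of bags "
       "correctly")
msg = msg.format(np.mean(result_dumb['test_bag_accuracy']) * 100)
print(msg)

msg = ("\nOur neighbor estimator didn't fare well either, and predicted {:.1f} percent "
       "of bags correctly")
msg = msg.format(np.mean(result_neighbor['test_bag_accuracy']) * 100)
print(msg)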
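The usage example stores bags in a 3-D array, which only works when every bag has the same number of instances m. When m varies between bags, one option is a 1-D NumPy object array whose elements are 2-D bags; it still supports indexing by fold indices. This is a sketch under assumptions: the bag sizes and labels are made up, and whether `cross_validate_bag` accepts this layout is not stated here, so check before relying on it.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bags, p = 6, 5

# Made-up bag sizes: each bag holds between 3 and 10 instances (inclusive).
sizes = rng.integers(3, 11, size=n_bags)

# A regular 3-D array cannot hold bags of different lengths, so store each
# (m_i x p) bag as one element of a 1-D object array.
bags = np.empty(n_bags, dtype=object)
for i, m in enumerate(sizes):
    bags[i] = rng.integers(0, 2, size=(m, p))
labels = np.array([1, 0, 1, 0, 1, 0])

# Indexing with fold indices selects whole bags, as with the 3-D array.
train_idx = np.array([0, 1, 2, 3])
train_bags, train_labels = bags[train_idx], labels[train_idx]

# Flatten the selected bags to single instances; each inherits its bag label.
X_si = np.vstack(list(train_bags))
y_si = np.repeat(train_labels, [len(b) for b in train_bags])
print(X_si.shape, y_si.shape)
```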



Download files


Source Distribution

bag_cross_validate-0.0.1.tar.gz (17.1 kB)

Uploaded Source

Built Distribution

bag_cross_validate-0.0.1-py3-none-any.whl (18.2 kB)

Uploaded Python 3

File details

Details for the file bag_cross_validate-0.0.1.tar.gz.

File metadata

  • Download URL: bag_cross_validate-0.0.1.tar.gz
  • Upload date:
  • Size: 17.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.10

File hashes

Hashes for bag_cross_validate-0.0.1.tar.gz
Algorithm Hash digest
SHA256 1945e9ae11cbbab4015262dafe610c031e6f19f4d2eebad18f77b8b3617a2bac
MD5 f9851abec73fb5e2b017a91cb6d0e9f8
BLAKE2b-256 1a88d40cf583ecff827fa204b89275f20ae378ebb9901a282e5c7688b7d8de98


File details

Details for the file bag_cross_validate-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: bag_cross_validate-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 18.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.10

File hashes

Hashes for bag_cross_validate-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 ca87b2b997d73ac397a71b9a743165e3bfe23f4da8c3f8b9f70f88506a81fb53
MD5 0b7e666135218efe26b2b51ea71be453
BLAKE2b-256 fae661e1319eefb976909c42648d79561f8a9b60b5b54952257e574b2bad2e97

