
Code for the paper 'Is Learn to Defer Enough? Optimal Predictors that Incorporate Human Decisions'


Quick Start

In this package, we provide the code to reproduce the experiments in the paper "Is Learn to Defer Enough? Optimal Predictors that Incorporate Human Decisions".

Quick Installation

Using pip (Recommended)

This package can be installed with pip using the commands below. We recommend creating a virtual environment and installing everything there:

# Creating a virtual environment (optional)
python3 -m venv beyonddefer-venv
source beyonddefer-venv/bin/activate

# Adding the Package
pip install beyonddefer
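After installing, you can verify that the package is importable before running any experiments. A small sketch using only the standard library (`importlib.util.find_spec` locates a module without importing it):

```python
# Quick sanity check: confirm a package is importable before use.
# find_spec returns None when the module cannot be located.
import importlib.util


def is_installed(name: str) -> bool:
    """Return True if `name` can be located by the import system."""
    return importlib.util.find_spec(name) is not None


print("beyonddefer installed:", is_installed("beyonddefer"))
```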

Cloning Repository

Alternatively, clone this repository and add the package's path to the Python path (via the PYTHONPATH environment variable):

# clone the repository and install requirements
git clone <repo-url>
cd BeyondDefer
pip install -r requirements.txt

# adding the package to python path
export PYTHONPATH=$PWD

# now run your Python script that imports beyonddefer

Usage Example

In this section, we walk through a simple Python script that uses the beyonddefer package. We train the Additional Beyond Defer method with a WideResNet model on the synthetic CIFAR10K dataset for k = 5, then evaluate it and print the results.

Note: Before running any experiments, first create the data, models, and Results directories in the directory of your Python script:

mkdir data models Results
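If you prefer to do this from Python (for example, at the top of your experiment script), the following is a cross-platform equivalent of the `mkdir` command above; `exist_ok=True` makes it safe to re-run:

```python
# Create the directories the experiments expect, skipping any that exist.
import os

for d in ("data", "models", "Results"):
    os.makedirs(d, exist_ok=True)
```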

Here is the code for this example:

# import beyonddefer itself for some initializations
import beyonddefer

# import the required modules
from beyonddefer.human_ai_defer.datasetsdefer.cifar_synth import \
    CifarSynthDataset
from beyonddefer.MyMethod.additional_defer import AdditionalBeyond
from beyonddefer.MyNet.call_net import networks, optimizer_scheduler
from beyonddefer.metrics.metrics import compute_additional_defer_metrics
import torch
import logging

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Setting log level to INFO to see training logs
logging.getLogger().setLevel(logging.INFO)

# Adding the dataset
k = 5 # expert k
dataset = CifarSynthDataset(k, False, batch_size=512)

dataset_name = "cifar_synth"
epochs = 150
num_classes = 10

# Adding the networks and model
classifier, human, meta = networks(dataset_name, "Additional", device)
AB = AdditionalBeyond(num_classes, classifier, human, meta, device)

# Optimizer and scheduler
optimizer, scheduler = optimizer_scheduler()

# Training the model
AB.fit(dataset.data_train_loader, dataset.data_val_loader,
           dataset.data_test_loader, num_classes, epochs, optimizer, lr=1e-3,
           scheduler=scheduler, verbose=False)

# Generating test results
test_data = AB.test(dataset.data_test_loader, num_classes)

# Extracting useful information from the raw test data
res_AB = compute_additional_defer_metrics(test_data)

print(res_AB)
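If you want to keep metrics across runs, one option is to write them into the Results/ directory created above. A minimal sketch, assuming the metrics object is a plain dict of scalars (the actual return type of `compute_additional_defer_metrics` may differ; the file name and placeholder values here are purely illustrative):

```python
# Persist a metrics dict to Results/ so runs can be compared later.
import json
import os

res_AB = {"accuracy": 0.91, "coverage": 0.78}  # placeholder values for illustration

os.makedirs("Results", exist_ok=True)
with open(os.path.join("Results", "additional_beyond_cifar10k.json"), "w") as f:
    json.dump(res_AB, f, indent=2)
```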

Experiments

The main experiments from the paper (Section 7) are in Experiments/. Specifically,

  • in Experiments/acc_vs_c.py the code corresponding to the accuracy of methods based on additional defer cost is provided,
  • in Experiments/CIFAR10K.py the code corresponding to the CIFAR10K experiment for different $K$ is provided,
  • in Experiments/cost_sensitive_cov_acc.py the code of accuracy vs. coverage for cost-sensitive methods is provided,
  • in Experiments/SampleComp.py the role of sample complexity is studied, and
  • in Experiments/no_loss_cov_acc.py the code of accuracy vs. coverage for methods for 0-1 losses is provided.
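The coverage/accuracy experiments above all revolve around the same trade-off: a deferral rule decides per example whether the classifier or the human answers, trading coverage (the fraction handled by the classifier) against overall accuracy. As a generic illustration in plain Python (not the package's API; the function and variable names are made up for this sketch):

```python
# Generic illustration of the accuracy/coverage trade-off under deferral:
# each example is answered by the human if deferred, else by the classifier.

def system_metrics(clf_correct, human_correct, defer_mask):
    """Overall accuracy and coverage for fixed per-example defer decisions."""
    n = len(clf_correct)
    correct = sum(h if d else c
                  for c, h, d in zip(clf_correct, human_correct, defer_mask))
    coverage = 1 - sum(defer_mask) / n
    return correct / n, coverage


# Toy data: classifier right on 3/5, human right on 4/5; defer the last two.
clf = [1, 1, 1, 0, 0]
hum = [0, 1, 1, 1, 1]
defer = [0, 0, 0, 1, 1]
acc, cov = system_metrics(clf, hum, defer)
print(acc, cov)  # 1.0 accuracy on this toy split, 60% coverage
```

Deferring exactly where the classifier is weak lifts accuracy above what either agent achieves alone, which is the effect the cost- and coverage-sweep experiments quantify.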
