
The Robust Deep Learning Library


THIS IS AN EARLY DRAFT VERSION. PLEASE DO NOT USE IT.



Train your model from scratch or fine-tune a pretrained model to produce deep neural networks with improved robustness and uncertainty estimation.

  • Model Independent: Use models from the timm library or any other model you want.
  • Hyperparameter-Free: No need to train many times! Use hyperparameter-free "You Only Train Once" (YOTO) losses.
  • Large-Scale Models and Datasets: Train on ImageNet or any other large-scale dataset.
  • Media Independent: Most cases work for any type of media (e.g., image, text, audio, and others).
  • Standard Interface: Use the same API to train models with improved robustness using different losses.
  • State-of-the-art: SOTA results for out-of-distribution detection and uncertainty estimation.
  • Efficient Inference: The trained models are as efficient as those trained using the cross-entropy loss.
  • No Need for Additional Data: The losses used in this library do not require collecting additional data.
  • Temperature Calibration: Calibrate the uncertainty estimates by updating the temperature of the last layer.
  • Threshold Computation: Compute the threshold used to decide whether an example is out-of-distribution.
  • Score Computation: Compute scores of many different available types.
  • Out-of-Distribution Detection: Detect out-of-distribution examples using the computed scores.
  • Scalability: Our entropic losses perform increasingly better as the dataset and model sizes increase.
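The Temperature Calibration feature divides the last-layer logits by a temperature before the softmax. The `calibrate_temperature(..., optimize="ECE")` call in the Usage section chooses that temperature by optimizing the expected calibration error (ECE).

The plain-Python sketch below only illustrates the idea and is not the library's implementation. It computes a temperature-scaled softmax, the ECE, and the negative log-likelihood (NLL), then picks a temperature. Several choices here are our own assumptions:

* the helper names (`softmax`, `expected_calibration_error`, `negative_log_likelihood`, `calibrate_temperature_grid`);
* the 15 equal-width confidence bins;
* the grid search over temperatures;
* the example logits.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of a list of logits divided by a (positive) temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def expected_calibration_error(all_logits, labels, temperature=1.0, n_bins=15):
    """Weighted average gap between confidence and accuracy over equal-width confidence bins."""
    bins = [[] for _ in range(n_bins)]
    for logits, label in zip(all_logits, labels):
        probs = softmax(logits, temperature)
        confidence = max(probs)
        prediction = probs.index(confidence)
        index = min(int(confidence * n_bins), n_bins - 1)  # confidence 1.0 goes to the last bin
        bins[index].append((confidence, prediction == label))
    ece = 0.0
    for bin_items in bins:
        if bin_items:
            avg_confidence = sum(c for c, _ in bin_items) / len(bin_items)
            accuracy = sum(1 for _, correct in bin_items if correct) / len(bin_items)
            ece += len(bin_items) / len(labels) * abs(avg_confidence - accuracy)
    return ece


def negative_log_likelihood(all_logits, labels, temperature=1.0):
    """Mean negative log-probability assigned to the true labels."""
    return -sum(
        math.log(softmax(logits, temperature)[label])
        for logits, label in zip(all_logits, labels)
    ) / len(labels)


def calibrate_temperature_grid(all_logits, labels):
    """Hypothetical helper: pick the temperature in {0.1, 0.2, ..., 10.0} with the lowest ECE."""
    candidates = [k / 10 for k in range(1, 101)]
    return min(candidates, key=lambda t: expected_calibration_error(all_logits, labels, t))


# Hypothetical, overconfident validation logits (two classes) and their labels
val_logits = [[8.0, 0.0], [7.0, 1.0], [0.0, 6.0], [5.0, 0.0]]
val_labels = [0, 1, 1, 1]

best_temperature = calibrate_temperature_grid(val_logits, val_labels)
print("temperature:", best_temperature)
print("ECE before/after:",
      expected_calibration_error(val_logits, val_labels),
      expected_calibration_error(val_logits, val_labels, best_temperature))
print("NLL before/after:",
      negative_log_likelihood(val_logits, val_labels),
      negative_log_likelihood(val_logits, val_labels, best_temperature))
```

Dividing all logits by the same positive temperature does not change which class has the largest logit. Plain temperature scaling therefore leaves the accuracy unchanged and only reshapes the confidences.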

Some code is reused from deep_Mahalanobis_detector and odin-pytorch.

Installation

pip install robust-deep-learning

Usage

# Import the robust deep learning library (the module name uses underscores)
import robust_deep_learning as rdl
import timm
import torch

########################################################################################
########################################################################################
# Training or Fine-tuning the Robust Deep Neural Network
########################################################################################
########################################################################################


##############
# Custom model
##############

# Create the model from a model definition file.
# For example, if you imported a class "Model" from a model definition file:
model = Model()
# To fine-tune rather than train from scratch, load a checkpoint from a pretrained model now,
# e.g., model.load_state_dict(torch.load("checkpoint.pth"))  # hypothetical checkpoint path

# Change the last (classification) layer of your model.
# If your last layer is called "classifier", i.e., you have something like
# "model.classifier = nn.Linear(num_features, num_classes)",
# then add the following line:
model.classifier = rdl.DisMaxLossFirstPart(num_features, num_classes)

# Replace the cross-entropy loss
criterion = rdl.DisMaxLossSecondPart(model.classifier, add_on="FPR")

############
# timm model
############

# To use a model from the timm library, use its "create_model" function.
# Set pretrained=True to start from a pretrained model and fine-tune it using the new loss.
model = timm.create_model('resnet18', pretrained=False)

# If you are using timm models and the last layer is called "fc", do the following instead:
model.fc = rdl.DisMaxLossFirstPart(model.get_classifier().in_features, num_classes)

# Then replace the cross-entropy loss, passing the new last layer:
criterion = rdl.DisMaxLossSecondPart(model.fc, add_on="FPR")

###############
# training loop
###############

for epoch in range(num_epochs):

    # Training loop ("optimizer" is your usual optimizer, e.g., torch.optim.SGD)
    model.train()
    for inputs, targets in in_data_train_loader:

        # In the training loop, add the line below for preprocessing before forwarding.
        inputs, targets = criterion.preprocess(inputs, targets)

        # The lines below are usually found in a training loop.
        # They should not be changed!
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

########################################################################################
########################################################################################
# Uncertainty Estimation
########################################################################################
########################################################################################

#############
# calibrating
#############

# After training, compute the outputs, labels, and metrics on the in-distribution validation set.
results = rdl.get_outputs_labels_and_metrics(model, in_data_val_loader)

# Convert the outputs (logits) into probabilities, then print the accuracy,
# the expected calibration error (ECE), and the negative log-likelihood (NLL).
probabilities = torch.nn.Softmax(dim=1)(results["outputs"])
print(probabilities)
print(results["acc"], results["ece"], results["nll"])

# Calibrate the temperature of the last layer on the validation set by optimizing the ECE.
# (For timm models whose last layer is "fc", pass model.fc instead of model.classifier.)
rdl.calibrate_temperature(model.classifier, model, in_data_val_loader, optimize="ECE")

# Print the new temperature after calibration
print(model.classifier.temperature)

#######################
# verifying calibration
#######################

# Recompute the outputs, labels, and metrics using the calibrated temperature.
results = rdl.get_outputs_labels_and_metrics(model, in_data_val_loader)

# Compare the new probabilities and metrics with the ones computed before calibration.
probabilities = torch.nn.Softmax(dim=1)(results["outputs"])
print(probabilities)
print(results["acc"], results["ece"], results["nll"])

#######################################################################################
#######################################################################################
# Out-of-Distribution Detection
#######################################################################################
#######################################################################################

########################
# estimation performance
########################

# Define the score type, typically the best one for the loss you are using.
score_type = "MMLES"

# Evaluate the out-of-distribution detection performance. Try all score types!
ood_metrics = rdl.get_ood_metrics(model, in_data_val_loader, out_data_loader, score_type, fpr=0.05)

# Alternatively, compute the scores of the in-distribution validation data...
results = rdl.get_outputs_labels_and_metrics(model, in_data_val_loader)
in_data_scores = rdl.get_scores(results["outputs"], score_type)

# ...and of the out-of-distribution data...
results = rdl.get_outputs_labels_and_metrics(model, out_data_loader)
out_data_scores = rdl.get_scores(results["outputs"], score_type)

# ...and then compute the out-of-distribution detection metrics from these scores.
ood_metrics = rdl.get_ood_metrics_from_scores(in_data_scores, out_data_scores, fpr=0.05)

###########
# detecting
###########

# Compute the thresholds from the model and the in-distribution validation data.
thresholds = rdl.get_thresholds(model, in_data_val_loader, score_type)

# Alternatively, compute the thresholds from the in-distribution scores computed above.
thresholds = rdl.get_thresholds_from_scores(in_data_scores)

# The scores and thresholds above should be computed using the same score type.
# "inputs" is a batch of examples to check, e.g., inputs, _ = next(iter(data_loader)).
# The fpr argument selects which of the computed thresholds to use.
ood_detections = rdl.get_ood_detections(model, inputs, thresholds, fpr="0.05")
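This README does not describe how `get_scores`, `get_thresholds_from_scores`, and `get_ood_detections` work internally. The following plain-Python sketch shows only the general idea and is not their implementation. It rests on several assumptions:

* a higher score means "more in-distribution";
* the score used is the maximum probability score (presumably the MPS of the Results table) instead of MMLES;
* `fpr` is the fraction of in-distribution validation examples allowed to fall below the threshold, and thus to be wrongly flagged as out-of-distribution;
* the helper names and example values are hypothetical.

```python
import math


def max_probability_score(logits):
    """Maximum softmax probability of one example: higher means more in-distribution."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    return max(exps) / sum(exps)


def threshold_from_scores(in_scores, fpr=0.05):
    """Threshold such that about a fraction `fpr` of in-distribution scores fall below it."""
    ordered = sorted(in_scores)
    k = int(fpr * len(ordered))  # in-distribution examples allowed below the threshold
    return ordered[min(k, len(ordered) - 1)]


def detect_ood(scores, threshold):
    """Flag each example whose score is below the threshold as out-of-distribution."""
    return [score < threshold for score in scores]


# Hypothetical in-distribution validation scores: 0.500, 0.505, ..., 0.995
in_scores = [0.5 + i / 200 for i in range(100)]
threshold = threshold_from_scores(in_scores, fpr=0.05)
print("threshold:", threshold)
print("in-distribution examples flagged:", sum(detect_ood(in_scores, threshold)))

# Score and check a batch of hypothetical two-class logits
batch_logits = [[4.0, 0.0], [0.05, 0.0]]
batch_scores = [max_probability_score(logits) for logits in batch_logits]
print(detect_ood(batch_scores, threshold))  # the nearly uniform second example is flagged
```

Under this convention, lowering `fpr` moves the threshold down. Fewer in-distribution examples are then rejected, but more out-of-distribution examples may go undetected.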

Examples

Please move to the data directory and run all the data preparation bash scripts:

# Download and prepare out-of-distribution data for the CIFAR10 and CIFAR100 datasets.
./prepare-cifar.sh
# Download and prepare out-of-distribution data for ImageNet.
./prepare-imagenet.sh

For an example of training on CIFAR using a model definition file, run:

# Run the CIFAR dataset example.
python -m robust_deep_learning.examples.cifar

For an example of training on ImageNet using a timm library model, run:

# Run the ImageNet dataset example.
python -m robust_deep_learning.examples.imagenet

Losses and Scores

The following losses are implemented:

The following scores are implemented:

  • Features
  • Features
  • Features
  • Features
  • Features
  • Features

Questions and Answers

  • Should I use augmentation? Use models from the timm library or whatever you want.
  • Should I use inference-based approaches? Use hyperparameter-free "You Only Train Once" (YOTO) losses.

Results

Dataset: ImageNet

| Model | Loss | Score | Classification ACC (%) | OOD Detection AUROC (%) [ImageNet-O] |
|---|---|---|---|---|
| ResNet18 | SoftMax (Baseline) | MPS | 69.9 | 52.4 |
| ResNet18 | DisMax | MMLES | 69.6 | 75.8 |
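The AUROC column measures how well a score separates in-distribution (ImageNet) examples from out-of-distribution (ImageNet-O) examples. A value of 50 corresponds to chance and 100 to perfect separation.

AUROC equals the probability that a randomly chosen in-distribution example receives a higher score than a randomly chosen out-of-distribution one, with ties counting as one half. The plain-Python sketch below computes it this way. It is not the library's implementation. The helper name `auroc`, the example scores, and the convention that higher scores mean in-distribution are assumptions.

```python
def auroc(in_scores, out_scores):
    """AUROC via pairwise comparisons (Mann-Whitney U statistic); ties count as 1/2."""
    wins = 0.0
    for s_in in in_scores:
        for s_out in out_scores:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(in_scores) * len(out_scores))


# Hypothetical scores: higher means "more in-distribution"
in_scores = [0.9, 0.8, 0.7, 0.4]
out_scores = [0.6, 0.3, 0.2]
print(f"AUROC: {100 * auroc(in_scores, out_scores):.1f}")  # prints AUROC: 91.7 (a percentage, as in the table)
```

This double loop is quadratic in the number of examples. For large evaluation sets, a sort-based (rank) computation or a library routine such as scikit-learn's `roc_auc_score` is preferable.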

Experiments

Train and evaluate the classification, uncertainty estimation, and out-of-distribution detection performance:

./run_cifar100_densenetbc100.sh
./run_cifar100_resnet34.sh
./run_cifar100_wideresnet2810.sh
./run_cifar10_densenetbc100.sh
./run_cifar10_resnet34.sh
./run_cifar10_wideresnet2810.sh
./run_imagenet1k_resnet18.sh

Print the experiment results:

./analize.sh

Citing

BibTeX

Please cite our papers if you use our losses in your work:

@article{macedo2022distinction,
    title={Distinction Maximization Loss:
    Efficiently Improving Out-of-Distribution Detection
    and Uncertainty Estimation by Replacing the Loss and Calibrating}, 
    author={David Macêdo and Cleber Zanchettin and Teresa Ludermir},
    year={2022},
    eprint={2205.05874},
    archivePrefix={arXiv},
    primaryClass={cs.LG}}



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

robust_deep_learning-0.0.1-py3-none-any.whl (25.2 kB)

Uploaded Python 3
