The Robust Deep Learning Library
THIS IS AN EARLY DRAFT VERSION. PLEASE DO NOT USE IT.
Train your model from scratch or fine-tune a pretrained model to produce deep neural networks with improved robustness and uncertainty estimation.
- Model Independent: Use models from the timm library or any model you want.
- Hyperparameter-Free: Do not train many times! Use the hyperparameter-free "You Only Train Once" (YOTO) losses.
- Large-Scale Models and Datasets: Train on ImageNet or any other large-scale dataset.
- Media Independent: Most cases work for any type of media (e.g., image, text, audio, and others).
- Standard Interface: Use the same API to train models with improved robustness using different losses.
- State-of-the-art: SOTA results for Out-of-Distribution Detection and Uncertainty Estimation.
- Efficient Inference: The trained models are as efficient as those trained using the cross-entropy loss.
- No Need for Additional Data: The losses used in this library do not require collecting additional data.
- Temperature Calibration: Calibrate the temperature of the last layer to improve uncertainty estimation.
- Threshold Computation: Compute the threshold for deciding whether an example is out-of-distribution.
- Scores Computation: Compute scores, choosing from a set of many different available types.
- Detect Out-of-Distribution: Detect out-of-distribution examples using the computed scores.
- Scalability: Our entropic losses perform increasingly better as dataset and model sizes grow.
Some code is reused from deep_Mahalanobis_detector and odin-pytorch.
Installation
pip install robust-deep-learning
Usage
# Import the robust deep learning library
import robust_deep_learning as rdl
########################################################################################
########################################################################################
# Training or Fine-tuning the Robust Deep Neural Network
########################################################################################
########################################################################################
##############
# Custom model
##############
# Create from a model definition file.
# For example, suppose you imported a class "Model" from a model definition file.
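# For illustration only, such a class might look like the hypothetical toy model below
# (this model is an assumption for the sketch, not part of the library):
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, num_features=512, num_classes=10):
        super().__init__()
        # A tiny convolutional feature extractor, just for illustration.
        self.features = nn.Sequential(
            nn.Conv2d(3, num_features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten())
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))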
model = Model()
# You may load a checkpoint now if you intend to fine-tune rather than train from scratch.
# Then replace the last classification layer of your model.
# For example, if your last layer is called "classifier" and you have something like
# "model.classifier = nn.Linear(num_features, num_classes)",
# then add the following line:
model.classifier = rdl.DisMaxLossFirstPart(num_features, num_classes)
# Replace the Cross Entropy Loss
criterion = rdl.DisMaxLossSecondPart(model.classifier, add_on="FPR")
############
# timm model
############
# To use a model from the timm library, use the "create_model" function.
# It is possible to start from a pretrained model and fine-tune it using the new loss.
model = timm.create_model('resnet18', pretrained=False)
# If you are using timm models and the last layer is called "fc", do the following instead:
model.fc = rdl.DisMaxLossFirstPart(model.get_classifier().in_features, num_classes)
# Replace the Cross Entropy Loss accordingly:
criterion = rdl.DisMaxLossSecondPart(model.fc, add_on="FPR")
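# If fine-tuning, timm can also load pretrained weights and resize the head in one call
# (num_classes is a standard timm argument; the layer swap above is still required):
# model = timm.create_model('resnet18', pretrained=True, num_classes=num_classes)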
###############
# training loop
###############
for epoch in range(num_epochs):
    for inputs, targets in in_data_train_loader:
        # In the training loop, add the line below for preprocessing before forwarding.
        inputs, targets = criterion.preprocess(inputs, targets)
        # The three lines below are usually found in any training loop.
        # These lines should not be changed!
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
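        # The usual optimizer calls complete the iteration (standard PyTorch;
        # "optimizer" is assumed to be defined beforehand, e.g.,
        # optimizer = torch.optim.SGD(model.parameters(), lr=0.1)):
        optimizer.step()
        optimizer.zero_grad()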
########################################################################################
########################################################################################
# Uncertainty Estimation
########################################################################################
########################################################################################
#############
# calibrating
#############
# Compute the outputs, labels, and metrics on the in-distribution validation set.
results = rdl.get_outputs_labels_and_metrics(model, in_data_val_loader)
# Convert the outputs to probabilities.
probabilities = torch.nn.Softmax(dim=1)(results["outputs"])
print(probabilities)
print(results["acc"], results["ece"], results["nll"])
# Calibrate the temperature of the last layer by optimizing the ECE.
rdl.calibrate_temperature(model.classifier, model, in_data_val_loader, optimize="ECE")
# print the new temperature after calibration
print(model.classifier.temperature)
######################
# verifying calibration
######################
# Recompute the outputs, labels, and metrics after calibration.
results = rdl.get_outputs_labels_and_metrics(model, in_data_val_loader)
# Convert the outputs to probabilities.
probabilities = torch.nn.Softmax(dim=1)(results["outputs"])
print(probabilities)
print(results["acc"], results["ece"],= results["nll"])
#######################################################################################
#######################################################################################
# Out-of-Distribution Detection
#######################################################################################
#######################################################################################
########################
# estimation performance
########################
# Define a score type. Typically the best for the loss you are using.
score_type = "MMLES"
# Evaluate the out-of-distribution detection performance. Try all the available scores!
ood_metrics = rdl.get_ood_metrics(model, in_data_val_loader, out_data_loader, score_type, fpr=0.05)
# Alternatively, first get the outputs and scores for the in-distribution validation data...
results = rdl.get_outputs_labels_and_metrics(model, in_data_val_loader)
in_data_scores = rdl.get_scores(results["outputs"], score_type)
# ...and then for the out-of-distribution data.
results = rdl.get_outputs_labels_and_metrics(model, out_data_loader)
out_data_scores = rdl.get_scores(results["outputs"], score_type)
ood_metrics = rdl.get_ood_metrics_from_scores(in_data_scores, out_data_scores, fpr=0.05)
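# For reference, the AUROC can also be computed directly from the two score sets using
# scikit-learn (this sketch assumes the scores are 1-D PyTorch tensors):
from sklearn.metrics import roc_auc_score
import numpy as np
all_scores = np.concatenate([in_data_scores.cpu().numpy(), out_data_scores.cpu().numpy()])
all_labels = np.concatenate([np.ones(len(in_data_scores)), np.zeros(len(out_data_scores))])
print("AUROC:", roc_auc_score(all_labels, all_scores))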
###########
# detecting
###########
# Compute the thresholds from the in-distribution validation data.
thresholds = rdl.get_thresholds(model, in_data_val_loader, score_type)
# Alternatively, compute the thresholds from precomputed in-distribution scores.
thresholds = rdl.get_thresholds_from_scores(in_data_scores)
# The scores and thresholds calculated above must be based on the same score type.
# fpr="0.05" selects the threshold computed at a 5% false positive rate.
# Get a batch of inputs, for example: inputs, _ = next(iter(in_data_val_loader))
ood_detections = rdl.get_ood_detections(model, inputs, thresholds, fpr="0.05")
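# A hedged sketch of manual detection from scores ("threshold_value" below is a
# hypothetical scalar; the exact structure of "thresholds" may differ by version):
# with torch.no_grad():
#     batch_scores = rdl.get_scores(model(inputs), score_type)
# is_ood = batch_scores < threshold_value  # scores below the threshold are flagged as OOD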
Examples
Please move to the data directory and run all the data preparation bash scripts:
# Download and prepare out-of-distribution data for the CIFAR10 and CIFAR100 datasets.
./prepare-cifar.sh
# Download and prepare out-of-distribution data for ImageNet.
./prepare-imagenet.sh
For an example of training on CIFAR using a model definition file, see:
# Run the CIFAR dataset example.
python -m robust_deep_learning.examples.cifar
For an example of training on ImageNet using a timm library model, see:
# Run the ImageNet dataset example.
python -m robust_deep_learning.examples.imagenet
Losses and Scores
The following losses are implemented:
- Isotropy Maximization Loss (arXiv)
- Isotropy Maximization Loss (conference version)
- Isotropy Maximization Loss (journal version)
- Enhanced Isotropy Maximization Loss (arXiv)
- Distinction Maximization Loss (arXiv)
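All of the losses above follow the same two-part pattern shown for DisMax in the usage section: the first part replaces the model's last layer and the second part replaces the criterion. A minimal sketch, assuming the IsoMax classes follow the same naming convention (class names here are illustrative; check the package for the exact ones):
# Hypothetical class names, following the DisMaxLossFirstPart/SecondPart pattern:
model.classifier = rdl.IsoMaxLossFirstPart(num_features, num_classes)
criterion = rdl.IsoMaxLossSecondPart(model.classifier)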
The following scores are implemented:
- Maximum Probability Score (MPS)
- Entropic Score (ES)
- Max-Mean Logit Entropy Score (MMLES)
Questions and Answers
- Should I use augmentation? The losses place no restrictions on data augmentation or on the model; your usual training pipeline applies.
- Should I use inference-based approaches? There is no need: the hyperparameter-free "You Only Train Once" (YOTO) losses make them unnecessary.
Results
Dataset: ImageNet
Model | Loss | Score | Classification (ACC) | OOD Detection [ImageNet-O] (AUROC) |
---|---|---|---|---|
ResNet18 | SoftMax (Baseline) | MPS | 69.9 | 52.4 |
ResNet18 | DisMax | MMLES | 69.6 | 75.8 |
Experiments
Train and evaluate the classification, uncertainty estimation, and out-of-distribution detection performances:
./run_cifar100_densenetbc100.sh
./run_cifar100_resnet34.sh
./run_cifar100_wideresnet2810.sh
./run_cifar10_densenetbc100.sh
./run_cifar10_resnet34.sh
./run_cifar10_wideresnet2810.sh
./run_imagenet1k_resnet18.sh
Print the experiment results:
./analize.sh
Citing
BibTeX
Please cite our papers if you use our losses in your work:
@article{macedo2022distinction,
  title={Distinction Maximization Loss: Efficiently Improving Out-of-Distribution Detection and Uncertainty Estimation by Replacing the Loss and Calibrating},
  author={David Macêdo and Cleber Zanchettin and Teresa Ludermir},
  year={2022},
  eprint={2205.05874},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}