Comprehensive and simple framework for conducting AI robustness experiments


As machine learning approaches to artificial intelligence grow in popularity, secure implementation and evaluation become increasingly important. The concern is greatest in safety-critical applications such as object detection for self-driving cars, monitoring of nuclear power plants, and medical diagnosis. To this end, we present a simple yet comprehensive interface for the robust training and evaluation of PyTorch classifiers.

Although other toolkits, such as the Adversarial Robustness Toolbox (ART) and MAIR, have addressed this need, they offer a narrower range of attacks and defenses.

Installation

Our work is available via this repository and as a PyPI package.

From PyPI (Recommended)

python3 -m pip install airtk

From Repo Source (Not Recommended)

To install from source, you will need:

  • The Conda environment manager.
  • The Git version control system.

git clone https://github.com/LAiSR-SK/AiRobustnessTestingKit-AiR-TK-
cd AiRobustnessTestingKit-AiR-TK-
conda env create -p .conda
conda activate ./.conda
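Whichever route you take, a quick way to confirm that the package is visible to your interpreter is to query its metadata. The snippet below is a small sketch that uses only the Python standard library; it assumes that both the PyPI distribution name and the import name are airtk (as in the pip command and the import lines in this document), and the helper name installed_version is hypothetical, not part of airtk.

```python
from importlib import metadata
from importlib.util import find_spec
from typing import Optional


def installed_version(dist_name: str = "airtk", module_name: str = "airtk") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing.

    Assumption: the distribution and the top-level module are both named "airtk".
    """
    # find_spec checks whether the module is importable without importing it
    if find_spec(module_name) is None:
        return None
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        # Importable (e.g. a source checkout on sys.path) but not pip-installed
        return None


if __name__ == "__main__":
    version = installed_version()
    print(f"airtk version: {version}" if version else "airtk is not installed")
```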

Contents

Attacks

Defenses

You can import and use our defenses as shown:

from airtk.defense import TradesTraining

if __name__ == "__main__":
    # Initialize the training function. Positional arguments (here the
    # dataset name and the model name) must come before keyword arguments.
    training = TradesTraining("cifar10",
                              "res101",
                              batch_size=512,
                              epochs=100,
                              lr=0.01,
                              seed=0,
                              model_dir="data/model/TRADES/",
                              save_freq=10)

    # Run the specified training regime
    training()
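For reference, TRADES (the defense trained above) minimizes a natural cross-entropy term plus β times the KL divergence between the model's prediction on a clean input x and on an adversarial input x′, where x′ is chosen within the ε-ball around x to maximize that divergence. The sketch below computes this per-example loss in plain Python from two logit vectors. It assumes x′ has already been found by an inner attack; the helper names and the beta parameter follow the TRADES paper's notation and are not documented airtk arguments.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the largest logit before exponentiating for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def trades_loss(natural_logits: Sequence[float],
                adversarial_logits: Sequence[float],
                label: int,
                beta: float) -> float:
    """Per-example TRADES objective: CE(f(x), y) + beta * KL(f(x) || f(x'))."""
    log_p_nat = log_softmax(natural_logits)
    log_p_adv = log_softmax(adversarial_logits)
    # Natural (clean) cross-entropy: -log p(y | x)
    natural_ce = -log_p_nat[label]
    # Robust regularizer: KL divergence from the clean to the adversarial prediction
    kl = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p_nat, log_p_adv))
    return natural_ce + beta * kl


if __name__ == "__main__":
    # Identical logits: the KL term vanishes, leaving only ln 2 of cross-entropy
    print(trades_loss([0.0, 0.0], [0.0, 0.0], label=0, beta=6.0))  # ≈ 0.693
```

Working in log-probabilities avoids taking the logarithm of a probability that has underflowed to zero. The value beta=6.0 in the usage line is only an example; larger β weights robustness more heavily relative to clean accuracy.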

We support the following defenses:

  • Adversarial Distributional Training (ADT)
  • Adversarial Distributional Training++ (ADTPP)
  • Adversarial Training with Adversarial Weight Perturbation (ATAWP)
  • Curriculum Adversarial Training (Currat)
  • Federated Adversarial Training (FAT)
  • Feature Scattering (FS)
  • Geometry-Aware Instance-Reweighted Adversarial Training (GAIRAT)
  • TRadeoff-inspired Adversarial DEfenses via Surrogate loss minimization (TRADES)
  • TRADES with Adversarial Weight Perturbation (TRADESAWP)
  • Various Attacks (VA)
  • You Only Propagate Once (YOPO)
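Many of these defenses can be read as variations on the standard min-max formulation of adversarial training, in which the model parameters θ are trained against the worst-case perturbation δ inside a norm ball of radius ε around each input. Written for an L∞ ball (an assumption; the norm used by each airtk defense is not stated here), the objective is:

```latex
\min_{\theta} \; \mathbb{E}_{(x, y) \sim \mathcal{D}}
  \left[ \max_{\|\delta\|_\infty \le \epsilon}
    \mathcal{L}\big(f_\theta(x + \delta),\, y\big) \right]
```

Roughly speaking, the defenses differ in how they approximate the inner maximization and how they weight or regularize the outer loss; TRADES, for example, maximizes a KL term against the clean prediction instead of the cross-entropy, and adds that term to the clean cross-entropy in the outer loss.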

Most of these defenses accept the following keyword arguments:

Keyword argument   Description
dataset_name       name of the dataset to use
model_name         name of the model to use
epochs             number of epochs to train / test for
batch_size         size of the training and testing batches
eps                perturbation budget (maximum size of the adversarial image perturbations)
model_dir          directory to save models to
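The eps budget bounds how far an attack may move each pixel. To illustrate how such a budget is typically enforced during adversarial training, here is a minimal, dependency-free sketch of an L∞ projected gradient descent (PGD) attack against a toy logistic-regression model, whose input gradient has a closed form. The toy model, the function names, and the values eps = 8/255 and alpha = 2/255 (common CIFAR-10 choices in the literature) are assumptions for illustration, not airtk defaults. Practical PGD implementations usually also start from a random point inside the ball, which this sketch omits.

```python
import math
from typing import List, Sequence


def logit(x: Sequence[float], w: Sequence[float], b: float) -> float:
    """Linear score w . x + b of the toy logistic-regression model."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def sigmoid(z: float) -> float:
    # Branch on the sign of z so math.exp never overflows
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_with_logits(z: float, y: int) -> float:
    """Binary cross-entropy for a label y in {0, 1}, computed stably from the score z."""
    return max(z, 0.0) - y * z + math.log1p(math.exp(-abs(z)))


def _sign(v: float) -> int:
    return (v > 0) - (v < 0)


def pgd_linf(x: Sequence[float], y: int, w: Sequence[float], b: float,
             eps: float, alpha: float, steps: int) -> List[float]:
    """Ascend the loss while staying within the L-infinity ball of radius eps around x."""
    x_adv = list(x)
    for _ in range(steps):
        # For logistic regression, d(loss)/dx = (sigmoid(z) - y) * w
        scale = sigmoid(logit(x_adv, w, b)) - y
        grad = [scale * wi for wi in w]
        # Gradient-sign ascent step
        x_adv = [xa + alpha * _sign(g) for xa, g in zip(x_adv, grad)]
        # Project back into [x - eps, x + eps], then into the valid pixel range [0, 1]
        x_adv = [min(max(xa, xi - eps), xi + eps) for xa, xi in zip(x_adv, x)]
        x_adv = [min(max(xa, 0.0), 1.0) for xa in x_adv]
    return x_adv


if __name__ == "__main__":
    x = [0.5, 0.5, 0.5]  # a toy three-"pixel" input
    w, b, y = [1.0, -2.0, 0.5], 0.0, 1
    x_adv = pgd_linf(x, y, w, b, eps=8 / 255, alpha=2 / 255, steps=10)
    print("clean loss:", bce_with_logits(logit(x, w, b), y))
    print("adversarial loss:", bce_with_logits(logit(x_adv, w, b), y))
```

Clipping to [0, 1] after the ε-projection keeps the point inside the ε-ball as long as the clean input is itself a valid image, so both constraints hold at every step.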

Pretrained Models

To expedite progress in the field of secure AI, we provide the weights of our trained models on Hugging Face. They can be loaded with from_pretrained and then evaluated or further trained:

import torch
from airtk.data import CIFAR100
from airtk.model import ResNet50
from torch.utils.data import DataLoader

if __name__ == "__main__":
    device: str = "cuda"
    torch.set_default_device(device)

    # 1. Load the model and switch it to inference mode
    model: ResNet50 = ResNet50.from_pretrained("LAiSR-SK/curriculum-at-cifar100-res50")
    model.eval()

    # 2. Evaluate the model against the CIFAR100 test set (no shuffling needed)
    testset: CIFAR100 = CIFAR100(root="data/", train=False, download=True)
    test_loader: DataLoader = DataLoader(testset, batch_size=256, shuffle=False)

    total: int = 0
    correct: int = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            _, predicted = torch.max(logits, 1)

            correct += (predicted == y).sum().item()
            total += predicted.size(0)

    acc: float = 100 * correct / total

    print(f"Accuracy: {acc:.2f}%")
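The evaluation loop above measures clean accuracy: a prediction counts as correct when the index of the largest logit matches the label. The same metric computed on adversarially perturbed inputs (for example, inputs produced by a PGD attack) is commonly reported as robust accuracy. The hypothetical helper below reproduces the metric in plain Python so it can be checked on small hand-written examples; it is not part of airtk.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of examples whose highest-scoring class matches the label."""
    if not labels or len(logits) != len(labels):
        raise ValueError("need a non-empty batch with one label per row of logits")
    correct = 0
    for row, y in zip(logits, labels):
        # Index of the largest logit, as in torch.max(logits, 1); ties go to the first index here
        predicted = max(range(len(row)), key=lambda k: row[k])
        correct += int(predicted == y)
    return 100.0 * correct / len(labels)


if __name__ == "__main__":
    logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.0, 0.4, 0.9], [1.1, 0.9, 0.0]]
    labels = [0, 1, 2, 1]
    print(top1_accuracy(logits, labels))  # → 75.0
```

Passing logits obtained on adversarial versions of the same inputs gives robust accuracy; comparing the two numbers is a common way to summarize how much an attack degrades a model.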

Cite Us

See CITATION.cff or the sidebar for details on how to cite our work.


Download files

Source Distribution

airtk-0.1.2.tar.gz (289.2 kB)

Built Distribution

airtk-0.1.2-py3-none-any.whl (247.0 kB)
