Comprehensive and simple framework for conducting AI robustness experiments

As machine learning approaches to artificial intelligence continue to grow in popularity, secure implementation and evaluation become increasingly important. This is an especially pressing concern in safety-critical applications such as object detection for self-driving cars, monitoring of nuclear power plants, and medical diagnosis. To this end, we present a simple yet comprehensive interface for the robust training and evaluation of PyTorch classifiers.

Although other tools, such as the Adversarial Robustness Toolbox and MAIR, have addressed this need in the past, they offer a narrower range of attacks and defenses.

Installation

Our work is available from the project's GitHub repository and as a PyPI package.

From PyPI (Recommended)

python3 -m pip install airtk
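
After installing, one way to confirm that the package is visible to your interpreter is to query its installed version with the standard library's importlib.metadata module. This is a minimal sketch and not part of airtk; the helper name installed_version is hypothetical, and the sketch assumes the distribution name airtk used in the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("airtk")
    if found is None:
        print("airtk is not installed in this environment")
    else:
        print(f"airtk {found} is installed")
```

Run it with the same python3 that you used for pip; a None result usually means the package went into a different environment.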

From Repo Source (Not Recommended)

To install from source, you will need:

  • The Conda environment manager.
  • The Git version control system.
git clone https://github.com/LAiSR-SK/AiRobustnessTestingKit-AiR-TK-
cd AiRobustnessTestingKit-AiR-TK-
conda env create -p .conda
conda activate ./.conda

Contents

Attacks

Defenses

You can import and use our defenses as shown:

from airtk.defense import TradesTraining

if __name__ == "__main__":
    # Configure TRADES training of the "res101" model on CIFAR-10.
    # Keyword arguments are used throughout: a positional argument may not
    # follow a keyword argument in Python.
    training = TradesTraining(dataset_name="cifar10",
                              model_name="res101",
                              batch_size=512,
                              epochs=100,
                              lr=0.01,
                              seed=0,
                              model_dir="data/model/TRADES/",
                              save_freq=10)

    # Run the specified training regime
    training()
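
For readers unfamiliar with TRADES, the objective that TradesTraining is named after combines a standard cross-entropy loss on clean inputs with a β-weighted KL divergence between the model's predictions on clean and adversarially perturbed inputs. The pure-Python sketch below computes that per-example loss from two logit vectors. It is an illustration, not airtk's implementation: the function names and the beta parameter are hypothetical here, and the inner maximization that actually produces the adversarial input (typically a few PGD steps on the KL term) is omitted.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one logit vector."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def trades_loss(clean_logits: Sequence[float],
                adv_logits: Sequence[float],
                label: int,
                beta: float) -> float:
    """Per-example TRADES-style objective:
    cross-entropy on the clean input + beta * KL(p_clean || p_adv)."""
    log_p = log_softmax(clean_logits)
    log_q = log_softmax(adv_logits)
    natural = -log_p[label]  # standard classification loss on the clean input
    # KL divergence from the clean prediction to the perturbed prediction
    robust = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    return natural + beta * robust


if __name__ == "__main__":
    # Identical predictions on clean and perturbed inputs: the KL term vanishes.
    print(trades_loss([2.0, 0.5, -1.0], [2.0, 0.5, -1.0], label=0, beta=6.0))
    # Diverging predictions: the regularizer adds a penalty scaled by beta.
    print(trades_loss([2.0, 0.5, -1.0], [0.0, 1.5, -1.0], label=0, beta=6.0))
```

Larger values of beta trade clean accuracy for robustness by penalizing any change in prediction under perturbation more heavily.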

We support the following defenses:

  • Adversarial Distributional Training (ADT)
  • Adversarial Distributional Training++ (ADTPP)
  • Adversarial Training with Adversarial Weight Perturbation (ATAWP)
  • Curriculum Adversarial Training (Currat)
  • Federated Adversarial Training (FAT)
  • Feature Scatter (FS)
  • Geometry Aware Instance Reweighted Adversarial Training (GAIRAT)
  • TRadeoff-inspired Adversarial DEfenses via Surrogate loss minimization (TRADES)
  • TRADES with Adversarial Weight Perturbation (TRADESAWP)
  • Various Attacks (VA)
  • You Only Propagate Once (YOPO)

Most of these defenses accept the following keyword arguments:

  • dataset_name: name of the dataset to use
  • model_name: name of the model to use
  • epochs: number of epochs to train (or test) for
  • batch_size: size of the training and testing batches
  • eps: size of the adversarial image perturbations
  • model_dir: directory to save models to
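
To make the role of eps concrete, the sketch below applies a single fast gradient sign method (FGSM) step to a toy logistic-regression model in pure Python: every input feature moves by at most eps in the direction that increases the loss, and the result is clipped back to [0, 1]. It is an illustration only, not airtk code; fgsm_linear, the [0, 1] pixel range, and the L-infinity threat model are assumptions, and individual defenses in the package may use different threat models.

```python
import math
from typing import List, Sequence


def sign(v: float) -> float:
    """Return -1.0, 0.0, or 1.0 according to the sign of v."""
    return float((v > 0) - (v < 0))


def fgsm_linear(x: Sequence[float], w: Sequence[float], b: float,
                y: int, eps: float) -> List[float]:
    """One FGSM step against a logistic model with label y in {-1, +1}.

    Each feature moves by at most eps (an L-infinity budget) in the direction
    that increases the logistic loss, then is clipped to the [0, 1] range.
    """
    score = sum(wi * xi for wi, xi in zip(w, x)) + b
    # Gradient of log(1 + exp(-y * score)) w.r.t. x is -y * sigmoid(-y * score) * w
    coeff = -y / (1.0 + math.exp(y * score))
    grad = [coeff * wi for wi in w]
    return [min(1.0, max(0.0, xi + eps * sign(g))) for xi, g in zip(x, grad)]


if __name__ == "__main__":
    # Each coordinate moves by eps in the loss-increasing direction.
    print(fgsm_linear([0.5, 0.5], w=[1.0, -1.0], b=0.0, y=1, eps=0.1))
```

A larger eps gives the attacker more room per feature, which is why robustness results are only comparable at the same eps.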

Pretrained Models

To expedite progress in the field of secure AI, we provide the weights of our trained models on Hugging Face. They can be loaded with from_pretrained and then evaluated or further trained:

import torch
from torch.utils.data import DataLoader

from airtk.data import CIFAR100
from airtk.model import ResNet50

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load the pretrained weights and switch to inference mode
    model: ResNet50 = ResNet50.from_pretrained("LAiSR-SK/curriculum-at-cifar100-res50")
    model = model.to(device).eval()

    # 2. Evaluate the model on the CIFAR100 test split
    testset: CIFAR100 = CIFAR100(root="data/", train=False, download=True)
    test_loader: DataLoader = DataLoader(testset, batch_size=256, shuffle=False)

    total: int = 0
    correct: int = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            predicted = logits.argmax(dim=1)

            correct += (predicted == y).sum().item()
            total += y.size(0)

    acc: float = 100 * correct / total

    print(f"Accuracy: {acc:.2f}%")

Cite Us

See CITATION.cff in the project repository for details on how to cite our work.

Download files

Source Distribution

airtk-0.1.2.tar.gz (289.2 kB)

Built Distribution

airtk-0.1.2-py3-none-any.whl (247.0 kB)

File details

Details for the file airtk-0.1.2.tar.gz.

File metadata

  • Download URL: airtk-0.1.2.tar.gz
  • Upload date:
  • Size: 289.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.18

File hashes

Hashes for airtk-0.1.2.tar.gz (algorithm, hash digest):
SHA256 945910cac042951a8532465caa86311a40511b62827189ffbd55f44401f43683
MD5 2c708aa6209e0e924f84d87fde570574
BLAKE2b-256 241193bfc81cd58b92617c04c25bcb7ef4a6d625bd0f23d7618a76c328c34f7a
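
To check a downloaded archive against the SHA256 digest listed above, you can hash it locally with Python's standard hashlib module. This is a minimal sketch; the helper name sha256_of and the local file path are assumptions.

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for airtk-0.1.2.tar.gz
EXPECTED_SHA256 = "945910cac042951a8532465caa86311a40511b62827189ffbd55f44401f43683"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in fixed-size chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    archive = Path("airtk-0.1.2.tar.gz")  # assumed download location
    if archive.exists():
        ok = sha256_of(archive) == EXPECTED_SHA256
        print("hash matches" if ok else "HASH MISMATCH: do not install this file")
    else:
        print(f"{archive} not found")
```

Alternatively, running pip hash airtk-0.1.2.tar.gz prints the file's SHA256 digest for the same comparison.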

File details

Details for the file airtk-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: airtk-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 247.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.18

File hashes

Hashes for airtk-0.1.2-py3-none-any.whl (algorithm, hash digest):
SHA256 dcdaf16b258036cd25517afc66272d1289c21eea49bb8fdc9eb5f65473b5d77f
MD5 f508f8af9c147b0671469adf1df78dce
BLAKE2b-256 4e0121ba0837d6beff18be3b2b24a38e1d37d934352bb2459368725dc37b2b90
