Comprehensive and simple framework for conducting AI robustness experiments
As machine learning approaches to artificial intelligence grow in popularity, secure implementation and evaluation become increasingly important. This is an especially serious concern in safety-critical applications such as object detection for self-driving cars, monitoring of nuclear power plants, and medical diagnosis. To this end, we present a simple yet comprehensive interface for the robust training and evaluation of PyTorch classifiers.
Although other libraries, such as the Adversarial Robustness Toolbox (ART) and MAIR, have addressed this need, they offer a narrower selection of attacks and defenses.
Installation
Our work is available via this repository and as a PyPI package.
From PyPI (Recommended)
```shell
python3 -m pip install airtk
```
From Repo Source (Not Recommended)
To install from source, you will need git and conda. Then run:
```shell
git clone https://github.com/LAiSR-SK/AiRobustnessTestingKit-AiR-TK-
cd AiRobustnessTestingKit-AiR-TK-
conda env create -p .conda
conda activate ./.conda
```
Contents
Attacks
Defenses
You can import and use our defenses as shown:
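```python
from airtk.defense import TradesTraining

if __name__ == "__main__":
    # Initialize the training function. All arguments are passed by keyword
    # (a positional argument cannot follow a keyword argument in Python);
    # dataset_name and model_name follow the keyword-argument table in this README.
    training = TradesTraining(dataset_name="cifar10",
                              model_name="res101",
                              batch_size=512,
                              epochs=100,
                              lr=0.01,
                              seed=0,
                              model_dir="data/model/TRADES/",
                              save_freq=10)
    # Run the specified training regime
    training()
```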
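For reference, TRADES trains on a natural cross-entropy term plus a KL-divergence term that pulls the model's predictions on adversarial inputs toward its predictions on clean inputs, weighted by β. The NumPy sketch below computes that published objective from precomputed logits. It is an illustration of the loss, not AiR-TK's internal implementation; the function names `log_softmax` and `trades_loss` and the default `beta=6.0` are assumptions chosen for the example.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def trades_loss(nat_logits: np.ndarray, adv_logits: np.ndarray,
                labels: np.ndarray, beta: float = 6.0) -> float:
    """TRADES objective: CE(f(x), y) + beta * KL(f(x) || f(x_adv)), batch-averaged."""
    nat_logp = log_softmax(nat_logits)
    adv_logp = log_softmax(adv_logits)
    # Natural term: mean cross-entropy on the clean inputs
    natural = -nat_logp[np.arange(len(labels)), labels].mean()
    # Robust term: KL divergence from clean to adversarial predictions
    p_nat = np.exp(nat_logp)
    robust = (p_nat * (nat_logp - adv_logp)).sum(axis=-1).mean()
    return float(natural + beta * robust)


nat_logits = np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
adv_logits = np.array([[1.0, 1.0, -0.5], [0.4, 0.9, 0.6]])
labels = np.array([0, 1])
print(trades_loss(nat_logits, adv_logits, labels))
```

When the adversarial logits equal the clean ones, the KL term vanishes and the loss reduces to plain cross-entropy. In actual training the adversarial logits come from an inner attack, and the loss is computed on PyTorch tensors so it can be backpropagated.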
We support the following defenses:
- Adversarial Distributional Training (ADT)
- Adversarial Distributional Training++ (ADTPP)
- Adversarial Training with Adversarial Weight Perturbation (ATAWP)
- Curriculum Adversarial Training (Currat)
- Federated Adversarial Training (FAT)
- Feature Scatter (FS)
- Geometry Aware Instance Reweighted Adversarial Training (GAIRAT)
- TRadeoff-inspired Adversarial DEfenses via Surrogate loss minimization (TRADES)
- TRADES with Adversarial Weight Perturbation (TRADESAWP)
- Various Attacks (VA)
- You Only Propagate Once (YOPO)
Most of these defenses accept the following keyword arguments:
Keyword argument | Description |
---|---|
dataset_name | name of the dataset to use |
model_name | name of the model architecture to use |
epochs | number of epochs to train or test for |
batch_size | batch size for training and testing |
eps | perturbation budget (maximum size of the adversarial perturbation applied to each image) |
model_dir | directory in which to save trained models |
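In most adversarial-training code, `eps` bounds each pixel's perturbation in the L∞ norm, commonly 8/255 for CIFAR images scaled to [0, 1]. Whether every AiR-TK defense uses the L∞ norm and this pixel scale is an assumption here, and the helper `project_linf` is hypothetical. The sketch shows the projection step that keeps an adversarial example inside the eps-ball around the clean image and inside the valid pixel range:

```python
import numpy as np


def project_linf(x_adv: np.ndarray, x: np.ndarray, eps: float) -> np.ndarray:
    """Clip x_adv into the L-infinity ball of radius eps around x, then into [0, 1]."""
    # Keep every pixel within eps of the clean image
    x_adv = np.clip(x_adv, x - eps, x + eps)
    # Keep the result a valid image
    return np.clip(x_adv, 0.0, 1.0)


eps = 8 / 255
rng = np.random.default_rng(0)
x = rng.random((3, 32, 32))  # a fake CIFAR-sized image in [0, 1]
x_adv = project_linf(x + rng.normal(scale=0.1, size=x.shape), x, eps)
print(np.abs(x_adv - x).max() <= eps + 1e-12)  # True
```

Iterative attacks such as PGD apply this projection after every gradient step, so a larger `eps` lets the attacker move each pixel further from its clean value.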
Pretrained Models
To expedite progress in the field of secure AI, we provide the weights of our trained models on Hugging Face. These can be loaded with the model's from_pretrained method and then evaluated or trained further:
```python
import torch
from torch.utils.data import DataLoader

from airtk.data import CIFAR100
from airtk.model import ResNet50

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load the pretrained model and switch it to evaluation mode
    model: ResNet50 = ResNet50.from_pretrained("LAiSR-SK/curriculum-at-cifar100-res50")
    model = model.to(device).eval()

    # 2. Evaluate the model on the CIFAR100 test set
    # (assumes the dataset already yields (image tensor, label) pairs)
    testset: CIFAR100 = CIFAR100(root="data/", train=False, download=True)
    test_loader: DataLoader = DataLoader(testset, batch_size=256, shuffle=False)

    total: int = 0
    correct: int = 0
    with torch.no_grad():  # gradients are not needed for clean evaluation
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            _, predicted = torch.max(logits, 1)  # index of the highest logit
            correct += (predicted == y).sum().item()
            total += y.size(0)

    acc: float = 100 * correct / total
    print(f"Accuracy: {acc:.2f}%")
```
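The loop above measures clean accuracy only; robust accuracy repeats the same count on adversarially perturbed inputs. As a self-contained illustration (not AiR-TK's attack API), the NumPy sketch below runs the Fast Gradient Sign Method (FGSM) against a toy linear softmax classifier, for which the input gradient of the cross-entropy loss has the closed form W^T (softmax(Wx + b) - onehot(y)). The model weights, random data, and function names are all hypothetical.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def fgsm_linear(X: np.ndarray, y: np.ndarray, W: np.ndarray,
                b: np.ndarray, eps: float) -> np.ndarray:
    """One-step L-infinity FGSM attack on a linear softmax classifier."""
    probs = softmax(X @ W.T + b)
    onehot = np.eye(W.shape[0])[y]
    # Closed-form gradient of the cross-entropy loss w.r.t. each input row
    grad = (probs - onehot) @ W
    # Step eps in the sign of the gradient, then keep pixels in [0, 1]
    return np.clip(X + eps * np.sign(grad), 0.0, 1.0)


def accuracy(X: np.ndarray, y: np.ndarray, W: np.ndarray, b: np.ndarray) -> float:
    """Top-1 accuracy in percent, the same quantity the evaluation loop computes."""
    predicted = np.argmax(X @ W.T + b, axis=1)
    return 100 * float((predicted == y).mean())


rng = np.random.default_rng(0)
W = rng.normal(size=(10, 3 * 32 * 32))  # hypothetical 10-class linear model
b = np.zeros(10)
X = rng.random((256, 3 * 32 * 32))      # random flattened "images" in [0, 1]
y = np.argmax(X @ W.T + b, axis=1)      # label inputs with the model's own predictions

eps = 8 / 255
X_adv = fgsm_linear(X, y, W, b, eps)
clean_acc = accuracy(X, y, W, b)
robust_acc = accuracy(X_adv, y, W, b)
print(f"Clean accuracy: {clean_acc:.1f}%, FGSM accuracy: {robust_acc:.1f}%")
```

With a real PyTorch model the closed-form gradient is replaced by autograd; multi-step attacks such as PGD usually give a lower (more pessimistic) robust accuracy than single-step FGSM.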
Cite Us
See CITATION.cff or the sidebar of our GitHub repository for details on how to cite our work.
Download files
File details
Details for the file airtk-0.1.2.tar.gz.
File metadata
- Download URL: airtk-0.1.2.tar.gz
- Size: 289.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.18
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 945910cac042951a8532465caa86311a40511b62827189ffbd55f44401f43683 |
MD5 | 2c708aa6209e0e924f84d87fde570574 |
BLAKE2b-256 | 241193bfc81cd58b92617c04c25bcb7ef4a6d625bd0f23d7618a76c328c34f7a |
File details
Details for the file airtk-0.1.2-py3-none-any.whl.
File metadata
- Download URL: airtk-0.1.2-py3-none-any.whl
- Size: 247.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.18
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | dcdaf16b258036cd25517afc66272d1289c21eea49bb8fdc9eb5f65473b5d77f |
MD5 | f508f8af9c147b0671469adf1df78dce |
BLAKE2b-256 | 4e0121ba0837d6beff18be3b2b24a38e1d37d934352bb2459368725dc37b2b90 |
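A downloaded distribution can be checked against the SHA256 digests listed for the two airtk 0.1.2 files. The sketch below uses only the Python standard library; the helper names `sha256_of` and `verify` are hypothetical, and the expected digests are copied from the file details.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Published SHA256 digests for the airtk 0.1.2 distribution files
EXPECTED = {
    "airtk-0.1.2.tar.gz": "945910cac042951a8532465caa86311a40511b62827189ffbd55f44401f43683",
    "airtk-0.1.2-py3-none-any.whl": "dcdaf16b258036cd25517afc66272d1289c21eea49bb8fdc9eb5f65473b5d77f",
}


def verify(path: str) -> bool:
    """Compare a downloaded airtk file against its published SHA256 digest."""
    return sha256_of(path) == EXPECTED[Path(path).name]
```

After downloading, verify("path/to/airtk-0.1.2.tar.gz") should return True; reading in chunks keeps memory use constant regardless of file size.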