
B-cos models.


B-cos Networks v2

M. Böhle, N. Singh, M. Fritz, B. Schiele.

Improved B-cos Networks.


Introduction

This repository contains the code for the B-cos v2 models. (Paper coming soon!)

These models are more efficient and easier to train than the original B-cos models. Furthermore, we make a large number of pretrained B-cos models available for use.

If you want to take a quick look at the explanations the models generate, you can try out the Gradio web demo on Hugging Face Spaces.

If you prefer a more hands-on approach, you can take a look at the demo notebook on Colab or load the models directly via torch.hub as explained below.

If you simply want to copy the model definitions, we provide a minimal, single-file reference implementation including explanation mode in extra/minimal_bcos_resnet.py!

UPDATE: We have also released our ViT models! See Model Zoo.

Quick Start

You only need to make sure you have torch and torchvision installed.

Then, loading the models via torch.hub is as easy as:

import torch

# list all available models
torch.hub.list('B-cos/B-cos-v2')

# load a pretrained model
model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)

Inference and explanation visualization are as simple as:

from PIL import Image
import matplotlib.pyplot as plt

# load image
img = model.transform(Image.open('cat.jpg'))
img = img[None].requires_grad_()

# predict and explain
model.eval()
expl_out = model.explain(img)
print("Prediction:", expl_out["prediction"])  # predicted class idx
plt.imshow(expl_out["explanation"])
plt.show()

Each of the models has its inference transform attached to it, accessible via model.transform. Furthermore, each model has a .explain() method that takes an image tensor and returns a dictionary containing the prediction, the explanation, and some extras.

See the demo notebook for more details on the .explain() method.
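For intuition about where these explanations come from: in the original B-cos formulation, a unit with weight w (normalized to ŵ = w / ‖w‖) computes B-cos(x; w) = |cos(x, w)|^(B-1) · ŵᵀx. Since this equals W(x)ᵀx for the input-dependent weight W(x) = |cos(x, w)|^(B-1) · ŵ, the output splits exactly into per-input contributions, which is the property B-cos explanations build on. The dependency-free sketch below illustrates a single unit only; it is not the repository's implementation (the actual models, e.g. the reference in extra/minimal_bcos_resnet.py, contain further components), and the function name bcos_unit and the default B = 2 are assumptions made for illustration.

```python
import math
from typing import List, Tuple


def bcos_unit(x: List[float], w: List[float], b: float = 2.0) -> Tuple[float, List[float]]:
    """Hypothetical single B-cos unit: |cos(x, w)|^(B-1) * (w_hat . x).

    Returns the output and the per-input contributions W(x)_i * x_i,
    which sum to the output. Assumes w is not the zero vector.
    """
    w_norm = math.sqrt(sum(wi * wi for wi in w))
    x_norm = math.sqrt(sum(xi * xi for xi in x))
    w_hat = [wi / w_norm for wi in w]  # unit-norm weight
    lin = sum(wh * xi for wh, xi in zip(w_hat, x))  # w_hat^T x
    # cos(x, w) = w_hat^T x / ||x||; the epsilon guards against x = 0
    cos = lin / (x_norm + 1e-12)
    scale = abs(cos) ** (b - 1)  # alignment factor; B = 1 gives a plain linear unit
    dynamic_w = [scale * wh for wh in w_hat]  # input-dependent weight W(x)
    contributions = [dw * xi for dw, xi in zip(dynamic_w, x)]
    return scale * lin, contributions


if __name__ == "__main__":
    out, contrib = bcos_unit([3.0, 4.0], [1.0, 0.0])
    print(out, contrib, sum(contrib))
```

With B = 2, an input that is poorly aligned with w is suppressed even if ŵᵀx is large, so a high output requires the weights to align with the input; larger B sharpens this effect.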

Furthermore, each model has get_classifier and get_feature_extractor methods that return the classifier and feature extractor modules, respectively. These can be useful for fine-tuning the models!

Installation

Depending on your use case, you can either install the bcos package or set up the development environment for training the models (for your custom models or for reproducing the results).

bcos Package

If you are simply interested in using the models (pretrained or otherwise), then we provide a bcos package that can be installed via pip:

pip install bcos

This contains the models, their modules, transforms, and other utilities making it easy to use and build B-cos models. Take a look at the public API here. (I'll add a proper docs site if I have time or there's enough interest. Nonetheless, I have tried to keep the code well-documented, so it should be easy to follow.)

Training Environment Setup

If you want to train your own B-cos models using this repository or are interested in reproducing the results, you can set up the development environment as follows:

Using conda (recommended, especially if you want to reproduce the results):

conda env create -f environment.yml
conda activate bcos

Using pip:

pip install -r requirements-train.txt

Setting Data Paths

You can either set the paths in bcos/settings.py or set the environment variables

  1. DATA_ROOT
  2. IMAGENET_PATH

to the paths of the data directories.

The DATA_ROOT environment variable should point to the data root directory for CIFAR-10 (will be automatically downloaded). For ImageNet, the IMAGENET_PATH environment variable should point to the directory containing the train and val directories.
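Before launching a long run, it can help to check the environment variables up front. The stdlib sketch below is not the code in bcos/settings.py; the helper names env_path and missing_imagenet_splits are hypothetical. It only reads DATA_ROOT and IMAGENET_PATH and reports whether the ImageNet directory contains the expected train and val subdirectories.

```python
import os
from pathlib import Path
from typing import List, Optional


def env_path(name: str) -> Optional[Path]:
    """Hypothetical helper: return env var `name` as a Path, or None if unset/empty."""
    value = os.environ.get(name)
    return Path(value).expanduser() if value else None


def missing_imagenet_splits(imagenet_path: Path) -> List[str]:
    """Hypothetical helper: list which of the expected `train`/`val` dirs are absent."""
    return [split for split in ("train", "val") if not (imagenet_path / split).is_dir()]


if __name__ == "__main__":
    print("DATA_ROOT:", env_path("DATA_ROOT"))  # CIFAR-10 is downloaded here
    imagenet_path = env_path("IMAGENET_PATH")
    if imagenet_path is None:
        print("IMAGENET_PATH is not set")
    else:
        missing = missing_imagenet_splits(imagenet_path)
        print("IMAGENET_PATH:", imagenet_path, "missing:", missing or "none")
```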

Usage

For the bcos package, as mentioned earlier, take a look at the public API here.

For evaluating or training the models, you can use the evaluate.py and train.py scripts, as follows:

Evaluation

You can evaluate the accuracy of the models on the ImageNet validation set using:

python evaluate.py --dataset ImageNet --hubconf resnet18

This will download the model from torch.hub and evaluate it on the ImageNet validation set. The default batch size is 1, but you can change it using the --batch-size argument. Replace resnet18 with any of the other models listed in Model Zoo that you wish to evaluate.
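The Model Zoo reports top-1 and top-5 accuracy. As a framework-agnostic illustration of how these metrics are usually defined (this is a sketch, not the code in evaluate.py; topk_accuracy is a hypothetical name), a sample counts as correct under top-k if its true class is among the k highest-scoring classes:

```python
from typing import Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int], k: int = 1) -> float:
    """Fraction of samples whose target is among the k highest-scoring classes.

    Assumes at least one sample. Ties are broken in favor of the lower class
    index, because Python's sort is stable even with reverse=True.
    """
    hits = 0
    for scores, target in zip(logits, targets):
        top = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
        hits += target in top  # bool counts as 0 or 1
    return hits / len(targets)


if __name__ == "__main__":
    logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    targets = [1, 1, 0]
    print("top-1:", topk_accuracy(logits, targets, k=1))
    print("top-2:", topk_accuracy(logits, targets, k=2))
```

Top-5 is simply k=5 over the 1000 ImageNet classes, so it is always at least as high as top-1, as the table shows.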

Training

Short version:

python train.py \
  --dataset ImageNet \
  --base_network bcos_final \
  --experiment_name resnet18

Long version: See TRAINING.md for more details on how the setup works and how to train your own models.

Model Zoo

Here are the ImageNet pre-trained models available in the model zoo. You can find the links to the model weights below (uploaded to the Weights GitHub release).

| Model/Entrypoint | Top-1 Accuracy | Top-5 Accuracy | #Params | Download |
|---|---|---|---|---|
| resnet18 | 68.736% | 87.430% | 11.69M | link |
| resnet34 | 72.284% | 90.052% | 21.80M | link |
| resnet50 | 75.882% | 92.528% | 25.52M | link |
| resnet101 | 76.532% | 92.538% | 44.50M | link |
| resnet152 | 76.484% | 92.398% | 60.13M | link |
| resnext50_32x4d | 75.820% | 91.810% | 25.00M | link |
| densenet121 | 73.612% | 91.106% | 7.95M | link |
| densenet161 | 76.622% | 92.554% | 28.58M | link |
| densenet169 | 75.186% | 91.786% | 14.08M | link |
| densenet201 | 75.480% | 91.992% | 19.91M | link |
| vgg11_bnu | 69.310% | 88.388% | 132.86M | link |
| convnext_tiny | 77.488% | 93.192% | 28.54M | link |
| convnext_base | 79.650% | 94.614% | 88.47M | link |
| convnext_tiny_bnu | 76.826% | 93.090% | 28.54M | link |
| convnext_base_bnu | 80.142% | 94.834% | 88.47M | link |
| densenet121_long | 77.302% | 93.234% | 7.95M | link |
| resnet50_long | 79.468% | 94.452% | 25.52M | link |
| resnet152_long | 80.144% | 94.116% | 60.13M | link |
| simple_vit_ti_patch16_224 | 59.960% | 81.838% | 5.80M | link |
| simple_vit_s_patch16_224 | 69.246% | 88.096% | 22.28M | link |
| simple_vit_b_patch16_224 | 74.408% | 91.156% | 86.90M | link |
| simple_vit_l_patch16_224 | 75.060% | 91.378% | 178.79M | link |
| vitc_ti_patch1_14 | 67.260% | 86.774% | 5.32M | link |
| vitc_s_patch1_14 | 74.504% | 91.288% | 20.88M | link |
| vitc_b_patch1_14 | 77.152% | 92.926% | 81.37M | link |
| vitc_l_patch1_14 | 77.782% | 92.966% | 167.44M | link |
| standard_simple_vit_ti_patch16_224 | 70.230% | 89.380% | 5.67M | link |
| standard_simple_vit_s_patch16_224 | 74.470% | 91.226% | 21.96M | link |
| standard_simple_vit_b_patch16_224 | 75.300% | 91.026% | 86.38M | link |
| standard_simple_vit_l_patch16_224 | 75.710% | 90.050% | 178.10M | link |
| standard_vitc_ti_patch1_14 | 72.590% | 90.788% | 5.33M | link |
| standard_vitc_s_patch1_14 | 75.756% | 91.994% | 20.91M | link |
| standard_vitc_b_patch1_14 | 76.790% | 92.024% | 81.39M | link |
| standard_vitc_l_patch1_14 | 77.866% | 92.298% | 167.54M | link |

You can find these entrypoints in bcos/models/pretrained.py.
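If you want to pick an entrypoint programmatically, for example the most accurate model that fits a parameter budget, a small lookup over the table is enough. The sketch below is hypothetical (MODEL_ZOO and best_entrypoint are not part of the bcos package) and copies only a subset of the rows from the Model Zoo table above:

```python
from typing import Dict, Optional, Tuple

# Hypothetical lookup: entrypoint -> (top-1 accuracy in %, parameters in millions),
# copied from a subset of the Model Zoo table.
MODEL_ZOO: Dict[str, Tuple[float, float]] = {
    "resnet18": (68.736, 11.69),
    "resnet50": (75.882, 25.52),
    "resnet50_long": (79.468, 25.52),
    "resnet152_long": (80.144, 60.13),
    "densenet121_long": (77.302, 7.95),
    "convnext_tiny": (77.488, 28.54),
    "convnext_base_bnu": (80.142, 88.47),
    "standard_vitc_s_patch1_14": (75.756, 20.91),
}


def best_entrypoint(max_params_m: float) -> Optional[str]:
    """Most accurate entrypoint with at most `max_params_m` million parameters, or None."""
    fitting = {name: acc for name, (acc, params) in MODEL_ZOO.items() if params <= max_params_m}
    return max(fitting, key=fitting.get) if fitting else None


if __name__ == "__main__":
    for budget in (10, 30, 100):
        print(f"<= {budget}M params:", best_entrypoint(budget))
```

The returned name can then be passed to torch.hub.load('B-cos/B-cos-v2', name, pretrained=True) as in the Quick Start. Note that parameter count is only a rough proxy for cost; inference speed and memory also depend on the architecture.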

License

This repository's code is licensed under the Apache License 2.0 which you can find in the LICENSE file.

The pre-trained models are trained on (and hence derived from) ImageNet, which is subject to the ImageNet Terms of access; among other things, these only allow non-commercial use of the dataset. It is therefore your responsibility to check whether you have permission to use the pre-trained models for your use case.

Download files

Source Distribution

bcos-0.1.0.tar.gz (91.1 kB)

Built Distribution

bcos-0.1.0-py3-none-any.whl (114.7 kB)

File details

Details for the file bcos-0.1.0.tar.gz.

File metadata

  • Download URL: bcos-0.1.0.tar.gz
  • Size: 91.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.8

File hashes

Hashes for bcos-0.1.0.tar.gz
Algorithm Hash digest
SHA256 a1bb12bf220fb7056441308a8ecab932fe1c82fc87d70b5237057a2745a75c8f
MD5 41bf0e2e024978890310699231b631bc
BLAKE2b-256 e1b778cb9a9c8250f56870fc4d5593d84151e4aeba012fb818c6304de47bef28
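To check a downloaded file against the SHA256 digests listed here, Python's standard hashlib module is sufficient. This is a minimal sketch; the helper names sha256_of and verify are hypothetical, and the file is read in chunks so that large files do not need to fit in memory:

```python
import hashlib
from pathlib import Path

# SHA256 digest of bcos-0.1.0.tar.gz, copied from the hash table above
EXPECTED_SDIST_SHA256 = "a1bb12bf220fb7056441308a8ecab932fe1c82fc87d70b5237057a2745a75c8f"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, computed incrementally over fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_sha256: str) -> bool:
    """True if the file's digest matches (case-insensitive hex comparison)."""
    return sha256_of(path) == expected_sha256.lower()


if __name__ == "__main__":
    archive = Path("bcos-0.1.0.tar.gz")
    if archive.exists():
        print("OK" if verify(archive, EXPECTED_SDIST_SHA256) else "MISMATCH")
```

The same check applies to the wheel using its SHA256 digest listed below.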


File details

Details for the file bcos-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: bcos-0.1.0-py3-none-any.whl
  • Size: 114.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.8

File hashes

Hashes for bcos-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8f176fda508d55a144b9f99b85613a692df95c6c19c5d973ec69a1f2241c03a3
MD5 b9d4d10f9faaa483c9fba0059a4764f3
BLAKE2b-256 fe0fdefa14947504654e3288542f9c1155b7f799f6e53f77f19410aaa93e22b7
