
B-cos models.


B-cos Networks v2

M. Böhle, N. Singh, M. Fritz, B. Schiele.

Improved B-cos Networks.


Introduction

This repository contains the code for the B-cos v2 models. (Paper coming soon!)

These models are more efficient and easier to train than the original B-cos models. Furthermore, we make a large number of pretrained B-cos models available for use.

If you want to take a quick look at the explanations the models generate, you can try out the Gradio web demo on Hugging Face Spaces.

If you prefer a more hands-on approach, you can take a look at the demo notebook on Colab or load the models directly via torch.hub as explained below.

If you simply want to copy the model definitions, we provide a minimal, single-file reference implementation including explanation mode in extra/minimal_bcos_resnet.py!

Quick Start

You only need to make sure you have torch and torchvision installed.

Then, loading the models via torch.hub is as easy as:

import torch

# list all available models
torch.hub.list('B-cos/B-cos-v2')

# load a pretrained model
model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)

Inference and explanation visualization are as simple as:

from PIL import Image
import matplotlib.pyplot as plt

# load image
img = model.transform(Image.open('cat.jpg'))
img = img[None].requires_grad_()  # add batch dimension; gradients are needed for the explanation

# predict and explain
model.eval()
expl_out = model.explain(img)
print("Prediction:", expl_out["prediction"])  # predicted class idx
plt.imshow(expl_out["explanation"])
plt.show()

Each of the models has its inference transform attached to it, accessible via model.transform. Furthermore, each model has a .explain() method that takes an image tensor and returns a dictionary containing the prediction, the explanation, and some extras.

See the demo notebook for more details on the .explain() method.

Furthermore, each model has get_classifier and get_feature_extractor methods that return the classifier and feature extractor modules, respectively. These can be useful for fine-tuning the models!
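As a sketch of how these accessors enable fine-tuning: freeze the feature extractor's parameters and train only the classifier head. The ToyModel below is a hypothetical stand-in that merely mimics the get_classifier/get_feature_extractor interface; in practice you would use a real B-cos model loaded via torch.hub as in the Quick Start.

```python
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical stand-in exposing the same accessors as the B-cos models."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Flatten(), nn.Linear(16, 8), nn.ReLU())
        self.classifier = nn.Linear(8, 4)

    def forward(self, x):
        return self.classifier(self.features(x))

    def get_feature_extractor(self):
        return self.features

    def get_classifier(self):
        return self.classifier

model = ToyModel()  # in practice: torch.hub.load('B-cos/B-cos-v2', 'resnet18', pretrained=True)

# Freeze the feature extractor; only the classifier head stays trainable.
for p in model.get_feature_extractor().parameters():
    p.requires_grad = False

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only classifier.* parameters remain
```

An optimizer built from only the trainable parameters (e.g. filtering on p.requires_grad) then updates just the head while the frozen features act as a fixed embedding.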

Installation

Depending on your use case, you can either install the bcos package or set up the development environment for training the models (for your custom models or for reproducing the results).

bcos Package

If you are simply interested in using the models (pretrained or otherwise), then we provide a bcos package that can be installed via pip:

pip install bcos

This contains the models, their modules, transforms, and other utilities making it easy to use and build B-cos models. Take a look at the public API here. (I'll add a proper docs site if I have time or there's enough interest. Nonetheless, I have tried to keep the code well-documented, so it should be easy to follow.)

Training Environment Setup

If you want to train your own B-cos models using this repository or are interested in reproducing the results, you can set up the development environment as follows:

Using conda (recommended, especially if you want to reproduce the results):

conda env create -f environment.yml
conda activate bcos

Using pip

pip install -r requirements-train.txt

Setting Data Paths

You can either set the paths in bcos/settings.py or set the environment variables

  1. DATA_ROOT
  2. IMAGENET_PATH

to the paths of the data directories.

The DATA_ROOT environment variable should point to the data root directory for CIFAR-10 (will be automatically downloaded). For ImageNet, the IMAGENET_PATH environment variable should point to the directory containing the train and val directories.
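If you go the environment-variable route, setting the variables from Python (before any bcos code reads them) could look like the following; both paths are placeholders to adjust for your setup:

```python
import os

# Placeholder paths -- point these at your actual data directories.
os.environ["DATA_ROOT"] = "/path/to/data"            # CIFAR-10 root (auto-downloaded here)
os.environ["IMAGENET_PATH"] = "/path/to/imagenet"    # must contain train/ and val/
```

Equivalently, you can export DATA_ROOT and IMAGENET_PATH in your shell before launching the training scripts.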

Usage

For the bcos package, as mentioned earlier, take a look at the public API here.

For evaluating or training the models, you can use the evaluate.py and train.py scripts, as follows:

Evaluation

You can evaluate the accuracy of the models on the ImageNet validation set using:

python evaluate.py --dataset ImageNet --hubconf resnet18

This will download the model from torch.hub and evaluate it on the ImageNet validation set. The default batch size is 1, but you can change it using the --batch-size argument. Replace resnet18 with any of the other models listed in Model Zoo that you wish to evaluate.

Training

Short version:

python train.py \
  --dataset ImageNet \
  --base_network bcos_final \
  --experiment_name resnet18

Long version: See TRAINING.md for more details on how the setup works and how to train your own models.

Model Zoo

Here are the ImageNet pre-trained models available in the model zoo. You can find the links to the model weights below (uploaded to the Weights GitHub release).

Model/Entrypoint    Top-1 Accuracy  Top-5 Accuracy  #Params  Download
resnet18            68.736%         87.430%         11.69M   link
resnet34            72.284%         90.052%         21.80M   link
resnet50            75.882%         92.528%         25.52M   link
resnet101           76.532%         92.538%         44.50M   link
resnet152           76.484%         92.398%         60.13M   link
resnext50_32x4d     75.820%         91.810%         25.00M   link
densenet121         73.612%         91.106%         7.95M    link
densenet161         76.622%         92.554%         28.58M   link
densenet169         75.186%         91.786%         14.08M   link
densenet201         75.480%         91.992%         19.91M   link
vgg11_bnu           69.310%         88.388%         132.86M  link
convnext_tiny       77.488%         93.192%         28.54M   link
convnext_base       79.650%         94.614%         88.47M   link
convnext_tiny_bnu   76.826%         93.090%         28.54M   link
convnext_base_bnu   80.142%         94.834%         88.47M   link
densenet121_long    77.302%         93.234%         7.95M    link
resnet50_long       79.468%         94.452%         25.52M   link
resnet152_long      80.144%         94.116%         60.13M   link

You can find these entrypoints in bcos/models/pretrained.py. We'll add ViT models to the model zoo soon, stay tuned!

License

This repository's code is licensed under the Apache License 2.0 which you can find in the LICENSE file.

The pre-trained models are trained on ImageNet (and are hence derived from it), which is licensed under the ImageNet Terms of Access, which, among other things, only allows non-commercial use of the dataset. It is therefore your responsibility to check whether you have permission to use the pre-trained models for your use case.

Project details


Download files

Download the file for your platform.

Source Distribution

bcos-0.0.3.tar.gz (84.1 kB)

Uploaded Source

Built Distribution

bcos-0.0.3-py3-none-any.whl (106.7 kB)

Uploaded Python 3

File details

Details for the file bcos-0.0.3.tar.gz.

File metadata

  • Download URL: bcos-0.0.3.tar.gz
  • Upload date:
  • Size: 84.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.8

File hashes

Hashes for bcos-0.0.3.tar.gz:

  Algorithm    Hash digest
  SHA256       deaa52a2c75c75ce53a80d3fb8490c5cf3ecca815ee7bdc059856be224431f22
  MD5          048baf029babe26be9cd6dfa73723134
  BLAKE2b-256  1e98074ea7b27f5d4265603cc97e0cef2d67e3b3bfdcc440c4fd8484ad093d58


File details

Details for the file bcos-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: bcos-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 106.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.8

File hashes

Hashes for bcos-0.0.3-py3-none-any.whl:

  Algorithm    Hash digest
  SHA256       d628f773b377f4f48a04047bac928b6ea7db25484c253cf6e0afc8dff6b50faf
  MD5          a7ba89af8c1228352f71c4701d6722f0
  BLAKE2b-256  ac5b7100b4fe36150a1b7e56118aa9fb0e2033df47f6dac0c7379bd3ba840ac5

