
Deep learning image classification informed by expert attention

Project description

Expert-attention guided deep learning for medical images

Get Started

Pip install the PyPI distribution:

pip install expert-informed-dl

Here's an example of how to use the trained model for inference on subimages; see eidl/examples/subimage_example.py for the full script.

from eidl.utils.model_utils import get_subimage_model

subimage_handler = get_subimage_model()
subimage_handler.compute_perceptual_attention('9025_OD_2021_widefield_report', is_plot_results=True, discard_ratio=0.1)
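The discard_ratio argument follows the usual attention-rollout convention of dropping the weakest attention links before aggregation; a minimal numpy sketch of that step (illustrative only, not the eidl internals):

```python
import numpy as np

def discard_weakest(attn, discard_ratio=0.1):
    # Zero out the lowest `discard_ratio` fraction of attention values
    # so the rollout keeps only the strongest links.
    flat = attn.flatten()
    n_discard = int(flat.size * discard_ratio)
    if n_discard:
        weakest = np.argsort(flat)[:n_discard]
        flat[weakest] = 0.0
    return flat.reshape(attn.shape)

attn = np.random.rand(5, 5)
filtered = discard_weakest(attn.copy(), discard_ratio=0.2)
print((filtered == 0).sum())  # 5 of the 25 values are zeroed
```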

If you want to use the rollouts/gradcams in a user interface, you may consider precomputing them, as it can be slow to compute them on the fly.

from eidl.utils.model_utils import get_subimage_model

subimage_handler = get_subimage_model(precompute='vit')

# or

subimage_handler = get_subimage_model(precompute='resnet')

# or

subimage_handler = get_subimage_model(precompute=['vit', 'resnet'])
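Precomputation here amounts to a cache keyed by image name, built once at startup so interactive lookups stay fast; the idea can be sketched generically (the names below are illustrative, not the eidl API):

```python
import time

def slow_rollout(image_name):
    # Stand-in for an expensive attention-rollout / grad-CAM computation.
    time.sleep(0.01)
    return f"rollout for {image_name}"

# Precompute once at startup...
cache = {name: slow_rollout(name) for name in ["img_a", "img_b"]}

# ...so that later UI lookups are instant dictionary reads.
result = cache["img_a"]
print(result)
```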

If you don't want to use subimages:

Check out eidl/examples/example.py for a simple example of how to use the trained model for inference.

When forwarding an image through the network, use the argument collapse_attention_matrix=True to get the attention matrix averaged across all heads and keys for each query token, or collapse_attention_matrix=False to get the full per-head matrix:

y_pred, attention_matrix = model(image_data, collapse_attention_matrix=False)
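What collapsing does can be sketched in plain numpy (a conceptual sketch of the averaging, not the eidl internals):

```python
import numpy as np

# Hypothetical uncollapsed attention matrix: 1 image, 1 head, and
# 32 * 32 patches plus one classification token = 1025 tokens.
n_tokens = 32 * 32 + 1
attn = np.random.rand(1, 1, n_tokens, n_tokens)
attn /= attn.sum(axis=-1, keepdims=True)  # each query's weights sum to 1

# Collapsing averages across heads (axis 1) and keys (axis 3),
# leaving one attention value per query token.
collapsed = attn.mean(axis=(1, 3))
print(collapsed.shape)  # (1, 1025)
```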

For example, if you have 32 × 32 patches, the collapsed attention matrix will have 32 × 32 + 1 = 1025 entries; the extra entry is for the classification token. If you set collapse_attention_matrix=False, the attention matrix is returned uncollapsed, with shape (n_batch, n_heads, n_queries, n_keys). For example, with one image, one head, and 32 × 32 patches, the attention matrix has shape (1, 1, 1025, 1025).

Train model locally

Install the dependencies in requirements.txt.

Download a PyTorch build with a CUDA version matching your GPU from pytorch.org.

Run train.py

Troubleshoot

If the get-model functions raise the following error:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

You will need to install a PyTorch build with a CUDA version matching your GPU (from pytorch.org). This is because all the models were trained on GPU.
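If you only need CPU inference, the map_location workaround mentioned in the error message looks like this (a generic PyTorch sketch with a toy checkpoint, not the eidl loader):

```python
import torch

# Toy checkpoint standing in for a model saved on a CUDA machine.
model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "checkpoint.pt")

# map_location=torch.device('cpu') remaps CUDA tensors onto the CPU,
# so the checkpoint loads on a CPU-only machine.
state = torch.load("checkpoint.pt", map_location=torch.device("cpu"))
model.load_state_dict(state)
```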

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

expert_informed_dl-0.0.23.tar.gz (1.3 MB)

Uploaded Source

Built Distribution

expert_informed_dl-0.0.23-py3-none-any.whl (1.3 MB)

Uploaded Python 3

File details

Details for the file expert_informed_dl-0.0.23.tar.gz.

File metadata

  • Download URL: expert_informed_dl-0.0.23.tar.gz
  • Upload date:
  • Size: 1.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.13

File hashes

Hashes for expert_informed_dl-0.0.23.tar.gz
Algorithm Hash digest
SHA256 6da012d855b0c215ef1004b2d60d3bf7fe1f4f0bdf42349a2f23a9daeeb6c600
MD5 a243d91c7bb4e4c4e4fab4c4f8c051a7
BLAKE2b-256 3ad5ea666fe39462ecf0b6a523daaf90c45e15e1e314ddcb3cf4b5c718ede698

See more details on using hashes here.

File details

Details for the file expert_informed_dl-0.0.23-py3-none-any.whl.

File metadata

File hashes

Hashes for expert_informed_dl-0.0.23-py3-none-any.whl
Algorithm Hash digest
SHA256 8949533874a02128a7746a5231f5a6da96b1e776f4f8651392cdefdff7257c7b
MD5 07c16d8fa1bb6b6eb169c29dbddd4a53
BLAKE2b-256 df605863cb353f666aa6460224889344df76bf2684842ba85fbb7687a4804b94

See more details on using hashes here.
