Memory Wrap: an extension for image classification models

Description

Memory Wrap is an extension for image classification models that improves both data efficiency and model interpretability by adding a sparse content-based attention mechanism between the input and a set of memories of past training samples.

Installation

This is a PyTorch implementation of Memory Wrap. To install Memory Wrap, run the following command:

pip install memorywrap

The library contains two main classes:

  • MemoryWrapLayer: the Memory Wrap variant described in the paper, which uses both the input encoding and the memory encoding to compute the output;
  • BaselineMemory: a baseline that uses only the memory encoding to compute the output.
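
Both classes can then be imported in your code (a minimal sketch, assuming they are exported at the package top level, as in the examples below):

from memorywrap import MemoryWrapLayer, BaselineMemory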

Usage

Instantiate the layer

memorywrap = MemoryWrapLayer(encoder_output_dim, output_dim, head=None, classifier=None, distance='cosine')

or, for the baseline that uses only the memory encoding to compute the prediction:

memorywrap = BaselineMemory(encoder_output_dim, output_dim, head=None, classifier=None, distance='cosine')

where

  • encoder_output_dim (int) is the output dimension of the last layer of the encoder

  • output_dim (int) is the desired output dimension. In the paper, output_dim is equal to the number of classes;

  • head (torch.nn.Module): read head used to project the query (input encoding) and the keys (memory encodings). It can be a linear or non-linear layer; its input dimension must be equal to encoder_output_dim. If None, it defaults to a linear layer whose input and output dimensions are both equal to encoder_output_dim. (See https://www.nature.com/articles/nature20101 for further information.)

  • classifier (torch.nn.Module): classifier on top of Memory Wrap. Its input dimension must be equal to encoder_output_dim*2 for MemoryWrapLayer and to encoder_output_dim for BaselineMemory. By default it is the MLP described in the paper: torch.nn.Sequential(torch.nn.Linear(encoder_output_dim*2, encoder_output_dim*4), torch.nn.ReLU(), torch.nn.Linear(encoder_output_dim*4, output_dim)). An alternative is a single linear layer, e.g. torch.nn.Linear(encoder_output_dim*2, output_dim); see the sketch after this list.

  • distance (str): distance used to compute the similarity between the input and the memory set. Allowed values are cosine, l2 and dot for, respectively, cosine similarity, L2 distance and dot product. Default: cosine.
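
For example, here is a minimal sketch that wires in a custom head and classifier. The dimensions are illustrative assumptions (1280 matches, e.g., an EfficientNet-B0 feature extractor), not values required by the library:

import torch
from memorywrap import MemoryWrapLayer

encoder_output_dim = 1280   # illustrative; must match your encoder's output size
output_dim = 10             # e.g. the number of classes

# Custom projection head: input dimension equal to encoder_output_dim.
head = torch.nn.Linear(encoder_output_dim, encoder_output_dim)

# Custom classifier: input dimension equal to encoder_output_dim*2 for MemoryWrapLayer.
classifier = torch.nn.Linear(encoder_output_dim * 2, output_dim)

memorywrap = MemoryWrapLayer(encoder_output_dim, output_dim, head=head, classifier=classifier, distance='cosine')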

Forward call

Add the forward call to your model's forward function.

output_memorywrap = memorywrap(input_encoding, memory_encoding, return_weights=False)

where input_encoding and memory_encoding are the encoder outputs for, respectively, the current input batch and the memory set.
The last argument of Memory Wrap's call function is a boolean flag controlling the number of returned outputs: if the flag is True, the layer returns both the output and the sparse attention weights associated with each memory sample; if the flag is False, the layer returns only the output.
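
Putting it together, here is a minimal end-to-end sketch. The toy encoder, shapes and batch sizes are illustrative assumptions, not part of the library:

import torch
from memorywrap import MemoryWrapLayer

encoder_output_dim = 1280
num_classes = 10

# Toy stand-in for a CNN encoder; any feature extractor with a fixed output size works.
encoder = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 32 * 32, encoder_output_dim),
    torch.nn.ReLU(),
)
memorywrap = MemoryWrapLayer(encoder_output_dim, num_classes)

x = torch.randn(16, 3, 32, 32)            # current input batch
memory_set = torch.randn(100, 3, 32, 32)  # samples drawn from the training set

input_encoding = encoder(x)
memory_encoding = encoder(memory_set)

# return_weights=True also yields the sparse attention weights over the memory set.
output, weights = memorywrap(input_encoding, memory_encoding, return_weights=True)
print(output.shape)   # expected: torch.Size([16, 10])
print(weights.shape)  # expected: torch.Size([16, 100]), one weight per memory sample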

Additional information

Here you can find links to additional sources of information about Memory Wrap:

Download files

Download the file for your platform.

Source Distribution

memorywrap-1.1.3.tar.gz (4.3 kB)

Uploaded Source

Built Distribution

memorywrap-1.1.3-py3-none-any.whl (4.6 kB)

Uploaded Python 3

File details

Details for the file memorywrap-1.1.3.tar.gz.

File metadata

  • Download URL: memorywrap-1.1.3.tar.gz
  • Upload date:
  • Size: 4.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.0 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.10

File hashes

Hashes for memorywrap-1.1.3.tar.gz

  • SHA256: de6d205d4a7a0207a369727a02534cc1dda5689b47b763ed24c5d2ddc28a2c50
  • MD5: 62fb52ff681b6555bde91686f34843a3
  • BLAKE2b-256: a3dde830bad59615a0ff702812968e9d1ce44a548c669e844ecf36734220f103


File details

Details for the file memorywrap-1.1.3-py3-none-any.whl.

File metadata

  • Download URL: memorywrap-1.1.3-py3-none-any.whl
  • Upload date:
  • Size: 4.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.0 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.10

File hashes

Hashes for memorywrap-1.1.3-py3-none-any.whl

  • SHA256: af092ed4ca557e208063bb08656b33122f1f8675d93097b5512ace979ea32b7a
  • MD5: 5597ec9e815a5e7943aa3667a5b1b9a9
  • BLAKE2b-256: ceec535d315ea7209a99e6ff70e92c08d04388f4212bd30c06afcb5ee4ee93fb

