Memory Wrap: an extension for image classification models
Description
Memory Wrap is an extension to image classification models that improves both data efficiency and model interpretability by adopting a sparse content-attention mechanism between the input and a set of memories of past training samples.
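The mechanism described above can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not the package's code: it assumes cosine similarity as the content score and sparsemax as the sparse normalization (sparsemax behaves like softmax but can assign exactly zero weight to a memory). `sparsemax` and `memory_attention` are hypothetical helpers written for illustration; they are not exported by memorywrap.

```python
import torch
import torch.nn.functional as F


def sparsemax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Sparsemax: a softmax alternative that can return exact zeros."""
    z_sorted, _ = torch.sort(scores, dim=dim, descending=True)
    n = scores.size(dim)
    shape = [1] * scores.dim()
    shape[dim] = n
    k = torch.arange(1, n + 1, dtype=scores.dtype, device=scores.device).view(shape)
    z_cumsum = z_sorted.cumsum(dim)
    # An entry stays in the support while 1 + k * z_(k) > sum of the top-k scores.
    support = (1 + k * z_sorted) > z_cumsum
    k_support = support.sum(dim=dim, keepdim=True)
    # Threshold tau chosen so that the surviving weights sum to 1.
    tau = (z_cumsum.gather(dim, k_support - 1) - 1) / k_support.to(scores.dtype)
    return torch.clamp(scores - tau, min=0)


def memory_attention(input_encoding: torch.Tensor, memory_encoding: torch.Tensor):
    """Hypothetical helper: sparse content attention of each input over the memory set."""
    # Cosine similarity between every input (batch, d) and every memory (memory_size, d).
    scores = F.normalize(input_encoding, dim=-1) @ F.normalize(memory_encoding, dim=-1).T
    weights = sparsemax(scores, dim=-1)        # (batch, memory_size), each row sums to 1
    memory_vector = weights @ memory_encoding  # (batch, d), weighted sum of memories
    return memory_vector, weights


print(sparsemax(torch.tensor([[2.0, 1.0, -1.0]])))  # tensor([[1., 0., 0.]])
```

The printed example shows the property that matters for interpretability: memories whose score falls below the threshold receive a weight of exactly zero, so only a few memory samples contribute to each prediction.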
Installation
This is a PyTorch implementation of Memory Wrap. To install Memory Wrap, run the following command:
pip install memorywrap
The library contains two main classes:
- MemoryWrapLayer: the Memory Wrap variant described in the paper, which uses both the input encoding and the memory encoding to compute the output;
- BaselineMemory: the baseline, which uses only the memory encoding to compute the output.
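The two classes differ in which tensor reaches the classifier. The sketch below assumes that the memory encodings are first reduced to one attention-weighted memory vector per input and that, for MemoryWrapLayer, this vector is concatenated after the input encoding. That is consistent with the classifier input sizes documented for the classifier parameter (encoder_output_dim*2 versus encoder_output_dim), but the concatenation order and the helper `classifier_input` are hypothetical.

```python
import torch


def classifier_input(input_encoding: torch.Tensor,
                     memory_vector: torch.Tensor,
                     variant: str) -> torch.Tensor:
    """Hypothetical helper: build the tensor each variant passes to its classifier."""
    if variant == "memorywrap":
        # MemoryWrapLayer style: input encoding and memory vector side by side -> (batch, 2 * d)
        return torch.cat([input_encoding, memory_vector], dim=-1)
    if variant == "baseline":
        # BaselineMemory style: the memory vector alone -> (batch, d)
        return memory_vector
    raise ValueError(f"unknown variant: {variant!r}")


d = 16
input_encoding = torch.randn(4, d)  # encoder output for the current batch
memory_vector = torch.randn(4, d)   # stand-in for the attention-weighted memory vector
print(classifier_input(input_encoding, memory_vector, "memorywrap").shape)  # torch.Size([4, 32])
print(classifier_input(input_encoding, memory_vector, "baseline").shape)    # torch.Size([4, 16])
```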
Usage
Instantiate the layer
memorywrap = MemoryWrapLayer(encoder_output_dim, output_dim, classifier=None, distance='cosine')
or, for the baseline that uses only the memory to output the prediction:
memorywrap = BaselineMemory(encoder_output_dim, output_dim, classifier=None, distance='cosine')
where:
- encoder_output_dim (int): the output dimension of the last layer of the encoder;
- output_dim (int): the desired output dimension. In the paper, output_dim is equal to the number of classes;
- classifier (torch.nn.Module): the classifier on top of Memory Wrap. Its input dimension must be equal to encoder_output_dim*2 for MemoryWrapLayer and to encoder_output_dim for BaselineMemory. By default, it is an MLP as described in the paper; an alternative is a linear layer (e.g. torch.nn.Linear(encoder_output_dim*2, output_dim)). Default: torch.nn.Sequential(torch.nn.Linear(encoder_output_dim*2, encoder_output_dim*4), torch.nn.ReLU(), torch.nn.Linear(encoder_output_dim*4, output_dim));
- distance (str): the distance used to compute the similarity between the input and the memory set. Allowed values are cosine, l2 and dot, for cosine similarity, L2 distance and dot product respectively. Default: cosine.
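The classifier and distance parameters can be illustrated in plain PyTorch. `mlp_classifier` rebuilds the documented default MLP, and `similarity` shows one way each distance value could be turned into a matrix of input-to-memory scores. Both helpers are hypothetical and not part of the memorywrap package; in particular, negating the L2 distance so that closer memories score higher is an assumption about how l2 is used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mlp_classifier(encoder_output_dim: int, output_dim: int) -> nn.Module:
    # Mirrors the documented default classifier for MemoryWrapLayer (input size 2 * d).
    return nn.Sequential(
        nn.Linear(encoder_output_dim * 2, encoder_output_dim * 4),
        nn.ReLU(),
        nn.Linear(encoder_output_dim * 4, output_dim),
    )


def similarity(input_encoding: torch.Tensor,
               memory_encoding: torch.Tensor,
               distance: str = "cosine") -> torch.Tensor:
    """Return a (batch, memory_size) score matrix; higher means more similar."""
    if distance == "cosine":
        return F.normalize(input_encoding, dim=-1) @ F.normalize(memory_encoding, dim=-1).T
    if distance == "dot":
        return input_encoding @ memory_encoding.T
    if distance == "l2":
        # Assumption: negate the Euclidean distance so nearer memories get higher scores.
        return -torch.cdist(input_encoding, memory_encoding, p=2)
    raise ValueError(f"unknown distance: {distance!r}")


enc_dim, n_classes = 16, 10
default_clf = mlp_classifier(enc_dim, n_classes)  # default-style classifier for MemoryWrapLayer
linear_clf = nn.Linear(enc_dim * 2, n_classes)    # linear alternative for MemoryWrapLayer
baseline_clf = nn.Linear(enc_dim, n_classes)      # input size expected by BaselineMemory
scores = similarity(torch.randn(4, enc_dim), torch.randn(20, enc_dim), "l2")
print(scores.shape)  # torch.Size([4, 20])
```

A custom classifier only has to respect the input size; its output size is whatever output_dim you pass.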
Forward call
Add the forward call to your forward function.
output_memorywrap = memorywrap(input_encoding, memory_encoding, return_weights=False)
where input_encoding and memory_encoding are the outputs of the encoder for, respectively, the current input and the memory set.
The last argument of the forward call, return_weights, is a boolean flag that controls the number of outputs returned. If the flag is True, the layer returns both the output and the sparse attention weight associated with each memory sample; if the flag is False, the layer returns only the output.
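The forward call and the return_weights flag can be mimicked with a small stand-in module. `ToyMemoryWrap` is hypothetical and is not the package's MemoryWrapLayer: it assumes cosine similarity, sparsemax as the sparse attention, a memory vector formed as the attention-weighted sum of memory encodings, and concatenation with the input encoding before the default MLP. The random tensors stand in for a batch and for a memory set sampled from the training data; both pass through the same encoder, as described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sparsemax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Sparse alternative to softmax: weights sum to 1 and may be exactly zero.
    z_sorted, _ = torch.sort(scores, dim=dim, descending=True)
    n = scores.size(dim)
    shape = [1] * scores.dim()
    shape[dim] = n
    k = torch.arange(1, n + 1, dtype=scores.dtype, device=scores.device).view(shape)
    z_cumsum = z_sorted.cumsum(dim)
    k_support = ((1 + k * z_sorted) > z_cumsum).sum(dim=dim, keepdim=True)
    tau = (z_cumsum.gather(dim, k_support - 1) - 1) / k_support.to(scores.dtype)
    return torch.clamp(scores - tau, min=0)


class ToyMemoryWrap(nn.Module):
    """Hypothetical stand-in for MemoryWrapLayer (cosine distance, default MLP)."""

    def __init__(self, encoder_output_dim: int, output_dim: int):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(encoder_output_dim * 2, encoder_output_dim * 4),
            nn.ReLU(),
            nn.Linear(encoder_output_dim * 4, output_dim),
        )

    def forward(self, input_encoding, memory_encoding, return_weights=False):
        # Cosine scores between each input and each memory: (batch, memory_size).
        scores = F.normalize(input_encoding, dim=-1) @ F.normalize(memory_encoding, dim=-1).T
        weights = sparsemax(scores, dim=-1)
        memory_vector = weights @ memory_encoding
        output = self.classifier(torch.cat([input_encoding, memory_vector], dim=-1))
        return (output, weights) if return_weights else output


torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
layer = ToyMemoryWrap(encoder_output_dim=32, output_dim=10)
images = torch.randn(4, 3, 8, 8)          # random stand-in for the current batch
memory_images = torch.randn(20, 3, 8, 8)  # random stand-in for the memory set
input_encoding = encoder(images)
memory_encoding = encoder(memory_images)
logits = layer(input_encoding, memory_encoding)  # return_weights=False: output only
logits, weights = layer(input_encoding, memory_encoding, return_weights=True)
print(logits.shape, weights.shape)  # torch.Size([4, 10]) torch.Size([4, 20])
```

In this sketch, the non-zero entries in each row of weights identify the memory samples that contributed to the corresponding prediction, which is what makes returning them useful for inspection.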
Additional information
Here you can find links to additional sources of information about Memory Wrap:
Hashes for memorywrap-1.1.7-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | bf00dd1e9152f51ef018d37217d4e63481a0048b289128b9fb7eccd010a15576
MD5 | dea629639e7a96dd078968b805b80267
BLAKE2b-256 | 27816092b3d7be1297dec69a5d57af145cc77b7ebf92f02015c88a9f4769a6a9