Memory Wrap: an extension for image classification models
Memory Wrap is an extension to image classification models that improves both data efficiency and model interpretability by applying a sparse content-based attention mechanism between the input and a set of memories of past training samples.
This is a PyTorch implementation of Memory Wrap. To install it, run the following command:
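The attention step can be pictured with a small pure-Python sketch. It assumes that the sparse normalizer is sparsemax and that scores are cosine similarities (the default distance); the helper names sparsemax, cosine_similarity and memory_attention are hypothetical and not part of the memorywrap API. Memories whose score falls below a data-dependent threshold receive a weight of exactly 0, and the memory vector is the weighted sum of the remaining memory encodings.

```python
import math
from typing import List, Tuple

Vector = List[float]


def sparsemax(scores: Vector) -> Vector:
    """Euclidean projection of `scores` onto the probability simplex.

    Unlike softmax, entries below the threshold tau are set to exactly 0.
    """
    z = sorted(scores, reverse=True)
    running_sum = 0.0
    support_size, support_sum = 0, 0.0
    for k, z_k in enumerate(z, start=1):
        running_sum += z_k
        # z_(k) is in the support while 1 + k * z_(k) > sum of the k largest scores
        if 1 + k * z_k > running_sum:
            support_size, support_sum = k, running_sum
    tau = (support_sum - 1) / support_size
    return [max(s - tau, 0.0) for s in scores]


def cosine_similarity(u: Vector, v: Vector) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def memory_attention(input_encoding: Vector, memory_encodings: List[Vector]) -> Tuple[Vector, Vector]:
    """Hypothetical sketch: sparse attention of one input over a memory set."""
    scores = [cosine_similarity(input_encoding, m) for m in memory_encodings]
    weights = sparsemax(scores)
    # memory vector = attention-weighted sum of the memory encodings
    memory_vector = [
        sum(w * m[i] for w, m in zip(weights, memory_encodings))
        for i in range(len(input_encoding))
    ]
    return memory_vector, weights


memory_vector, weights = memory_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
print(weights)        # [1.0, 0.0, 0.0]
print(memory_vector)  # [1.0, 0.0]
```

With softmax, every memory would keep a small positive weight; with sparsemax, only the memories most similar to the input survive, which is what makes the weights usable as an explanation.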
pip install memorywrap
The library contains two main classes:
- MemoryWrapLayer: it is the Memory Wrap variant described in the paper that uses both the input encoding and the memory encoding to compute the output;
- BaselineMemory: it is the baseline that uses only the memory encoding to compute the output.
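To make the difference between the two classes concrete, the sketch below builds the vector that each variant would feed to its classifier, given an input encoding and a memory vector (the attention-weighted sum of memory encodings). It assumes that MemoryWrapLayer concatenates the two vectors, which is consistent with the classifier input size of encoder_output_dim*2 required for it; the helper classifier_input and the concatenation order are illustrative, not the library's code.

```python
from typing import List

Vector = List[float]


def classifier_input(input_encoding: Vector, memory_vector: Vector, variant: str) -> Vector:
    """Hypothetical helper: the features each variant passes to its classifier."""
    if variant == "MemoryWrapLayer":
        # input encoding and memory vector side by side -> 2 * encoder_output_dim features
        return list(input_encoding) + list(memory_vector)
    if variant == "BaselineMemory":
        # memory vector only -> encoder_output_dim features
        return list(memory_vector)
    raise ValueError(f"unknown variant: {variant!r}")


encoder_output_dim = 4
x = [0.1] * encoder_output_dim  # stands in for the encoding of the current input
m = [0.2] * encoder_output_dim  # stands in for the memory vector
print(len(classifier_input(x, m, "MemoryWrapLayer")))  # 8
print(len(classifier_input(x, m, "BaselineMemory")))   # 4
```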
Instantiate the layer
memorywrap = MemoryWrapLayer(encoder_output_dim, output_dim, classifier=None, distance='cosine')
or, for the baseline that uses only the memory to output the prediction:
memorywrap = BaselineMemory(encoder_output_dim, output_dim, classifier=None, distance='cosine')
encoder_output_dim (int): output dimension of the last layer of the encoder;
output_dim (int): desired output dimension. In the paper, output_dim is equal to the number of classes;
classifier (torch.nn.Module): classifier on top of Memory Wrap. Its input dimension must be encoder_output_dim*2 for MemoryWrapLayer and encoder_output_dim for BaselineMemory. By default, it is an MLP as described in the paper; an alternative is a linear layer (e.g. torch.nn.Linear(encoder_output_dim*2, output_dim)). Default (shown for MemoryWrapLayer): torch.nn.Sequential(torch.nn.Linear(encoder_output_dim*2, encoder_output_dim*4), torch.nn.ReLU(), torch.nn.Linear(encoder_output_dim*4, output_dim));
distance (str): measure used to compute the similarity between the input and the memory set. Allowed values are cosine, l2 and dot, for cosine similarity, L2 distance and dot product, respectively. Default: cosine.
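The three options can be sketched as scoring functions over a pair of encodings, where a higher score means a more relevant memory. This is a pure-Python sketch under assumptions: in particular, it assumes that l2 is turned into a similarity by negating the Euclidean distance, which may differ from how the library handles it, and the function name similarity is hypothetical.

```python
import math
from typing import Sequence


def similarity(u: Sequence[float], v: Sequence[float], distance: str = "cosine") -> float:
    """Hypothetical scoring function for the three allowed `distance` values."""
    dot = sum(a * b for a, b in zip(u, v))
    if distance == "dot":
        # raw dot product: sensitive to vector magnitude
        return dot
    if distance == "cosine":
        # dot product of the normalized vectors: depends only on the angle
        return dot / (math.hypot(*u) * math.hypot(*v))
    if distance == "l2":
        # ASSUMPTION: negated Euclidean distance, so closer memories score higher
        return -math.dist(u, v)
    raise ValueError(f"distance must be 'cosine', 'l2' or 'dot', got {distance!r}")


u, v = [3.0, 4.0], [6.0, 8.0]
print(similarity(u, v, "dot"))     # 50.0
print(similarity(u, v, "cosine"))  # 1.0
print(similarity(u, v, "l2"))      # -5.0
```

The example shows why the choice matters: u and v point in the same direction, so cosine rates them as identical, while dot rewards the larger magnitude and l2 penalizes the gap between them.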
Then call the layer inside your model's forward function:
output_memorywrap = memorywrap(input_encoding, memory_encoding, return_weights=False)
where input_encoding and memory_encoding are the outputs of the encoder for the current input and for the memory set, respectively.
The last argument, return_weights, is a boolean flag that controls the number of outputs returned. If it is True, the layer returns both the output and the sparse attention weight associated with each memory sample; if it is False, the layer returns only the output.
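Because the weights are sparse, most memory samples receive a weight of exactly 0, and the remaining ones point to the training samples that contributed to a prediction. The sketch below, a hypothetical helper named top_memories, ranks them. It assumes the weights for a single input have already been converted to a plain list (for a PyTorch tensor, for example, with .tolist()); the exact shape of the tensor returned by the library is not specified here, and the weight values in the example are made up.

```python
from typing import List, Sequence, Tuple


def top_memories(weights: Sequence[float], k: int = 3) -> List[Tuple[int, float]]:
    """Hypothetical helper: the k memory indices with the largest non-zero weights."""
    # memories with weight 0 did not contribute to the output, so they are skipped
    ranked = sorted(
        ((index, w) for index, w in enumerate(weights) if w > 0),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k]


# In a real model the weights would come from:
#   output, weights = memorywrap(input_encoding, memory_encoding, return_weights=True)
# Made-up weights for one input over a memory set of five samples:
example_weights = [0.0, 0.55, 0.0, 0.45, 0.0]
print(top_memories(example_weights))  # [(1, 0.55), (3, 0.45)]
```

The returned indices can then be used to look up the corresponding memory images and show them alongside the prediction.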
Links to additional sources of information about Memory Wrap: