Memory Wrap: an extension for image classification models
Description
Memory Wrap is an extension to image classification models that improves both data efficiency and model interpretability. It does so by applying a sparse content-attention mechanism between the input and a set of memories made of past training samples.
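The core mechanism can be sketched in a few lines of plain Python. This is a minimal illustration, not the memorywrap implementation. It assumes cosine similarity as the content score and sparsemax as the sparse normalization, and the function names are hypothetical:

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def sparsemax(scores):
    """Sparsemax: Euclidean projection of the scores onto the probability simplex.

    Unlike softmax, it can give exactly zero weight to low-scoring entries.
    """
    ranked = sorted(scores, reverse=True)
    cumulative = 0.0
    support_size = 0
    support_sum = 0.0
    for k, z_k in enumerate(ranked, start=1):
        cumulative += z_k
        # z_k stays in the support while 1 + k * z_k exceeds the running sum.
        if 1 + k * z_k > cumulative:
            support_size = k
            support_sum = cumulative
    threshold = (support_sum - 1) / support_size
    return [max(s - threshold, 0.0) for s in scores]


def sparse_content_attention(input_encoding, memory_encodings):
    """Weight each memory by its similarity to the input, then mix the memories."""
    scores = [cosine_similarity(input_encoding, m) for m in memory_encodings]
    weights = sparsemax(scores)
    memory_vector = [
        sum(w * m[j] for w, m in zip(weights, memory_encodings))
        for j in range(len(input_encoding))
    ]
    return memory_vector, weights


memory_vector, weights = sparse_content_attention(
    [1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
)
print(weights)  # the dissimilar second memory gets a weight of exactly 0.0
```

The sparsity is what makes the weights readable: only a few memories receive non-zero weight, so they can be inspected directly.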
Installation
This is a PyTorch implementation of Memory Wrap. To install it, run the following command:
pip install memorywrap
The library contains two main classes:
- MemoryWrapLayer: it is the Memory Wrap variant described in the paper that uses both the input encoding and the memory encoding to compute the output;
- BaselineMemory: it is the baseline that uses only the memory encoding to compute the output.
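The two classes differ in what the final classifier receives. The sketch below is an illustration under assumptions: it uses a single linear layer as the classifier for brevity (the library may use a deeper network), plain lists instead of tensors, and hypothetical function names. It only shows the input sizes implied by the two descriptions:

```python
def linear(weight_rows, bias, x):
    """Dense layer: one row of weights per output unit."""
    for row in weight_rows:
        if len(row) != len(x):
            raise ValueError(f"expected input of size {len(row)}, got {len(x)}")
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight_rows, bias)]


def memory_wrap_head(input_encoding, memory_vector, weight_rows, bias):
    # MemoryWrapLayer-style: the classifier sees the input encoding and the
    # memory vector concatenated, so its input size is 2 * encoder_dim.
    return linear(weight_rows, bias, input_encoding + memory_vector)


def baseline_memory_head(memory_vector, weight_rows, bias):
    # BaselineMemory-style: the classifier sees only the memory vector
    # (input size encoder_dim).
    return linear(weight_rows, bias, memory_vector)


# Toy sizes: encoder_dim = 2, output_dim = 3.
input_encoding = [1.0, 0.0]
memory_vector = [0.5, 0.5]
wrap_logits = memory_wrap_head(
    input_encoding,
    memory_vector,
    [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
    [0.0, 0.0, 0.0],
)
baseline_logits = baseline_memory_head(
    memory_vector, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.0]
)
```

Feeding a 4-wide weight matrix to the baseline head raises a `ValueError`, which mirrors the fact that the two variants expect different classifier input sizes.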
Usage
Instantiate the layer
memorywrap = MemoryWrapLayer(encoder_dim, output_dim, return_weights=False)
or
memorywrap = BaselineMemory(encoder_dim, output_dim)
where:
- encoder_dim is the output dimension of the last layer of the encoder;
- output_dim is the desired output dimension. In the paper, output_dim is equal to the number of classes;
- return_weights is a flag telling the layer whether to return the sparse content weights.
Forward call
Add the forward call to your forward function.
output_memorywrap = memorywrap(input_encoding, memory_encoding)
where input_encoding and memory_encoding are the outputs of the encoder for the current input and for the memory set, respectively.
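Both encodings come from the same encoder. The following is a rough sketch of preparing them, assuming the memory set is a random sample of training examples. toy_encoder stands in for a real encoder, plain lists stand in for PyTorch tensors, and the shapes in the comments are illustrative assumptions:

```python
import random


def toy_encoder(sample):
    """Hypothetical stand-in for a real encoder: maps a raw sample to a 2-d vector."""
    return [float(sum(sample)), float(len(sample))]


def encode(encoder, samples):
    """Apply the same encoder to every sample."""
    return [encoder(s) for s in samples]


training_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
rng = random.Random(0)

inputs = [[1, 1], [2, 2]]                    # current batch
memory_set = rng.sample(training_data, k=3)  # memory drawn from the training data

input_encoding = encode(toy_encoder, inputs)       # assumed shape (batch_size, encoder_dim)
memory_encoding = encode(toy_encoder, memory_set)  # assumed shape (memory_size, encoder_dim)
# With real tensors, these two would then be passed to the layer:
# output_memorywrap = memorywrap(input_encoding, memory_encoding)
```

Sharing the encoder keeps both encodings in the same space with the same width (encoder_dim), which is what the content attention compares.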
If you have set return_weights to True, output_memorywrap is a tuple whose first element is the output and whose second element holds the content weights associated with each element of memory_encoding.
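When return_weights is True, the result has to be unpacked before use. This is a small sketch with hypothetical values and helper names, using plain lists instead of tensors and assuming one row of weights per input with one entry per memory element:

```python
def unpack_memorywrap_output(result, return_weights):
    """Split the layer's result according to the return_weights flag."""
    if return_weights:
        output, content_weights = result
        return output, content_weights
    return result, None


def most_relevant_memory(weights_row):
    """Index of the memory element with the largest content weight for one input."""
    return max(range(len(weights_row)), key=weights_row.__getitem__)


# Hypothetical result for a batch of one input and a memory set of three elements.
result = ([[0.1, 0.9]], [[0.0, 0.7, 0.3]])
output, content_weights = unpack_memorywrap_output(result, return_weights=True)
top = [most_relevant_memory(row) for row in content_weights]
print(top)  # [1]
```

With PyTorch tensors, `content_weights.argmax(dim=1)` gives the same per-row index. One way to use the weights for interpretability is to show the highest-weighted memory samples next to a prediction as examples that supported it.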
Additional information
Here you can find links to additional sources of information about Memory Wrap:
Download files
Built Distribution
Hashes for memorywrap-1.0.3-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 0d4ee881f4e481057309b7745da9cb297dd4a22ea17bcac58088c1965a8c6dee
MD5 | 3ec66a23c62051d1b6756232fa20542e
BLAKE2b-256 | ccada9eb665b01c89395eb421aeb189d75b31f5bdf3ea4c4c8403ce9bb49d544