Memory Wrap: an extension for image classification models
Description
Memory Wrap is an extension for image classification models that improves both data efficiency and model interpretability by applying a sparse content-based attention mechanism between the input and a set of memories of past training samples.
Installation
This is a PyTorch implementation of Memory Wrap. To install it, run the following command:
pip install memorywrap
The library contains two main classes:
- MemoryWrapLayer: the Memory Wrap variant described in the paper, which uses both the input encoding and the memory encoding to compute the output;
- BaselineMemory: the baseline variant, which uses only the memory encoding to compute the output.
Usage
Instantiate the layer
memorywrap = MemoryWrapLayer(encoder_dim, output_dim, mlp_activation=torch.nn.ReLU())
or, for the baseline that uses only the memory to output the prediction:
memorywrap = BaselineMemory(encoder_dim, output_dim, mlp_activation=torch.nn.ReLU())
where
- encoder_dim is the output dimension of the last layer of the encoder;
- output_dim is the desired output dimension. In the paper, output_dim is equal to the number of classes;
- mlp_activation is the activation function used in the hidden layer of the multi-layer perceptron.
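As a rough illustration of the pieces involved — not the library's internal implementation — the layer can be thought of as attending over the memory encodings and combining the attended memory vector with the input encoding before a small MLP. Below is a minimal, torch-only sketch under that assumption; `ToyMemoryLayer` is a hypothetical name, it uses plain (dense) softmax attention for simplicity, whereas the real MemoryWrapLayer uses the sparse attention described in the paper:

```python
import torch
import torch.nn as nn

class ToyMemoryLayer(nn.Module):
    """Hypothetical, simplified stand-in for MemoryWrapLayer.

    It attends over the memory encodings with dense softmax attention,
    concatenates the attended memory vector with the input encoding, and
    feeds the result to an MLP. The real layer uses *sparse* attention.
    """

    def __init__(self, encoder_dim, output_dim, mlp_activation=None):
        super().__init__()
        activation = mlp_activation or nn.ReLU()
        self.mlp = nn.Sequential(
            nn.Linear(encoder_dim * 2, encoder_dim * 2),
            activation,
            nn.Linear(encoder_dim * 2, output_dim),
        )

    def forward(self, input_encoding, memory_encoding, return_weights=False):
        # Content-based attention: cosine similarity between input and memories.
        scores = torch.einsum(
            "bd,md->bm",
            nn.functional.normalize(input_encoding, dim=1),
            nn.functional.normalize(memory_encoding, dim=1),
        )
        weights = torch.softmax(scores, dim=1)     # (batch, memory_size)
        memory_vector = weights @ memory_encoding  # (batch, encoder_dim)
        output = self.mlp(torch.cat([input_encoding, memory_vector], dim=1))
        return (output, weights) if return_weights else output

layer = ToyMemoryLayer(encoder_dim=64, output_dim=10)
out, w = layer(torch.randn(8, 64), torch.randn(100, 64), return_weights=True)
print(out.shape, w.shape)  # torch.Size([8, 10]) torch.Size([8, 100])
```

In the sketch, each row of the attention weights sums to one, so the memory vector is a convex combination of memory encodings; sparse attention instead zeroes out most memories, which is what makes the weights usable as explanations.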
Forward call
Add the forward call to your forward function.
output_memorywrap = memorywrap(input_encoding, memory_encoding, return_weights=False)
where input_encoding and memory_encoding are the encoder outputs for the current input and for the memory set, respectively.
The last argument of Memory Wrap's call function is a boolean flag controlling the number of outputs returned. If the flag is True, the layer returns both the output and the sparse attention weights associated with each memory sample; if the flag is False, the layer returns only the output.
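Putting the pieces together, here is a hedged sketch of where the forward call fits in a full model. `MemoryWrapNet` and `StandInWrap` are hypothetical names introduced for this example: `StandInWrap` only mimics the (input_encoding, memory_encoding, return_weights) call signature so that the snippet runs without the memorywrap package installed — in real code you would pass a MemoryWrapLayer instance instead:

```python
import torch
import torch.nn as nn

class StandInWrap(nn.Module):
    """Hypothetical stand-in with MemoryWrapLayer's call signature.

    It computes dense attention weights and a plain linear head, so this
    example is self-contained; replace it with MemoryWrapLayer in practice.
    """
    def __init__(self, encoder_dim, output_dim):
        super().__init__()
        self.head = nn.Linear(encoder_dim, output_dim)

    def forward(self, input_encoding, memory_encoding, return_weights=False):
        weights = torch.softmax(input_encoding @ memory_encoding.T, dim=1)
        output = self.head(input_encoding)
        return (output, weights) if return_weights else output

class MemoryWrapNet(nn.Module):
    """Encodes both the input batch and the memory set, then calls the wrap layer."""
    def __init__(self, encoder, wrap_layer):
        super().__init__()
        self.encoder = encoder
        self.wrap_layer = wrap_layer

    def forward(self, x, memory_set, return_weights=False):
        input_encoding = self.encoder(x)            # encode the current batch
        memory_encoding = self.encoder(memory_set)  # encode the memory samples
        return self.wrap_layer(input_encoding, memory_encoding, return_weights)

encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32))
net = MemoryWrapNet(encoder, StandInWrap(encoder_dim=32, output_dim=10))

# The memory set is typically a batch sampled at random from the training data.
train_images = torch.randn(500, 1, 28, 28)
memory_set = train_images[torch.randperm(500)[:100]]

logits, weights = net(torch.randn(4, 1, 28, 28), memory_set, return_weights=True)
print(logits.shape, weights.shape)  # torch.Size([4, 10]) torch.Size([4, 100])
```

Note that both the input and the memory set go through the same encoder, so the attention weights compare like with like in the same embedding space.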
Additional information
Here you can find links to additional sources of information about Memory Wrap:
Hashes for memorywrap-1.0.5-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | d6e28c01ffdc1f47a36dbf3a3a29644bd543c4be721bf700f333897c20a457d1
MD5 | 365046427b7cbc5365dc907cc676cb13
BLAKE2b-256 | a2acc412a1df87cd39ca31d4d5e60ca8d2890cebcbbb35bfae3d84c75dc2c684