Memory Wrap: an extension for image classification models
Description
Memory Wrap is an extension to image classification models that improves both data efficiency and model interpretability. It does so by applying a sparse content-attention mechanism between the input and a memory of past training samples.
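A minimal NumPy sketch of what a sparse content-attention step over memory encodings could look like. It assumes, as we read the Memory Wrap paper, that the input is compared with each memory sample using cosine similarity and that the scores are normalized with sparsemax. The function names `sparsemax` and `content_attention` are hypothetical; this is not the memorywrap source.

```python
import numpy as np

def sparsemax(z):
    """Sparsemax over a 1-D score vector: like softmax, but can output exact zeros."""
    z_sorted = np.sort(z)[::-1]                 # scores in descending order
    k = np.arange(1, z.size + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum         # True for entries kept in the support
    k_max = k[support][-1]
    tau = (cumsum[k_max - 1] - 1) / k_max       # threshold subtracted from every score
    return np.maximum(z - tau, 0.0)

def content_attention(input_enc, memory_enc):
    """Hypothetical helper: cosine similarity between one input encoding (dim,)
    and each memory encoding (memory_size, dim), normalized with sparsemax.
    Returns (weights, memory_vector)."""
    norms = np.linalg.norm(memory_enc, axis=1) * np.linalg.norm(input_enc)
    sims = (memory_enc @ input_enc) / np.maximum(norms, 1e-8)
    weights = sparsemax(sims)
    memory_vector = weights @ memory_enc        # weighted sum of memory encodings
    return weights, memory_vector

input_enc = np.array([1.0, 0.0])
memory_enc = np.array([[1.0, 0.0],    # same direction as the input
                       [0.6, 0.8],    # partially similar
                       [-1.0, 0.0]])  # opposite direction
weights, memory_vector = content_attention(input_enc, memory_enc)
print(weights)  # approximately [0.7, 0.3, 0.0]
```

The third memory sample receives exactly zero weight, so only the first two contribute to the memory vector. This sparsity is what keeps the set of memories behind each prediction small enough to inspect.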
Installation
This is a PyTorch implementation of Memory Wrap. To install Memory Wrap, run the following command:
pip install memorywrap
The library contains two main classes:
- MemoryWrapLayer: it is the Memory Wrap variant described in the paper that uses both the input encoding and the memory encoding to compute the output;
- BaselineMemory: it is the baseline that uses only the memory encoding to compute the output.
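The difference between the two classes lies in what reaches the final multi-layer perceptron. The sketch below assumes, based on our reading of the paper, that the full variant concatenates the input encoding with the attention-weighted memory vector, while the baseline passes the memory vector alone. `mlp_input` is a hypothetical helper, not part of the library.

```python
import numpy as np

def mlp_input(input_encoding, memory_vector, use_input=True):
    """Hypothetical helper that builds the features fed to the final MLP.

    use_input=True  -> MemoryWrapLayer-style: input encoding and memory vector
    use_input=False -> BaselineMemory-style: memory vector only
    """
    if use_input:
        return np.concatenate([input_encoding, memory_vector], axis=-1)
    return memory_vector

batch, encoder_dim = 4, 512
input_encoding = np.zeros((batch, encoder_dim))
memory_vector = np.ones((batch, encoder_dim))  # stand-in for the attention-weighted memory

full = mlp_input(input_encoding, memory_vector, use_input=True)
baseline = mlp_input(input_encoding, memory_vector, use_input=False)
print(full.shape, baseline.shape)  # (4, 1024) (4, 512)
```

Under this assumption, the baseline must predict from the retrieved memories alone, whereas the full variant can also rely on the current input.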
Usage
Instantiate the layer
memorywrap = MemoryWrapLayer(encoder_dim, output_dim, mlp_activation=torch.nn.ReLU())
or, for the baseline that uses only the memory to output the prediction:
memorywrap = BaselineMemory(encoder_dim, output_dim, mlp_activation=torch.nn.ReLU())
where
- encoder_dim is the output dimension of the last layer of the encoder;
- output_dim is the desired output dimension. In the paper, output_dim is equal to the number of classes;
- mlp_activation is the activation function used in the hidden layer of the multi-layer perceptron.
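To show how the three constructor arguments map onto a classifier head, here is a hypothetical one-hidden-layer MLP in NumPy. The hidden size, the weight initialization, and the `2 * encoder_dim` input size (input encoding plus memory vector) are assumptions chosen for illustration, not values taken from memorywrap.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

class MLPHead:
    """Hypothetical one-hidden-layer MLP head (not the memorywrap source)."""

    def __init__(self, in_dim, hidden_dim, output_dim, activation, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(scale=0.02, size=(in_dim, hidden_dim))
        self.b1 = np.zeros(hidden_dim)
        self.w2 = rng.normal(scale=0.02, size=(hidden_dim, output_dim))
        self.b2 = np.zeros(output_dim)
        self.activation = activation  # plays the role of mlp_activation

    def __call__(self, x):
        hidden = self.activation(x @ self.w1 + self.b1)  # activation applied only here
        return hidden @ self.w2 + self.b2                # one score per output unit

encoder_dim = 512  # output size of the encoder's last layer
output_dim = 10    # e.g. the number of classes
# Assumption: the head sees 2 * encoder_dim features; hidden size picked arbitrarily.
head = MLPHead(2 * encoder_dim, encoder_dim, output_dim, activation=relu)
logits = head(np.ones((4, 2 * encoder_dim)))
print(logits.shape)  # (4, 10)
```

Passing the activation as a callable mirrors the `mlp_activation=torch.nn.ReLU()` argument above: swapping in another activation changes only the hidden layer.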
Forward call
Call the layer inside your model's forward function:
output_memorywrap = memorywrap(input_encoding, memory_encoding, return_weights=False)
where input_encoding and memory_encoding are the outputs of the encoder for the current input and for the memory set, respectively.
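The two arguments come from running the same encoder on two batches: the current inputs and a memory set of training samples. The README does not say how the memory set is built, so the sketch below assumes uniform random sampling without replacement from the training data. The random data, the linear `encoder`, `sample_memory`, and the memory size of 100 are all hypothetical stand-ins chosen only to show the shapes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in training set: 1000 flattened 32x32x3 "images" of random values.
train_x = rng.random((1000, 32 * 32 * 3))

# Hypothetical encoder: a fixed random projection followed by ReLU.
encoder_dim = 64
w_enc = rng.normal(scale=0.01, size=(train_x.shape[1], encoder_dim))

def encoder(x):
    return np.maximum(x @ w_enc, 0.0)

def sample_memory(data, memory_size, rng):
    """Draw memory_size distinct training samples to serve as the memory set."""
    idx = rng.choice(len(data), size=memory_size, replace=False)
    return data[idx]

batch_x = train_x[:8]                        # current inputs
memory_x = sample_memory(train_x, 100, rng)  # memory set
input_encoding = encoder(batch_x)            # shape (8, 64)
memory_encoding = encoder(memory_x)          # shape (100, 64), same encoder
```

Both encodings share the last dimension, encoder_dim, which is the value passed to the layer's constructor.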
The last argument, return_weights, is a boolean flag that controls the number of outputs returned. If it is True, the layer returns both the output and the sparse attention weight associated with each memory sample; if it is False, the layer returns only the output.
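With return_weights=True, the attention weights can be used to explain a prediction by pointing at the memory samples that influenced it. The sketch assumes the weights arrive as a (batch, memory_size) array; if the layer returns a PyTorch tensor, `weights.detach().cpu().numpy()` converts it. The weights, `memory_labels`, and the `explain` helper below are hypothetical.

```python
import numpy as np

def explain(weights, memory_labels):
    """For each input, return the index and label of the memory sample with the
    largest weight, plus how many memory samples received non-zero weight."""
    top = weights.argmax(axis=1)
    n_used = (weights > 0).sum(axis=1)
    return top, memory_labels[top], n_used

# Hypothetical sparse weights for 2 inputs over a memory set of 5 samples.
weights = np.array([[0.0, 0.7, 0.3, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0, 0.0]])
memory_labels = np.array([3, 5, 5, 1, 0])  # class label of each memory sample

top, top_labels, n_used = explain(weights, memory_labels)
print(top, top_labels, n_used)  # [1 3] [5 1] [2 1]
```

Because the weights are sparse, each input is typically linked to only a few memory samples, which keeps such example-based explanations short.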
Additional information
Here you can find links to additional sources of information about Memory Wrap:
Hashes for memorywrap-1.0.6-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 19c4c75e658698e9fb4ad4acb8f8841a894ea37a75b7457bb5b4f08205683a7c |
| MD5 | 8ad3d1bc11ea6c5b007bad930f0a1703 |
| BLAKE2b-256 | 5c8e7383dce973e483ad580d31ed155c35d86a8a25fd904edf2896860bd02432 |