Product Key Memory

Standalone Product Key Memory module for augmenting Transformer models

Install

$ pip install product-key-memory

Usage

Replace the feedforward layers in a Transformer with the following module:

import torch
from product_key_memory import PKM

pkm = PKM(
    dim = 512,
    heads = 4,
    dim_head = 128,       # keep at 128 for best results
    num_keys = 256,       # number of subkeys; the number of values will be num_keys ** 2
    topk = 32             # number of top-scoring subkeys to select
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()
values = pkm(x, input_mask = mask) # (1, 1024, 512)
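The `num_keys` and `topk` settings above are easier to reason about with the lookup written out. What follows is a minimal pure-Python sketch of the product-key search described in the Lample et al. paper cited below, not this package's implementation. The function `product_key_topk` and the toy sizes are assumptions made for illustration. The query is split into two halves, each half is scored against its own set of subkeys, and the best candidates are combined with additive scores.

```python
import random

def product_key_topk(query, subkeys_a, subkeys_b, topk):
    """Return the top-k (score, value_index) pairs for one query vector.

    Hypothetical helper for illustration only; not part of product_key_memory.
    """
    half = len(query) // 2
    q_a, q_b = query[:half], query[half:]

    def dot(u, v):
        return sum(x * y for x, y in zip(u, v))

    def best(q, subkeys):
        # score one query half against every subkey, keep the top-k
        scored = [(dot(q, key), i) for i, key in enumerate(subkeys)]
        return sorted(scored, reverse=True)[:topk]

    top_a = best(q_a, subkeys_a)
    top_b = best(q_b, subkeys_b)

    # combine the two candidate lists: topk * topk pairs, scored additively
    num_keys = len(subkeys_b)
    candidates = [
        (score_a + score_b, i * num_keys + j)  # flat index into num_keys ** 2 values
        for score_a, i in top_a
        for score_b, j in top_b
    ]
    return sorted(candidates, reverse=True)[:topk]


random.seed(0)
num_keys, half_dim, topk = 16, 4, 3  # toy sizes, far smaller than the example above
subkeys_a = [[random.gauss(0, 1) for _ in range(half_dim)] for _ in range(num_keys)]
subkeys_b = [[random.gauss(0, 1) for _ in range(half_dim)] for _ in range(num_keys)]
query = [random.gauss(0, 1) for _ in range(2 * half_dim)]

selected = product_key_topk(query, subkeys_a, subkeys_b, topk)
for score, index in selected:
    print(f"value index {index} with score {score:.3f}")
```

With `num_keys = 256`, each query half is compared against only 256 subkeys, yet the selected indices address `256 ** 2 = 65536` values. This is why the value table can grow quadratically while the search cost stays small.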

Learning Rates

To give the value parameters of the product-key-memory network a different learning rate from the rest of the model, use the following helper function.

from torch.optim import Adam
from product_key_memory import fetch_pkm_value_parameters

# `model` is your root module; this helper finds every PKM module inside it
# and returns their embedding bag (value) weight parameters separately from all other parameters
pkm_parameters, other_parameters = fetch_pkm_value_parameters(model)

optim = Adam([
    {'params': other_parameters},
    {'params': pkm_parameters, 'lr': 1e-2}
], lr=1e-3)

Or, if the product-key-memory values are the only parameters that need a different learning rate:

from torch.optim import Adam
from product_key_memory import fetch_optimizer_parameters

# automatically creates the list of parameter groups, with the learning rate set to 1e-2 for the PKM values
parameters = fetch_optimizer_parameters(model)
optim = Adam(parameters, lr=1e-3)
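For reference, the parameter-group list passed to the optimizer has the same shape in both snippets above. The sketch below builds it by hand. `build_param_groups` is a hypothetical helper written for illustration and is not part of this package, and the strings are stand-ins for real parameter tensors so the example runs without PyTorch.

```python
def build_param_groups(pkm_value_params, other_params, pkm_lr=1e-2):
    """Build optimizer parameter groups with a separate learning rate for PKM values.

    Hypothetical helper for illustration only; not part of product_key_memory.
    """
    return [
        # no 'lr' key: this group uses the optimizer's default learning rate
        {'params': list(other_params)},
        # explicit 'lr' key: overrides the default for the PKM value parameters
        {'params': list(pkm_value_params), 'lr': pkm_lr},
    ]


# strings stand in for parameter tensors in this sketch
groups = build_param_groups(
    pkm_value_params=['pkm.values.weight'],
    other_params=['attn.weight', 'norm.weight'],
)
for group in groups:
    print(group)
```

A group without its own `'lr'` entry falls back to the `lr` given to the optimizer constructor, which is why `lr=1e-3` in the snippets above applies only to the non-PKM parameters.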

Appreciation

Special thanks go to Aran for encouraging me to look into this, and to Madison May for his educational blog post, which helped me understand this better.

Todo

  • offer stochasticity with annealed gumbel noise. seen dramatic effects in vector-quantization setting

  • offer a way for smaller value dimensions + concat and linear combination of heads (like multi-head attention)

  • get caught up on latest literature on product key memories, if any

  • instead of additive scores, try multiplicative using coordinate descent routing

Citations

@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{liu2020evolving,
    title   = {Evolving Normalization-Activation Layers},
    author  = {Hanxiao Liu and Andrew Brock and Karen Simonyan and Quoc V. Le},
    year    = {2020},
    eprint  = {2004.02967},
    archivePrefix = {arXiv}
}
@article{Shen2023ASO,
    title   = {A Study on ReLU and Softmax in Transformer},
    author  = {Kai Shen and Junliang Guo and Xuejiao Tan and Siliang Tang and Rui Wang and Jiang Bian},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2302.06461},
    url     = {https://api.semanticscholar.org/CorpusID:256827573}
}
@article{Csordas2023ApproximatingTF,
    title   = {Approximating Two-Layer Feedforward Networks for Efficient Transformers},
    author  = {R{\'o}bert Csord{\'a}s and Kazuki Irie and J{\"u}rgen Schmidhuber},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2310.10837},
    url     = {https://api.semanticscholar.org/CorpusID:264172384}
}
@inproceedings{anonymous2025continual,
    title   = {Continual Learning via Sparse Memory Finetuning},
    author  = {Anonymous},
    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
    year    = {2025},
    url     = {https://openreview.net/forum?id=LGo7U1m24L},
    note    = {under review}
}
