nGPT
nGPT (normalized GPT) - Pytorch
Quick implementation of nGPT (normalized GPT, learning entirely on the hypersphere) from NvidiaAI. The open question is whether they swept some loss of expressivity under the rug, but I'll take it in good faith.
This type of network should also be studied in the context of continual learning and loss of plasticity.
Adaptation to vision transformers is here
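To make "learning entirely on the hypersphere" concrete, here is a minimal sketch of the residual step as described in the nGPT paper. It is not this package's code. Each block output is projected to unit norm, the hidden state moves part of the way toward it, and the result is re-normalized: h ← Norm(h + α(Norm(block_out) − h)). The helper names `l2_normalize` and `normalized_residual_update` are hypothetical, and `alpha` is assumed to be a per-dimension vector of step sizes (the paper learns these).

```python
import math

def l2_normalize(v, eps=1e-8):
    """Project a vector onto the unit hypersphere (eps guards against a zero vector)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]

def normalized_residual_update(h, block_out, alpha):
    """One nGPT-style residual step (hypothetical helper, sketch only).

    h:         current hidden state, assumed already unit-norm
    block_out: raw output of an attention or MLP block
    alpha:     per-dimension step sizes (assumed; learned in the paper)
    """
    h_block = l2_normalize(block_out)  # block output placed on the sphere
    # move each coordinate of h part of the way toward the normalized block output
    moved = [hi + a * (bi - hi) for hi, bi, a in zip(h, h_block, alpha)]
    return l2_normalize(moved)  # re-project so the state stays on the sphere

h = l2_normalize([3.0, 4.0])                          # [0.6, 0.8]
h = normalized_residual_update(h, [0.0, 2.0], [0.5, 0.5])
print(h)  # still unit norm, rotated toward [0, 1]
```

With `alpha` at 0 the state is unchanged, and with `alpha` at 1 it jumps to the normalized block output, so the step sizes interpolate along the sphere instead of adding unbounded residuals.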
Install
$ pip install nGPT-pytorch
Usage
import torch
from nGPT_pytorch import nGPT
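model = nGPT(
    num_tokens = 256,     # vocabulary size
    dim = 512,            # model dimension
    depth = 4,            # number of layers
    attn_norm_qk = True   # query/key normalization in attention
)
x = torch.randint(0, 256, (2, 2048))   # batch of 2 token sequences, length 2048
loss = model(x, return_loss = True)
loss.backward()
logits = model(x) # (2, 2048, 256)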
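The `attn_norm_qk` flag presumably turns on query/key normalization. That reading is an assumption based on the flag name and the Cosine Normalization citation below, so check the source before relying on it. Once queries and keys have unit norm, their dot product is a cosine similarity bounded in [-1, 1], so a scale is needed to keep the softmax from being too flat. The sketch below shows that idea for a single query. `cosine_attention_weights` and its `scale` parameter are hypothetical, not part of this package.

```python
import math

def l2_normalize(v, eps=1e-8):
    """Project a vector onto the unit hypersphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]

def cosine_attention_weights(q, keys, scale):
    """Softmax over scaled cosine similarities between one query and several keys.

    Because q and every k are unit-normalized, each logit lies in [-scale, scale].
    `scale` is a hypothetical fixed temperature here; models often learn it.
    """
    qn = l2_normalize(q)
    logits = [scale * sum(a * b for a, b in zip(qn, l2_normalize(k))) for k in keys]
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

keys = [[2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]]
print(cosine_attention_weights([1.0, 0.0], keys, scale=10.0))
```

A useful property of this design is that rescaling the query or a key leaves the attention weights unchanged, since only directions matter.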
Test
Enwik8
$ python train.py
Citations
@inproceedings{Loshchilov2024nGPTNT,
title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273026160}
}
@article{Luo2017CosineNU,
title = {Cosine Normalization: Using Cosine Similarity Instead of Dot Product in Neural Networks},
author = {Chunjie Luo and Jianfeng Zhan and Lei Wang and Qiang Yang},
journal = {ArXiv},
year = {2017},
volume = {abs/1702.05870},
url = {https://api.semanticscholar.org/CorpusID:1505432}
}
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
Download files
Source Distribution
ngpt_pytorch-0.2.6.tar.gz (36.9 MB)
Built Distribution
Hashes for ngpt_pytorch-0.2.6-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 25b448fb6f00bae1ab68e8387906778de63737bc11ba4aab36a792a0f438ceae
MD5 | 9bd3d56b9f75ffb93e03379d95d0914d
BLAKE2b-256 | 739769071cef9b6ed42544fd966e1e0ccc736e0b059b9ece1cad9a07661443b4