
nGPT

Project description

nGPT (normalized GPT) - PyTorch

A quick implementation of nGPT from Nvidia AI, in which learning takes place entirely on the hypersphere. The open question is whether some expressivity was swept under the rug, but I'll take it on good faith.
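The core idea can be sketched in a few lines: hidden states and weight rows are kept l2-normalized, so every dot product becomes a cosine similarity. This is a minimal sketch of the normalization step, not the library's actual code:

```python
import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    # project vectors onto the unit hypersphere
    return F.normalize(t, p = 2, dim = dim)

# hidden states and weight rows both live on the hypersphere,
# so matmuls between them yield cosine similarities in [-1, 1]
h = l2norm(torch.randn(2, 8, 512))          # hidden states
w = l2norm(torch.randn(512, 512), dim = -1) # weight rows normalized

sims = h @ w.t()
assert torch.allclose(h.norm(dim = -1), torch.ones(2, 8), atol = 1e-5)
assert sims.abs().max() <= 1 + 1e-5
```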

This type of network should also be studied in the context of continual learning and loss of plasticity.

An adaptation to vision transformers is here.

Update: the normalized feedforward was successfully applied to improve the update-to-data (UTD) ratio for off-policy reinforcement learning in a new paper.

Install

$ pip install nGPT-pytorch

Usage

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    attn_norm_qk = True  # l2-normalize queries and keys before attention
)

x = torch.randint(0, 256, (2, 2048))

loss = model(x, return_loss = True)
loss.backward()

logits = model(x) # (2, 2048, 256)

Test

Enwik8

$ python train.py

Citations

@inproceedings{Loshchilov2024nGPTNT,
    title   = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
    author  = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273026160}
}
@article{Luo2017CosineNU,
    title     = {Cosine Normalization: Using Cosine Similarity Instead of Dot Product in Neural Networks},
    author    = {Chunjie Luo and Jianfeng Zhan and Lei Wang and Qiang Yang},
    journal   = {ArXiv},
    year      = {2017},
    volume    = {abs/1702.05870},
    url       = {https://api.semanticscholar.org/CorpusID:1505432}
}
@inproceedings{Zhou2024ValueRL,
    title   = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
    author  = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273532030}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

ngpt_pytorch-0.2.16.tar.gz (36.9 MB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

ngpt_pytorch-0.2.16-py3-none-any.whl (16.6 kB)

Uploaded Python 3

File details

Details for the file ngpt_pytorch-0.2.16.tar.gz.

File metadata

  • Download URL: ngpt_pytorch-0.2.16.tar.gz
  • Upload date:
  • Size: 36.9 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for ngpt_pytorch-0.2.16.tar.gz
Algorithm Hash digest
SHA256 37ac18ee5b6de25448de67bc7a83cdaecc06d7ef32e013b1b2337894a20a02f7
MD5 c8333452244093c704ebd7adf8af663b
BLAKE2b-256 e653232fd16e9d4bf2bda5a514714a82a1192e551de449230797e7a4e2ff5bd3

See more details on using hashes here.
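To check a downloaded file against the published SHA256 digest, the standard library's `hashlib` is enough. A minimal sketch (the filename assumes the sdist above has been downloaded to the current directory):

```python
import hashlib

# SHA256 published for ngpt_pytorch-0.2.16.tar.gz above
expected = "37ac18ee5b6de25448de67bc7a83cdaecc06d7ef32e013b1b2337894a20a02f7"

def sha256_of(path, chunk_size = 8192):
    # hash the file in chunks so large files don't need to fit in memory
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# uncomment after downloading the sdist:
# assert sha256_of("ngpt_pytorch-0.2.16.tar.gz") == expected
```

Alternatively, pip's hash-checking mode (`pip install --require-hashes -r requirements.txt`, with `--hash=sha256:...` entries in the requirements file) performs the same verification automatically.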

File details

Details for the file ngpt_pytorch-0.2.16-py3-none-any.whl.

File metadata

  • Download URL: ngpt_pytorch-0.2.16-py3-none-any.whl
  • Upload date:
  • Size: 16.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for ngpt_pytorch-0.2.16-py3-none-any.whl
Algorithm Hash digest
SHA256 01ca04c300994b427e9425acf4fa2343091a0d35627af7be57cbd1c2c6eeaf84
MD5 246efb96c196e13b86d1c46edf17f103
BLAKE2b-256 d310494fbda8fbfe497e4e653ee8b5156865818fa94be2c4ce79ced198e83d3e

See more details on using hashes here.
