nGPT (normalized GPT) - Pytorch

Quick implementation of nGPT, a transformer that learns entirely on the hypersphere, from NvidiaAI. The open question is whether the authors swept any loss of expressivity under the rug, but I'll take the results in good faith.
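
For readers new to the idea, learning on the hypersphere can be pictured roughly as follows: hidden states are kept at unit L2 norm, and each residual update moves the state a small step toward the normalized block output before projecting it back onto the sphere. This is only a sketch in plain PyTorch, not code from this package. The helper names `l2norm` and `hypersphere_residual` are hypothetical, and the fixed scalar `alpha` is an assumption made for illustration (the paper describes the step size as learned).

```python
import torch
import torch.nn.functional as F

def l2norm(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Project each vector onto the unit hypersphere (unit L2 norm).
    return F.normalize(t, p=2, dim=dim)

def hypersphere_residual(h: torch.Tensor, block_out: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    # Hypothetical helper: step a fraction `alpha` from h toward the
    # normalized block output, then renormalize back onto the sphere.
    h_target = l2norm(block_out)
    return l2norm(h + alpha * (h_target - h))

torch.manual_seed(0)
h = l2norm(torch.randn(2, 8, 512))   # hidden states with unit norm
block_out = torch.randn(2, 8, 512)   # stand-in for an attention or feedforward output
h_next = hypersphere_residual(h, block_out)

# For unit vectors the dot product is the cosine similarity, bounded in [-1, 1].
cos = (h_next * h).sum(dim=-1)
```

Because every vector involved has unit norm, the dot products used throughout the network behave like cosine similarities, which is the connection to the cosine normalization work cited below.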

This type of network should also be studied in the context of continual learning and loss of plasticity.

Adaptation to vision transformers is here

Install

$ pip install nGPT-pytorch

Usage

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    attn_norm_qk = True
)

x = torch.randint(0, 256, (2, 2048))

loss = model(x, return_loss = True)
loss.backward()

logits = model(x) # (2, 2048, 256)
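
The relationship between the logits and the loss in the usage above can be sketched as a standard next-token objective. This is an assumption about what `return_loss = True` computes, not a copy of the package's internals: the input is shifted by one position so that each token is predicted from the tokens before it. The function `next_token_loss` and the tiny `toy_model` are hypothetical stand-ins so the snippet runs without the package installed.

```python
import torch
import torch.nn.functional as F
from torch import nn

def next_token_loss(model: nn.Module, ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: predict token t+1 from tokens up to t.
    inputs, labels = ids[:, :-1], ids[:, 1:]
    logits = model(inputs)  # (batch, seq - 1, num_tokens)
    # cross_entropy expects the class dimension second: (batch, num_tokens, seq - 1)
    return F.cross_entropy(logits.transpose(1, 2), labels)

torch.manual_seed(0)

# Stand-in for nGPT: any module mapping (batch, seq) ids to (batch, seq, 256) logits.
toy_model = nn.Sequential(nn.Embedding(256, 32), nn.Linear(32, 256))

x = torch.randint(0, 256, (2, 64))
loss = next_token_loss(toy_model, x)
loss.backward()
```

Any module with the same input and output shapes can be dropped in for `toy_model`.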

Test

Enwik8

$ python train.py
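
Enwik8 is a byte-level benchmark, which matches the `num_tokens = 256` used in the usage example. The contents of `train.py` are not shown here, so the following is only a sketch of how byte-level data can be turned into training batches. The helpers `encode_bytes`, `decode_bytes`, and `random_batch` are hypothetical names, and the short repeated string stands in for the real dataset.

```python
import torch

def encode_bytes(text: str) -> torch.Tensor:
    # Each UTF-8 byte becomes one token id in [0, 255].
    return torch.tensor(list(text.encode("utf-8")), dtype=torch.long)

def decode_bytes(ids: torch.Tensor) -> str:
    # Inverse of encode_bytes; invalid byte sequences are replaced, not raised.
    return bytes(ids.tolist()).decode("utf-8", errors="replace")

def random_batch(data: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
    # Sample random contiguous windows of length seq_len from a 1-D id tensor.
    starts = torch.randint(0, data.numel() - seq_len + 1, (batch_size,))
    return torch.stack([data[s : s + seq_len] for s in starts.tolist()])

torch.manual_seed(0)
data = encode_bytes("hello hypersphere " * 8)   # toy stand-in for the enwik8 bytes
batch = random_batch(data, batch_size=2, seq_len=16)
```

A batch of this shape, `(batch_size, seq_len)` with integer ids below 256, is the same kind of input as `x` in the usage example.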

Citations

@inproceedings{Loshchilov2024nGPTNT,
    title   = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
    author  = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273026160}
}
@article{Luo2017CosineNU,
    title     = {Cosine Normalization: Using Cosine Similarity Instead of Dot Product in Neural Networks},
    author    = {Chunjie Luo and Jianfeng Zhan and Lei Wang and Qiang Yang},
    journal   = {ArXiv},
    year      = {2017},
    volume    = {abs/1702.05870},
    url       = {https://api.semanticscholar.org/CorpusID:1505432}
}
@inproceedings{Zhou2024ValueRL,
    title   = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
    author  = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273532030}
}
