GIVT-PyTorch

A partial implementation of the Generative Infinite-Vocabulary Transformer (GIVT) from Google DeepMind, in PyTorch.

This repo implements only the causal version of GIVT. It omits the k-mixture predictions and the full covariance matrix, since for most purposes they did not yield significantly better results.
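
Without mixtures or a full covariance matrix, the natural reading is that each latent vector is scored under a single Gaussian with independent dimensions. The sketch below works through that negative log-likelihood in plain Python so the arithmetic is visible. The function name and the log-standard-deviation parametrization are assumptions made for illustration, not taken from this repo, which operates on PyTorch tensors.

```python
import math


def diagonal_gaussian_nll(target, mean, log_std):
    """NLL of `target` under a Gaussian with independent dimensions.

    Hypothetical helper: `mean` and `log_std` stand in for whatever the
    model head predicts for one position. The result is summed over dimensions.
    """
    nll = 0.0
    for x, mu, ls in zip(target, mean, log_std):
        z = (x - mu) / math.exp(ls)  # standardized residual
        # 0.5 * z^2 + log(sigma) + 0.5 * log(2 * pi), per dimension
        nll += 0.5 * z * z + ls + 0.5 * math.log(2.0 * math.pi)
    return nll


# A target at the predicted mean scores better (lower NLL) than one far from it.
at_mean = diagonal_gaussian_nll([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
far_away = diagonal_gaussian_nll([2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
print(at_mean < far_away)
```

Averaging this quantity over positions and batch elements would give a scalar training loss of the kind the usage example reads from `.loss`, although the exact reduction used here is not stated.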

The decoder transformer implementation is also modernized: it adopts a Llama-style architecture with gated MLPs, SiLU activations, RMSNorm, and RoPE.
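
For readers unfamiliar with these components, here is a minimal sketch of RMSNorm and a SiLU-gated MLP on plain Python lists. It illustrates the standard definitions only. The function names, the `eps` default, and the toy weights are assumptions, the repo's own modules are PyTorch layers with learned parameters, and RoPE is left out for brevity.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Divide x by its root mean square, then scale by a per-dimension weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def silu(v):
    """SiLU (swish) activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def matvec(matrix, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in matrix]


def gated_mlp(x, w_gate, w_up, w_down):
    """Gated feed-forward: down(silu(gate(x)) * up(x))."""
    gate = [silu(g) for g in matvec(w_gate, x)]
    up = matvec(w_up, x)
    hidden = [g * u for g, u in zip(gate, up)]
    return matvec(w_down, hidden)


# Toy 2-dimensional example with identity weights (illustrative values only).
identity = [[1.0, 0.0], [0.0, 1.0]]
normed = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
print(gated_mlp(normed, identity, identity, identity))
```

The gate lets the network modulate each hidden unit multiplicatively, which is the main difference from a plain two-layer MLP with a single activation.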

Install

# for inference
pip install .

# for training/development
pip install -e '.[train]'

Usage

import torch

from givt_pytorch import GIVT

# load pretrained checkpoint
model = GIVT.from_pretrained('elyxlz/givt-test')

# forward pass on a batch of latents
latents = torch.randn((4, 500, 32))  # VAE latents (bs, seq_len, size)
loss = model(latents).loss  # NLL loss

# generation from a single prompt (batched inference is not implemented)
prompt = torch.randn((50, 32))  # (prompt_len, size)
generated = model.generate(
    prompt=prompt,
    max_len=500,
    cfg_scale=0.5,
    temperature=0.95,
)  # (500, 32)
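
To make the `max_len` and `temperature` arguments more concrete, the following is a rough sketch of autoregressive sampling over continuous vectors, written in plain Python. It is not this repo's `generate` method. The `predict` callback, the toy predictor, and the choice to apply temperature by scaling the predicted standard deviation are all assumptions, and `cfg_scale` is not modeled at all.

```python
import random


def sample_next(mean, std, temperature, rng):
    """Draw one vector, dimension by dimension, with std scaled by temperature."""
    return [rng.gauss(mu, temperature * s) for mu, s in zip(mean, std)]


def generate(prompt, max_len, predict, temperature=0.95, seed=0):
    """Extend `prompt` (a list of vectors) until it holds `max_len` vectors.

    Assumes `max_len` counts the prompt, which matches the (500, 32) output
    shape in the usage example but is not confirmed by the source.
    """
    rng = random.Random(seed)
    sequence = [list(v) for v in prompt]
    while len(sequence) < max_len:
        mean, std = predict(sequence)  # hypothetical model call
        sequence.append(sample_next(mean, std, temperature, rng))
    return sequence


def toy_predict(sequence):
    """Hypothetical stand-in for the transformer: center on the last vector."""
    last = sequence[-1]
    return last, [0.1] * len(last)


out = generate([[0.0, 0.0]], max_len=5, predict=toy_predict)
print(len(out), len(out[0]))
```

Under this scheme a temperature below 1 concentrates samples around the predicted mean, and a temperature of 0 would return the mean itself.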

Training

Define a config file in configs/, such as this one:

from givt_pytorch import (
    GIVT,
    GIVTConfig,
    DummyDataset,
    Trainer,
    TrainConfig
)

model = GIVT(GIVTConfig())
dataset = DummyDataset()

trainer = Trainer(
    model=model,
    dataset=dataset,
    train_config=TrainConfig()
)

Create an accelerate config.

accelerate config

And then run the training.

accelerate launch train.py {config_name}
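
The launch command takes a config name rather than a path, so the training script presumably resolves that name to a file under configs/. How train.py actually does this is not shown here. The sketch below is one plausible way to do it with the standard library, and the `load_config` helper and the `.py` naming convention are hypothetical.

```python
import importlib.util
from pathlib import Path


def load_config(config_name, config_dir="configs"):
    """Import <config_dir>/<config_name>.py and return it as a module.

    Hypothetical helper: the real train.py may resolve names differently.
    """
    path = Path(config_dir) / f"{config_name}.py"
    if not path.is_file():
        raise FileNotFoundError(f"no config file at {path}")
    spec = importlib.util.spec_from_file_location(config_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the config file's top-level code
    return module
```

A script built this way could then read the `trainer` object defined at the top level of the config module, for example `load_config("my_config").trainer`, where `my_config` is a placeholder name.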

TODO

  • Test with latents from an audio VAE
  • Add CFG with rejection sampling

References

@misc{litgpt2024,
    title = {lit-gpt on GitHub},
    url   = {https://github.com/Lightning-AI/lit-gpt},
    year  = {2024}
}

@misc{tschannen2023givt,
    title   = {GIVT: Generative Infinite-Vocabulary Transformers},
    author  = {Michael Tschannen and Cian Eastwood and Fabian Mentzer},
    year    = {2023},
    eprint  = {2312.02116},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
