
GIVT-PyTorch

A partial implementation of the Generative Infinite-Vocabulary Transformer (GIVT) from Google DeepMind, in PyTorch.

This repo implements only the causal version of GIVT. It also drops the k-component mixture predictions and the full covariance matrix, since for most purposes they did not yield significantly better results.
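For intuition, here is a minimal sketch of what a single-component, diagonal-covariance output looks like in plain PyTorch. It is not this package's code: `DiagonalGaussianHead`, `next_latent_nll`, and the softplus scale parameterization are hypothetical choices made for illustration. Each hidden state predicts a per-channel mean and standard deviation for the next latent, and because the covariance is diagonal, the per-channel log-probabilities simply add up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiagonalGaussianHead(nn.Module):
    """Hypothetical output head: maps each hidden state to the mean and
    standard deviation of a diagonal Gaussian over the next latent."""

    def __init__(self, hidden_dim: int, latent_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, 2 * latent_dim)

    def forward(self, h: torch.Tensor):
        mu, raw_scale = self.proj(h).chunk(2, dim=-1)
        sigma = F.softplus(raw_scale) + 1e-4  # keep the scale strictly positive
        return mu, sigma


def next_latent_nll(mu, sigma, latents):
    """Causal NLL: the prediction at position t is scored against latent t + 1.

    mu, sigma, latents: (batch, seq_len, latent_dim).
    """
    dist = torch.distributions.Normal(mu[:, :-1], sigma[:, :-1])
    # Diagonal covariance: channels are independent, so their log-probs add up.
    log_prob = dist.log_prob(latents[:, 1:]).sum(dim=-1)
    # Average over batch and sequence positions.
    return -log_prob.mean()
```

With a diagonal covariance the head needs only `2 * latent_dim` outputs per position, whereas a full covariance matrix grows quadratically with `latent_dim`.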

The decoder transformer is also modernized, adopting a Llama-style architecture with gated MLPs, SiLU activations, RMSNorm, and RoPE.
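The Llama-style components named above are small enough to sketch directly. The block below is an illustrative version, not this repo's source: the class and function names are hypothetical, and details such as `eps`, the RoPE base of 10000, and the half-split rotation layout are common defaults that this implementation may or may not share.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Scale by the root-mean-square of the features; no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class GatedMLP(nn.Module):
    """SwiGLU-style feed-forward: SiLU(x W_gate) * (x W_up), projected back down."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


def rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    """Precompute cos/sin tables of shape (seq_len, head_dim) for rotary embeddings."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)  # (seq_len, head_dim)
    return angles.cos(), angles.sin()


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs (i, i + head_dim // 2) by a position-dependent angle.

    x: (batch, heads, seq_len, head_dim); queries and keys are rotated before attention.
    """
    x1, x2 = x.chunk(2, dim=-1)
    rotated_half = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated_half * sin
```

Because RoPE is a pure rotation, it leaves vector norms unchanged and encodes position only through the relative angle between a query and a key.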

Install

# for inference (run from the repository root)
pip install .

# for training/development
pip install -e '.[train]'

Usage

import torch
from givt_pytorch import GIVT

# load a pretrained checkpoint
model = GIVT.from_pretrained('elyxlz/givt-test')

latents = torch.randn((4, 500, 32)) # VAE latents (bs, seq_len, size)
loss = model.forward(latents).loss # negative log-likelihood (NLL) loss

prompt = torch.randn((50, 32)) # a single unbatched prompt; batched inference is not implemented
generated = model.generate(
    prompt=prompt,
    max_len=500,
    cfg_scale=0.5,
    temperature=0.95,
) # (500, 32)
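Generation is autoregressive: each new latent is sampled from the predicted distribution and appended to the sequence. The loop below is a hypothetical sketch of that idea, not the `generate` implementation. It assumes `temperature` scales the predicted standard deviation and that `max_len` counts the prompt (consistent with the `(500, 32)` output shape noted above); classifier-free guidance (`cfg_scale`) is left out. `sample_autoregressive` and `predict_next` are illustrative names.

```python
import torch


@torch.no_grad()
def sample_autoregressive(predict_next, prompt, max_len, temperature=1.0):
    """Hypothetical sampling loop for a causal model over continuous latents.

    predict_next: callable mapping a (t, latent_dim) prefix to (mu, sigma),
        each of shape (latent_dim,), the parameters of a diagonal Gaussian
        over the next latent.
    prompt: (prompt_len, latent_dim) tensor; unbatched, like the usage example.
    Assumption: temperature multiplies the predicted standard deviation.
    """
    seq = prompt.clone()
    while seq.shape[0] < max_len:
        mu, sigma = predict_next(seq)
        # Reparameterized draw from N(mu, (temperature * sigma)^2), one value per channel.
        next_latent = mu + temperature * sigma * torch.randn_like(mu)
        seq = torch.cat([seq, next_latent.unsqueeze(0)], dim=0)
    return seq  # (max_len, latent_dim), prompt included
```

A real `predict_next` would run the transformer over the whole prefix at every step (or reuse a KV cache). With `temperature=0` this loop becomes greedy and always emits the predicted mean.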

Training

Define a config file in configs/, such as this one:

from givt_pytorch import (
    GIVT,
    GIVTConfig,
    DummyDataset,
    Trainer,
    TrainConfig
)

model = GIVT(GIVTConfig())
dataset = DummyDataset()

trainer = Trainer(
    model=model,
    dataset=dataset,
    train_config=TrainConfig()
)
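`DummyDataset` is presumably a placeholder; to train on real data you would swap in your own dataset. The class below is a hypothetical sketch that assumes the `Trainer` accepts any `torch.utils.data.Dataset` whose items are `(seq_len, latent_dim)` float tensors, matching the `(bs, seq_len, size)` batches passed to `model.forward` in the usage example. `LatentDataset` is an illustrative name; check the `Trainer` source before relying on that contract.

```python
import torch
from torch.utils.data import Dataset


class LatentDataset(Dataset):
    """Hypothetical dataset of pre-computed VAE latents.

    Assumption: the Trainer batches items of shape (seq_len, latent_dim).
    """

    def __init__(self, latents: torch.Tensor, seq_len: int):
        # latents: one long (total_len, latent_dim) sequence, cut into
        # non-overlapping fixed-size windows; any leftover tail is dropped.
        self.latents = latents
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.latents.shape[0] // self.seq_len

    def __getitem__(self, idx: int) -> torch.Tensor:
        start = idx * self.seq_len
        return self.latents[start:start + self.seq_len]
```

Used in the config above, this would replace `dataset = DummyDataset()` with something like `dataset = LatentDataset(my_latents, seq_len=500)`, where `my_latents` is your own tensor of encoded latents.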

Create an Accelerate config:

accelerate config

Then launch training, replacing {config_name} with the name of your config:

accelerate launch train.py {config_name}

TODO

  • Test out with latents from an audio VAE
  • Add CFG with rejection sampling

References

@misc{litgpt2024,
    title   = {lit-gpt on GitHub},
    url     = {https://github.com/Lightning-AI/lit-gpt},
    year    = {2024}
}

@misc{tschannen2023givt,
    title         = {GIVT: Generative Infinite-Vocabulary Transformers},
    author        = {Michael Tschannen and Cian Eastwood and Fabian Mentzer},
    year          = {2023},
    eprint        = {2312.02116},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}

