GIVT-PyTorch
A partial implementation of the Generative Infinite-Vocabulary Transformer (GIVT) from Google DeepMind, in PyTorch.
This repo implements only the causal version of GIVT. It omits the k-mixture predictions and the full covariance matrix, since for most purposes they did not yield significantly better results.
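Without mixtures or a full covariance matrix, each output position reduces to a single Gaussian with a diagonal covariance, and training minimizes its negative log-likelihood. The sketch below shows that loss under stated assumptions: the function name `diagonal_gaussian_nll` and the log-variance parameterization are illustrative choices, not necessarily how this repo computes its loss.

```python
import math

import torch


def diagonal_gaussian_nll(mean, log_var, target):
    """NLL of `target` under N(mean, diag(exp(log_var))).

    All tensors have shape (bs, seq_len, size). The per-dimension terms are
    summed over the feature axis and averaged over batch and sequence.
    Hypothetical helper for illustration only.
    """
    per_dim = 0.5 * (
        log_var
        + (target - mean) ** 2 / log_var.exp()
        + math.log(2 * math.pi)
    )
    return per_dim.sum(dim=-1).mean()


# Example: predicted parameters and targets for a batch of latents.
mean = torch.zeros(4, 500, 32)
log_var = torch.zeros(4, 500, 32)
target = torch.randn(4, 500, 32)
loss = diagonal_gaussian_nll(mean, log_var, target)
```

Predicting the log-variance rather than the variance keeps the variance positive without an explicit constraint.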
The decoder transformer is also modernized, adopting a Llama-style architecture with gated MLPs, SiLU activations, RMSNorm, and RoPE.
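For readers unfamiliar with these components, here is a minimal sketch of each. The class and function names (`RMSNorm`, `GatedMLP`, `apply_rope`), the default `eps` and `base` values, and the half-split rotation layout are assumptions made for illustration; the modules in this repo may differ.

```python
import torch
import torch.nn.functional as F
from torch import nn


class RMSNorm(nn.Module):
    """Scale features by their root mean square (no mean subtraction)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


class GatedMLP(nn.Module):
    """Feed-forward block where a SiLU-activated gate multiplies the up projection."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


def apply_rope(x, base=10000.0):
    """Rotary position embedding for x of shape (..., seq_len, head_dim).

    head_dim must be even. Pairs are formed between the first and second
    half of the feature axis and rotated by a position-dependent angle.
    """
    seq_len, dim = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    cos, sin = angles.cos(), angles.sin()  # each (seq_len, head_dim // 2)
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
```

In a Llama-style block, RoPE is applied to the query and key tensors inside attention, while RMSNorm precedes both the attention and the MLP sublayers.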
Install
# for inference
pip install .
# for training/development
pip install -e '.[train]'
Usage
import torch
from givt_pytorch import GIVT

# load pretrained checkpoint
model = GIVT.from_pretrained('elyxlz/givt-test')

latents = torch.randn((4, 500, 32))  # vae latents (bs, seq_len, size)
loss = model(latents).loss  # NLL loss

prompt = torch.randn((50, 32))  # no batched inference implemented
generated = model.generate(
    prompt=prompt,
    max_len=500,
    cfg_scale=0.5,
    temperature=0.95,
)  # (500, 32)
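To give a rough idea of what autoregressive generation over continuous latents involves, the sketch below samples one latent vector at a time from a predicted diagonal Gaussian. It is not this repo's `generate` implementation: `sample_autoregressive` and the `predict` callable are hypothetical, scaling the standard deviation by `temperature` is one common convention that is assumed here, and classifier-free guidance is left out entirely.

```python
import torch


@torch.no_grad()
def sample_autoregressive(predict, prompt, max_len, temperature=1.0):
    """Extend `prompt` of shape (prompt_len, size) to (max_len, size).

    `predict(seq)` is a hypothetical callable returning (mean, log_var),
    each of shape (size,), for the position following `seq`.
    """
    seq = prompt
    while seq.shape[0] < max_len:
        mean, log_var = predict(seq)
        std = (0.5 * log_var).exp() * temperature  # assumed temperature convention
        next_latent = mean + std * torch.randn_like(mean)
        seq = torch.cat([seq, next_latent.unsqueeze(0)], dim=0)
    return seq


# Stand-in predictor: a standard normal regardless of context.
def toy_predict(seq):
    size = seq.shape[-1]
    return torch.zeros(size), torch.zeros(size)


generated = sample_autoregressive(toy_predict, torch.randn(50, 32), max_len=500, temperature=0.95)
```

This loop re-reads the whole prefix at every step; a practical decoder would typically cache keys and values to avoid that cost.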
Training
Define a config file in configs/, such as this one:
from givt_pytorch import (
    GIVT,
    GIVTConfig,
    DummyDataset,
    Trainer,
    TrainConfig
)

model = GIVT(GIVTConfig())
dataset = DummyDataset()

trainer = Trainer(
    model=model,
    dataset=dataset,
    train_config=TrainConfig()
)
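To train on your own data, the dummy dataset in a config like the one above would be swapped for a real one. The sketch below is a stand-in under stated assumptions: `RandomLatentDataset` is a hypothetical class, and it assumes each item is a single tensor of shape (seq_len, size), matching the latent shapes in the usage example. Check the interface the repo's `Trainer` actually expects before relying on it.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomLatentDataset(Dataset):
    """Hypothetical dataset yielding fixed random latents of shape (seq_len, size)."""

    def __init__(self, num_items=64, seq_len=500, size=32, seed=0):
        generator = torch.Generator().manual_seed(seed)
        # Replace this with latents encoded by your VAE.
        self.latents = torch.randn((num_items, seq_len, size), generator=generator)

    def __len__(self):
        return self.latents.shape[0]

    def __getitem__(self, idx):
        return self.latents[idx]


# Quick shape check with a plain DataLoader.
loader = DataLoader(RandomLatentDataset(), batch_size=4, shuffle=True)
batch = next(iter(loader))  # (bs, seq_len, size)
```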
Create an accelerate config.
accelerate config
Then launch training, replacing {config_name} with the name of your config:
accelerate launch train.py {config_name}
TODO
- Test out with latents from an audio vae
- Add CFG with rejection sampling
References
@misc{litgpt2024,
    title = {lit-gpt on GitHub},
    url = {https://github.com/Lightning-AI/lit-gpt},
    year = {2024}
}
@misc{tschannen2023givt,
    title = {GIVT: Generative Infinite-Vocabulary Transformers},
    author = {Michael Tschannen and Cian Eastwood and Fabian Mentzer},
    year = {2023},
    eprint = {2312.02116},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}