A partial implementation of the Generative Infinite-Vocabulary Transformer (GIVT) from Google DeepMind, in PyTorch.
GIVT-PyTorch
This repo implements only the causal version of GIVT. It omits the k-mixture predictions and the full covariance matrix, since for most purposes they did not yield significantly better results.
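Without mixtures or a full covariance matrix, the natural output distribution is a single Gaussian with one variance per latent dimension. The sketch below shows the negative log-likelihood of one latent vector under such a distribution. It is written in plain Python for readability rather than with PyTorch tensors, and both the function name `diag_gaussian_nll` and the log-variance parameterization are assumptions for illustration, not taken from this repo.

```python
import math


def diag_gaussian_nll(target, mean, log_var):
    """NLL of `target` under N(mean, diag(exp(log_var))).

    Hypothetical helper: the repo may parameterize the variance differently.
    """
    assert len(target) == len(mean) == len(log_var)
    nll = 0.0
    for x, mu, lv in zip(target, mean, log_var):
        # Per-dimension term: 0.5 * (log(2*pi) + log_var + (x - mu)^2 / var)
        nll += 0.5 * (math.log(2.0 * math.pi) + lv + (x - mu) ** 2 / math.exp(lv))
    return nll


# One 3-dimensional latent with made-up predicted mean and log-variance
example_nll = diag_gaussian_nll([0.1, -0.3, 0.7], [0.0, 0.0, 0.5], [0.0, -1.0, 0.5])
```

Because the covariance is diagonal, the loss is a plain sum over dimensions, which is what keeps the output head small compared with a full covariance prediction.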
The decoder transformer implementation is also modernized, adopting a Llama-style architecture with gated MLPs, SiLU, RMSNorm, and RoPE.
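For readers unfamiliar with these components, here is a minimal, dependency-free sketch of RMSNorm and a SiLU-gated MLP operating on a single vector. All function and weight names (`rms_norm`, `gated_mlp`, `w_gate`, `w_up`, `w_down`) are illustrative assumptions and do not mirror this repo's modules; RoPE is left out to keep the example short.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Divide x by its root mean square, then apply a per-dimension weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def matvec(matrix, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in matrix]


def gated_mlp(x, w_gate, w_up, w_down):
    """Gated MLP: down(silu(gate(x)) * up(x))."""
    gate = [silu(v) for v in matvec(w_gate, x)]
    up = matvec(w_up, x)
    hidden = [g * u for g, u in zip(gate, up)]
    return matvec(w_down, hidden)


# Toy weights: model width 2, hidden width 3
w_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_up = [[0.5, 0.5], [1.0, -1.0], [0.0, 1.0]]
w_down = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
out = gated_mlp(rms_norm([3.0, 4.0], [1.0, 1.0]), w_gate, w_up, w_down)
```

Unlike LayerNorm, RMSNorm does not subtract the mean, and the gated MLP uses two input projections whose outputs are multiplied elementwise before the down projection.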
Install
# for inference
pip install .
# for training/development
pip install -e '.[train]'
Usage
import torch
from givt_pytorch import GIVT

# load pretrained checkpoint
model = GIVT.from_pretrained('elyxlz/givt-test')

# training: VAE latents of shape (batch_size, seq_len, latent_size)
latents = torch.randn((4, 500, 32))
loss = model.forward(latents).loss  # NLL loss

# inference: prompt of shape (prompt_len, latent_size); batched inference is not implemented
prompt = torch.randn((50, 32))
generated = model.generate(
    prompt=prompt,
    max_len=500,
    cfg_scale=0.5,
    temperature=0.95,
)  # (500, 32)
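How `temperature` enters `generate` is not documented here. One common reading, and an assumption in this sketch, is that it scales the variance of the predicted Gaussian before each latent is sampled. The helper `sample_latent` and its log-variance input are hypothetical and written in plain Python, not taken from this repo.

```python
import math
import random


def sample_latent(mean, log_var, temperature=1.0, rng=random):
    """Draw one latent from N(mean, temperature * diag(exp(log_var))).

    Assumption: temperature multiplies the variance; the repo may differ.
    """
    sample = []
    for mu, lv in zip(mean, log_var):
        std = math.sqrt(temperature * math.exp(lv))
        sample.append(mu + std * rng.gauss(0.0, 1.0))
    return sample


# Seeded generator so the draw is reproducible
rng = random.Random(0)
latent = sample_latent([0.0, 1.0, -1.0], [0.0, -2.0, 0.5], temperature=0.95, rng=rng)
```

Under this assumption a temperature below 1 pulls samples toward the predicted mean, and a temperature of 0 returns the mean itself.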
Training
Define a config file in configs/, such as this one:
from givt_pytorch import (
    GIVT,
    GIVTConfig,
    DummyDataset,
    Trainer,
    TrainConfig
)

model = GIVT(GIVTConfig())
dataset = DummyDataset()

trainer = Trainer(
    model=model,
    dataset=dataset,
    train_config=TrainConfig()
)
Create an accelerate config.
accelerate config
And then run the training.
accelerate launch train.py {config_name}
TODO
- Test with latents from an audio VAE
- Add CFG with rejection sampling
References
@misc{litgpt2024,
    title = {lit-gpt on GitHub},
    url = {https://github.com/Lightning-AI/lit-gpt},
    year = {2024}
}
@misc{tschannen2023givt,
    title = {GIVT: Generative Infinite-Vocabulary Transformers},
    author = {Michael Tschannen and Cian Eastwood and Fabian Mentzer},
    year = {2023},
    eprint = {2312.02116},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}