Gato: A Generalist Agent
[Deepmind Publication] [arXiv Paper]
Installation
$ pip install gato-torch
import torch
from gato import Gato

# Create the model instance
gato = Gato(
    input_dim=768,
    img_patch_size=16,
    token_sequence_length=1024,
    vocabulary_size=32000,
    actions_size=1024,
    continuous_values_size=1024,
    num_transformer_blocks=8,
    num_attention_heads=24,
    layer_width=768,
    feedforward_hidden_size=3072,
    key_value_size=32,
    dropout_rate=0.1,
    num_group_norm_groups=32,
    discretize_depth=128,
    local_position_encoding_size=512,
    max_seq_len=8192,
)
# Fake inputs for Gato
input_dim = 768  # must match the input_dim passed to Gato above
num_patches = 4  # representative image patches per observation, so the sequence length (12) matches the tensors below
input_ids = torch.cat(
    [torch.rand((1, 1, input_dim)) for _ in range(num_patches)]  # observation 1: image patches
    + [
        torch.full((1, 1, input_dim), 0.25),   # continuous value
        torch.full((1, 1, input_dim), 624.0),  # discrete (actions, texts)
    ]
    + [torch.rand((1, 1, input_dim)) for _ in range(num_patches)]  # observation 2: image patches
    + [
        torch.full((1, 1, input_dim), 0.12),   # continuous value
        torch.full((1, 1, input_dim), 295.0),  # discrete (actions, texts)
    ],
    dim=1,
)
encoding = torch.tensor([
    # 0: image patch, 1: continuous value, 2: discrete (actions, texts)
    [0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 1, 2]
])
row_pos = (
    torch.tensor([[0.00, 0.25, 0.50, 0.75, 0, 0, 0.00, 0.25, 0.50, 0.75, 0, 0]]),  # pos_from
    torch.tensor([[0.25, 0.50, 0.75, 1.00, 0, 0, 0.25, 0.50, 0.75, 1.00, 0, 0]]),  # pos_to
)
col_pos = (
    torch.tensor([[0.00, 0.00, 0.00, 0.80, 0, 0, 0.00, 0.00, 0.00, 0.80, 0, 0]]),  # pos_from
    torch.tensor([[0.20, 0.20, 0.20, 1.00, 0, 0, 0.20, 0.20, 0.20, 1.00, 0, 0]]),  # pos_to
)
obs = (
    torch.tensor([[0, 1, 2, 19, 20, 21, 0, 1, 2, 19, 20, 21]]),  # obs token
    torch.tensor([[1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]]),  # obs token masking (0 for action tokens)
)
hidden_states = gato((input_ids, (encoding, row_pos, col_pos), obs))
Dataset and Model Architecture
Paper Reviews
Full Episode Sequence
Architecture Variants
Appendix C.1. Transformer Hyperparameters
In the paper, DeepMind tested Gato with three architecture variants: 1.18B, 364M, and 79M parameters.
I have named them `large()`, `baseline()`, and `small()` respectively in `GatoConfig`.
Hyperparameters | Large(1.18B) | Baseline(364M) | Small(79M) |
---|---|---|---|
Transformer blocks | 24 | 12 | 8 |
Attention heads | 16 | 12 | 24 |
Layer width | 2048 | 1536 | 768 |
Feedforward hidden size | 8192 | 6144 | 3072 |
Key/value size | 128 | 128 | 32 |
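The table can also be read as data. The sketch below is a minimal, hypothetical stand-in: the class name `VariantConfig` and its field names are my own assumptions for illustration, not the actual `GatoConfig` API. It encodes the three columns and checks two regularities visible in the numbers: attention heads times key/value size equals the layer width, and the feedforward hidden size is four times the layer width.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VariantConfig:
    """Hypothetical container for one column of the hyperparameter table."""
    num_transformer_blocks: int
    num_attention_heads: int
    layer_width: int
    feedforward_hidden_size: int
    key_value_size: int

    @classmethod
    def large(cls) -> "VariantConfig":  # 1.18B
        return cls(24, 16, 2048, 8192, 128)

    @classmethod
    def baseline(cls) -> "VariantConfig":  # 364M
        return cls(12, 12, 1536, 6144, 128)

    @classmethod
    def small(cls) -> "VariantConfig":  # 79M
        return cls(8, 24, 768, 3072, 32)


for name in ("large", "baseline", "small"):
    cfg = getattr(VariantConfig, name)()
    # Heads x key/value size reproduces the layer width in every variant.
    assert cfg.num_attention_heads * cfg.key_value_size == cfg.layer_width
    # The feedforward layer is 4x wider than the residual stream.
    assert cfg.feedforward_hidden_size == 4 * cfg.layer_width
```

The `small()` values are the same ones passed to `Gato(...)` in the usage example.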
Residual Embedding
Appendix C.2. Embedding Function
The paper does not mention how many residual blocks must be stacked for token embeddings.
Therefore, I have left the number configurable in `GatoConfig`.
However many residual layers there are, full pre-activation is key.
The blocks consist of:
- Version 2 ResNet architecture (based on ResNet50V2)
- GroupNorm (instead of LayerNorm)
- GeLU (instead of ReLU)
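To make the ordering concrete, here is a minimal PyTorch sketch of a full pre-activation residual block using the pieces listed above. It is an illustration under assumptions, not the block shipped in this package: the class name `PreActResidualBlock` is hypothetical, and it uses two 3x3 convolutions rather than the three-convolution bottleneck of ResNet50V2.

```python
import torch
from torch import nn


class PreActResidualBlock(nn.Module):
    """Hypothetical full pre-activation block: (GroupNorm -> GELU -> Conv) x 2 + skip."""

    def __init__(self, channels: int, num_groups: int = 32):
        super().__init__()
        # channels must be divisible by num_groups for GroupNorm.
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.GroupNorm(num_groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalization and activation come before each convolution,
        # and the skip connection adds the untouched input.
        out = self.conv1(self.act(self.norm1(x)))
        out = self.conv2(self.act(self.norm2(out)))
        return x + out


block = PreActResidualBlock(channels=64, num_groups=32)
patches = torch.rand(2, 64, 16, 16)  # (batch, channels, height, width)
out = block(patches)  # same shape as the input
```

Because the skip path carries `x` unchanged, any number of these blocks can be stacked without altering the tensor shape.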
Position Encodings
Appendix C.3. Position Encodings
Patch Position Encodings
Like the Vision Transformer (ViT) by Google, Gato takes input images as raster-ordered 16x16 patches.
Unlike ViT, however, Gato splits its patch position encoding strategy into two phases: training and evaluation.
For high-performance computation in TensorFlow, I have used the following expressions.
$C$ and $R$ denote column-wise and row-wise, and $F$ and $T$ denote `from` and `to`, respectively.
$$
\begin{align}
  v^R_F &= \begin{bmatrix} 0 & 32 & 64 & 96 \end{bmatrix} \\
  v^R_T &= \begin{bmatrix} 32 & 64 & 96 & 128 \end{bmatrix} \\
  v^C_F &= \begin{bmatrix} 0 & 26 & 51 & 77 & 102 \end{bmatrix} \\
  v^C_T &= \begin{bmatrix} 26 & 51 & 77 & 102 & 128 \end{bmatrix} \\
  \\
  P_R &= \begin{cases}
    \mathsf{if} \ \mathsf{training} & v^R_F + \mathsf{uniform}(v^R_T - v^R_F) \\
    \mathsf{otherwise} & \mathsf{round}(\frac{v^R_F + v^R_T}{2})
  \end{cases} \\
  P_C &= \begin{cases}
    \mathsf{if} \ \mathsf{training} & v^C_F + \mathsf{uniform}(v^C_T - v^C_F) \\
    \mathsf{otherwise} & \mathsf{round}(\frac{v^C_F + v^C_T}{2})
  \end{cases} \\
  \\
  E^R_P &= P_R \cdot 1^{\mathsf{T}}_C \\
  E^C_P &= 1^{\mathsf{T}}_R \cdot P_C \\
  \\
  \therefore E &= E_I + E^R_P + E^C_P
\end{align}
$$
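The expressions above can be traced in plain Python. This is a rough sketch rather than the package's implementation: the function name `patch_positions` is hypothetical, and sampling an integer offset during training is my assumption, since the formula only says `uniform`. Python's `round` rounds halves to the nearest even integer.

```python
import random


def patch_positions(pos_from, pos_to, training, rng=None):
    """Pick one position index per interval [from, to)."""
    rng = rng or random.Random()
    if training:
        # v_F + uniform(v_T - v_F): a random integer offset inside each interval
        return [lo + rng.randrange(hi - lo) for lo, hi in zip(pos_from, pos_to)]
    # Evaluation: the rounded midpoint of each interval
    return [round((lo + hi) / 2) for lo, hi in zip(pos_from, pos_to)]


row_from, row_to = [0, 32, 64, 96], [32, 64, 96, 128]
col_from, col_to = [0, 26, 51, 77, 102], [26, 51, 77, 102, 128]

p_r = patch_positions(row_from, row_to, training=False)  # [16, 48, 80, 112]
p_c = patch_positions(col_from, col_to, training=False)

# P_R * 1_C^T and 1_R^T * P_C: one row index and one column index
# for every patch in the 4 x 5 grid
row_grid = [[r] * len(p_c) for r in p_r]
col_grid = [list(p_c) for _ in p_r]
```

Each entry of `row_grid` and `col_grid` would then index a learned embedding table, and the two looked-up embeddings are added to the patch embedding $E_I$.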
Local Observation Position Encodings
By the definition in Appendix B, text tokens, image patch tokens, and discrete and continuous values are all observation tokens.
When Gato receives these tokens, they must be encoded with their positions within their own (local) time step.
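As a small illustration, the sketch below derives per-token local indices and an observation mask from a list of time steps, mirroring the `obs` tuple in the usage example (indices plus a mask that is 0 for action tokens). The function name and the `"obs"` / `"action"` labels are hypothetical, and how the model treats masked positions is not shown here.

```python
def local_observation_positions(timesteps):
    """timesteps: list of time steps, each a list of "obs" / "action" token kinds."""
    positions, mask = [], []
    for step in timesteps:
        for local_index, kind in enumerate(step):
            positions.append(local_index)  # the index restarts at 0 in every time step
            mask.append(1 if kind == "obs" else 0)  # 0 marks action tokens
    return positions, mask


steps = [["obs", "obs", "obs", "action"], ["obs", "obs", "action"]]
positions, mask = local_observation_positions(steps)
# positions == [0, 1, 2, 3, 0, 1, 2]
# mask      == [1, 1, 1, 0, 1, 1, 0]
```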
Contributing
License
Licensed under the MIT license.
Roadmap:
- Get a functional prototype
- Integrate ALiBi, multi-query attention, QK norm, and other SOTA stuff
- Integrate action tokens