Gato: A Generalist Agent
[Deepmind Publication] [arXiv Paper]
Installation
$ pip install gato-torch
```python
import torch
from gato import Gato

# Create a model instance
gato = Gato(input_dim=768,
            img_patch_size=16,
            token_sequence_length=1024,
            vocabulary_size=32000,
            actions_size=1024,
            continuous_values_size=1024,
            num_transformer_blocks=8,
            num_attention_heads=24,
            layer_width=768,
            feedforward_hidden_size=3072,
            key_value_size=32,
            dropout_rate=0.1,
            num_group_norm_groups=32,
            discretize_depth=128,
            local_position_encoding_size=512,
            max_seq_len=8192)

# Fake inputs for Gato
input_dim = 768  # must match the model's input_dim above
input_ids = torch.cat(
    [torch.rand((1, 1, input_dim)) for _ in range(4)] +   # 4 image patches
    [torch.full((1, 1, input_dim), 0.25),                 # continuous value
     torch.full((1, 1, input_dim), 624.0)] +              # discrete (actions, texts)
    [torch.rand((1, 1, input_dim)) for _ in range(4)] +   # 4 image patches
    [torch.full((1, 1, input_dim), 0.12),                 # continuous value
     torch.full((1, 1, input_dim), 295.0)],               # discrete (actions, texts)
    dim=1)

# Token type encoding:
# 0 - image patch, 1 - continuous value, 2 - discrete (actions, texts)
encoding = torch.tensor([
    [0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 1, 2]
])

# Patch position intervals (row-wise), normalized to [0, 1]
row_pos = (
    torch.tensor([[0.00, 0.25, 0.50, 0.75, 0, 0, 0.00, 0.25, 0.50, 0.75, 0, 0]]),  # pos_from
    torch.tensor([[0.25, 0.50, 0.75, 1.00, 0, 0, 0.25, 0.50, 0.75, 1.00, 0, 0]])   # pos_to
)

# Patch position intervals (column-wise), normalized to [0, 1]
col_pos = (
    torch.tensor([[0.00, 0.00, 0.00, 0.80, 0, 0, 0.00, 0.00, 0.00, 0.80, 0, 0]]),  # pos_from
    torch.tensor([[0.20, 0.20, 0.20, 1.00, 0, 0, 0.20, 0.20, 0.20, 1.00, 0, 0]])   # pos_to
)

# Local observation positions and masking
obs = (
    torch.tensor([[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]]),  # local observation token positions
    torch.tensor([[1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]])   # observation token masking (0 for action tokens)
)

hidden_states = gato((input_ids, (encoding, row_pos, col_pos), obs))
```
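The returned hidden_states are the per-token outputs of the transformer; with the configuration above, the expected shape is (1, 12, 768), i.e. (batch size, number of tokens, layer_width).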
Dataset and Model Architecture
Paper Reviews
Full Episode Sequence
Architecture Variants
Appendix C.1. Transformer Hyperparameters
In the paper, DeepMind tested Gato with three architecture variants: 1.18B, 364M, and 79M parameters. I have named them `large()`, `baseline()`, and `small()` respectively in `GatoConfig`.
| Hyperparameters | Large (1.18B) | Baseline (364M) | Small (79M) |
|---|---|---|---|
| Transformer blocks | 24 | 12 | 8 |
| Attention heads | 16 | 12 | 24 |
| Layer width | 2048 | 1536 | 768 |
| Feedforward hidden size | 8192 | 6144 | 3072 |
| Key/value size | 128 | 128 | 32 |
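For reference, these hyperparameters correspond to constructor arguments used in the installation example above. Below is a minimal sketch of the three variants as plain keyword-argument dictionaries; this is illustrative only, not the package's `GatoConfig` API.

```python
# Appendix C.1 variants expressed as Gato constructor arguments (illustrative only).
# Pass them together with the remaining arguments from the installation example,
# e.g. Gato(input_dim=768, ..., **VARIANTS["small"]).
VARIANTS = {
    "large":    dict(num_transformer_blocks=24, num_attention_heads=16,
                     layer_width=2048, feedforward_hidden_size=8192, key_value_size=128),
    "baseline": dict(num_transformer_blocks=12, num_attention_heads=12,
                     layer_width=1536, feedforward_hidden_size=6144, key_value_size=128),
    "small":    dict(num_transformer_blocks=8, num_attention_heads=24,
                     layer_width=768, feedforward_hidden_size=3072, key_value_size=32),
}
```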
Residual Embedding
Appendix C.2. Embedding Function
The paper does not mention how many residual blocks should be stacked for the token embeddings, so I have left this configurable in `GatoConfig`.
However many residual layers there are, full pre-activation is the key; a sketch of such a block follows the list below. The blocks consist of:
- Version 2 ResNet architecture (based on ResNet50V2)
- GroupNorm (instead of LayerNorm)
- GeLU (instead of ReLU)
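As a rough illustration, a single full pre-activation block in this style might look like the following sketch. The channel count, kernel sizes, and number of groups are assumptions, not the package's exact layer layout.

```python
import torch
from torch import nn

class PreActResidualBlock(nn.Module):
    """Full pre-activation residual block (ResNetV2 ordering):
    normalization and activation come before each convolution,
    with GroupNorm in place of LayerNorm and GELU in place of ReLU."""

    def __init__(self, channels: int, num_groups: int = 32):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv1(self.act(self.norm1(x)))
        x = self.conv2(self.act(self.norm2(x)))
        return x + residual

# Example: embed a 16x16 patch with 768 channels
block = PreActResidualBlock(channels=768, num_groups=32)
out = block(torch.rand(1, 768, 16, 16))
```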
Position Encodings
Appendix C.3. Position Encodings
Patch Position Encodings
Like the Vision Transformer (ViT) by Google, Gato takes input images as raster-ordered 16x16 patches. Unlike ViT, however, Gato splits its patch position encoding strategy into two phases: training and evaluation.
For efficient vectorized computation, I have used the following expressions. $C$ and $R$ denote column-wise and row-wise, and $F$ and $T$ denote from and to, respectively.
$$
\begin{align}
v^R_F &= \begin{bmatrix} 0 & 32 & 64 & 96 \end{bmatrix} \\
v^R_T &= \begin{bmatrix} 32 & 64 & 96 & 128 \end{bmatrix} \\
v^C_F &= \begin{bmatrix} 0 & 26 & 51 & 77 & 102 \end{bmatrix} \\
v^C_T &= \begin{bmatrix} 26 & 51 & 77 & 102 & 128 \end{bmatrix} \\
P_R &= \begin{cases} v^R_F + \mathsf{uniform}(v^R_T - v^R_F) & \mathsf{if\ training} \\ \mathsf{round}\!\left(\frac{v^R_F + v^R_T}{2}\right) & \mathsf{otherwise} \end{cases} \\
P_C &= \begin{cases} v^C_F + \mathsf{uniform}(v^C_T - v^C_F) & \mathsf{if\ training} \\ \mathsf{round}\!\left(\frac{v^C_F + v^C_T}{2}\right) & \mathsf{otherwise} \end{cases} \\
E^R_P &= P_R \cdot 1^{\mathsf{T}}_C \\
E^C_P &= 1^{\mathsf{T}}_R \cdot P_C \\
\therefore E &= E_I + E^R_P + E^C_P
\end{align}
$$
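A minimal PyTorch sketch of this interval-based lookup follows; the function and variable names are my own, not the package's implementation.

```python
import torch

def patch_position_ids(from_v: torch.Tensor, to_v: torch.Tensor, training: bool) -> torch.Tensor:
    """Map per-patch [from, to) intervals to integer position ids:
    a uniform sample within each interval during training,
    the interval midpoint during evaluation."""
    if training:
        # uniform(v_T - v_F): integer offset drawn inside each interval
        offsets = torch.floor(torch.rand_like(from_v.float()) * (to_v - from_v).float())
        return from_v + offsets.long()
    return torch.round((from_v + to_v).float() / 2).long()

# Row intervals for four rows quantized into 128 bins (v^R_F and v^R_T above)
v_row_from = torch.tensor([0, 32, 64, 96])
v_row_to = torch.tensor([32, 64, 96, 128])

train_pos = patch_position_ids(v_row_from, v_row_to, training=True)
eval_pos = patch_position_ids(v_row_from, v_row_to, training=False)  # -> [16, 48, 80, 112]
```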
Local Observation Position Encodings
In the definition of Appendix B, text tokens, image patch tokens, and discrete and continuous values are observation tokens. When Gato receives these values, they must be encoded with their own (local) time steps.
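As a rough sketch of how the obs tuple from the installation example might be consumed, the local time steps can index a learned embedding table, with action tokens zeroed out by the mask. The names and the table sizes below are assumptions, not the package's internals.

```python
import torch
from torch import nn

# Assumed sizes from the installation example:
# local_position_encoding_size=512, layer_width=768
local_position_embedding = nn.Embedding(512, 768)

obs_positions = torch.tensor([[0, 1, 2, 3, 4, 5]])  # local time steps within one observation
obs_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])       # 0 masks out action tokens

# Look up local position embeddings and zero them for action tokens
local_pos_emb = local_position_embedding(obs_positions) * obs_mask.unsqueeze(-1)
```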
Contributing
License
Licensed under the MIT license.
Roadmap:
- Get functional prototype
- Integrate ALiBi, multi-query attention, QK norm, and other SOTA techniques
- Integrate action tokens