
Gato: A Generalist Agent


[Deepmind Publication] [arXiv Paper]


Installation

$ pip install gato-torch

Usage

import torch
from gato import Gato

# create a model instance
gato = Gato(input_dim=768,
            img_patch_size=16,
            token_sequence_length=1024,
            vocabulary_size=32000,
            actions_size=1024,
            continuous_values_size=1024,
            num_transformer_blocks=8,
            num_attention_heads=24,
            layer_width=768,
            feedforward_hidden_size=3072,
            key_value_size=32,
            dropout_rate=0.1,
            num_group_norm_groups=32,
            discretize_depth=128,
            local_position_encoding_size=512,
            max_seq_len=8192)


# fake inputs for Gato
input_dim = 768  # must match the input_dim passed to Gato above
input_ids = torch.cat(
    [torch.rand((1, 1, input_dim)) for _ in range(4)] +   # 4 image patches
    [torch.full((1, 1, input_dim), 0.25),                 # continuous value
     torch.full((1, 1, input_dim), 624.0)] +              # discrete (actions, text)
    [torch.rand((1, 1, input_dim)) for _ in range(4)] +   # 4 image patches
    [torch.full((1, 1, input_dim), 0.12),                 # continuous value
     torch.full((1, 1, input_dim), 295.0)],               # discrete (actions, text)
    dim=1)  # 12 tokens total, matching encoding, row_pos, col_pos, and obs below

# token-type encoding: 0 = image patch, 1 = continuous value, 2 = discrete (actions, text)
encoding = torch.tensor([
    [0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 1, 2]
])

# patch position intervals, normalized to [0, 1]; non-patch tokens use 0
row_pos = (
    torch.tensor([[0.00, 0.25, 0.50, 0.75, 0, 0, 0.00, 0.25, 0.50, 0.75, 0, 0]]),  # pos_from
    torch.tensor([[0.25, 0.50, 0.75, 1.00, 0, 0, 0.25, 0.50, 0.75, 1.00, 0, 0]])  # pos_to
)

col_pos = (
    torch.tensor([[0.00, 0.00, 0.00, 0.80, 0, 0, 0.00, 0.00, 0.00, 0.80, 0, 0]]),  # pos_from
    torch.tensor([[0.20, 0.20, 0.20, 1.00, 0, 0, 0.20, 0.20, 0.20, 1.00, 0, 0]])  # pos_to
)


# local observation token positions and a mask
# (0 marks action tokens, which receive no local position encoding)
obs = (
    torch.tensor([[ 0,  1,  2, 19, 20, 21,  0,  1,  2, 19, 20, 21]]),  # obs tokens
    torch.tensor([[ 1,  1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0]])  # obs token mask
)


hidden_states = gato((input_ids, (encoding, row_pos, col_pos), obs))

Dataset and Model Architecture

[Figure: Gato dataset and model architecture]

Paper Reviews

Full Episode Sequence

[Figure: Gato dataset architecture (full episode sequence)]

Architecture Variants

Appendix C.1. Transformer Hyperparameters

In the paper, DeepMind tested Gato with three architecture variants: 1.18B, 364M, and 79M parameters.
I have named them large(), baseline(), and small() respectively in GatoConfig.

Hyperparameters         | Large (1.18B) | Baseline (364M) | Small (79M)
----------------------- | ------------- | --------------- | -----------
Transformer blocks      | 24            | 12              | 8
Attention heads         | 16            | 12              | 24
Layer width             | 2048          | 1536            | 768
Feedforward hidden size | 8192          | 6144            | 3072
Key/value size          | 128           | 128             | 32
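
For illustration, here is a minimal sketch of selecting a variant. It assumes GatoConfig exposes large(), baseline(), and small() as named above; the import path and calling convention are assumptions, not a confirmed API:

from gato import GatoConfig  # import path is an assumption

# Pick the small (79M) variant; baseline() and large() select the other columns.
config = GatoConfig.small()

# Per the table above, this variant should use 8 transformer blocks,
# 24 attention heads, and a layer width of 768.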

Residual Embedding

Appendix C.2. Embedding Function

The paper does not mention how many residual networks should be stacked for token embeddings, so I have left the depth configurable in GatoConfig.

However many residual layers are stacked, full pre-activation is the key.

The blocks consist of:

  • Version 2 ResNet architecture (based on ResNet50V2)
  • GroupNorm (instead of LayerNorm)
  • GeLU (instead of ReLU)
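
A minimal PyTorch sketch of one such block, combining the three ingredients above (channel counts, kernel sizes, and the group count are illustrative assumptions, not values from the paper):

import torch
from torch import nn
import torch.nn.functional as F

class ResidualUnitV2(nn.Module):
    # Full pre-activation residual unit (ResNetV2 style) with GroupNorm
    # in place of LayerNorm and GELU in place of ReLU.
    def __init__(self, channels: int, num_groups: int = 32):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Full pre-activation: normalize and activate *before* each conv,
        # then add the untouched input back through the identity skip.
        h = self.conv1(F.gelu(self.norm1(x)))
        h = self.conv2(F.gelu(self.norm2(h)))
        return x + h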

Position Encodings

Appendix C.3. Position Encodings

Patch Position Encodings

Like Google's Vision Transformer (ViT), Gato takes the input images as raster-ordered 16x16 patches.
Unlike the Vision Transformer model, however, Gato splits its patch position encoding strategy into two phases: training and evaluation.

For high-performance computation in TensorFlow, I have used the following vectorized expressions.

$C$ and $R$ denote column-wise and row-wise terms, and $F$ and $T$ denote "from" and "to" respectively.

$$
\begin{align}
v^R_F &= \begin{bmatrix} 0 & 32 & 64 & 96 \end{bmatrix} \\
v^R_T &= \begin{bmatrix} 32 & 64 & 96 & 128 \end{bmatrix} \\
v^C_F &= \begin{bmatrix} 0 & 26 & 51 & 77 & 102 \end{bmatrix} \\
v^C_T &= \begin{bmatrix} 26 & 51 & 77 & 102 & 128 \end{bmatrix} \\[6pt]
P_R &= \begin{cases} v^R_F + \mathsf{uniform}(v^R_T - v^R_F) & \text{if training} \\ \mathsf{round}\left(\frac{v^R_F + v^R_T}{2}\right) & \text{otherwise} \end{cases} \\
P_C &= \begin{cases} v^C_F + \mathsf{uniform}(v^C_T - v^C_F) & \text{if training} \\ \mathsf{round}\left(\frac{v^C_F + v^C_T}{2}\right) & \text{otherwise} \end{cases} \\[6pt]
E^R_P &= P_R \cdot 1^{\mathsf{T}}_C \\
E^C_P &= 1^{\mathsf{T}}_R \cdot P_C \\[6pt]
\therefore E &= E_I + E^R_P + E^C_P
\end{align}
$$
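
To make the two branches concrete, here is a small PyTorch sketch of the sampling rule above; the function name and layout are my own, and the example intervals are the row-wise values $v^R_F$ and $v^R_T$ from the equations:

import torch

def patch_positions(pos_from: torch.Tensor, pos_to: torch.Tensor,
                    training: bool) -> torch.Tensor:
    # One position index per patch interval [pos_from, pos_to):
    # uniformly random during training, the rounded midpoint otherwise.
    if training:
        offsets = torch.floor(torch.rand(pos_from.shape) * (pos_to - pos_from))
        return pos_from + offsets.long()
    return torch.round((pos_from + pos_to) / 2.0).long()

v_from = torch.tensor([0, 32, 64, 96])   # v^R_F for a 128px image
v_to = torch.tensor([32, 64, 96, 128])   # v^R_T

patch_positions(v_from, v_to, training=True)   # e.g. tensor([ 5, 60, 71, 97])
patch_positions(v_from, v_to, training=False)  # tensor([ 16,  48,  80, 112])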

Local Observation Position Encodings

By the definitions in Appendix B, text tokens, image patch tokens, and discrete & continuous values are all observation tokens.
When Gato receives those values, each must be encoded with its own (local) time step.
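
As a rough sketch of what that lookup could look like, assuming an embedding table of local_position_encoding_size entries (as configured in the usage example) and reusing the obs tuple from above; all names here are illustrative:

import torch
from torch import nn

# Table of local (within-observation) position encodings;
# sizes mirror local_position_encoding_size=512 and layer_width=768 above.
local_pos_embedding = nn.Embedding(num_embeddings=512, embedding_dim=768)

obs_tokens = torch.tensor([[0, 1, 2, 19, 20, 21, 0, 1, 2, 19, 20, 21]])
obs_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]])

# Look up each token's local time-step encoding, then zero it out for the
# masked action tokens, which do not receive an observation position.
local_encoding = local_pos_embedding(obs_tokens) * obs_mask.unsqueeze(-1)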

Contributing

We welcome all contributions; please submit a pull request or open an issue in the Agora Discord.

License

Licensed under the MIT license.

Roadmap:

  • Get a functional prototype

  • Integrate ALiBi, multi-query attention, QK norm, and other SOTA techniques

  • Integrate action tokens

