
🥞 x-Metaformer

A PyTorch implementation of "MetaFormer Baselines" with optional extensions.
We support several self-supervised pretraining approaches, such as BarlowTwins, MoCoV3, and VICReg (see x_metaformer.pretraining).
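
As a rough illustration of what one of these objectives computes, here is a minimal, framework-free sketch of the Barlow Twins loss. It assumes the standard formulation: standardize each feature over the batch, build the cross-correlation matrix between the embeddings of two augmented views, push its diagonal toward 1 and its off-diagonal toward 0. This is not the code in x_metaformer.pretraining; the function names are hypothetical, and the default weight lambd=5e-3 is assumed from the original Barlow Twins paper.

```python
import math
from typing import List

Matrix = List[List[float]]  # N rows (batch) x D columns (features)


def _standardize_columns(z: Matrix, eps: float = 1e-8) -> Matrix:
    """Give each feature (column) zero mean and unit (population) std over the batch."""
    n, d = len(z), len(z[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [z[i][j] for i in range(n)]
        mean = sum(col) / n
        std = math.sqrt(sum((v - mean) ** 2 for v in col) / n)
        for i in range(n):
            out[i][j] = (col[i] - mean) / (std + eps)
    return out


def barlow_twins_loss(z_a: Matrix, z_b: Matrix, lambd: float = 5e-3) -> float:
    """Hypothetical helper: Barlow Twins objective for two N x D embedding batches."""
    n, d = len(z_a), len(z_a[0])
    a, b = _standardize_columns(z_a), _standardize_columns(z_b)
    loss = 0.0
    for i in range(d):
        for j in range(d):
            # Cross-correlation between feature i of view A and feature j of view B.
            c_ij = sum(a[k][i] * b[k][j] for k in range(n)) / n
            if i == j:
                loss += (1.0 - c_ij) ** 2  # invariance term: diagonal -> 1
            else:
                loss += lambd * c_ij ** 2  # redundancy reduction: off-diagonal -> 0
    return loss
```

When both views are identical, the invariance term vanishes and only the (small) redundancy penalty between correlated features remains; a view that is the negation of the other is penalized heavily because its diagonal correlations are -1.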

Setup

Simply run: pip install x-metaformer

Example

import torch
from x_metaformer import CAFormer, ConvFormer


my_metaformer = CAFormer(
    in_channels=3,
    depths=(3, 3, 9, 3),
    dims=(64, 128, 320, 512),
    multi_query_attention=False,  # share keys and values across query heads
    use_seqpool=True,  # use sequence pooling from CCT (Compact Convolutional Transformers)
    init_kernel_size=3,
    init_stride=2,
    drop_path_rate=0.4,
    norm='ln',  # ln, bn, rms (layernorm, batchnorm, rmsnorm)
    use_grn_mlp=True,  # use global response norm in mlps
    use_dual_patchnorm=False,  # norm on both sides for the patch embedding
    use_pos_emb=True,  # use 2D sinusoidal positional embeddings
    head_dim=32,
    num_heads=4,
    attn_dropout=0.1,
    proj_dropout=0.1,
    patchmasking_prob=0.05,  # replace 5% of the initial tokens with a </mask> token
    scale_value=1.0,  # scale attention logits by this value
    trainable_scale=False,  # whether the scale is trainable
    num_mem_vecs=0,  # additional memory vectors (in the attention layers)
    sparse_topk=0,  # sparsify: keep only the top-k values (in the attention layers)
    l2=False,  # L2-normalize tokens (in the attention layers)
    improve_locality=False,  # mask out each token's attention to itself
    use_starreglu=False  # use gated StarReLU
)

x   = torch.randn(64, 3, 64, 64)  # B C H W
out = my_metaformer(x, return_embeddings=False)  # returns average pooled tokens
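
Setting use_seqpool=True swaps plain averaging for the sequence pooling introduced by CCT: a linear layer scores every token, a softmax over the token axis turns the scores into weights, and the output is the weighted sum of tokens. Below is a minimal plain-Python sketch assuming that formulation; seq_pool, w, and b are hypothetical names, and the library's actual layer may differ in detail.

```python
import math
from typing import List, Sequence


def seq_pool(tokens: List[List[float]], w: Sequence[float], b: float = 0.0) -> List[float]:
    """Hypothetical SeqPool sketch: softmax-weighted sum over T tokens of dim C."""
    # Score each token with a linear map g(x) = w . x + b (one scalar per token).
    scores = [sum(wi * xi for wi, xi in zip(w, tok)) + b for tok in tokens]
    # Softmax over the token axis; subtracting the max keeps exp() stable.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of tokens -> a single C-dimensional embedding.
    dim = len(tokens[0])
    return [sum(a * tok[d] for a, tok in zip(weights, tokens)) for d in range(dim)]


print(seq_pool([[1.0, 2.0], [3.0, 4.0]], w=[0.0, 0.0]))
```

With all-zero scoring weights the softmax is uniform, so SeqPool reduces to the average pooling described in the example's last comment; learned weights let the model emphasize informative tokens instead.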

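The sparse_topk option keeps only the largest attention logits for each query. The sketch below shows the idea for a single row of logits, assuming the common scheme of excluding everything outside the top k before the softmax, and assuming that k=0 means "no sparsification" (the default sparse_topk=0 suggests this, but it is not confirmed here). sparse_topk_softmax is a hypothetical helper, not part of x_metaformer.

```python
import math
from typing import List


def sparse_topk_softmax(logits: List[float], k: int) -> List[float]:
    """Hypothetical helper: softmax over only the k largest logits of one query row."""
    n = len(logits)
    if k <= 0 or k >= n:
        keep = set(range(n))  # assumed: k <= 0 disables sparsification
    else:
        # Indices of the k largest logits (ties resolved by original order).
        keep = set(sorted(range(n), key=lambda i: logits[i], reverse=True)[:k])
    m = max(logits[i] for i in keep)
    # Dropped positions behave like -inf logits: they get exactly zero probability.
    exps = [math.exp(logits[i] - m) if i in keep else 0.0 for i in range(n)]
    total = sum(exps)
    return [e / total for e in exps]


print(sparse_topk_softmax([1.0, 3.0, 2.0, 0.0], k=2))
```

With k=2 and logits [1, 3, 2, 0], only the two largest entries receive probability mass, so each query attends to at most k keys.
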


Download files


Source Distribution

x-Metaformer-0.3.1.tar.gz (11.6 kB)

Built Distribution

x_Metaformer-0.3.1-py3-none-any.whl (16.4 kB, Python 3)
