soft-mixture-of-experts

PyTorch implementation of Soft MoE by Google Brain in From Sparse to Soft Mixtures of Experts

(Figure: soft-moe-layer diagram)

Thanks to lucidrains for his excellent x-transformers library! 🎉

The ViT implementations here are heavily based on his ViTransformerWrapper.

TODO

  • Implement Soft MoE layer (Usage, Code)
  • Example end-to-end Transformer models
    • vision transformer (Usage, Code)
    • language model (skip for now)
    • add to README
  • Set up unit tests
    • SoftMoE
    • Transformer layers
    • ViT models
  • Reproduce parameter counts from Table 3 (Ablations)
  • Reproduce inference benchmarks from Tables 1, 2 (Ablations)
  • Release on PyPI
    • Prerelease
    • Stable

Install

PyPI:

pip install soft-mixture-of-experts

From source:

pip install "soft-mixture-of-experts @ git+ssh://git@github.com/fkodom/soft-mixture-of-experts.git"

For contributors:

# Clone/fork this repo. Example:
gh repo clone fkodom/soft-mixture-of-experts
cd soft-mixture-of-experts
# Install all dev dependencies (tests etc.) in editable mode
pip install -e .[test]
# Setup pre-commit hooks
pre-commit install

Usage

Vision Transformers

Using the ViT and SoftMoEViT classes directly:

import torch

from soft_mixture_of_experts.vit import ViT, SoftMoEViT

vit = ViT(num_classes=1000, device="cuda")
moe_vit = SoftMoEViT(num_classes=1000, num_experts=32, device="cuda")

# image shape: (batch_size, channels, height, width)
image = torch.randn(1, 3, 224, 224, device="cuda")

# classification prediction
# output shape: (batch_size, num_classes)
y_vit = vit(image)
y_moe = moe_vit(image)

# feature embeddings
# output shape: (batch_size, num_patches, d_model)
features_vit = vit(image, return_features=True)
features_moe = moe_vit(image, return_features=True)

or using pre-configured models:

from soft_mixture_of_experts.vit import soft_moe_vit_small

# Available models:
# - soft_moe_vit_small
# - soft_moe_vit_base
# - soft_moe_vit_large
# - vit_small
# - vit_base
# - vit_large
# - vit_huge

# Roughly 930M parameters 👀
moe_vit = soft_moe_vit_small(num_classes=1000, device="cuda")

# Everything else works the same as above...
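For example, the pre-configured models can be used exactly like the ViT and SoftMoEViT classes above. A minimal sketch (it assumes the same forward signature, including return_features):

import torch

from soft_mixture_of_experts.vit import soft_moe_vit_small

moe_vit = soft_moe_vit_small(num_classes=1000, device="cuda")

# image shape: (batch_size, channels, height, width)
image = torch.randn(1, 3, 224, 224, device="cuda")

# classification prediction, shape: (batch_size, num_classes)
y = moe_vit(image)

# feature embeddings, shape: (batch_size, num_patches, d_model)
features = moe_vit(image, return_features=True)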

Transformer Layers

import torch

from soft_mixture_of_experts.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
)

encoder = TransformerEncoder(
    TransformerEncoderLayer(d_model=512, nhead=8),
    num_layers=6,
)
decoder = TransformerDecoder(
    TransformerDecoderLayer(d_model=512, nhead=8),
    num_layers=6,
)

# input shape: (batch_size, seq_len, d_model)
x = torch.randn(2, 128, 512, device="cuda")

mem = encoder(x)
print(mem.shape)
# torch.Size([2, 128, 512])

y = decoder(x, mem)
print(y.shape)
# torch.Size([2, 128, 512])

Soft MoE

import torch

from soft_mixture_of_experts.soft_moe import SoftMoE

# SoftMoE with 32 experts, 2 slots per expert (64 total):
moe = SoftMoE(
    in_features=512,
    out_features=512,
    num_experts=32,
    slots_per_expert=2,
    bias=False,  # optional, default: True
    device="cuda",  # optional, default: None
)

# input shape: (batch_size, seq_len, embed_dim)
x = torch.randn(2, 128, 512, device="cuda")

y = moe(x)
print(y.shape)
# torch.Size([2, 128, 512])
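For reference, the core Soft MoE routing from the paper works roughly as follows: tokens are softly dispatched into a fixed number of slots, each expert processes its share of slots, and slot outputs are softly combined back per token. This is a minimal, self-contained sketch of that algorithm, not necessarily how SoftMoE is implemented internally; the function and argument names are illustrative only:

import torch

def soft_moe_forward(x, phi, experts):
    # x:       (batch, seq_len, d) input tokens
    # phi:     (d, num_slots) learnable slot parameters
    # experts: list of modules; slots are split evenly among them
    logits = torch.einsum("bnd,ds->bns", x, phi)  # token-slot affinities
    dispatch = logits.softmax(dim=1)  # normalize over tokens (per slot)
    combine = logits.softmax(dim=2)   # normalize over slots (per token)

    # Each slot's input is a weighted average of all tokens.
    slots = torch.einsum("bns,bnd->bsd", dispatch, x)

    # Each expert processes an equal share of the slots.
    chunks = slots.chunk(len(experts), dim=1)
    outs = torch.cat([f(c) for f, c in zip(experts, chunks)], dim=1)

    # Each token's output is a weighted average of all slot outputs.
    return torch.einsum("bns,bsd->bnd", combine, outs)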

Ablations

I closely reproduce the parameter counts and (relative) inference times from the paper.

Table 3

All models are benchmarked with:

batch_size = 8  # see note below
image_size = 224
num_channels = 3
num_classes = 21000  # as in ImageNet 21k

$\dagger$ The authors benchmark "eval ms/img" on TPUv3, while I use a single A100 40GB. The authors are also not explicit about the batch size used for inference, but Figure 6 specifically mentions batch size 8. So I assume a batch size of 8, and observe inference times similar to those reported in the paper.

| Model | Params | Params (paper) | Eval ms/img $\dagger$ | Eval ms/img (paper) |
| --- | --- | --- | --- | --- |
| ViT S/16 | 30 M | 33 M | 0.9 | 0.5 |
| Soft MoE S/16 128E | 932 M | 933 M | 1.3 | 0.7 |
| Soft MoE S/14 128E | 1.8 B | 1.8 B | 1.5 | 0.9 |
| ViT B/16 | 102 M | 108 M | 1.0 | 0.9 |
| Soft MoE B/16 128E | 3.7 B | 3.7 B | 1.5 | 1.5 |
| ViT L/16 | 325 M | 333 M | 1.8 | 4.9 |
| Soft MoE L/16 128E | 13.1 B | 13.1 B | 3.5 | 4.8 |
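A rough sketch of the benchmark loop behind the numbers above (the constructor arguments for the 128-expert variants and the exact warmup/repeat counts are assumptions; swap in the other pre-configured models to fill out the table):

import time

import torch

from soft_mixture_of_experts.vit import soft_moe_vit_small

batch_size = 8
num_classes = 21000

model = soft_moe_vit_small(num_classes=num_classes, device="cuda").eval()
print(f"params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f} M")

image = torch.randn(batch_size, 3, 224, 224, device="cuda")
with torch.no_grad():
    for _ in range(10):  # warmup
        model(image)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        model(image)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(f"eval ms/img: {1000 * elapsed / (100 * batch_size):.2f}")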

Test

Tests run automatically through GitHub Actions on each git push.

You can also run tests manually with pytest:

pytest

Citations

@misc{puigcerver2023sparse,
      title={From Sparse to Soft Mixtures of Experts}, 
      author={Joan Puigcerver and Carlos Riquelme and Basil Mustafa and Neil Houlsby},
      year={2023},
      eprint={2308.00951},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
