soft-mixture-of-experts
PyTorch implementation of Soft MoE, introduced by Google Brain in "From Sparse to Soft Mixtures of Experts".
Thanks to lucidrains for his excellent x-transformers library! 🎉 The ViT implementations here are heavily based on his ViTransformerWrapper.
TODO
- Implement Soft MoE layer (Usage, Code)
- Example end-to-end Transformer models
- Set up unit tests
  - SoftMoE
  - Transformer layers
  - ViT models
- Reproduce parameter counts from Table 3
- Reproduce inference benchmarks from Tables 1, 2
- Release on PyPI
  - Prerelease
  - Stable
Install
PyPI:
work in progress
From source:
pip install "soft-mixture-of-experts @ git+ssh://git@github.com/fkodom/soft-mixture-of-experts.git"
For contributors:
# Clone/fork this repo. Example:
gh repo clone fkodom/soft-mixture-of-experts
cd soft-mixture-of-experts
# Install all dev dependencies (tests etc.) in editable mode
pip install -e .[test]
# Setup pre-commit hooks
pre-commit install
Usage
Vision Transformers
Using the ViT and SoftMoEViT classes directly:
import torch

from soft_mixture_of_experts.vit import ViT, SoftMoEViT
vit = ViT(num_classes=1000, device="cuda")
moe_vit = SoftMoEViT(num_classes=1000, num_experts=32, device="cuda")
# image shape: (batch_size, channels, height, width)
image = torch.randn(1, 3, 224, 224, device="cuda")
# classification prediction
# output shape: (batch_size, num_classes)
y_vit = vit(image)
y_moe = moe_vit(image)
# feature embeddings
# output shape: (batch_size, num_patches, d_model)
features_vit = vit(image, return_features=True)
features_moe = moe_vit(image, return_features=True)
or using pre-configured models:
from soft_mixture_of_experts.vit import soft_moe_vit_small
# Available models:
# - soft_moe_vit_small
# - soft_moe_vit_base
# - soft_moe_vit_large
# - vit_small
# - vit_base
# - vit_large
# - vit_huge
# Roughly 930M parameters 👀
moe_vit = soft_moe_vit_small(num_classes=1000, device="cuda")
# Everything else works the same as above...
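As a usage example, a minimal classification training step with one of these models might look like the sketch below. It assumes the forward pass returns unnormalized class logits (consistent with the output shape shown above); the optimizer choice and dummy batch are purely illustrative, not part of this library.

import torch
import torch.nn.functional as F

from soft_mixture_of_experts.vit import soft_moe_vit_small

model = soft_moe_vit_small(num_classes=1000, device="cuda")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Dummy batch, for illustration only
images = torch.randn(8, 3, 224, 224, device="cuda")
labels = torch.randint(0, 1000, (8,), device="cuda")

logits = model(images)                  # (batch_size, num_classes)
loss = F.cross_entropy(logits, labels)  # assumes raw (unnormalized) logits
loss.backward()
optimizer.step()
optimizer.zero_grad()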
Transformer Layers
import torch

from soft_mixture_of_experts.transformer import (
TransformerEncoder,
TransformerEncoderLayer,
TransformerDecoder,
TransformerDecoderLayer,
)
encoder = TransformerEncoder(
TransformerEncoderLayer(d_model=512, nhead=8),
num_layers=6,
)
decoder = TransformerDecoder(
TransformerDecoderLayer(d_model=512, nhead=8),
num_layers=6,
)
# input shape: (batch_size, seq_len, d_model)
x = torch.randn(2, 128, 512)
mem = encoder(x)
print(mem.shape)
# torch.Size([2, 128, 512])
y = decoder(x, mem)
print(y.shape)
# torch.Size([2, 128, 512])
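These layers can be composed into a full sequence-to-sequence model in the usual way. The wrapper below is only a sketch: the embeddings, output projection, and vocabulary size are hypothetical additions for illustration, not part of this library.

import torch
import torch.nn as nn

from soft_mixture_of_experts.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
)

class TinySeq2Seq(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512):
        super().__init__()
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model=d_model, nhead=8), num_layers=6
        )
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model=d_model, nhead=8), num_layers=6
        )
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        mem = self.encoder(self.src_embed(src))     # (batch, src_len, d_model)
        y = self.decoder(self.tgt_embed(tgt), mem)  # (batch, tgt_len, d_model)
        return self.out(y)                          # (batch, tgt_len, vocab_size)

model = TinySeq2Seq(vocab_size=32000)
src = torch.randint(0, 32000, (2, 128))
tgt = torch.randint(0, 32000, (2, 64))
print(model(src, tgt).shape)
# torch.Size([2, 64, 32000])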
Soft MoE
import torch
from soft_mixture_of_experts.soft_moe import SoftMoE
# SoftMoE with 32 experts, 2 slots per expert (64 total):
moe = SoftMoE(
embed_dim=512,
num_experts=32,
slots_per_expert=2,
bias=False, # optional, default: True
device="cuda", # optional, default: None
)
# input shape: (batch_size, seq_len, embed_dim)
x = torch.randn(2, 128, 512, device="cuda")
y = moe(x)
print(y.shape)
# torch.Size([2, 128, 512])
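For intuition, the core Soft MoE computation from the paper boils down to a soft dispatch and combine: token-to-slot logits are normalized over tokens to form slot inputs, each expert processes its own slots, and the same logits normalized over slots mix the expert outputs back per token. The standalone sketch below illustrates the idea only; it is not this library's implementation, and the toy experts are plain linear layers.

import torch
import torch.nn as nn

batch, seq_len, d = 2, 128, 512
num_experts, slots_per_expert = 4, 2
num_slots = num_experts * slots_per_expert

x = torch.randn(batch, seq_len, d)
phi = torch.randn(d, num_slots)  # learnable slot parameters
experts = nn.ModuleList(nn.Linear(d, d) for _ in range(num_experts))  # toy experts

logits = x @ phi                   # (batch, seq_len, num_slots)
dispatch = logits.softmax(dim=1)   # normalize over tokens
combine = logits.softmax(dim=2)    # normalize over slots

# Soft dispatch: each slot is a weighted average of all tokens
slots = torch.einsum("bms,bmd->bsd", dispatch, x)
slots = slots.reshape(batch, num_experts, slots_per_expert, d)

# Each expert processes its own slots
outs = torch.stack([experts[i](slots[:, i]) for i in range(num_experts)], dim=1)
outs = outs.reshape(batch, num_slots, d)

# Soft combine: each token is a weighted average of all slot outputs
y = torch.einsum("bms,bsd->bmd", combine, outs)
print(y.shape)
# torch.Size([2, 128, 512])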
Test
Tests run automatically through GitHub Actions on each git push. You can also run tests manually with pytest:
pytest
Citations
@misc{puigcerver2023sparse,
title={From Sparse to Soft Mixtures of Experts},
author={Joan Puigcerver and Carlos Riquelme and Basil Mustafa and Neil Houlsby},
year={2023},
eprint={2308.00951},
archivePrefix={arXiv},
primaryClass={cs.LG}
}