The one-stop solution to easily integrate MoE & MoD layers into custom PyTorch code.

Project description

PyTorch Mixtures

A plug-and-play module for Mixture-of-Experts and Mixture-of-Depths in PyTorch. Your one-stop solution for inserting MoE/MoD layers into custom neural networks effortlessly!

Sources:

  1. Sparse Mixture of Experts, 2017
  2. Mixture of Depths, 2024

Features/Todo

  • Mixture of Experts
    • Top-k Routing
    • Expert Choice Routing (contrasted with top-k routing in the sketch after this list)
    • router-z loss
    • load-balancing loss
    • Testing of all MoE protocols - finished
  • Mixture of Depths
    • capacity-based routing around attention layer (sketched in plain PyTorch after the usage examples below)
    • Testing of MoD protocol - finished
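
A quick plain-PyTorch sketch of the difference between the two MoE routing schemes above (illustrative only, not the package's internal code): in top-k (token-choice) routing each token picks its top-k experts, so a popular expert can be oversubscribed, while in expert-choice routing each expert picks a fixed number of tokens, so expert load is balanced by construction.

import torch

NUM_TOKENS, DIM, NUM_EXPERTS = 128, 768, 8
TOP_K = 2
CAPACITY = (TOP_K * NUM_TOKENS) // NUM_EXPERTS  # tokens each expert takes under expert choice

x = torch.randn(NUM_TOKENS, DIM)          # token embeddings
w_router = torch.randn(DIM, NUM_EXPERTS)  # router projection (a learned nn.Linear in practice)
logits = x @ w_router                     # [NUM_TOKENS, NUM_EXPERTS] token-expert affinities

# top-k (token-choice) routing: every token selects its TOP_K highest-scoring experts
token_choice = torch.topk(logits, k=TOP_K, dim=-1).indices  # [NUM_TOKENS, TOP_K]

# expert-choice routing: every expert selects its CAPACITY highest-scoring tokens
expert_choice = torch.topk(logits.t(), k=CAPACITY, dim=-1).indices  # [NUM_EXPERTS, CAPACITY]

# token choice can overload an expert; expert choice fixes every expert's load at CAPACITY
print(torch.bincount(token_choice.flatten(), minlength=NUM_EXPERTS))  # uneven per-expert counts
print(expert_choice.shape)                                            # uniform: CAPACITY tokens per expert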

Installation

Simply run pip3 install pytorch-mixtures to install this package. Note that torch and einops must already be installed as dependencies. If you would like to build this package from source, run the following commands:

git clone https://github.com/jaisidhsingh/pytorch-mixtures.git
cd pytorch-mixtures
pip3 install .
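
After installing, a quick sanity check (a minimal sketch, assuming the import paths shown in the Usage section below) confirms that the package and its torch/einops dependencies are importable:

# verify the installation
import torch
import einops

from pytorch_mixtures.routing import ExpertChoiceRouter
from pytorch_mixtures.moe_layer import MoELayer

print(torch.__version__, einops.__version__)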

Usage

pytorch-mixtures is designed to integrate effortlessly into your existing code for any neural network of your choice. For example:

from pytorch_mixtures.routing import ExpertChoiceRouter
from pytorch_mixtures.moe_layer import MoELayer

import torch
import torch.nn as nn


# define some config
BATCH_SIZE = 16
SEQ_LEN = 128
DIM = 768
NUM_EXPERTS = 8
CAPACITY_FACTOR = 1.25

# first initialize the router
router = ExpertChoiceRouter(dim=DIM, num_experts=NUM_EXPERTS)

# choose the experts you want: pytorch-mixtures just needs a list of `nn.Module` experts
# e.g., here the experts are just linear layers
experts = [nn.Linear(DIM, DIM) for _ in range(NUM_EXPERTS)]

# supply the router and experts to the MoELayer for modularity
moe = MoELayer(
    num_experts=NUM_EXPERTS, 
    router=router, 
    experts=experts, 
    capacity_factor=CAPACITY_FACTOR
)

# initialize some test input
x = torch.randn(BATCH_SIZE, SEQ_LEN, DIM)

# pass through moe
moe_output = moe(x) # shape: [BATCH_SIZE, SEQ_LEN, DIM]
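
The Features list above also mentions a router-z loss and a load-balancing loss. Whether and how MoELayer exposes them is not shown in this snippet, but the two regularizers themselves are standard (they follow the usual Switch Transformer / ST-MoE recipes, not anything specific to this package). A hedged sketch of what they compute, given raw router logits and reusing the config above:

import torch.nn.functional as F

def router_z_loss(logits):
    # penalizes large router logits by pushing each token's
    # log-sum-exp of logits toward zero (ST-MoE style)
    return torch.logsumexp(logits, dim=-1).pow(2).mean()

def load_balancing_loss(logits, num_experts):
    # encourages uniform expert usage: for each expert, multiply the fraction
    # of tokens whose top-1 choice is that expert by its mean gate probability
    probs = F.softmax(logits, dim=-1)                               # [num_tokens, num_experts]
    dispatch = F.one_hot(probs.argmax(dim=-1), num_experts).float() # top-1 dispatch mask
    return num_experts * torch.sum(dispatch.mean(dim=0) * probs.mean(dim=0))

# hypothetical example: weight the auxiliary losses and add them to the task loss
router_logits = torch.randn(BATCH_SIZE * SEQ_LEN, NUM_EXPERTS)
aux_loss = 1e-3 * router_z_loss(router_logits) + 1e-2 * load_balancing_loss(router_logits, NUM_EXPERTS)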

You can also use this easily within your own nn.Module classes:

from pytorch_mixtures.routing import ExpertChoiceRouter
from pytorch_mixtures.moe_layer import MoELayer
from pytorch_mixtures.utils import MHSA # multi-head self-attention layer provided for convenience
import torch
import torch.nn as nn


class CustomMoEAttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, num_experts, capacity_factor, experts):
        super().__init__()
        self.attn = MHSA(dim, num_heads)
        self.router = ExpertChoiceRouter(dim, num_experts)
        self.moe = MoELayer(
            num_experts=num_experts,
            router=self.router,
            experts=experts,
            capacity_factor=capacity_factor
        )
        
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
    
    def forward(self, x):
        x = self.norm1(self.attn(x) + x)
        x = self.norm2(self.moe(x) + x)
        return x


experts = [nn.Linear(768, 768) for _ in range(8)]
my_block = CustomMoEAttentionBlock(
    dim=768,
    num_heads=8,
    num_experts=8,
    capacity_factor=1.25,
    experts=experts
)

# some test input
x = torch.randn(16, 128, 768)
output = my_block(x) # shape: [16, 128, 768]
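
The block above runs attention over every token. The Mixture-of-Depths feature listed earlier instead uses capacity-based routing around the attention layer: only the highest-scoring fraction of tokens goes through attention and the rest skip it unchanged. The package's MoD API is not shown on this page, so the following is only a plain-PyTorch sketch of that idea, reusing the MHSA layer and imports from the snippet above (the class name, router, and sigmoid gating are illustrative choices, not the library's implementation):

class CapacityRoutedAttention(nn.Module):
    def __init__(self, dim, num_heads, capacity_factor=0.5):
        super().__init__()
        self.attn = MHSA(dim, num_heads)
        self.router = nn.Linear(dim, 1)   # scores how much each token "needs" attention
        self.capacity_factor = capacity_factor

    def forward(self, x):                 # x: [batch, seq_len, dim]
        B, N, D = x.shape
        k = max(1, int(self.capacity_factor * N))                 # capacity: tokens routed per sequence
        scores = self.router(x).squeeze(-1)                       # [B, N]
        topk = torch.topk(scores, k=k, dim=-1).indices            # indices of the routed tokens
        idx = topk.unsqueeze(-1).expand(-1, -1, D)                # [B, k, D]
        routed = torch.gather(x, 1, idx)                          # gather routed tokens
        gate = torch.sigmoid(torch.gather(scores, 1, topk)).unsqueeze(-1)
        attended = self.attn(routed) * gate                       # gated so the router receives gradient
        out = x.clone()                                           # unrouted tokens skip the layer entirely
        out.scatter_add_(1, idx, attended)                        # residual add only for routed tokens
        return out


# with capacity_factor=0.5, only 64 of the 128 tokens in each sequence are attended
mod_block = CapacityRoutedAttention(dim=768, num_heads=8, capacity_factor=0.5)
mod_output = mod_block(x) # shape: [16, 128, 768]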

Citation

If you found this package useful, please cite it in your work:

@misc{JaisidhSingh2024,
  author = {Singh, Jaisidh},
  title = {pytorch-mixtures},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/jaisidhsingh/pytorch-mixtures}},
}

References

This package was built with the help of the open-source code and resources listed below:

  1. Google Flaxformer
  2. ST-MoE by Lucidrains
  3. MoD Huggingface blog

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch-mixtures-0.1.0.tar.gz (7.6 kB)

Uploaded Source

Built Distribution

pytorch_mixtures-0.1.0-py3-none-any.whl (8.9 kB)

Uploaded Python 3

File details

Details for the file pytorch-mixtures-0.1.0.tar.gz.

File metadata

  • Download URL: pytorch-mixtures-0.1.0.tar.gz
  • Upload date:
  • Size: 7.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.6

File hashes

Hashes for pytorch-mixtures-0.1.0.tar.gz:

  • SHA256: d9d2f5edd1aed308bbb5e2b747db126ad52d614c2280e79ab31fe843633e2904
  • MD5: cc9fd8ffc929f76ec89bc80340d23c81
  • BLAKE2b-256: 3e333432c98d0dc9197fa4021662d29704a89c4312397a620babc8f40e57b899

See more details on using hashes here.

File details

Details for the file pytorch_mixtures-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for pytorch_mixtures-0.1.0-py3-none-any.whl:

  • SHA256: d39154004da976e36063d3b02934f610e4a1981013e4bdc246bf8e2becbd7c00
  • MD5: 191641a32fd260b81a1ec5420c1edf89
  • BLAKE2b-256: 4254957d50bb5f83a1b2ffae18c4be9ef8982ae5789f69ce4c084c8b2a9880a1

See more details on using hashes here.
