
The one-stop solution to easily integrate MoE & MoD layers into custom PyTorch code.


PyTorch Mixtures

A plug-and-play module for Mixture-of-Experts and Mixture-of-Depths in PyTorch. Your one-stop solution for effortlessly inserting MoE layers into custom neural networks!

Installation

Install the package with pip3 install pytorch-mixtures. Note that torch and einops must already be installed, since they are required dependencies. To build the package from source instead, run the following commands:

git clone https://github.com/jaisidhsingh/pytorch-mixtures.git
cd pytorch-mixtures
pip3 install .
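Because torch and einops must be installed beforehand, a quick pre-flight check can save you from a confusing import error later. The sketch below is not part of pytorch-mixtures; the helper name missing_dependencies is hypothetical and relies only on the standard library's importlib.util.find_spec.

```python
import importlib.util

# Modules that pytorch-mixtures expects to be importable.
REQUIRED = ["torch", "einops"]


def missing_dependencies(names):
    """Return the subset of `names` that cannot be found on the import path.

    Hypothetical helper for illustration, not part of pytorch-mixtures.
    """
    # find_spec returns None when a top-level module cannot be located.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies(REQUIRED)
    if missing:
        print("Install these first:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

Running the script before pip3 install pytorch-mixtures lists any dependency that still needs to be installed.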

Usage

pytorch-mixtures is designed to integrate effortlessly into your existing code for any neural network of your choice. For example:

import torch
from pytorch_mixtures import TopkMoE, ExpertChoiceMoE, MoEConfig


BATCH_SIZE = 16
SEQ_LEN = 128
DIM = 768
NUM_EXPERTS = 8
CAPACITY_FACTOR = 1.25

config = MoEConfig(
    hidden_dim=DIM,
    intermediate_dim=DIM * 4,
    num_experts=NUM_EXPERTS,
    expert_fn="ff",
    expert_act="silu",
    router_fn="topk",
    capacity_factor=CAPACITY_FACTOR,
    topk=2,
    dtype=torch.float32
)

moe = TopkMoE(config)

x = torch.randn(BATCH_SIZE, SEQ_LEN, DIM)
output = moe(x)  # shape: [BATCH_SIZE, SEQ_LEN, DIM]
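The topk=2 and capacity_factor=1.25 settings above control how tokens are dispatched to experts. The dependency-free sketch below illustrates one common convention, which is an assumption: this page does not specify the package's exact routing or capacity formula. Each token takes a softmax over its router logits, keeps the k largest probabilities, and renormalizes them to sum to 1; each expert then accepts at most ceil(capacity_factor * num_tokens * k / num_experts) tokens. The function names softmax, topk_route, and expert_capacity are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk_route(logits, k):
    """Return (expert_indices, weights) for a single token.

    Picks the k experts with the largest softmax probability and
    renormalizes their probabilities so the weights sum to 1.
    """
    probs = softmax(logits)
    # Indices of the k largest probabilities, highest first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    mass = sum(probs[i] for i in top)
    return top, [probs[i] / mass for i in top]


def expert_capacity(num_tokens, num_experts, capacity_factor, k):
    """Maximum tokens each expert accepts under the assumed convention."""
    return math.ceil(capacity_factor * num_tokens * k / num_experts)


# Values from the example above: 16 * 128 tokens, 8 experts, top-2 routing.
print(expert_capacity(16 * 128, 8, 1.25, 2))  # 640
```

Under this convention, top-2 routing doubles the number of dispatched token copies, so per-expert capacity scales with k. In many MoE implementations, tokens that arrive after an expert is full are dropped or passed through unchanged.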

You can also use these layers inside your own nn.Module classes:

import torch
import torch.nn as nn
from pytorch_mixtures import TopkMoE, ExpertChoiceMoE, MoEConfig


class CustomMoEBlock(nn.Module):
    def __init__(self, dim, num_experts, capacity_factor):
        super().__init__()
        self.config = MoEConfig(
            hidden_dim=dim,
            intermediate_dim=dim * 4,
            num_experts=num_experts,
            expert_fn="ff",
            expert_act="silu",
            router_fn="topk",
            capacity_factor=capacity_factor,
            topk=2,
            dtype=torch.float32
        )
        self.moe = TopkMoE(self.config)
        
    def forward(self, x):
        return self.moe(x)


my_block = CustomMoEBlock(
    dim=768,
    num_experts=8,
    capacity_factor=1.25
)

x = torch.randn(16, 128, 768)
output = my_block(x)  # output shape: [16, 128, 768]

Testing

This package lets you run a simple sanity test of the MoE code: if all experts are initialized as the same module, the output of the MoE layer should equal the input tensor passed through any single expert. You can run these tests yourself with:

from pytorch_mixtures import run_tests

run_tests()

Or from the command line:

python -m tests.tests

Note: All tests are expected to pass. If a test fails, the cause is most likely an edge case in the random initialization; rerunning the tests should make it pass.
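The test works because, when every expert computes the same function f, the MoE output for a token is the sum over the selected experts of w_i * f(x) = f(x) * (sum of w_i), which equals f(x) whenever the routing weights sum to 1 (and, as an assumption here, no token is dropped for exceeding expert capacity). Floating-point rounding means the comparison needs a tolerance. The sketch below checks this property with plain Python lists; expert, moe_combine, and renormalized_weights are hypothetical stand-ins, not the package's actual test code.

```python
import math
import random


def expert(x):
    """A stand-in expert: the same elementwise function for every expert."""
    return [math.tanh(2.0 * v) + 0.5 * v for v in x]


def moe_combine(x, weights):
    """Weighted sum of identical experts' outputs for one token.

    `weights` are the routing weights of the selected experts.
    """
    out = [0.0] * len(x)
    for w in weights:
        y = expert(x)  # every expert is the same module
        out = [o + w * v for o, v in zip(out, y)]
    return out


def renormalized_weights(k, rng):
    """Random positive weights scaled so they sum to 1."""
    raw = [rng.random() + 1e-6 for _ in range(k)]
    total = sum(raw)
    return [r / total for r in raw]


rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(8)]
combined = moe_combine(x, renormalized_weights(2, rng))
single = expert(x)

# Equal up to floating-point rounding, so compare with a tolerance.
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
          for a, b in zip(combined, single)))
```

An exact equality check (a == b) could fail on rounding alone, which is one reason a tolerance-based comparison is the safer choice for this kind of test.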

Citation

If you found this package useful, please cite it in your work:

@misc{JaisidhSingh2024,
  author = {Singh, Jaisidh},
  title = {pytorch-mixtures},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/jaisidhsingh/pytorch-mixtures}},
}



Download files

Download the file for your platform.

Source Distribution

pytorch_mixtures-0.1.4.tar.gz (6.8 kB)

Built Distribution

pytorch_mixtures-0.1.4-py3-none-any.whl (7.3 kB)

File details

Details for the file pytorch_mixtures-0.1.4.tar.gz.

File metadata

  • Download URL: pytorch_mixtures-0.1.4.tar.gz
  • Size: 6.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.13

File hashes

Hashes for pytorch_mixtures-0.1.4.tar.gz:
  • SHA256: e255783422df6f8b803ec8fe683fa63e9321736edbec231faf05b982c83d4c6a
  • MD5: cd617bee2f782a2e79f17d9a8807f3be
  • BLAKE2b-256: 196c1f4d99ba28600b15e7006ebbc1b83914e67d1ff64b4489b6ff9cd4c85581


File details

Details for the file pytorch_mixtures-0.1.4-py3-none-any.whl.


File hashes

Hashes for pytorch_mixtures-0.1.4-py3-none-any.whl:
  • SHA256: f0fddde4ae43413707250dbc5f23046fdbef37c82539ae4b7ce46030c29553f2
  • MD5: 4924051db1fdea0621dfad0d68b864f8
  • BLAKE2b-256: 61f7dcf1dc28d936c21dd3c615ecb2c32199fad69cbdf4a5b8b081a11834909d

