The one-stop solution to easily integrate MoE & MoD layers into custom PyTorch code.

PyTorch Mixtures

A plug-and-play module for Mixture-of-Experts and Mixture-of-Depths in PyTorch. Your one-stop solution for inserting MoE layers into custom neural networks effortlessly!

Installation

Run pip3 install pytorch-mixtures to install this package. Note that torch and einops must already be installed as dependencies. If you would like to build the package from source instead, run the following commands:

git clone https://github.com/jaisidhsingh/pytorch-mixtures.git
cd pytorch-mixtures
pip3 install .
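
Since torch and einops are expected to be present before installing, it can help to confirm that both are importable first. The following is a minimal sketch using only the standard library; the helper name missing_dependencies is hypothetical and is not part of pytorch-mixtures.

```python
from importlib.util import find_spec


def missing_dependencies(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


# torch and einops are the two dependencies named in the installation note.
missing = missing_dependencies(["torch", "einops"])
if missing:
    print("Install these first:", ", ".join(missing))
else:
    print("All dependencies found.")
```

find_spec only locates a module without importing it, so the check stays fast even though torch itself is slow to import.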

Usage

pytorch-mixtures is designed to integrate into your existing code for any neural network of your choice. For example:

import torch
from pytorch_mixtures import TopkMoE, ExpertChoiceMoE, MoEConfig


BATCH_SIZE = 16
SEQ_LEN = 128
DIM = 768
NUM_EXPERTS = 8
CAPACITY_FACTOR = 1.25

config = MoEConfig(
    hidden_dim=DIM,
    intermediate_dim=DIM * 4,
    num_experts=NUM_EXPERTS,
    expert_fn="ff",
    expert_act="silu",
    router_fn="topk",
    capacity_factor=CAPACITY_FACTOR,
    topk=2,
    dtype=torch.float32
)

moe = TopkMoE(config)

x = torch.randn(BATCH_SIZE, SEQ_LEN, DIM)
output = moe(x)  # shape: [BATCH_SIZE, SEQ_LEN, DIM]
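
The example sets capacity_factor without showing how such a value usually translates into a per-expert token budget. As a rough sketch, assuming the common Switch-Transformer-style convention (capacity = capacity_factor * tokens / experts, rounded up), the arithmetic looks like the code below. This is an assumption about the general technique, not a description of how pytorch-mixtures computes capacity internally: implementations differ on whether tokens are counted per batch or per sequence and on whether topk enters the formula. The helper name expert_capacity is hypothetical.

```python
import math


def expert_capacity(num_tokens, num_experts, capacity_factor):
    """Token budget per expert under the assumed Switch-style convention."""
    return math.ceil(capacity_factor * num_tokens / num_experts)


# Values taken from the usage example.
BATCH_SIZE, SEQ_LEN, NUM_EXPERTS, CAPACITY_FACTOR = 16, 128, 8, 1.25

num_tokens = BATCH_SIZE * SEQ_LEN  # 2048 tokens in the batch
capacity = expert_capacity(num_tokens, NUM_EXPERTS, CAPACITY_FACTOR)
print(capacity)  # 320, i.e. 25% above the even share of 256 tokens per expert
```

Under this convention, a factor of 1.0 gives each expert exactly its even share, and larger values leave headroom for unbalanced routing at the cost of extra compute and memory.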

You can also use it within your own nn.Module classes:

import torch
import torch.nn as nn
from pytorch_mixtures import TopkMoE, ExpertChoiceMoE, MoEConfig


class CustomMoEBlock(nn.Module):
    def __init__(self, dim, num_experts, capacity_factor):
        super().__init__()
        self.config = MoEConfig(
            hidden_dim=dim,
            intermediate_dim=dim * 4,
            num_experts=num_experts,
            expert_fn="ff",
            expert_act="silu",
            router_fn="topk",
            capacity_factor=capacity_factor,
            topk=2,
            dtype=torch.float32
        )
        self.moe = TopkMoE(self.config)
        
    def forward(self, x):
        return self.moe(x)


my_block = CustomMoEBlock(
    dim=768,
    num_experts=8,
    capacity_factor=1.25
)

x = torch.randn(16, 128, 768)
output = my_block(x)  # output shape: [16, 128, 768]

Testing

This package includes a simple test for the MoE code. If all experts are initialized as the same module, the output of the MoE should equal the input tensor passed through any single expert. You can run the tests yourself as follows:

from pytorch_mixtures import run_tests

run_tests()

Or from the command line:

python -m tests.tests

Note: All tests are expected to pass. If a test fails, it is likely due to an edge case in the random initialization; rerunning the tests should resolve it.
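
The property behind this test can be illustrated without the package itself. The following is a minimal sketch in plain PyTorch, assuming top-k routing whose gate weights are renormalized to sum to 1 per token and no capacity limit (so no tokens are dropped). It is a toy mixture written for illustration, not the implementation used by pytorch-mixtures or its run_tests function.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dim, num_experts, topk = 16, 4, 2

# One shared expert reused in every slot, so all experts are identical.
expert = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))
experts = [expert] * num_experts
router = nn.Linear(dim, num_experts)

x = torch.randn(8, dim)  # 8 tokens

with torch.no_grad():
    # Top-k routing, with the selected gate weights renormalized to sum to 1.
    probs = router(x).softmax(dim=-1)
    weights, indices = probs.topk(topk, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    # Dense evaluation for clarity: run every expert, then pick the chosen ones.
    all_outputs = torch.stack([e(x) for e in experts], dim=1)  # [tokens, experts, dim]
    chosen = all_outputs.gather(1, indices.unsqueeze(-1).expand(-1, -1, dim))
    moe_output = (weights.unsqueeze(-1) * chosen).sum(dim=1)  # [tokens, dim]

    reference = expert(x)

# Identical experts plus weights summing to 1 give back the single-expert output.
matches = torch.allclose(moe_output, reference, atol=1e-5)
print(matches)
```

The equality holds only up to floating-point tolerance, and only because the weights sum to 1 and every token is processed. A router that leaves weights unnormalized or drops tokens over capacity would not satisfy it exactly.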

Citation

If you found this package useful, please cite it in your work:

@misc{JaisidhSingh2024,
  author = {Singh, Jaisidh},
  title = {pytorch-mixtures},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/jaisidhsingh/pytorch-mixtures}},
}
