Mammoth: Mixture of Mini Experts in Pathology
Mixture of Mini Experts: Overcoming the Linear Layer Bottleneck in Multiple Instance Learning, ICLR 2026.
Daniel Shao, Joel Runevic, Richard J. Chen, Drew F. K. Williamson, Ahrong Kim, Andrew H. Song*, Faisal Mahmood*
A parameter-efficient, plug-and-play mixture of experts for multiple instance learning models in computational pathology
Paper | OpenReview | Citation
How does Mammoth work?
Key Ideas
In Multiple Instance Learning (MIL) for whole-slide images, the standard pipeline is:
1. Extract patch features (e.g. from a pretrained encoder).
2. Transform them with a linear layer into task-specific patch features.
3. Aggregate the patches into a slide-level representation for classification.
Most works focus on steps (1) and (3). Mammoth explicitly targets step (2): it replaces the single linear layer with a low-rank mixture of experts so that each patch receives a transformation tailored to its phenotype, at a parameter count comparable to that of the original linear layer.
- Low-rank: Each expert is a factorized (LoRA-style) linear layer, keeping the parameter count close to a single linear layer.
- Mixture of experts: Slot-based routing assigns each patch to a combination of experts; the final representation is a weighted combination of expert outputs.
- Plug-and-play: Drop-in replacement for the patch embedding linear layer in any MIL method. Works with mean/max pooling, attention, CLAM, TransMIL, and others.
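To make these three ideas concrete, here is a simplified, hypothetical PyTorch sketch of a low-rank mixture of experts with soft slot routing. It assumes a Soft MoE-style router (each (expert, slot) pair has a learned routing vector; dispatch weights are normalized over patches and combine weights over slots) and a shared first low-rank projection followed by one up-projection per expert. The class name `LowRankSlotMoE` and its arguments are illustrative only. This is not the Mammoth module, which additionally uses multiple heads, dropout, and automatic rank selection, and may route differently.

```python
import torch
import torch.nn as nn


class LowRankSlotMoE(nn.Module):
    """Hypothetical sketch: soft slot routing over low-rank (factorized) experts."""

    def __init__(self, in_dim, dim, num_experts=4, num_slots=2, rank=16):
        super().__init__()
        self.E, self.S = num_experts, num_slots
        # One learned routing vector per (expert, slot) pair: E*S columns.
        self.slot_embed = nn.Parameter(
            torch.randn(in_dim, num_experts * num_slots) * in_dim ** -0.5
        )
        # Shared first low-rank layer (in_dim -> rank) ...
        self.down = nn.Linear(in_dim, rank, bias=False)
        # ... followed by one up-projection (rank -> dim) per expert.
        self.up = nn.Parameter(torch.randn(num_experts, rank, dim) * rank ** -0.5)
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x, keep_slots=False):
        # x: (B, N, in_dim) patch features
        B = x.shape[0]
        logits = x @ self.slot_embed            # (B, N, E*S) patch-to-slot affinities
        dispatch = logits.softmax(dim=1)        # each slot: weights over patches
        combine = logits.softmax(dim=2)         # each patch: weights over slots
        slots = dispatch.transpose(1, 2) @ x    # (B, E*S, in_dim) slot inputs
        slots = self.down(slots)                # (B, E*S, rank) shared factor
        slots = slots.reshape(B, self.E, self.S, -1)          # (B, E, S, rank)
        out = torch.einsum("besr,erd->besd", slots, self.up)  # expert-specific factor
        out = out.reshape(B, self.E * self.S, -1) + self.bias  # (B, E*S, dim)
        if keep_slots:
            return out                          # E*S aggregated slot features
        return combine @ out                    # (B, N, dim) per-patch mixture


layer = LowRankSlotMoE(in_dim=64, dim=32, num_experts=4, num_slots=2, rank=8)
feats = torch.randn(2, 50, 64)
patch_out = layer(feats)                   # one transformed feature per patch
slot_out = layer(feats, keep_slots=True)   # one feature per (expert, slot)
```

Each patch's output is a weighted combination of all expert outputs (via `combine`), while `keep_slots=True` returns the E*S slot features instead, mirroring the `keep_slots` option in the package's recommended hyperparameters.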
Main Findings
- Improved performance: Across 8 MIL methods and 19 classification tasks, Mammoth improves performance in 130 of 152 configurations, with an average change of +3.8%, and often has a larger effect than the choice of aggregation method. [Figure: average performance per MIL method, averaged across all tasks.]
- Structured Feature Space: Mammoth yields a structured feature space, with outputs forming distinct clusters per expert, and subclusters per slot.
- Expert specialization: Mammoth experts focus on diverse morphological phenotypes, enabling context-specific processing.
- Mitigated Instance-Gradient Interference: Heterogeneous instances yield conflicting gradient updates for the standard linear layer; Mammoth's expert routing mitigates this.
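The interference idea can be illustrated with a toy example (not the paper's measurement protocol): compute each instance's loss gradient with respect to one shared linear weight and compare their directions. A negative cosine similarity means the two instances push the shared weight in opposing directions. The helper `instance_grad` and the two-instance setup are hypothetical.

```python
import torch
import torch.nn.functional as F


def instance_grad(weight, x, target):
    """Gradient of a squared-error loss for ONE instance w.r.t. a shared weight."""
    weight = weight.detach().requires_grad_(True)
    loss = (x @ weight - target).pow(2).sum()
    (grad,) = torch.autograd.grad(loss, weight)
    return grad.flatten()


shared_w = torch.zeros(2, 1)  # a single linear layer shared by all instances

# Two instances with similar features but opposite desired outputs
x_a, t_a = torch.tensor([[1.0, 0.2]]), torch.tensor([[1.0]])
x_b, t_b = torch.tensor([[0.9, 0.3]]), torch.tensor([[-1.0]])

g_a = instance_grad(shared_w, x_a, t_a)
g_b = instance_grad(shared_w, x_b, t_b)
cos = F.cosine_similarity(g_a, g_b, dim=0)
print(f"gradient cosine similarity: {cos.item():.3f}")
```

Under expert routing, instances like these could be dispatched to different experts, so their updates would land on different parameters instead of partially canceling on one shared weight matrix.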
Representative performance
| MIL Model | Linear (Morph, T=6) | Mammoth (Morph, T=6) | Linear (Molec, T=13) | Mammoth (Molec, T=13) |
|---|---|---|---|---|
| ABMIL | 75.2 | 78.4 | 72.8 | 74.6 |
| CLAM | 71.7 | 78.5 | 72.9 | 73.7 |
| TransMIL | 72.8 | 76.5 | 72.2 | 73.7 |
| Transformer | 73.5 | 77.5 | 71.8 | 74.2 |
| ILRA | 71.5 | 77.7 | 71.6 | 72.8 |
| MeanMIL | 72.5 | 77.0 | 72.6 | 74.5 |
| MaxMIL | 71.9 | 74.8 | 72.9 | 74.1 |
| DSMIL | 72.7 | 75.6 | 72.1 | 73.3 |
Average performance of the linear layer vs. Mammoth across MIL methods, using UNI patch features. T is the number of tasks in each group. Balanced accuracy is reported for the morphological subtyping tasks and AUROC for the molecular subtyping tasks.
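The per-method gains in this table can be tabulated directly. The script below uses only the numbers above and computes the absolute point difference (Mammoth minus linear) for each method. These averages cover only this table (UNI features), so they are not expected to match the paper-wide +3.8% figure.

```python
# (linear morph, Mammoth morph, linear molec, Mammoth molec), copied from the table
rows = {
    "ABMIL":       (75.2, 78.4, 72.8, 74.6),
    "CLAM":        (71.7, 78.5, 72.9, 73.7),
    "TransMIL":    (72.8, 76.5, 72.2, 73.7),
    "Transformer": (73.5, 77.5, 71.8, 74.2),
    "ILRA":        (71.5, 77.7, 71.6, 72.8),
    "MeanMIL":     (72.5, 77.0, 72.6, 74.5),
    "MaxMIL":      (71.9, 74.8, 72.9, 74.1),
    "DSMIL":       (72.7, 75.6, 72.1, 73.3),
}

# Absolute point differences, rounded to the table's precision
morph_delta = {m: round(v[1] - v[0], 1) for m, v in rows.items()}
molec_delta = {m: round(v[3] - v[2], 1) for m, v in rows.items()}

mean_morph = sum(morph_delta.values()) / len(rows)
mean_molec = sum(molec_delta.values()) / len(rows)
best_morph = max(morph_delta, key=morph_delta.get)

print(f"mean morph gain: {mean_morph:.2f}, mean molec gain: {mean_molec:.2f}")
print(f"largest morph gain: {best_morph} (+{morph_delta[best_morph]})")
```

On this table every method improves in both task groups, with larger gains on the morphological tasks (a mean of about 4.3 points, largest for CLAM at +6.8) than on the molecular tasks (a mean of 1.5 points).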
Installation
Install via pip
Mammoth can be installed as a Python package:
```shell
pip install mammoth-moe
```
The package depends on:
- PyTorch
- einops
These will be installed automatically.
Development installation
To install from source:
```shell
pip install -e .
```
Manual installation
Use your existing Python environment. The mammoth.py module depends on:
- PyTorch
- einops
For quickstart instructions on using the MIL models in this repository, including environment setup, please see the MIL-Lab repository.
Minimal example: adding Mammoth to any MIL model
Mammoth is a drop-in replacement for the first linear layer that maps patch features to the dimension used by the rest of your MIL model. Below, a simple mean-pooling MIL model uses either a linear layer or Mammoth:
```python
import torch
import torch.nn as nn
from mammoth import Mammoth


class MeanMIL(nn.Module):
    def __init__(self, in_dim, out_dim, num_classes, moe_args=None):
        super().__init__()
        moe_args = dict(moe_args or {})  # copy; avoids a shared mutable default
        if moe_args.get("num_experts", 0) > 0:
            # input_dim/dim always follow the constructor arguments
            moe_args.update(input_dim=in_dim, dim=out_dim)
            self.fc = Mammoth(**moe_args)
        else:
            self.fc = nn.Linear(in_dim, out_dim)
        self.classifier = nn.Linear(out_dim, num_classes)

    def forward(self, x):
        # x: (batch, num_patches, in_dim)
        # -> (batch, num_patches, out_dim), or (batch, E*S, out_dim) with keep_slots=True
        x = self.fc(x)
        x = torch.mean(x, dim=1)  # aggregate: (batch, out_dim)
        return self.classifier(x)


in_dim = 1024  # e.g. patch feature dimension from a backbone
dim = 512      # dimension for aggregation / classifier
num_classes = 3

# our recommended hyperparameters for Mammoth
moe_args = {
    "input_dim": in_dim,
    "dim": dim,
    "num_experts": 30,
    "num_slots": 10,
    "num_heads": 16,
    "slot_dim": 256,
    "keep_slots": True,          # if True, return the E*S aggregated features instead of the N transformed patch features
    "share_lora_weights": True,  # share the weights of the first low-rank layer
    "dropout": 0.1,
    "auto_rank": True,           # automatically calculate the appropriate low rank for parameter efficiency
}

model = MeanMIL(in_dim, dim, num_classes, moe_args=moe_args)
x = torch.randn(2, 1000, in_dim)
logits = model(x)  # (2, num_classes)
```
> [!NOTE]
> Mammoth is intended to be a drop-in replacement for the linear layer at a comparable parameter count. While `num_experts`, `num_slots`, and `num_heads` may be adjusted, we strongly recommend setting `share_lora_weights=True` and `auto_rank=True` so that the layer sizes are computed automatically.
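Why weight sharing matters comes down to parameter arithmetic. The sketch below is a hypothetical budget calculation, not the formula `auto_rank` actually uses: it ignores routing, slot, and head parameters and simply finds the largest rank r such that E factorized experts fit in the parameter budget of one `nn.Linear(d_in, d_out)`. The function names are illustrative.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Parameters of nn.Linear(d_in, d_out): weight plus bias."""
    return d_in * d_out + d_out


def rank_for_budget(d_in: int, d_out: int, num_experts: int, shared_down: bool = True) -> int:
    """Largest rank r whose factorized experts fit in a linear layer's budget.

    shared_down=True:  d_in*r + E*r*d_out  (first low-rank layer shared by all experts)
    shared_down=False: E*r*(d_in + d_out)  (every expert has its own two factors)
    """
    budget = linear_params(d_in, d_out)
    if shared_down:
        per_rank = d_in + num_experts * d_out
    else:
        per_rank = num_experts * (d_in + d_out)
    return budget // per_rank


print(rank_for_budget(1024, 512, 30))                     # 32
print(rank_for_budget(1024, 512, 30, shared_down=False))  # 11
```

With the example sizes (d_in=1024, d_out=512, E=30), sharing the first low-rank layer allows roughly three times the rank (32 vs. 11) for the same budget, because the d_in-sized factor is paid for only once.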
Viewing Expert Specialization
The routing scores for heatmaps can be generated via the parameter return_weights.
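```python
# model.patch_router is the Mammoth layer inside a MIL model.
# B = batch size, N = number of patches, E = num_experts, S = num_slots,
# H = num_heads, D = per-head dimension (placeholders; set to your sizes).
features = torch.randn(B, N, H * D)

# out: (B, S*E, H*D)
out = model.patch_router(features)

# routing_weights: (B, N, E, S, H, D)
out, routing_weights = model.patch_router(features, return_weights=True)
```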
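For a simple patch-level heatmap, the routing weights can be reduced to one dominant expert per patch. This sketch assumes `routing_weights` has the (B, N, E, S, H, D) layout stated above and sums out slots, heads, and head dimensions; the helper `dominant_expert` is hypothetical, and the tutorial script may aggregate the scores differently.

```python
import torch


def dominant_expert(routing_weights: torch.Tensor) -> torch.Tensor:
    """Assumes shape (B, N, E, S, H, D). Returns (B, N) expert indices."""
    # Sum over slots (3), heads (4) and head dims (5) -> (B, N, E) expert scores
    expert_scores = routing_weights.sum(dim=(3, 4, 5))
    # Highest-scoring expert for every patch
    return expert_scores.argmax(dim=-1)


# Random stand-in with the documented layout: B=1, N=100, E=30, S=10, H=2, D=3
rw = torch.rand(1, 100, 30, 10, 2, 3)
idx = dominant_expert(rw)                      # (1, 100)
counts = torch.bincount(idx[0], minlength=30)  # patches dominated by each expert
```

The per-patch indices in `idx[b]` can then be mapped back to patch coordinates to color a WSI thumbnail, and `counts` gives how many patches each expert dominates.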
For starter code to generate your own visualizations with the routing scores, please see [this script](./examples/tutorial_mammoth_visualization.py).
Full MIL models
Enabling Mammoth requires passing a moe_args dict with num_experts > 0 and the usual Mammoth arguments (num_experts, input_dim, dim, num_heads, etc.). If moe_args is empty or num_experts == 0, the model uses the original linear layer.
Example: ABMIL with Mammoth
```python
from abmil_mammoth import ABMIL, ABMILGatedBaseConfig, ABMILModel

# Minimal args needed to initialize Mammoth. This creates 30 experts, 16 heads, 10 slots/expert, and weight sharing
moe_args = {
    "num_experts": 30
}

config = ABMILGatedBaseConfig(
    in_dim=1024,
    embed_dim=512,
    num_classes=2,
    moe_args=moe_args,
)
model = ABMILModel(config)
# Forward: (B, M, D) patch features -> logits, loss, etc.
```
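The `ABMILGatedBaseConfig` name points to the gated-attention variant of ABMIL. As an illustration of what the aggregator does with the patch embeddings produced by the linear layer or Mammoth, here is a sketch of the standard gated attention pooling formulation of Ilse et al. (2018). The class `GatedAttentionPool` is hypothetical; MIL-Lab's ABMIL implementation may differ in details such as dropout, hidden size, or the number of attention branches.

```python
import torch
import torch.nn as nn


class GatedAttentionPool(nn.Module):
    """Gated attention pooling: a = softmax(w^T (tanh(V h) * sigmoid(U h)))."""

    def __init__(self, dim: int, attn_dim: int = 128):
        super().__init__()
        self.V = nn.Linear(dim, attn_dim)  # tanh branch
        self.U = nn.Linear(dim, attn_dim)  # sigmoid gate branch
        self.w = nn.Linear(attn_dim, 1)    # one attention score per patch

    def forward(self, h: torch.Tensor):
        # h: (B, M, dim) patch embeddings
        scores = self.w(torch.tanh(self.V(h)) * torch.sigmoid(self.U(h)))  # (B, M, 1)
        attn = scores.softmax(dim=1)                                      # normalize over patches
        slide = (attn * h).sum(dim=1)                                     # (B, dim)
        return slide, attn.squeeze(-1)                                    # (B, dim), (B, M)


pool = GatedAttentionPool(dim=512)
h = torch.randn(2, 300, 512)
slide, attn = pool(h)
```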
MIL models with Mammoth can also be instantiated with MIL-Lab's create_model function by specifying the base_mammoth config:
```python
from src.builder import create_model

# standard ABMIL model with a linear layer and UNI's 1024-dim input
create_model('abmil.base.uni', num_classes=5)

# standard ABMIL with Mammoth
create_model('abmil.base_mammoth.uni', num_classes=5)

# specify the encoder to automatically update the input dimension
create_model('abmil.base_mammoth.conch_v15', num_classes=5)
```
The following MIL implementations are available. In each, the patch_embed layer can optionally be equipped with Mammoth, either by passing moe_args to the model class or via create_model.
| Model | Code | Paper | Model Class | Initialization |
|---|---|---|---|---|
| ABMIL | Link | Link | ABMILModel() | create_model('abmil.base_mammoth') |
| TransMIL | Link | Link | TransMILModel() | create_model('transmil.base_mammoth') |
| Transformer | Link | Link | TransformerModel() | create_model('transformer.base_mammoth') |
| WiKG | Link | Link | WIKGMILModel() | create_model('wikg.base_mammoth') |
| DFTD | Link | Link | DFTDModel() | create_model('dftd.base_mammoth') |
| DSMIL | Link | Link | DSMILModel() | create_model('dsmil.base_mammoth') |
| ILRA | Link | Link | ILRAModel() | create_model('ilra.base_mammoth') |
| RRT | Link | Link | RRTMILModel() | create_model('rrt.base_mammoth') |
| CLAM | Link | Link | CLAMModel() | create_model('clam.base_mammoth') |
Repository layout
| Path | Description |
|---|---|
| modules/mammoth.py | Core Mammoth module: Mammoth, factorized experts, slot routing, and supporting layers |
| modules/components.py | Shared utilities (e.g. ensure_batched) used by mammoth.py |
| MIL-Lab/src/models/ | MIL model wrappers (ABMIL, CLAM, DSMIL, TransMIL, etc.) with optional Mammoth patch embedding |
| examples/tutorial_mammoth_visualization.py | Expert dispatch heatmaps on WSIs using a saved Mammoth checkpoint |
| config/paths.py | Central path mappings for tasks and WSI/feature directories |
Issues
- The preferred mode of communication is via GitHub issues.
- If GitHub issues are inappropriate, email dshao@mit.edu and asong2@mdanderson.org
Funding
This work was funded by NIH NIGMS R35GM138216.
License and Terms of Use
ⓒ Mahmood Lab. This repository is released under the CC-BY-NC-ND 4.0 license and may only be used for non-commercial, academic research purposes with proper attribution. Any commercial use, sale, or other monetization of this repository is prohibited and requires prior approval. By downloading any pretrained encoder, you agree to follow the model's respective license.
Acknowledgements
This project builds on excellent repositories such as HuggingFace, as well as community open-source implementations of the MIL models. We thank the authors and developers for their contributions.
Citation
If you use this code, the Mammoth method, or the MIL model implementations in your work, please cite:
```bibtex
@inproceedings{shao2026mammoth,
  title={Mixture of Mini Experts: Overcoming the Linear Layer Bottleneck in Multiple Instance Learning},
  author={Shao, Daniel and Runevic, Joel and Chen, Richard J. and Williamson, Drew F. K. and Kim, Ahrong and Song, Andrew H. and Mahmood, Faisal},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026},
  url={https://openreview.net/forum?id=S5Io33pc78}
}

@inproceedings{shao2025do,
  title={Do Multiple Instance Learning Models Transfer?},
  author={Shao, Daniel and Chen, Richard J and Song, Andrew H and Runevic, Joel and Lu, Ming Y. and Ding, Tong and Mahmood, Faisal},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025},
}
```
License
See the repository for license information. The paper is under CC BY 4.0.