FMS Acceleration for Mixture-of-Experts

This library contains plugins to accelerate finetuning with the following optimizations:

  1. Expert-Parallel MoE with Triton Kernels from ScatterMoE, and some extracted from megablocks.
    • Megablocks kernels for gather and scatter
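To illustrate what the gather and scatter kernels accomplish, here is a plain-Python sketch of MoE token dispatch: tokens are gathered into contiguous per-expert groups, processed, and scattered back into their original order. This is a conceptual toy only; the actual plugin fuses these steps into Triton kernels operating on tensors.

```python
# Conceptual sketch of the gather/scatter pattern used in MoE dispatch.
# Plain Python for illustration only; the plugin uses fused Triton kernels.

def moe_forward(tokens, expert_ids, experts):
    """Route each token to its assigned expert and restore order."""
    n = len(tokens)
    # Sort token indices by expert id so each expert sees a contiguous block.
    order = sorted(range(n), key=lambda i: expert_ids[i])
    gathered = [tokens[i] for i in order]          # "gather"

    # Run each expert on its contiguous slice of tokens.
    outputs = [None] * n
    start = 0
    while start < n:
        e = expert_ids[order[start]]
        end = start
        while end < n and expert_ids[order[end]] == e:
            end += 1
        for pos in range(start, end):
            outputs[pos] = experts[e](gathered[pos])
        start = end

    # Scatter results back to the original token order.
    result = [None] * n
    for pos, i in enumerate(order):
        result[i] = outputs[pos]                   # "scatter"
    return result

# Two toy "experts": double or negate.
experts = [lambda x: 2 * x, lambda x: -x]
out = moe_forward([1, 2, 3, 4], [1, 0, 1, 0], experts)
print(out)  # [-1, 4, -3, 8]
```

The gather step is what makes expert computation efficient: each expert can run one dense batch over a contiguous slice instead of masked, strided accesses.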

Plugins

| Plugin | Description | Depends | Loading | Augmentation | Callbacks |
| --- | --- | --- | --- | --- | --- |
| scattermoe | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | | | |

Adding New Models

Our ScatterMoE implementation is a module swap; to add a new model, update the specifications in scattermoe_constants.py.

  • See the code documentation within to understand how to add new models.
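The spec-driven module-swap approach can be sketched as follows. This is a hypothetical illustration in plain Python (no torch); the class names, spec shape, and `swap_modules` helper below are invented for clarity and are not the actual entries in scattermoe_constants.py.

```python
# Hypothetical sketch of a spec-driven module swap (pure Python, no torch).
# The real plugin reads per-architecture specs from scattermoe_constants.py;
# all names below are illustrative only.

class DenseMoEBlock:
    """Stand-in for a model's original MoE layer."""
    def __init__(self, num_experts):
        self.num_experts = num_experts

class ScatterMoEBlock:
    """Stand-in for the expert-parallel replacement."""
    def __init__(self, num_experts):
        self.num_experts = num_experts

# A "spec" maps the module class to swap out to a factory for its replacement;
# adding a new model amounts to adding an entry like this.
SPECS = {
    DenseMoEBlock: lambda old: ScatterMoEBlock(old.num_experts),
}

def swap_modules(model):
    """Replace every spec'd submodule in a {name: module} tree in place."""
    for name, module in list(model.items()):
        factory = SPECS.get(type(module))
        if factory is not None:
            model[name] = factory(module)
    return model

model = {"embed": object(), "moe.0": DenseMoEBlock(8), "moe.1": DenseMoEBlock(8)}
swap_modules(model)
print(type(model["moe.0"]).__name__)  # ScatterMoEBlock
```

The key property of a module swap is that only layers matching a spec entry are touched; everything else in the model is left intact.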

Using ScatterMoE Saved Checkpoints

ScatterMoE checkpoints are saved using torch.distributed.checkpoint (DCP), by default with StateDictType.SHARDED_STATE_DICT:

  • DTensor has only limited support for full state dicts.
  • sharded state dicts are extremely efficient and require little communication overhead when saving.
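To see why sharded saving is cheap and why a separate recovery step is needed, here is a toy sketch using plain Python lists in place of tensors. The `shard`/`merge` functions are illustrative stand-ins, not the DCP API: each rank only ever holds and writes its own slice, so recovering the full checkpoint requires merging the shards offline.

```python
# Toy illustration of sharded vs. full state dicts (plain Python lists in
# place of tensors; the real plugin uses torch.distributed.checkpoint).

def shard(full_state, world_size):
    """Each rank keeps only its slice of every parameter - cheap to save,
    since no rank ever materializes the full tensor."""
    return [
        {k: v[rank::world_size] for k, v in full_state.items()}
        for rank in range(world_size)
    ]

def merge(shards):
    """Recover the full state dict by interleaving the shards back -
    roughly what a checkpoint-recovery script must do offline."""
    world_size = len(shards)
    full = {}
    for k in shards[0]:
        n = sum(len(s[k]) for s in shards)
        out = [None] * n
        for rank, s in enumerate(shards):
            out[rank::world_size] = s[k]
        full[k] = out
    return full

state = {"w": [1, 2, 3, 4, 5, 6]}
shards = shard(state, world_size=2)
print(shards[0]["w"])          # [1, 3, 5]
print(merge(shards) == state)  # True
```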

We provide a script to recover the original checkpoint:

  • currently the script has only been tested in the case where DCP saved the model on a single node.

If the checkpoint is stored in hf/checkpoint-10, run the following to write the converted checkpoint to output_dir:

python -m fms_acceleration_moe.utils.checkpoint_utils \
    hf/checkpoint-10 output_dir \
    mistralai/Mixtral-8x7B-Instruct-v0.1

Code Extracted from Megablocks

Notes on code extraction:

Running Benchmarks

Run the following in the top-level directory of this repo:

  • the scattermoe dependency is not installed by default, so the -x switch installs it.
  • consider disabling torch memory logging for improved speeds.

tox -e run-benches \
    -x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \
    -- \
    "1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full

or run the larger Mixtral-8x7B bench:

tox ... \
    8 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full-mixtral

NOTE: if a FileNotFoundError on the triton cache is observed, similar to issues like these:

then tox is somehow causing problems with triton and multiprocessing (there is some race condition). The workaround is to first activate the tox env and run the script manually in bash:

# if FileNotFoundError in the triton cache is observed
# - then activate the env and run the script manually

source .tox/run-benches/bin/activate
bash scripts/run_benchmarks.sh \
    ....

Triton Kernel Dependencies

The Triton kernels were copied into scattermoe_utils from kernel hyperdrive, which is a fork of cute kernels.

Known Issues

These are some known issues that are not yet resolved:

  • the dependency on the external kernel-hyperdrive repository should eventually be removed.
  • only sharded, non-GGUF safetensors MoE checkpoints are currently supported for loading. This is a reasonable assumption, since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
  • when used together with FSDP, FSDP's clip_grad_norm will not be computed properly for ScatterMoE, see issue here.
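On the clip_grad_norm point: with sharded parameters, each rank only sees part of every gradient, so a norm computed locally underestimates the true value; a correct implementation must combine per-shard norms across ranks. The following plain-Python sketch shows the math only, not the plugin's actual fix:

```python
# Why clipping needs a *global* norm: the squared L2 norm of a full gradient
# equals the sum of squared norms of its shards. Plain-Python sketch of the
# math; not the plugin's actual fix.
import math

def global_grad_norm(local_grad_shards):
    """sqrt of the summed squared local norms == norm of the full gradient."""
    return math.sqrt(sum(
        sum(g * g for g in shard) for shard in local_grad_shards
    ))

full_grad = [3.0, 4.0, 0.0, 12.0]
shards = [full_grad[:2], full_grad[2:]]  # as if split across 2 ranks

naive = math.sqrt(sum(g * g for g in shards[0]))   # rank 0's local norm only
correct = global_grad_norm(shards)                 # matches ||full_grad||
print(naive, correct)  # 5.0 13.0
```

In a real distributed run, the sum of squared local norms would be obtained with an all-reduce before taking the square root.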

Source Distributions

No source distribution files available for this release.

Built Distribution


fms_acceleration_moe-0.4.6-py3-none-any.whl (51.6 kB), uploaded for Python 3.

File details

Hashes for fms_acceleration_moe-0.4.6-py3-none-any.whl:

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 8ff199666431a76244dbfc47e05ee3ed31177dcb60163b2c2da4f0b78bf15354 |
| MD5 | 64537f544be61eac2eec26f649e8caf5 |
| BLAKE2b-256 | cc171d0568956b7cfb80b468bbb8e574815dd70cf3662e4eef1dd9bc2abe074c |

Provenance

The following attestation bundles were made for fms_acceleration_moe-0.4.6-py3-none-any.whl:

Publisher: build-and-publish.yml on foundation-model-stack/fms-acceleration

