A package for MLX model sharding and distributed inference

MLX Sharding

This project demonstrates how to implement pipeline parallelism for large language models using MLX. It includes tools for sharding a model, serving shards across multiple machines, and generating text using the distributed model. Additionally, it features an OpenAI API-compatible server for easier integration and usage.

Demo Video

To see the distributed inference in action, check out our demo video:

Sharding DeepSeek-Coder-V2-Lite-Instruct Demo

Educational Purpose

This repository is designed for educational purposes to illustrate how pipeline parallelism can be implemented in MLX. It provides a basic framework for:

  1. Sharding a large language model
  2. Distributing model shards across multiple machines
  3. Implementing a simple pipeline for text generation
  4. Serving the model through an OpenAI API-compatible interface

While not optimized for production use, this demo serves as a starting point for understanding and experimenting with pipeline parallelism in machine learning workflows.
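The core idea can be illustrated with a toy sketch (plain Python, not the package's actual code): the model's layers are split into contiguous shards, and the hidden state produced by one shard is handed to the next. In the real system each shard lives on its own machine and the handoff happens over gRPC.

```python
# Toy illustration of pipeline parallelism. Each "layer" here is a
# stand-in for a transformer layer; a shard is a contiguous slice of
# the layer list.
def make_layer(i):
    return lambda x: x + i  # stand-in for a real layer's forward pass

layers = [make_layer(i) for i in range(6)]
shard_0, shard_1 = layers[:3], layers[3:]  # two pipeline stages

def run_shard(shard, hidden):
    # Run the shard's layers in order, threading the hidden state through.
    for layer in shard:
        hidden = layer(hidden)
    return hidden

# In the distributed setup, this handoff crosses the network.
hidden = run_shard(shard_0, 0)
output = run_shard(shard_1, hidden)
print(output)  # 15
```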

Setup and Usage

1. Model Preparation

You have two main options for preparing and using the model:

Option A: Pre-Sharding the Model

If you prefer to pre-shard the model, use sharding_weight.py:

python sharding_weight.py --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --output_dir shard_0 --start_layer 0 --end_layer 14 --total_layers 27
python sharding_weight.py --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --output_dir shard_1 --start_layer 14 --end_layer 27 --total_layers 27
# Repeat for additional shards as needed

Option B: Dynamic Sharding

You can let the system dynamically load and shard the weights when starting the server. This option doesn't require pre-sharding.
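Whichever option you choose, each shard needs a contiguous, non-overlapping layer range. A small planning helper (hypothetical, not part of the package) shows how the `--start_layer`/`--end_layer` values above fit together, assuming the end index is exclusive as in the two-shard commands shown:

```python
# Split total_layers into contiguous [start, end) ranges, one per shard.
# Earlier shards absorb the remainder when the split is uneven.
def plan_shards(total_layers, num_shards):
    base, rem = divmod(total_layers, num_shards)
    ranges, start = [], 0
    for i in range(num_shards):
        end = start + base + (1 if i < rem else 0)
        ranges.append((start, end))
        start = end
    return ranges

print(plan_shards(27, 2))  # [(0, 14), (14, 27)]
```

Note that each range starts exactly where the previous one ends, so no layer is loaded twice and none is skipped.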

2. Distribute Shards (If Using Option A)

If you've pre-sharded the model, copy the shard directories to their respective machines. Skip this step for Option B.

3. Start the Servers

Start server instances based on your chosen approach:

For Pre-Sharded Model (Option A)

On each machine with a shard, start a server instance. For example:

python -m shard.main --model mzbac/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx-shard-1

For Dynamic Sharding (Option B)

Start the server with specific layer ranges:

python -m shard.main --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --start-layer 0 --end-layer 14

Note the IP address and port printed by each server.

4. Generate Text

Using the generate script

For a dynamically sharded setup:

python generate.py --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --start_layer 0 --end_layer 14 --server_address <remote_ip1>:<port1>,<remote_ip2>:<port2> --prompt "Your prompt here" --max_tokens 512

For a pre-sharded setup:

python generate.py --model mzbac/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx-shard-0 --server_address <remote_ip1>:<port1>,<remote_ip2>:<port2> --prompt "Your prompt here" --max_tokens 512

Using the OpenAI API-compatible server

  1. Start the server:

    For dynamic sharding:

    python -m shard.openai_api --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --llm-shard-addresses localhost:50051,<remote_ip1>:<port1>,<remote_ip2>:<port2> --start-layer 0 --end-layer 14
    

    For pre-sharded model:

    python -m shard.openai_api --model mzbac/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx-shard-0 --llm-shard-addresses localhost:50051,<remote_ip1>:<port1>,<remote_ip2>:<port2>
    
  2. Use the API endpoints:

    • /v1/completions: Text completion endpoint
    • /v1/chat/completions: Chat completion endpoint

Example usage:

curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "messages": [{"role": "user", "content": "Say this is a test!"}],
     "temperature": 0.7
   }'
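The same request can be made from Python using only the standard library. This is a sketch assuming the server is running on localhost:8080 as in the curl example; the response shape follows the OpenAI chat completions format.

```python
import json
import urllib.request

# Same payload as the curl example above.
payload = {
    "messages": [{"role": "user", "content": "Say this is a test!"}],
    "temperature": 0.7,
}

def chat(payload, base_url="http://localhost:8080"):
    # POST the JSON payload to the chat completions endpoint.
    req = urllib.request.Request(
        base_url + "/v1/chat/completions",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

# chat(payload)  # requires a running server on localhost:8080
```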

Limitations and Considerations

  1. Network Dependency: The performance of this pipeline parallelism implementation is heavily dependent on network speed and latency between machines.

  2. Error Handling: The current implementation has basic error handling. In a production environment, you'd want to implement more robust error handling and recovery mechanisms.

  3. Security: This demo uses insecure gRPC channels. For any real-world application, implement proper security measures.

  4. Shard Configuration: Ensure that when using multiple shards, the layer ranges are set correctly to cover the entire model without overlap.
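The shard-configuration point can be checked mechanically. The helper below (hypothetical, not part of the package) verifies that a set of (start, end) layer ranges covers every layer exactly once, again treating the end index as exclusive as in the CLI flags:

```python
# Return True if the ranges tile [0, total_layers) with no gaps or overlaps.
def covers_model(ranges, total_layers):
    ranges = sorted(ranges)
    if not ranges or ranges[0][0] != 0 or ranges[-1][1] != total_layers:
        return False
    # Each shard must begin exactly where the previous one ended.
    return all(
        prev_end == start
        for (_, prev_end), (start, _) in zip(ranges, ranges[1:])
    )

print(covers_model([(0, 14), (14, 27)], 27))  # True
print(covers_model([(0, 14), (15, 27)], 27))  # False: layer 14 is unserved
```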

Extending the System

To extend the system for more shards:

  1. If pre-sharding, create additional shards using sharding_weight.py.
  2. Set up more server instances, one for each new shard.
  3. In generate.py or when using the OpenAI API server, include all shard addresses.
  4. Adjust the layer ranges accordingly when using dynamic sharding.

Requirements

  • Python 3.x
  • MLX library
  • gRPC and related dependencies
  • NumPy
  • Transformers library
  • Sufficient RAM on each machine to load and process its model shard
