
A DP-based dynamic batching scheduler for inference workloads

Project description

Razors-Edge-batching-scheduler

A scheduler that maximizes throughput while meeting latency objectives for ML inference requests.

What this is

This project experiments with a smarter way to group requests into batches so you can:

  • increase throughput
  • keep latency low
  • handle variable input sizes better than simple fixed-size batching

It targets workloads such as embedding and classification, where batched compute is much faster than processing requests one at a time.

How it works

When variable-length inputs are batched for model inference, shorter inputs are padded to the length of the longest input in the batch. That padding is wasted compute, so to maximize throughput, inputs with very different sizes should not be batched together.

In addition, the best next batch to run can be chosen using different strategies (e.g., FIFO, MINMAX, or GUARDED_BATCH_SIZE) depending on latency and throughput goals.

This repo describes a scheduler that takes these details into account.
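
The cost of padding is easy to see with a toy model (illustrative only; these helper functions are not part of the scheduler's API):

def padded_cost(sizes):
    # A padded batch does work proportional to batch size times its longest input.
    return len(sizes) * max(sizes)

def wasted(sizes):
    # Compute spent on padding rather than on real tokens.
    return padded_cost(sizes) - sum(sizes)

mixed = [8, 12, 1000]                   # one long request forces heavy padding
grouped = [[8, 12], [1000]]             # similar sizes batched together

print(wasted(mixed))                    # 1980 padded positions
print(sum(wasted(g) for g in grouped))  # 4 padded positions

Grouping requests of similar size removes almost all of the padding here; that is the intuition the scheduler builds on.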

Installation

Install from PyPI:

pip install razors-edge-batching-scheduler

For local development (from this repository):

python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
pip install -e .

Demo (real GPU benchmark task)

Below is the complete GPU benchmark task (a real model on a real GPU) used in the Razor's Edge scheduling demos:

from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle
from typing import Any

import numpy as np

from razors_edge.razors_edge_compute_task import RazorsEdgeComputeTask


class RazorsEdgeGPUBenchmarkTask(RazorsEdgeComputeTask):
    """Dummy task with realistic benchmarking, batching, and post-processing behavior."""

    @property
    def batch_benchmark_sizes(self) -> list[int]:
        return [1, 2, 3, 5, 8, 10, 13, 16]

    @property
    def min_input_size(self) -> int:
        return 1

    @property
    def max_input_size(self) -> int:
        return 1024

    @property
    def max_input_points(self) -> int:
        return 7

    @property
    def is_gpu(self) -> bool:
        return True

    def get_input_size(self, input_data: Any, preprocessed_input: Any) -> int:
        """Return the token count for pre-tokenized model input."""
        return int(preprocessed_input["input_ids"].shape[1])

    def generate_test_input(self, batch_size: int, input_size: int) -> tuple[tuple, dict[str, np.ndarray]]:
        """Build a dummy pre-tokenized CUDA batch of the given shape (valid once load_model has set self.torch)."""
        return (), {
            "input_ids": self.torch.ones((batch_size, input_size), dtype=self.torch.long, device="cuda"),
            "attention_mask": self.torch.ones((batch_size, input_size), dtype=self.torch.long, device="cuda")
        }

    def load_model(self, model_pool: ThreadPoolExecutor) -> Any:
        import os
        BASE_DIR = "E:\\Github\\Razors-Edge-batching-scheduler"
        os.environ["HF_HUB_OFFLINE"] = "1"
        os.environ["HF_HOME"] = f"{BASE_DIR}\\models"
        import torch
        assert torch.cuda.is_available(), "CUDA NOT AVAILABLE"
        from transformers import AutoTokenizer, AutoModel
        with torch.inference_mode():
            self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
            model = AutoModel.from_pretrained("BAAI/bge-m3")
            model = model.eval().half().to("cuda")
        self.torch = torch
        # Preallocate a rotating set of maximum-size input buffers (one per model
        # worker, plus a spare) so create_batch can reuse GPU memory on every call.
        max_batch_size = self.batch_benchmark_sizes[-1]
        max_input_size = self.max_input_size
        self.token_buffer = cycle(
            [self.generate_test_input(max_batch_size, max_input_size)[1] for _ in range(model_pool._max_workers + 1)]
        )
        torch.cuda.empty_cache()

        def run_model(*_, **inputs):
            with torch.inference_mode(), torch.autocast("cuda"):
                return model(**inputs)

        return run_model

    def preprocess_input_without_size(self, input_data: str) -> tuple[str, dict[str, np.ndarray]]:
        return self.tokenizer([input_data], padding=True, truncation=True, return_tensors="pt")

    def create_batch(self, to_batch: list[tuple[str, dict[str, np.ndarray]]]) -> tuple[tuple, dict[str, np.ndarray]]:
        # Slice a preallocated GPU buffer down to this batch's shape and copy each
        # request's tokens into it; remaining positions stay padded and masked out.
        token_buffer = next(self.token_buffer)
        max_size = max(payload["input_ids"].shape[1] for payload in to_batch)
        batch_size = len(to_batch)
        buffer_copy = {k: v[:batch_size, :max_size] for k, v in token_buffer.items()}
        buffer_copy["input_ids"].fill_(1)       # reset to the pad token id
        buffer_copy["attention_mask"].fill_(0)  # attend to nothing by default
        for row, payload in enumerate(to_batch):
            for key, value in payload.items():
                buffer_copy[key][row, : value.shape[1]] = value[0]
        return (), buffer_copy

    def postprocess_output(self, call_output: Any) -> Iterable[list[float]]:
        """Normalize embeddings and return list rows."""
        with self.torch.inference_mode(), self.torch.autocast("cuda"):
            # Mean-pool over the sequence dimension, then L2-normalize each row.
            embeddings = call_output.last_hidden_state.mean(dim=1)
            embeddings = self.torch.nn.functional.normalize(embeddings, p=2, dim=1)
            return embeddings.tolist()

Usage

Note: multiple tasks can be registered with a single ComputeExecutor and run with multiple threads.

import asyncio
import random
import string
import time

from batching_executor.process_manager import ComputeExecutor

executor = ComputeExecutor(
    [RazorsEdgeGPUBenchmarkTask],
    async_limit=64,
    model_thread_limit=1,
)


def generate_random_strings(n, a, b, seed=42):
    # n random alphanumeric strings with lengths drawn uniformly from [a, b].
    random.seed(seed)
    chars = string.ascii_letters + string.digits
    return [''.join(random.choice(chars) for _ in range(random.randint(a, b))) for _ in range(n)]


async def benchmark_async(target, parallelism_limit: int, max_token_count: int, request_count: int):
    payloads = generate_random_strings(request_count, 1, max_token_count)
    start = time.perf_counter()

    # Cap the number of requests submitted concurrently to the executor.
    semaphore = asyncio.Semaphore(parallelism_limit)

    async def limited_task(*args, **kwargs):
        async with semaphore:
            return await executor.async_compute_fn(*args, **kwargs)

    await asyncio.gather(*(limited_task(target, payload) for payload in payloads))
    elapsed = time.perf_counter() - start
    return elapsed, request_count / elapsed
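
A minimal driver for this benchmark might look like the following. This is a sketch: it assumes async_compute_fn takes the task class as its first argument (as limited_task above passes target); adjust it to match your executor setup.

if __name__ == "__main__":
    elapsed, throughput = asyncio.run(
        benchmark_async(
            target=RazorsEdgeGPUBenchmarkTask,  # assumed: the task is addressed by its class
            parallelism_limit=32,
            max_token_count=512,
            request_count=1000,
        )
    )
    print(f"{elapsed:.2f}s elapsed, {throughput:.1f} requests/s")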

Project layout

  • src/ → scheduler + task logic
  • tests/ → test coverage
  • demos/ → experiments and benchmark notebooks
  • images/ → generated benchmark plots
  • PAPER.md → full deep-dive explanation

Result images

Benchmark plots are in images/.

  • Synthetic throughput comparisons
  • Gains from variable batch sizing
  • Real workload benchmarks

Run tests

python -m coverage run --source=src -m unittest discover -v
coverage html

Recommended Background Music

When using these methods, it is recommended that you listen to this for better code.

Razor's Edge (Official Nightcore)
