Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] support weighted round robin routing #213

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ The router ensures efficient request distribution among backends. It supports:
- Exporting observability metrics for each serving engine instance, including QPS, time-to-first-token (TTFT), number of pending/running/finished requests, and uptime
- Automatic service discovery and fault tolerance by Kubernetes API
- Multiple different routing algorithms
- Round-robin routing
- Session-ID based routing
- Round-robin routing: Distributes requests evenly across all endpoints
- Session-ID based routing: Routes requests from the same session to the same endpoint
- Weighted routing: Distributes traffic proportionally based on configured weights using the Smooth Weighted Round Robin algorithm
- (WIP) prefix-aware routing

Please refer to the [router documentation](./src/vllm_router/README.md) for more details.
Expand Down
185 changes: 185 additions & 0 deletions src/tests/test_weighted_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import time
import unittest
from collections import Counter
from unittest.mock import MagicMock

from fastapi import Request

from vllm_router.routing_logic import WeightedRouter
from vllm_router.service_discovery import EndpointInfo


class TestWeightedRouter(unittest.TestCase):
    """Tests for ``WeightedRouter``.

    Covers four properties: traffic share follows the configured weights,
    routing keeps working when the endpoint list changes between calls,
    endpoints without a configured weight are still selected, and no
    endpoint is picked in long consecutive bursts.
    """

    def setUp(self):
        """Build three endpoints, a weight map, a router and request stubs."""
        # All endpoints serve the same model and share one timestamp so that
        # the configured weight is the only thing distinguishing them.
        added_at = time.time()
        self.endpoints = [
            EndpointInfo(
                url=f"http://endpoint{index}:8000",
                model_name="model1",
                added_timestamp=added_at,
            )
            for index in (1, 2, 3)
        ]

        # Configure weights (51, 27, 22). They sum to 100 here, but the
        # assertions below normalise by the total, so they stay valid if
        # these values are changed.
        self.weights = {
            "http://endpoint1:8000": 51,
            "http://endpoint2:8000": 27,
            "http://endpoint3:8000": 22,
        }

        # Initialize router
        self.router = WeightedRouter(weights=self.weights)

        # Create mock request, engine stats and request stats. The stats
        # dicts are passed empty in every test.
        self.mock_request = MagicMock(spec=Request)
        self.mock_engine_stats = {}
        self.mock_request_stats = {}

    def _route(self, endpoints=None, router=None):
        """Route one request and return the chosen endpoint URL.

        Args:
            endpoints: Endpoint list to route over; defaults to
                ``self.endpoints`` when ``None``.
            router: Router to use; defaults to ``self.router`` when ``None``.
        """
        # Explicit ``is None`` checks so an (intentionally) empty endpoint
        # list is forwarded as-is rather than silently replaced.
        if endpoints is None:
            endpoints = self.endpoints
        if router is None:
            router = self.router
        return router.route_request(
            endpoints,
            self.mock_engine_stats,
            self.mock_request_stats,
            self.mock_request,
        )

    def test_weight_distribution(self):
        """Test if requests are distributed according to the configured weights."""
        # Number of requests to simulate
        num_requests = 1000

        # Collect routing decisions
        routing_results = Counter(self._route() for _ in range(num_requests))

        # Acceptable margin of error (in percentage points)
        margin = 5.0
        total_weight = sum(self.weights.values())
        actual_distribution = {
            url: count / num_requests * 100
            for url, count in routing_results.items()
        }

        # Verify distribution matches configured weights within margin
        for url, weight in self.weights.items():
            expected_pct = weight / total_weight * 100
            # Counter yields 0 for an endpoint that was never chosen, so a
            # starved endpoint fails the assertion instead of raising KeyError.
            actual_pct = routing_results[url] / num_requests * 100
            self.assertAlmostEqual(
                actual_pct,
                expected_pct,
                delta=margin,
                msg=(
                    f"Distribution for {url} ({actual_pct:.1f}%) differs from "
                    f"expected weight ({expected_pct:.1f}%) by more than "
                    f"{margin}%; actual distribution: {actual_distribution}"
                ),
            )

    def test_dynamic_endpoint_changes(self):
        """Test if router handles endpoint changes correctly."""
        # Initial routing with all endpoints
        self.assertIn(self._route(), self.weights)

        # Remove the first endpoint; the result must come from what is left
        reduced_endpoints = self.endpoints[1:]
        self.assertIn(
            self._route(reduced_endpoints),
            [endpoint.url for endpoint in reduced_endpoints],
        )

        # Add back all endpoints
        self.assertIn(self._route(), self.weights)

    def test_missing_weights(self):
        """Test if router handles endpoints without configured weights."""
        # Create router with weights for only some endpoints (endpoint3 is
        # deliberately left out)
        partial_weights = {
            "http://endpoint1:8000": 50,
            "http://endpoint2:8000": 50,
        }
        router = WeightedRouter(weights=partial_weights)

        # Route requests and record every endpoint that gets picked
        used_endpoints = {self._route(router=router) for _ in range(100)}

        # Verify all endpoints are used, even those without configured weights
        self.assertEqual(
            used_endpoints,
            {endpoint.url for endpoint in self.endpoints},
            "Not all endpoints were used in routing",
        )

    def test_smooth_distribution(self):
        """Test if the distribution is smooth without bursts."""
        # Longest streak of back-to-back selections seen per endpoint
        max_consecutive = dict.fromkeys(self.weights, 0)
        run_url = None
        run_length = 0

        # Route a significant number of requests
        num_requests = 1000
        for _ in range(num_requests):
            url = self._route()

            # Extend the current streak, or start a new one when the choice
            # changes; only the endpoint just chosen can set a new maximum.
            if url == run_url:
                run_length += 1
            else:
                run_url = url
                run_length = 1
            max_consecutive[url] = max(max_consecutive.get(url, 0), run_length)

        # Check that no endpoint was selected too many times in a row
        # For SWRR, the maximum consecutive selections should be relatively small
        for url, weight in self.weights.items():
            expected_max = (weight / 10) + 2  # Heuristic threshold
            self.assertLess(
                max_consecutive[url],
                expected_max,
                f"Endpoint {url} was selected {max_consecutive[url]} times consecutively, "
                f"which is more than expected ({expected_max}) for its weight {weight}%",
            )


# Allow running this test module directly as a script, in addition to
# discovery by a test runner.
if __name__ == "__main__":
    unittest.main()
40 changes: 35 additions & 5 deletions src/vllm_router/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The source code for the request router.
- Multiple different routing algorithms
- Round-robin routing
- Session-ID based routing
- Weighted routing (Smooth Weighted Round Robin)
- (WIP) prefix-aware routing

## Running the router
Expand All @@ -32,8 +33,9 @@ The router can be configured using command-line arguments. Below are the availab

### Routing Logic Options

- `--routing-logic`: The routing logic to use. Options are `roundrobin` or `session`. This option is required.
- `--session-key`: The key (in the header) to identify a session.
- `--routing-logic`: The routing logic to use. Options are `roundrobin`, `session`, or `weighted`. Default is `roundrobin`.
- `--session-key`: The key (in the header) to identify a session when using session-based routing.
- `--weights`: A JSON string mapping endpoint URLs to their weights for weighted routing. Example: `{"http://endpoint1:8000": 30, "http://endpoint2:8000": 70}`.

### Monitoring Options

Expand Down Expand Up @@ -62,7 +64,7 @@ You can install the router using the following command:
pip install -e .
```

**Example 1:** running the router locally at port 8000 in front of multiple serving engines:
**Example 1:** running the router locally at port 8000 in front of multiple serving engines with round-robin routing:

```bash
vllm-router --port 8000 \
Expand All @@ -74,6 +76,17 @@ vllm-router --port 8000 \
--routing-logic roundrobin
```

**Example 2:** running the router with weighted routing to distribute traffic proportionally:

```bash
vllm-router --port 8000 \
--service-discovery static \
--static-backends "http://localhost:9001,http://localhost:9002,http://localhost:9003" \
--static-models "facebook/opt-125m,meta-llama/Llama-3.1-8B-Instruct,facebook/opt-125m" \
--routing-logic weighted \
--weights '{"http://localhost:9001": 50, "http://localhost:9002": 30, "http://localhost:9003": 20}'
```

## Dynamic Router Config

The router can be configured dynamically using a json file when passing the `--dynamic-config-json` option.
Expand All @@ -84,7 +97,7 @@ Currently, the dynamic config supports the following fields:
**Required fields:**

- `service_discovery`: The service discovery type. Options are `static` or `k8s`.
- `routing_logic`: The routing logic to use. Options are `roundrobin` or `session`.
- `routing_logic`: The routing logic to use. Options are `roundrobin`, `session`, or `weighted`.

**Optional fields:**

Expand All @@ -94,8 +107,9 @@ Currently, the dynamic config supports the following fields:
- (When using `k8s` service discovery) `k8s_namespace`: The namespace of vLLM pods when using K8s service discovery. Default is `default`.
- (When using `k8s` service discovery) `k8s_label_selector`: The label selector to filter vLLM pods when using K8s service discovery.
- `session_key`: The key (in the header) to identify a session when using session-based routing.
- `weights`: A dictionary mapping endpoint URLs to their weights when using weighted routing.

Here is an example dynamic config file:
Here is an example dynamic config file with round-robin routing:

```json
{
Expand All @@ -106,6 +120,22 @@ Here is an example dynamic config file:
}
```

Here is an example dynamic config file with weighted routing:

```json
{
"service_discovery": "static",
"routing_logic": "weighted",
"static_backends": "http://localhost:9001,http://localhost:9002,http://localhost:9003",
"static_models": "facebook/opt-125m,meta-llama/Llama-3.1-8B-Instruct,facebook/opt-125m",
"weights": {
"http://localhost:9001": 50,
"http://localhost:9002": 30,
"http://localhost:9003": 20
}
}
```

### Get current dynamic config

If the dynamic config is enabled, the router will reflect the current dynamic config in the `/health` endpoint.
Expand Down
10 changes: 9 additions & 1 deletion src/vllm_router/dynamic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class DynamicRouterConfig:

# Routing logic configurations
session_key: Optional[str] = None
weights: Optional[dict] = (
None # Mapping of endpoint URLs to their weights for weighted routing
)

# Batch API configurations
# TODO (ApostaC): Support dynamic reconfiguration of batch API
Expand Down Expand Up @@ -64,6 +67,11 @@ def from_args(args) -> "DynamicRouterConfig":
# Routing logic configurations
routing_logic=args.routing_logic,
session_key=args.session_key,
weights=(
json.loads(args.weights)
if hasattr(args, "weights") and args.weights
else None
),
)

@staticmethod
Expand Down Expand Up @@ -139,7 +147,7 @@ def reconfigure_routing_logic(self, config: DynamicRouterConfig):
Reconfigures the router with the given config.
"""
routing_logic = ReconfigureRoutingLogic(
config.routing_logic, session_key=config.session_key
config.routing_logic, session_key=config.session_key, weights=config.weights
)
self.app.state.router = routing_logic
logger.info(f"DynamicConfigWatcher: Routing logic reconfiguration complete")
Expand Down
20 changes: 16 additions & 4 deletions src/vllm_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,11 @@ def validate_args(args):


def parse_args():
parser = argparse.ArgumentParser(description="Run the FastAPI app.")
parser = argparse.ArgumentParser(
description="vLLM router that routes requests to vLLM engines."
)

# Service discovery arguments
parser.add_argument(
"--host", default="0.0.0.0", help="The host to run the server on."
)
Expand Down Expand Up @@ -752,19 +756,27 @@ def parse_args():
default="",
help="The label selector to filter vLLM pods when using K8s service discovery.",
)

# Routing logic arguments
parser.add_argument(
"--routing-logic",
type=str,
required=True,
choices=["roundrobin", "session"],
help="The routing logic to use",
default="roundrobin",
choices=["roundrobin", "session", "weighted"],
help="The routing logic to use.",
)
parser.add_argument(
"--session-key",
type=str,
default=None,
help="The key (in the header) to identify a session.",
)
parser.add_argument(
"--weights",
type=str,
default=None,
help='JSON string mapping of endpoint URLs to their weights for weighted routing. Example: {"http://endpoint1:8000": 30, "http://endpoint2:8000": 70}',
)

# Batch API
# TODO(gaocegege): Make these batch api related arguments to a separate config.
Expand Down
Loading