Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Prefix-aware routing and load balancing #239

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/vllm_router/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ vllm-router --port 8000 \
--static-models "facebook/opt-125m,meta-llama/Llama-3.1-8B-Instruct,facebook/opt-125m" \
--engine-stats-interval 10 \
--log-stats \
--routing-logic roundrobin
--routing-logic longest_prefix
```

## Dynamic Router Config
Expand Down
8 changes: 7 additions & 1 deletion src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,13 @@ def initialize_all(app: FastAPI, args):
args.batch_processor, args.file_storage_path, app.state.batch_storage
)

initialize_routing_logic(args.routing_logic, session_key=args.session_key)
initialize_routing_logic(
args.routing_logic,
args.session_key,
args.routing_logic_config,
args.endpoint_filters,
args.endpoint_filters_configs,
)

# Initialize feature gates
initialize_feature_gates(args.feature_gates)
Expand Down
35 changes: 31 additions & 4 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def validate_args(args):
raise ValueError("K8s port must be provided when using K8s service discovery.")
if args.routing_logic == "session" and args.session_key is None:
raise ValueError(
"Session key must be provided when using session routing logic."
"Session key must be provided when using session routing affinity."
)
if args.log_stats and args.log_stats_interval <= 0:
raise ValueError("Log stats interval must be greater than 0.")
Expand Down Expand Up @@ -99,14 +99,41 @@ def parse_args():
"--routing-logic",
type=str,
required=True,
choices=["roundrobin", "session"],
help="The routing logic to use",
choices=["roundrobin", "session", "longest_prefix", "simhash"],
help="The routing affinity to use.",
)
parser.add_argument(
"--session-key",
type=str,
default=None,
help="The key (in the header) to identify a session.",
help="The key (in the header) to identify a session. This is a shortcut"
" for --routing-logic-config "
'\'{"session_key": "<session_key>"}\'.',
)
parser.add_argument(
"--routing-logic-config",
type=str,
default="{}",
help="The routing configuration in JSON format.",
)
parser.add_argument(
"--endpoint-filters",
nargs="+",
default=[],
choices=["num_queueing_request"],
help="Tndpoint filters to use. Example usage: "
"--endpoint-filters num_queueing_request other_filter_A "
"other_filter_B ...",
)
parser.add_argument(
"--endpoint-filters-configs",
nargs="+",
default=[],
help="The configurations for endpoint filters, in JSON format. "
"Example usage: "
"--endpoint-filters-configs "
"'{\"percentile\": 0.9}' "
"other_filter_config_A other_filter_config_B ...",
)

# Batch API
Expand Down
255 changes: 121 additions & 134 deletions src/vllm_router/routers/routing_logic.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import abc
import enum
from typing import Dict, List
import json
from typing import Any, Dict, List

from fastapi import Request
from uhashring import HashRing

from vllm_router.log import init_logger
from vllm_router.service_discovery import EndpointInfo
from vllm_router.services.routing_service.affinity.factory import get_affinity
from vllm_router.services.routing_service.endpoint_filter.factory import (
get_endpoint_filter,
)
from vllm_router.stats.engine_stats import EngineStats
from vllm_router.stats.request_stats import RequestStats
from vllm_router.utils import SingletonABCMeta

logger = init_logger(__name__)


class RoutingLogic(str, enum.Enum):
ROUND_ROBIN = "roundrobin"
SESSION_BASED = "session"


class RoutingInterface(metaclass=SingletonABCMeta):
@abc.abstractmethod
def route_request(
Expand All @@ -27,6 +27,7 @@ def route_request(
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
request: Request,
request_json: Dict[str, Any],
) -> str:
"""
Route the request to the appropriate engine URL
Expand All @@ -42,163 +43,149 @@ def route_request(
raise NotImplementedError


class RoundRobinRouter(RoutingInterface):
# TODO (ApostaC): when available engines in the endpoints changes, the
# algorithm may not be "perfectly" round-robin.
def __init__(self):
if hasattr(self, "_initialized"):
return
self.req_id = 0
self._initialized = True
class Router(RoutingInterface):

def route_request(
def __init__(
self,
endpoints: List[EndpointInfo],
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
request: Request,
) -> str:
"""
Route the request to the appropriate engine URL using a simple
round-robin algorithm

Args:
endpoints (List[EndpointInfo]): The list of engine URLs
engine_stats (Dict[str, EngineStats]): The engine stats indicating
the 'physical' load of each engine
request_stats (Dict[str, RequestStats]): The request stats
indicating the request-level performance of each engine
request (Request): The incoming request
"""
len_engines = len(endpoints)
chosen = sorted(endpoints, key=lambda e: e.url)[self.req_id % len_engines]
self.req_id += 1
return chosen.url


class SessionRouter(RoutingInterface):
"""
Route the request to the appropriate engine URL based on the session key
in the request headers
"""
routing_logic: str,
routing_logic_config: Dict[str, Any],
endpoint_filters: List[str],
endpoint_filters_configs: List[Dict[str, Any]],
):

def __init__(self, session_key: str = None):
if hasattr(self, "_initialized"):
return
if session_key is None:
raise ValueError("SessionRouter must be initialized with a session_key")
self.session_key = session_key
self.hash_ring = HashRing()
self._initialized = True

def _qps_routing(
self, endpoints: List[EndpointInfo], request_stats: Dict[str, RequestStats]
) -> str:
"""
Route the request to the appropriate engine URL based on the QPS of
each engine
self.reconfigure(
routing_logic=routing_logic,
routing_logic_config=routing_logic_config,
endpoint_filters=endpoint_filters,
endpoint_filters_configs=endpoint_filters_configs,
)
self.initialized = True

Args:
request_stats (Dict[str, RequestStats]): The request stats
indicating the request-level performance of each engine
"""
lowest_qps = float("inf")
ret = None
for info in endpoints:
url = info.url
if url not in request_stats:
return url # This engine does not have any requests
request_stat = request_stats[url]
if request_stat.qps < lowest_qps:
lowest_qps = request_stat.qps
ret = url
return ret

def _update_hash_ring(self, endpoints: List["EndpointInfo"]):
"""
Update the hash ring with the current list of endpoints.
"""
# Extract endpoint URLs
endpoint_urls = [endpoint.url for endpoint in endpoints]

# Get the current nodes in the hash ring
current_nodes = set(self.hash_ring.get_nodes())

# Convert the new endpoint URLs to a set for easy comparison
new_nodes = set(endpoint_urls)

# Remove nodes that are no longer in the list
for node in current_nodes - new_nodes:
self.hash_ring.remove_node(node)
def reconfigure(
self,
routing_logic: str,
routing_logic_config: str,
endpoint_filters: List[str],
endpoint_filters_configs: List[str],
):

# Initialize the affinity module
self.affinity = None

routing_logic_config = json.loads(routing_logic_config)
self.affinity = get_affinity(routing_logic, **routing_logic_config)

# Initialize the endpoint filters
self.endpoint_filters = []

assert len(endpoint_filters) == len(endpoint_filters_configs), (
"The number of items in endpoint filters and endpoint filter "
"configs must be the same"
)

for endpoint_filter_name, endpoint_filter_config in zip(
endpoint_filters, endpoint_filters_configs
):
self.endpoint_filters.append(
get_endpoint_filter(
endpoint_filter_name, **json.loads(endpoint_filter_config)
)
)

# Add new nodes that are not already in the hash ring
for node in new_nodes - current_nodes:
self.hash_ring.add_node(node)
self._initialized = True

def route_request(
self,
endpoints: List[EndpointInfo],
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
request: Request,
request_json: Dict[str, Any],
) -> str:
"""
Route the request to the appropriate engine URL by the 'session id' in
the request headers.
If there is no session id in the request header, it will pick a server
with lowest qps

Args:
endpoints (List[EndpointInfo]): The list of engine URLs
engine_stats (Dict[str, EngineStats]): The engine stats indicating
the 'physical' load of each engine
request_stats (Dict[str, RequestStats]): The request stats
indicating the request-level performance of each engine
request (Request): The incoming request
"""
session_id = request.headers.get(self.session_key, None)
logger.debug(f"Got session id: {session_id}")
endpoints = set(endpoint.url for endpoint in endpoints)
assert endpoints, "No endpoints provided for the routing logic."

for endpoint_filter in self.endpoint_filters:
previous_endpoints = endpoints
endpoints = endpoint_filter.get_filtered_endpoints(
endpoints, request_stats, engine_stats
)
if not endpoints:
logger.warning(
f"Endpoint filter {endpoint_filter.name} "
f"removed all endpoints from "
f"{previous_endpoints}. Reverting to previous "
f"endpoints and skipping all remaining "
f"endpoint filters."
)
endpoints = previous_endpoints
break

# NOTE(Kuntai): Only update the endpoint stats for the candidate
# endpoints instead of all endpoints.
# Another design is to actually update the endpoint stats for all
# endpoints. But I don't see that there is a strong reason to do so.
self.affinity.update_endpoints_stats(endpoints, engine_stats, request_stats)

# Update the hash ring with the current list of endpoints
self._update_hash_ring(endpoints)
selected_endpoint = self.affinity.get_high_affinity_endpoint(
request, request_json, endpoints
)

if session_id is None:
# Route based on QPS if no session ID is present
url = self._qps_routing(endpoints, request_stats)
else:
# Use the hash ring to get the endpoint for the session ID
url = self.hash_ring.get_node(session_id)
self.affinity.on_request_routed(request, request_json, selected_endpoint)

return url
return selected_endpoint


_router = None


# Instead of managing a global _global_router, we can define the initialization functions as:
def initialize_routing_logic(
routing_logic: RoutingLogic, *args, **kwargs
routing_logic: str,
session_key: str,
routing_logic_config: str,
endpoint_filters: List[str],
endpoint_filters_configs: str,
) -> RoutingInterface:
if routing_logic == RoutingLogic.ROUND_ROBIN:
logger.info("Initializing round-robin routing logic")
return RoundRobinRouter()
elif routing_logic == RoutingLogic.SESSION_BASED:
logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}")
return SessionRouter(kwargs.get("session_key"))
else:
raise ValueError(f"Invalid routing logic {routing_logic}")

global _router
assert _router is None, "Routing logic already initialized"
if routing_logic == "session":
routing_logic_config.update({"session_key": session_key})
_router = Router(
routing_logic=routing_logic,
routing_logic_config=routing_logic_config,
endpoint_filters=endpoint_filters,
endpoint_filters_configs=endpoint_filters_configs,
)
return _router


def reconfigure_routing_logic(
routing_logic: RoutingLogic, *args, **kwargs
routing_logic: str,
session_key: str,
routing_logic_config: str,
endpoint_filters: List[str],
endpoint_filters_configs: str,
) -> RoutingInterface:
# Remove the existing routers from the singleton registry
for cls in (SessionRouter, RoundRobinRouter):
if cls in SingletonABCMeta._instances:
del SingletonABCMeta._instances[cls]
return initialize_routing_logic(routing_logic, *args, **kwargs)
global _router
_router.reconfigure(
routing_logic=routing_logic,
routing_logic_config=routing_logic_config,
endpoint_filters=endpoint_filters,
endpoint_filters_configs=endpoint_filters_configs,
)
return _router


def get_routing_logic() -> RoutingInterface:
# Look up in our singleton registry which router (if any) has been created.
for cls in (SessionRouter, RoundRobinRouter):
if cls in SingletonABCMeta._instances:
return cls()
raise ValueError("The global router has not been initialized")
global _router
assert _router is not None, (
"Routing logic not initialized. "
"Please call initialize_routing_logic() first."
)
return _router
Loading
Loading