Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Prefix-aware routing and load balancing #239

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def parse_args():
default=None,
help="The key (in the header) to identify a session.",
)
parser.add_argument(
"--routing-config",
type=str,
default="{}",
help="The routing configuration in JSON format.",
)

# Batch API
# TODO(gaocegege): Make these batch api related arguments to a separate config.
Expand Down
203 changes: 72 additions & 131 deletions src/vllm_router/routers/routing_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
from vllm_router.stats.request_stats import RequestStats
from vllm_router.utils import SingletonABCMeta

logger = init_logger(__name__)

from vllm_router.routers.affinity.factory import get_affinity
from vllm_router.routers.endpoint_filter.factory import get_endpoint_filter

class RoutingLogic(str, enum.Enum):
ROUND_ROBIN = "roundrobin"
SESSION_BASED = "session"
logger = init_logger(__name__)


class RoutingInterface(metaclass=SingletonABCMeta):
Expand All @@ -27,6 +25,7 @@ def route_request(
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
request: Request,
request_json: Dict[str, Any],
) -> str:
"""
Route the request to the appropriate engine URL
Expand All @@ -42,163 +41,105 @@ def route_request(
raise NotImplementedError


class RoundRobinRouter(RoutingInterface):
# TODO (ApostaC): when available engines in the endpoints changes, the
# algorithm may not be "perfectly" round-robin.
def __init__(self):
if hasattr(self, "_initialized"):
return
self.req_id = 0
self._initialized = True

def route_request(
self,
endpoints: List[EndpointInfo],
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
request: Request,
) -> str:
"""
Route the request to the appropriate engine URL using a simple
round-robin algorithm

Args:
endpoints (List[EndpointInfo]): The list of engine URLs
engine_stats (Dict[str, EngineStats]): The engine stats indicating
the 'physical' load of each engine
request_stats (Dict[str, RequestStats]): The request stats
indicating the request-level performance of each engine
request (Request): The incoming request
"""
len_engines = len(endpoints)
chosen = sorted(endpoints, key=lambda e: e.url)[self.req_id % len_engines]
self.req_id += 1
return chosen.url

class Router(RoutingInterface):

class SessionRouter(RoutingInterface):
"""
Route the request to the appropriate engine URL based on the session key
in the request headers
"""
def __init__(
self,
**kwargs: Dict[str, Any],
):

def __init__(self, session_key: str = None):
if hasattr(self, "_initialized"):
return
if session_key is None:
raise ValueError("SessionRouter must be initialized with a session_key")
self.session_key = session_key
self.hash_ring = HashRing()
self._initialized = True

def _qps_routing(
self, endpoints: List[EndpointInfo], request_stats: Dict[str, RequestStats]
) -> str:
"""
Route the request to the appropriate engine URL based on the QPS of
each engine
self.reconfigure(**kwargs)
self.initialized = True

Args:
request_stats (Dict[str, RequestStats]): The request stats
indicating the request-level performance of each engine
"""
lowest_qps = float("inf")
ret = None
for info in endpoints:
url = info.url
if url not in request_stats:
return url # This engine does not have any requests
request_stat = request_stats[url]
if request_stat.qps < lowest_qps:
lowest_qps = request_stat.qps
ret = url
return ret

def _update_hash_ring(self, endpoints: List["EndpointInfo"]):
"""
Update the hash ring with the current list of endpoints.
"""
# Extract endpoint URLs
endpoint_urls = [endpoint.url for endpoint in endpoints]
def reconfigure(self, **kwargs: Dict[str, Any]):

# Get the current nodes in the hash ring
current_nodes = set(self.hash_ring.get_nodes())
# Initialize the affinity module
self.affinity = None
if "affinity" not in kwargs:
logger.warning("No affinity specified, using simple round-robin logic to select endpoints")
self.affinity = get_affinity("round_robin", {})
else:
self.affinity = get_affinity(**kwargs["affinity"])

# Convert the new endpoint URLs to a set for easy comparison
new_nodes = set(endpoint_urls)
# Initialize the endpoint filters
self.endpoint_filters = []
if "endpoint_filters" not in kwargs:
logger.info("No endpoint filters specified.")
else:
for endpoint_filter_kwargs in kwargs["endpoint_filters"]:
self.endpoint_filters.append(get_endpoint_filter(**endpoint_filter_kwargs))

# Remove nodes that are no longer in the list
for node in current_nodes - new_nodes:
self.hash_ring.remove_node(node)
self._initialized = True

# Add new nodes that are not already in the hash ring
for node in new_nodes - current_nodes:
self.hash_ring.add_node(node)

def route_request(
self,
endpoints: List[EndpointInfo],
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
request: Request,
request_json: Dict[str, Any],
) -> str:
"""
Route the request to the appropriate engine URL by the 'session id' in
the request headers.
If there is no session id in the request header, it will pick a server
with lowest qps

Args:
endpoints (List[EndpointInfo]): The list of engine URLs
engine_stats (Dict[str, EngineStats]): The engine stats indicating
the 'physical' load of each engine
request_stats (Dict[str, RequestStats]): The request stats
indicating the request-level performance of each engine
request (Request): The incoming request
"""
session_id = request.headers.get(self.session_key, None)
logger.debug(f"Got session id: {session_id}")
self.affinity.update_endpoints_stats(endpoints, engine_stats, request_stats)

# Update the hash ring with the current list of endpoints
self._update_hash_ring(endpoints)
endpoints = set(endpoint.url for endpoint in endpoints)
assert endpoints, "No endpoints provided for the routing logic."

if session_id is None:
# Route based on QPS if no session ID is present
url = self._qps_routing(endpoints, request_stats)
else:
# Use the hash ring to get the endpoint for the session ID
url = self.hash_ring.get_node(session_id)
for endpoint_filter in self.endpoint_filters:
previous_endpoints = endpoints
endpoints = endpoint_filter.get_filtered_endpoints(
endpoints,
request_stats,
engine_stats
)
if not endpoints:
logger.warning(f"Endpoint filter {endpoint_filter.name} "
f"removed all endpoints from "
f"{previous_endpoints}. Reverting to previous "
f"endpoints and skipping all remaining "
f"endpoint filters.")
endpoints = previous_endpoints
break

selected_endpoint = self.affinity.get_high_affinity_endpoint(
request,
request_json,
endpoints
)

return url
self.affinity.on_request_routed(
request,
request_json,
selected_endpoint
)

return selected_endpoint


_router = None

# Instead of managing a global _global_router, we can define the initialization functions as:
def initialize_routing_logic(
routing_logic: RoutingLogic, *args, **kwargs
**kwargs
) -> RoutingInterface:
if routing_logic == RoutingLogic.ROUND_ROBIN:
logger.info("Initializing round-robin routing logic")
return RoundRobinRouter()
elif routing_logic == RoutingLogic.SESSION_BASED:
logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}")
return SessionRouter(kwargs.get("session_key"))
else:
raise ValueError(f"Invalid routing logic {routing_logic}")

assert _router is None, "Routing logic already initialized"
_router = Router(**kwargs)
return _router


def reconfigure_routing_logic(
routing_logic: RoutingLogic, *args, **kwargs
**kwargs
) -> RoutingInterface:
# Remove the existing routers from the singleton registry
for cls in (SessionRouter, RoundRobinRouter):
if cls in SingletonABCMeta._instances:
del SingletonABCMeta._instances[cls]
return initialize_routing_logic(routing_logic, *args, **kwargs)
_router.reconfigure(**kwargs)
return _router


def get_routing_logic() -> RoutingInterface:
# Look up in our singleton registry which router (if any) has been created.
for cls in (SessionRouter, RoundRobinRouter):
if cls in SingletonABCMeta._instances:
return cls()
raise ValueError("The global router has not been initialized")
assert _router is not None, ("Routing logic not initialized. "
"Please call initialize_routing_logic() first.")
return _router
Empty file.
59 changes: 59 additions & 0 deletions src/vllm_router/services/routing_service/affinity/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Abstract class for best endpoint selector.
"""

import abc
from typing import Set, Dict, Any
from fastapi import Request
from vllm_router.types import EngineStats, RequestStats

class BaseAffinityMaintainer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_high_affinity_endpoint(
self,
request: Request,
request_json: Dict[str, Any],
unavailable_endpoints: Set[str] = set(),
) -> str:
"""
Get the endpoint with the highest affinity for the request.
If there are multiple endpoints with the same affinity, return one of them randomly.

Args:
request (Request): The request.
request_json (Dict[str, Any]): The jsonized request body.
unavailable_endpoints (Set[str]): The endpoints that are temporarily unavailable.

Returns:
str: The endpoint with the highest affinity for the request.
"""
pass

@abc.abstractmethod
def on_request_routed(
self,
request: Request,
request_json: Dict[str, Any],
endpoint: str,
) -> None:
"""
Notify the affinity maintainer that the request is actually routed to the endpoint.

Args:
request (Request): The request.
request_json (Dict[str, Any]): The jsonized request body.
endpoint (str): The endpoint that is actually routed to.
"""
pass

@abc.abstractmethod
def update_endpoints_stats(
self,
endpoints: Set[str],
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
) -> None:
"""
Update the endpoint stats. This will not remove any endpoints.
"""
pass
32 changes: 32 additions & 0 deletions src/vllm_router/services/routing_service/affinity/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

from vllm_router.services.routing_service.affinity.base import BaseAffinity

import json
from logging import getLogger

logger = getLogger(__name__)

from vllm_router.services.routing_service.affinity.round_robin_affinity import RoundRobinAffinity
from vllm_router.services.routing_service.affinity.session_based_affinity import SessionBasedAffinity
from vllm_router.services.routing_service.affinity.longest_prefix_affinity import LongestPrefixAffinity
from vllm_router.services.routing_service.affinity.simhash_affinity import SimhashAffinity

affinity_name_to_class = {
"round_robin": RoundRobinAffinity,
"session": SessionBasedAffinity,
"longest_prefix": LongestPrefixAffinity,
"simhash": SimhashAffinity,
}

def get_affinity(affinity_name: str, affinity_config: Dict[str, Any] = {}, **kwargs) -> BaseAffinity:

if affinity_name not in affinity_name_to_class:
raise ValueError(f"Invalid affinity name: {affinity_name}")


assert kwargs == {}, ("There are extra kwargs forwarded to the affinity "
"factory method. This is likely unintended. "
"Received kwargs: %s" % kwargs)

logger.info(f"Using affinity type: {affinity_name} with config: {affinity_config}")
return affinity_name_to_class[affinity_name](**affinity_config)
Loading
Loading