diff --git a/src/vllm_router/affinity/factory.py b/src/vllm_router/affinity/factory.py deleted file mode 100644 index 86f4809..0000000 --- a/src/vllm_router/affinity/factory.py +++ /dev/null @@ -1,28 +0,0 @@ - -from vllm_router.affinity.base import BaseAffinity -from vllm_router.overload_detectors.base import BaseOverloadDetector - -import json -from logging import getLogger - -logger = getLogger(__name__) - -from vllm_router.affinity.round_robin_affinity import RoundRobinAffinity -from vllm_router.affinity.session_based_affinity import SessionBasedAffinity -from vllm_router.affinity.longest_prefix_affinity import LongestPrefixAffinity -from vllm_router.affinity.simhash_affinity import SimhashAffinity - -affinity_str_to_class = { - "round_robin": RoundRobinAffinity, - "session": SessionBasedAffinity, - "longest_prefix": LongestPrefixAffinity, - "simhash": SimhashAffinity, -} - -def get_affinity(affinity_type: str, affinity_config: str) -> BaseAffinity: - if affinity_type not in affinity_str_to_class: - raise ValueError(f"Invalid affinity type: {affinity_type}") - - kwargs = json.loads(affinity_config) - logger.info(f"Using affinity type: {affinity_type} with config: {kwargs}") - return affinity_str_to_class[affinity_type](**kwargs) diff --git a/src/vllm_router/overload_detector/factory.py b/src/vllm_router/overload_detector/factory.py deleted file mode 100644 index 92304e2..0000000 --- a/src/vllm_router/overload_detector/factory.py +++ /dev/null @@ -1,21 +0,0 @@ - -from vllm_router.overload_detectors.base import BaseOverloadDetector - -import json -from logging import getLogger - -logger = getLogger(__name__) - -from vllm_router.overload_detector.num_queued_requests import NumQueuedRequestsOverloadDetector - -overload_detector_str_to_class = { - "num_queued_requests": NumQueuedRequestsOverloadDetector, -} - -def get_overload_detector(overload_detector_type: str, overload_detector_config: str) -> BaseOverloadDetector: - if overload_detector_type not in overload_detector_str_to_class: - raise ValueError(f"Invalid overload detector type: {overload_detector_type}") - - kwargs = json.loads(overload_detector_config) - logger.info(f"Using overload detector type: {overload_detector_type} with config: {kwargs}") - return overload_detector_str_to_class[overload_detector_type](**kwargs) diff --git a/src/vllm_router/affinity/__init__.py b/src/vllm_router/routers/affinity/__init__.py similarity index 100% rename from src/vllm_router/affinity/__init__.py rename to src/vllm_router/routers/affinity/__init__.py diff --git a/src/vllm_router/affinity/base.py b/src/vllm_router/routers/affinity/base.py similarity index 95% rename from src/vllm_router/affinity/base.py rename to src/vllm_router/routers/affinity/base.py index 3157cf4..277c700 100644 --- a/src/vllm_router/affinity/base.py +++ b/src/vllm_router/routers/affinity/base.py @@ -53,6 +53,6 @@ def update_endpoints_stats( request_stats: Dict[str, RequestStats], ) -> None: """ - Update the endpoint stats. + Update the endpoint stats. This will not remove any endpoints. """ pass diff --git a/src/vllm_router/routers/affinity/factory.py b/src/vllm_router/routers/affinity/factory.py new file mode 100644 index 0000000..59a375d --- /dev/null +++ b/src/vllm_router/routers/affinity/factory.py @@ -0,0 +1,32 @@ + +from vllm_router.routers.affinity.base import BaseAffinity + +import json +from logging import getLogger + +logger = getLogger(__name__) + +from vllm_router.affinity.round_robin_affinity import RoundRobinAffinity +from vllm_router.affinity.session_based_affinity import SessionBasedAffinity +from vllm_router.affinity.longest_prefix_affinity import LongestPrefixAffinity +from vllm_router.affinity.simhash_affinity import SimhashAffinity + +affinity_name_to_class = { + "round_robin": RoundRobinAffinity, + "session": SessionBasedAffinity, + "longest_prefix": LongestPrefixAffinity, + "simhash": SimhashAffinity, +} + +def get_affinity(affinity_name: str, affinity_config: Dict[str, Any] = {}, **kwargs) -> BaseAffinity: + + if affinity_name not in affinity_name_to_class: + raise ValueError(f"Invalid affinity name: {affinity_name}") + + + assert kwargs == {}, ("There are extra kwargs forwarded to the affinity " + "factory method. This is likely unintended. " + "Received kwargs: %s" % kwargs) + + logger.info(f"Using affinity type: {affinity_name} with config: {affinity_config}") + return affinity_name_to_class[affinity_name](**affinity_config) diff --git a/src/vllm_router/affinity/longest_prefix_affinity.py b/src/vllm_router/routers/affinity/longest_prefix_affinity.py similarity index 98% rename from src/vllm_router/affinity/longest_prefix_affinity.py rename to src/vllm_router/routers/affinity/longest_prefix_affinity.py index 5ef35de..a0e8b4c 100644 --- a/src/vllm_router/affinity/longest_prefix_affinity.py +++ b/src/vllm_router/routers/affinity/longest_prefix_affinity.py @@ -110,6 +110,7 @@ def __init__(self, **kwargs): self.trie = HashTrie(chunk_size=chunk_size) self.chunk_size = chunk_size self.logger = logging.getLogger("LCPMatcher") + self.name = "longest_prefix_affinity" def get_high_affinity_endpoint( self, @@ -155,7 +156,5 @@ def update_endpoints_stats( engine_stats: Dict[str, EngineStats], request_stats: Dict[str, RequestStats], ) -> None: - """ - Set the endpoints for the trie. - """ - self.trie.root.endpoints = endpoints + pass + diff --git a/src/vllm_router/affinity/round_robin_affinity.py b/src/vllm_router/routers/affinity/round_robin_affinity.py similarity index 86% rename from src/vllm_router/affinity/round_robin_affinity.py rename to src/vllm_router/routers/affinity/round_robin_affinity.py index 1de4f96..e9d0350 100644 --- a/src/vllm_router/affinity/round_robin_affinity.py +++ b/src/vllm_router/routers/affinity/round_robin_affinity.py @@ -21,21 +21,18 @@ def __init__( **kwargs ): self.index = 0 - self.endpoints = set() - + self.name = "round_robin_affinity" def get_high_affinity_endpoint( self, request: Request, request_json: Dict[str, Any], - unavailable_endpoints: Set[str], + available_endpoints: Set[str], ) -> str: - - available_endpoints = list(self.endpoints - unavailable_endpoints) - if not available_endpoints: raise ValueError(f"No available endpoints for request: {request}") + available_endpoints = list(available_endpoints) endpoint = available_endpoints[self.index % len(available_endpoints)] self.index = self.index + 1 return endpoint @@ -49,5 +46,4 @@ def update_endpoints_stats( engine_stats: Dict[str, EngineStats], request_stats: Dict[str, RequestStats], ) -> None: - - self.endpoints = endpoints \ No newline at end of file + pass diff --git a/src/vllm_router/affinity/session_affinity.py b/src/vllm_router/routers/affinity/session_affinity.py similarity index 78% rename from src/vllm_router/affinity/session_affinity.py rename to src/vllm_router/routers/affinity/session_affinity.py index 0eb731c..3bd7b77 100644 --- a/src/vllm_router/affinity/session_affinity.py +++ b/src/vllm_router/routers/affinity/session_affinity.py @@ -8,32 +8,31 @@ def __init__( self, **kwargs ): - if hasattr(self, "_initialized"): - return if "session_key" not in kwargs: raise ValueError("Using session affinity without specifying " "session_key in affinity config. Please specify a session_key.") self.session_key = kwargs["session_key"] self.hash_ring = HashRing() - self._initialized = True + self.name = "session_affinity" def get_high_affinity_endpoint( self, request: Request, request_json: Dict[str, Any], - unavailable_endpoints: Set[str], + available_endpoints: Set[str], ) -> str: - assert unavailable_endpoints.issubset(self.endpoints) - + assert available_endpoints.issubset(self.endpoints), ( + f"Available endpoints must be a subset of the endpoints in the hash" + f"ring. \nAvailable endpoints: {available_endpoints} \n" + f"Endpoints in hash ring: {self.endpoints}\n" + ) session_id = request.headers.get(self.session_key, None) - # Iterate through nodes starting from the hash position for endpoint in self.hash_ring.iterate_nodes(str(session_id), distinct=True): - - if endpoint not in unavailable_endpoints: + if endpoint in available_endpoints: return endpoint raise ValueError(f"No endpoint found for request: {request}") @@ -58,10 +57,6 @@ def update_endpoints_stats( # Convert the new endpoint URLs to a set for easy comparison new_nodes = endpoints - # Remove nodes that are no longer in the list - for node in current_nodes - new_nodes: - self.hash_ring.remove_node(node) - # Add new nodes that are not already in the hash ring for node in new_nodes - current_nodes: self.hash_ring.add_node(node) diff --git a/src/vllm_router/affinity/simhash_affinity.py b/src/vllm_router/routers/affinity/simhash_affinity.py similarity index 98% rename from src/vllm_router/affinity/simhash_affinity.py rename to src/vllm_router/routers/affinity/simhash_affinity.py index 4174b05..1527ea8 100644 --- a/src/vllm_router/affinity/simhash_affinity.py +++ b/src/vllm_router/routers/affinity/simhash_affinity.py @@ -54,6 +54,7 @@ def __init__( self.hash_ring = HashRing() self.hash_func = hash_type.get_hash_func(max_length=max_length) self.endpoints = set() + self.name = "simhash_affinity" def get_high_affinity_endpoint( diff --git a/src/vllm_router/overload_detector/__init__.py b/src/vllm_router/routers/endpoint_filter/__init__.py similarity index 100% rename from src/vllm_router/overload_detector/__init__.py rename to src/vllm_router/routers/endpoint_filter/__init__.py diff --git a/src/vllm_router/overload_detector/base.py b/src/vllm_router/routers/endpoint_filter/base.py similarity index 62% rename from src/vllm_router/overload_detector/base.py rename to src/vllm_router/routers/endpoint_filter/base.py index c4884bd..df4faf9 100644 --- a/src/vllm_router/overload_detector/base.py +++ b/src/vllm_router/routers/endpoint_filter/base.py @@ -3,15 +3,15 @@ from typing import Set, Dict import abc -class BaseOverloadDetector(metaclass=abc.ABCMeta): +class BaseEndpointFilter(metaclass=abc.ABCMeta): @abc.abstractmethod - def get_overload_endpoints( + def get_filtered_endpoints( self, endpoints: Set[str], request_stats: Dict[str, RequestStats], engine_stats: Dict[str, EngineStats], ) -> Set[str]: """ - Check if the endpoint is overloaded. + Filter the endpoints based on the request stats and engine stats. """ - pass \ No newline at end of file + pass diff --git a/src/vllm_router/routers/endpoint_filter/factory.py b/src/vllm_router/routers/endpoint_filter/factory.py new file mode 100644 index 0000000..a5ea05b --- /dev/null +++ b/src/vllm_router/routers/endpoint_filter/factory.py @@ -0,0 +1,24 @@ + +from vllm_router.routers.endpoint_filter.base import BaseEndpointFilter + +import json +from logging import getLogger + +logger = getLogger(__name__) + +from vllm_router.routers.endpoint_filter.num_queueing_request_filter import NumQueueingRequestFilter + +endpoint_filter_name_to_class = { + "num_queueing_request_filter": NumQueueingRequestFilter, +} + +def get_endpoint_filter(endpoint_filter_name: str, endpoint_filter_config: Dict[str, Any] = {}, **kwargs) -> BaseEndpointFilter: + if endpoint_filter_name not in endpoint_filter_name_to_class: + raise ValueError(f"Invalid endpoint filter name: {endpoint_filter_name}") + + assert kwargs == {}, ("There are extra kwargs forwarded to the endpoint filter " + "factory method. This is likely unintended. " + "Received kwargs: %s" % kwargs) + + logger.info(f"Using endpoint filter type: {endpoint_filter_name} with config: {endpoint_filter_config}") + return endpoint_filter_name_to_class[endpoint_filter_name](**endpoint_filter_config) diff --git a/src/vllm_router/overload_detector/num_queueing_request.py b/src/vllm_router/routers/endpoint_filter/num_queueing_request_filter.py similarity index 69% rename from src/vllm_router/overload_detector/num_queueing_request.py rename to src/vllm_router/routers/endpoint_filter/num_queueing_request_filter.py index a244722..e6bbf99 100644 --- a/src/vllm_router/overload_detector/num_queueing_request.py +++ b/src/vllm_router/routers/endpoint_filter/num_queueing_request_filter.py @@ -1,26 +1,27 @@ -from vllm_router.load_metrics.base import BaseLoadMetric +from vllm_router.routers.endpoint_filter.base import BaseEndpointFilter from vllm_router.types import RequestStats, EngineStats import logging logger = logging.getLogger(__name__) -class NumQueueingRequest(BaseOverloadDetector): +class NumQueueingRequestFilter(BaseEndpointFilter): def __init__(self, **kwargs): if "percentile" not in kwargs: - logger.warning("Using num_queueing_request overload detector " - "without specifying percentile in overload detector config." + logger.warning("Using num_queueing_request endpoint filter " + "without specifying percentile in endpoint filter config." "Setting percentile to default value: 0.9") percentile = 0.9 else: percentile = kwargs["percentile"] self.percentile = percentile + self.name = "num_queueing_request_filter" - def get_overload_endpoints( + def get_filtered_endpoints( self, endpoints: Set[str], request_stats: Dict[str, RequestStats], diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index a0b35be..814c934 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -11,12 +11,10 @@ from vllm_router.stats.request_stats import RequestStats from vllm_router.utils import SingletonABCMeta -logger = init_logger(__name__) - +from vllm_router.routers.affinity.factory import get_affinity +from vllm_router.routers.endpoint_filter.factory import get_endpoint_filter -class RoutingLogic(str, enum.Enum): - ROUND_ROBIN = "roundrobin" - SESSION_BASED = "session" +logger = init_logger(__name__) class RoutingInterface(metaclass=SingletonABCMeta): @@ -27,6 +25,7 @@ def route_request( engine_stats: Dict[str, EngineStats], request_stats: Dict[str, RequestStats], request: Request, + request_json: Dict[str, Any], ) -> str: """ Route the request to the appropriate engine URL @@ -42,60 +41,105 @@ def route_request( raise NotImplementedError -class RoundRobinRouter(RoutingInterface): - # TODO (ApostaC): when available engines in the endpoints changes, the - # algorithm may not be "perfectly" round-robin. - def __init__(self): +class Router(RoutingInterface): + + def __init__( + self, + **kwargs: Dict[str, Any], + ): + if hasattr(self, "_initialized"): return - self.req_id = 0 + + self.reconfigure(**kwargs) + self.initialized = True + + def reconfigure(self, **kwargs: Dict[str, Any]): + + # Initialize the affinity module + self.affinity = None + if "affinity" not in kwargs: + logger.warning("No affinity specified, using simple round-robin logic to select endpoints") + self.affinity = get_affinity("round_robin", {}) + else: + self.affinity = get_affinity(**kwargs["affinity"]) + + # Initialize the endpoint filters + self.endpoint_filters = [] + if "endpoint_filters" not in kwargs: + logger.info("No endpoint filters specified.") + else: + for endpoint_filter_kwargs in kwargs["endpoint_filters"]: + self.endpoint_filters.append(get_endpoint_filter(**endpoint_filter_kwargs)) + self._initialized = True + def route_request( self, endpoints: List[EndpointInfo], engine_stats: Dict[str, EngineStats], request_stats: Dict[str, RequestStats], request: Request, + request_json: Dict[str, Any], ) -> str: - - overload_endpoints = self.overload_detector.get_overload_endpoints(endpoints, request_stats, engine_stats) - self.affinity.update_endpoints_stats(overload_endpoints, engine_stats, request_stats) - url = self.affinity.get_high_affinity_endpoint(request, request_json, overload_endpoints) + self.affinity.update_endpoints_stats(endpoints, engine_stats, request_stats) + + endpoints = set(endpoint.url for endpoint in endpoints) + assert endpoints, "No endpoints provided for the routing logic." - self.affinity.on_request_routed(url, request_stats, engine_stats) + for endpoint_filter in self.endpoint_filters: + previous_endpoints = endpoints + endpoints = endpoint_filter.get_filtered_endpoints( + endpoints, + request_stats, + engine_stats + ) + if not endpoints: + logger.warning(f"Endpoint filter {endpoint_filter.name} " + f"removed all endpoints from " + f"{previous_endpoints}. Reverting to previous " + f"endpoints and skipping all remaining " + f"endpoint filters.") + endpoints = previous_endpoints + break - return url + selected_endpoint = self.affinity.get_high_affinity_endpoint( + request, + request_json, + endpoints + ) + self.affinity.on_request_routed( + request, + request_json, + selected_endpoint + ) + + return selected_endpoint + + +_router = None # Instead of managing a global _global_router, we can define the initialization functions as: def initialize_routing_logic( - routing_logic: RoutingLogic, *args, **kwargs + **kwargs ) -> RoutingInterface: - if routing_logic == RoutingLogic.ROUND_ROBIN: - logger.info("Initializing round-robin routing logic") - return RoundRobinRouter() - elif routing_logic == RoutingLogic.SESSION_BASED: - logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}") - return SessionRouter(kwargs.get("session_key")) - else: - raise ValueError(f"Invalid routing logic {routing_logic}") + + assert _router is None, "Routing logic already initialized" + _router = Router(**kwargs) + return _router def reconfigure_routing_logic( - routing_logic: RoutingLogic, *args, **kwargs + **kwargs ) -> RoutingInterface: - # Remove the existing routers from the singleton registry - for cls in (SessionRouter, RoundRobinRouter): - if cls in SingletonABCMeta._instances: - del SingletonABCMeta._instances[cls] - return initialize_routing_logic(routing_logic, *args, **kwargs) + _router.reconfigure(**kwargs) + return _router def get_routing_logic() -> RoutingInterface: - # Look up in our singleton registry which router (if any) has been created. - for cls in (SessionRouter, RoundRobinRouter): - if cls in SingletonABCMeta._instances: - return cls() - raise ValueError("The global router has not been initialized") + assert _router is not None, ("Routing logic not initialized. " + "Please call initialize_routing_logic() first.") + return _router