Commit

Add first set of APIs
KuntaiDu committed Mar 5, 2025
1 parent 8e39751 commit 061458c
Showing 14 changed files with 163 additions and 120 deletions.
28 changes: 0 additions & 28 deletions src/vllm_router/affinity/factory.py

This file was deleted.

21 changes: 0 additions & 21 deletions src/vllm_router/overload_detector/factory.py

This file was deleted.

File renamed without changes.
@@ -53,6 +53,6 @@ def update_endpoints_stats(
         request_stats: Dict[str, RequestStats],
     ) -> None:
         """
-        Update the endpoint stats.
+        Update the endpoint stats. This will not remove any endpoints.
         """
         pass
32 changes: 32 additions & 0 deletions src/vllm_router/routers/affinity/factory.py
@@ -0,0 +1,32 @@

import json
from logging import getLogger
from typing import Any, Dict

from vllm_router.routers.affinity.base import BaseAffinity
from vllm_router.routers.affinity.round_robin_affinity import RoundRobinAffinity
from vllm_router.routers.affinity.session_based_affinity import SessionBasedAffinity
from vllm_router.routers.affinity.longest_prefix_affinity import LongestPrefixAffinity
from vllm_router.routers.affinity.simhash_affinity import SimhashAffinity

logger = getLogger(__name__)

affinity_name_to_class = {
    "round_robin": RoundRobinAffinity,
    "session": SessionBasedAffinity,
    "longest_prefix": LongestPrefixAffinity,
    "simhash": SimhashAffinity,
}


def get_affinity(affinity_name: str, affinity_config: Dict[str, Any] = {}, **kwargs) -> BaseAffinity:

    if affinity_name not in affinity_name_to_class:
        raise ValueError(f"Invalid affinity name: {affinity_name}")

    assert kwargs == {}, ("There are extra kwargs forwarded to the affinity "
                          "factory method. This is likely unintended. "
                          "Received kwargs: %s" % kwargs)

    logger.info(f"Using affinity type: {affinity_name} with config: {affinity_config}")
    return affinity_name_to_class[affinity_name](**affinity_config)
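
For reference, a minimal sketch of how this factory might be wired up from router setup code. The affinity name comes from affinity_name_to_class above; the session header name is an assumed example, not a value defined by this commit.

from vllm_router.routers.affinity.factory import get_affinity

# Build the session-based affinity policy; everything in affinity_config is
# forwarded to the class constructor. "x-session-id" is a hypothetical header
# name used only for illustration.
affinity = get_affinity(
    affinity_name="session",
    affinity_config={"session_key": "x-session-id"},
)

# Stray keyword arguments trip the assert in get_affinity, so callers are
# expected to pass all options through affinity_config instead.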
@@ -110,6 +110,7 @@ def __init__(self, **kwargs):
         self.trie = HashTrie(chunk_size=chunk_size)
         self.chunk_size = chunk_size
         self.logger = logging.getLogger("LCPMatcher")
+        self.name = "longest_prefix_affinity"

     def get_high_affinity_endpoint(
         self,
@@ -155,7 +156,5 @@ def update_endpoints_stats(
         engine_stats: Dict[str, EngineStats],
         request_stats: Dict[str, RequestStats],
     ) -> None:
-        """
-        Set the endpoints for the trie.
-        """
-        self.trie.root.endpoints = endpoints
+        pass

@@ -21,21 +21,18 @@ def __init__(
         **kwargs
     ):
         self.index = 0
-        self.endpoints = set()

+        self.name = "round_robin_affinity"

     def get_high_affinity_endpoint(
         self,
         request: Request,
         request_json: Dict[str, Any],
-        unavailable_endpoints: Set[str],
+        available_endpoints: Set[str],
     ) -> str:

-        available_endpoints = list(self.endpoints - unavailable_endpoints)
-
         if not available_endpoints:
             raise ValueError(f"No available endpoints for request: {request}")

+        available_endpoints = list(available_endpoints)
         endpoint = available_endpoints[self.index % len(available_endpoints)]
         self.index = self.index + 1
         return endpoint
@@ -49,5 +46,4 @@ def update_endpoints_stats(
         engine_stats: Dict[str, EngineStats],
         request_stats: Dict[str, RequestStats],
     ) -> None:
-
-        self.endpoints = endpoints
+        pass
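
A rough usage sketch of the new calling convention, assuming the router now resolves the set of available endpoints itself and passes it in on every call; the endpoint URLs below are made up.

from vllm_router.routers.affinity.factory import get_affinity

affinity = get_affinity("round_robin")

# The router, not the affinity object, now owns the endpoint set.
available_endpoints = {"http://engine-0:8000", "http://engine-1:8000"}  # example URLs

# RoundRobinAffinity only cycles through available_endpoints; the request
# object and its JSON body are accepted but unused, so placeholders suffice
# in this sketch.
endpoint = affinity.get_high_affinity_endpoint(
    request=None,
    request_json={},
    available_endpoints=available_endpoints,
)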
@@ -8,32 +8,31 @@ def __init__(
         self,
         **kwargs
     ):
-        if hasattr(self, "_initialized"):
-            return
         if "session_key" not in kwargs:
             raise ValueError("Using session affinity without specifying "
                              "session_key in affinity config. Please specify a session_key.")

         self.session_key = kwargs["session_key"]
         self.hash_ring = HashRing()
-        self._initialized = True
+        self.name = "session_affinity"

     def get_high_affinity_endpoint(
         self,
         request: Request,
         request_json: Dict[str, Any],
-        unavailable_endpoints: Set[str],
+        available_endpoints: Set[str],
     ) -> str:

-        assert unavailable_endpoints.issubset(self.endpoints)
-
+        assert available_endpoints.issubset(self.endpoints), (
+            f"Available endpoints must be a subset of the endpoints in the hash "
+            f"ring. \nAvailable endpoints: {available_endpoints} \n"
+            f"Endpoints in hash ring: {self.endpoints}\n"
+        )

         session_id = request.headers.get(self.session_key, None)

         # Iterate through nodes starting from the hash position
         for endpoint in self.hash_ring.iterate_nodes(str(session_id), distinct=True):
-            if endpoint not in unavailable_endpoints:
+            if endpoint in available_endpoints:
                 return endpoint

         raise ValueError(f"No endpoint found for request: {request}")
@@ -58,10 +57,6 @@ def update_endpoints_stats(
         # Convert the new endpoint URLs to a set for easy comparison
         new_nodes = endpoints

-        # Remove nodes that are no longer in the list
-        for node in current_nodes - new_nodes:
-            self.hash_ring.remove_node(node)
-
         # Add new nodes that are not already in the hash ring
         for node in new_nodes - current_nodes:
             self.hash_ring.add_node(node)
@@ -54,6 +54,7 @@ def __init__(
         self.hash_ring = HashRing()
         self.hash_func = hash_type.get_hash_func(max_length=max_length)
         self.endpoints = set()
+        self.name = "simhash_affinity"


     def get_high_affinity_endpoint(
@@ -3,15 +3,15 @@
 from typing import Set, Dict
 import abc

-class BaseOverloadDetector(metaclass=abc.ABCMeta):
+class BaseEndpointFilter(metaclass=abc.ABCMeta):
     @abc.abstractmethod
-    def get_overload_endpoints(
+    def get_filtered_endpoints(
         self,
         endpoints: Set[str],
         request_stats: Dict[str, RequestStats],
         engine_stats: Dict[str, EngineStats],
     ) -> Set[str]:
         """
-        Check if the endpoint is overloaded.
+        Filter the endpoints based on the request stats and engine stats.
         """
         pass
24 changes: 24 additions & 0 deletions src/vllm_router/routers/endpoint_filter/factory.py
@@ -0,0 +1,24 @@

import json
from logging import getLogger
from typing import Any, Dict

from vllm_router.routers.endpoint_filter.base import BaseEndpointFilter
from vllm_router.routers.endpoint_filter.num_queueing_request_filter import NumQueueingRequestFilter

logger = getLogger(__name__)

endpoint_filter_name_to_class = {
    "num_queueing_request_filter": NumQueueingRequestFilter,
}


def get_endpoint_filter(endpoint_filter_name: str, endpoint_filter_config: Dict[str, Any] = {}, **kwargs) -> BaseEndpointFilter:

    if endpoint_filter_name not in endpoint_filter_name_to_class:
        raise ValueError(f"Invalid endpoint filter name: {endpoint_filter_name}")

    assert kwargs == {}, ("There are extra kwargs forwarded to the endpoint filter "
                          "factory method. This is likely unintended. "
                          "Received kwargs: %s" % kwargs)

    logger.info(f"Using endpoint filter type: {endpoint_filter_name} with config: {endpoint_filter_config}")
    return endpoint_filter_name_to_class[endpoint_filter_name](**endpoint_filter_config)
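
As with the affinity factory, a small usage sketch; the percentile value is an assumed example rather than a default prescribed by this commit.

from vllm_router.routers.endpoint_filter.factory import get_endpoint_filter

# Build the queueing-based filter; omitting "percentile" would fall back to
# the filter's own default of 0.9.
endpoint_filter = get_endpoint_filter(
    endpoint_filter_name="num_queueing_request_filter",
    endpoint_filter_config={"percentile": 0.95},
)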
@@ -1,26 +1,27 @@
-from vllm_router.load_metrics.base import BaseLoadMetric
+from vllm_router.routers.endpoint_filter.base import BaseEndpointFilter
 from vllm_router.types import RequestStats, EngineStats
+from typing import Dict, Set

 import logging

 logger = logging.getLogger(__name__)

-class NumQueueingRequest(BaseOverloadDetector):
+class NumQueueingRequestFilter(BaseEndpointFilter):

     def __init__(self, **kwargs):

         if "percentile" not in kwargs:
-            logger.warning("Using num_queueing_request overload detector "
-                           "without specifying percentile in overload detector config."
+            logger.warning("Using num_queueing_request endpoint filter "
+                           "without specifying percentile in endpoint filter config. "
                            "Setting percentile to default value: 0.9")
             percentile = 0.9
         else:
             percentile = kwargs["percentile"]

         self.percentile = percentile
+        self.name = "num_queueing_request_filter"

-    def get_overload_endpoints(
+    def get_filtered_endpoints(
         self,
         endpoints: Set[str],
         request_stats: Dict[str, RequestStats],
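
Taken together, the new APIs suggest a two-stage selection flow: an endpoint filter narrows the candidate set using the latest stats, then the affinity policy picks one endpoint from what remains. The sketch below is an assumed composition rather than code from this commit; it also assumes get_filtered_endpoints returns the endpoints that pass the filter, and engine_stats / request_stats stand in for whatever the router's stat scrapers provide.

from vllm_router.routers.affinity.factory import get_affinity
from vllm_router.routers.endpoint_filter.factory import get_endpoint_filter

endpoint_filter = get_endpoint_filter("num_queueing_request_filter")
affinity = get_affinity("round_robin")


def route_request(request, request_json, engine_stats, request_stats):
    # Start from every endpoint we currently have stats for.
    candidates = set(engine_stats.keys())

    # Stage 1: drop endpoints the filter rejects (assumption: the returned
    # set contains the endpoints that are still acceptable).
    candidates = endpoint_filter.get_filtered_endpoints(
        endpoints=candidates,
        request_stats=request_stats,
        engine_stats=engine_stats,
    )

    # Stage 2: let the affinity policy choose among the remaining endpoints.
    return affinity.get_high_affinity_endpoint(
        request=request,
        request_json=request_json,
        available_endpoints=candidates,
    )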