From c1eb84bdf92ff2f597d5d49733bbddc31a9a84fa Mon Sep 17 00:00:00 2001 From: Yihua Cheng Date: Sat, 1 Mar 2025 16:35:31 -0600 Subject: [PATCH] [Feat] dynamic configuration support for router (#207) * [Add] dynamic router config support Signed-off-by: ApostaC * [Fix] small errors and add readme Signed-off-by: ApostaC * [Add] proper close and reconfigure for different components Signed-off-by: ApostaC * [Add] getting dynamic config from the /health endpoint Signed-off-by: ApostaC --------- Signed-off-by: ApostaC --- src/vllm_router/README.md | 53 +++++++ src/vllm_router/dynamic_config.py | 228 +++++++++++++++++++++++++++ src/vllm_router/engine_stats.py | 42 +++-- src/vllm_router/router.py | 73 ++++++--- src/vllm_router/routing_logic.py | 21 +-- src/vllm_router/service_discovery.py | 109 ++++++++++--- src/vllm_router/utils.py | 53 +++++++ 7 files changed, 508 insertions(+), 71 deletions(-) create mode 100644 src/vllm_router/dynamic_config.py diff --git a/src/vllm_router/README.md b/src/vllm_router/README.md index 1039986f..6c711cc8 100644 --- a/src/vllm_router/README.md +++ b/src/vllm_router/README.md @@ -44,6 +44,10 @@ The router can be configured using command-line arguments. Below are the availab - `--log-stats`: Log statistics every 30 seconds. +### Dynamic Config Options + +- `--dynamic-config-json`: The path to the json file containing the dynamic configuration. + ## Build docker image ```bash @@ -69,3 +73,52 @@ vllm-router --port 8000 \ --log-stats \ --routing-logic roundrobin ``` + +## Dynamic Router Config + +The router can be configured dynamically using a json file when passing the `--dynamic-config-json` option. +The router will watch the json file for changes and update the configuration accordingly (every 10 seconds). + +Currently, the dynamic config supports the following fields: + +**Required fields:** + +- `service_discovery`: The service discovery type. Options are `static` or `k8s`. +- `routing_logic`: The routing logic to use. Options are `roundrobin` or `session`. + +**Optional fields:** + +- (When using `static` service discovery) `static_backends`: The URLs of static serving engines, separated by commas (e.g., `http://localhost:9001,http://localhost:9002,http://localhost:9003`). +- (When using `static` service discovery) `static_models`: The models running in the static serving engines, separated by commas (e.g., `model1,model2`). +- (When using `k8s` service discovery) `k8s_port`: The port of vLLM processes when using K8s service discovery. Default is `8000`. +- (When using `k8s` service discovery) `k8s_namespace`: The namespace of vLLM pods when using K8s service discovery. Default is `default`. +- (When using `k8s` service discovery) `k8s_label_selector`: The label selector to filter vLLM pods when using K8s service discovery. +- `session_key`: The key (in the header) to identify a session when using session-based routing. + +Here is an example dynamic config file: + +```json +{ + "service_discovery": "static", + "routing_logic": "roundrobin", + "static_backends": "http://localhost:9001,http://localhost:9002,http://localhost:9003", + "static_models": "facebook/opt-125m,meta-llama/Llama-3.1-8B-Instruct,facebook/opt-125m" +} +``` + +### Get current dynamic config + +If the dynamic config is enabled, the router will reflect the current dynamic config in the `/health` endpoint. + +```bash +curl http://:/health +``` + +The response will be a JSON object with the current dynamic config. + +```json +{ + "status": "healthy", + "dynamic_config": +} +``` diff --git a/src/vllm_router/dynamic_config.py b/src/vllm_router/dynamic_config.py new file mode 100644 index 00000000..bc7e2549 --- /dev/null +++ b/src/vllm_router/dynamic_config.py @@ -0,0 +1,228 @@ +import json +import threading +import time +from dataclasses import dataclass +from typing import Optional + +from fastapi import FastAPI + +from vllm_router.log import init_logger +from vllm_router.routing_logic import ReconfigureRoutingLogic +from vllm_router.service_discovery import ( + ReconfigureServiceDiscovery, + ServiceDiscoveryType, +) +from vllm_router.utils import SingletonMeta, parse_static_model_names, parse_static_urls + +logger = init_logger(__name__) + + +@dataclass +class DynamicRouterConfig: + """ + Re-configurable configurations for the VLLM router. + """ + + # Required configurations + service_discovery: str + routing_logic: str + + # Optional configurations + # Service discovery configurations + static_backends: Optional[str] = None + static_models: Optional[str] = None + k8s_port: Optional[int] = None + k8s_namespace: Optional[str] = None + k8s_label_selector: Optional[str] = None + + # Routing logic configurations + session_key: Optional[str] = None + + # Batch API configurations + # TODO (ApostaC): Support dynamic reconfiguration of batch API + # enable_batch_api: bool + # file_storage_class: str + # file_storage_path: str + # batch_processor: str + + # Stats configurations + # TODO (ApostaC): Support dynamic reconfiguration of stats monitor + # engine_stats_interval: int + # request_stats_window: int + # log_stats: bool + # log_stats_interval: int + + @staticmethod + def from_args(args) -> "DynamicRouterConfig": + return DynamicRouterConfig( + service_discovery=args.service_discovery, + static_backends=args.static_backends, + static_models=args.static_models, + k8s_port=args.k8s_port, + k8s_namespace=args.k8s_namespace, + k8s_label_selector=args.k8s_label_selector, + # Routing logic configurations + routing_logic=args.routing_logic, + session_key=args.session_key, + ) + + @staticmethod + def from_json(json_path: str) -> "DynamicRouterConfig": + with open(json_path, "r") as f: + config = json.load(f) + return DynamicRouterConfig(**config) + + def to_json_str(self) -> str: + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class DynamicConfigWatcher(metaclass=SingletonMeta): + """ + Watches a config json file for changes and updates the DynamicRouterConfig accordingly. + """ + + def __init__( + self, + config_json: str, + watch_interval: int, + init_config: DynamicRouterConfig, + app: FastAPI, + ): + """ + Initializes the ConfigMapWatcher with the given ConfigMap name and namespace. + + Args: + config_json: the path to the json file containing the dynamic configuration + watch_interval: the interval in seconds at which to watch the for changes + app: the fastapi app to reconfigure + """ + self.config_json = config_json + self.watch_interval = watch_interval + self.current_config = init_config + self.app = app + + # Watcher thread + self.running = True + self.watcher_thread = threading.Thread(target=self._watch_worker) + self.watcher_thread.start() + assert hasattr(self.app, "state") + + def get_current_config(self) -> DynamicRouterConfig: + return self.current_config + + def reconfigure_service_discovery(self, config: DynamicRouterConfig): + """ + Reconfigures the router with the given config. + """ + if config.service_discovery == "static": + ReconfigureServiceDiscovery( + ServiceDiscoveryType.STATIC, + urls=parse_static_urls(config.static_backends), + models=parse_static_model_names(config.static_models), + ) + elif config.service_discovery == "k8s": + ReconfigureServiceDiscovery( + ServiceDiscoveryType.K8S, + namespace=config.k8s_namespace, + port=config.k8s_port, + label_selector=config.k8s_label_selector, + ) + else: + raise ValueError( + f"Invalid service discovery type: {config.service_discovery}" + ) + + logger.info(f"DynamicConfigWatcher: Service discovery reconfiguration complete") + + def reconfigure_routing_logic(self, config: DynamicRouterConfig): + """ + Reconfigures the router with the given config. + """ + routing_logic = ReconfigureRoutingLogic( + config.routing_logic, session_key=config.session_key + ) + self.app.state.router = routing_logic + logger.info(f"DynamicConfigWatcher: Routing logic reconfiguration complete") + + def reconfigure_batch_api(self, config: DynamicRouterConfig): + """ + Reconfigures the router with the given config. + """ + # TODO (ApostaC): Implement reconfigure_batch_api + pass + + def reconfigure_stats(self, config: DynamicRouterConfig): + """ + Reconfigures the router with the given config. + """ + # TODO (ApostaC): Implement reconfigure_stats + pass + + def reconfigure_all(self, config: DynamicRouterConfig): + """ + Reconfigures the router with the given config. + """ + self.reconfigure_service_discovery(config) + self.reconfigure_routing_logic(config) + self.reconfigure_batch_api(config) + self.reconfigure_stats(config) + + def _sleep_or_break(self, check_interval: float = 1): + """ + Sleep for self.watch_interval seconds if self.running is True. + Otherwise, break the loop. + """ + for _ in range(int(self.watch_interval / check_interval)): + if not self.running: + break + time.sleep(check_interval) + + def _watch_worker(self): + """ + Watches the config file for changes and updates the DynamicRouterConfig accordingly. + On every watch_interval, it will try loading the config file and compare the changes. + If the config file has changed, it will reconfigure the system with the new config. + """ + while self.running: + try: + config = DynamicRouterConfig.from_json(self.config_json) + if config != self.current_config: + logger.info( + f"DynamicConfigWatcher: Config changed, reconfiguring..." + ) + self.reconfigure_all(config) + logger.info( + f"DynamicConfigWatcher: Config reconfiguration complete" + ) + self.current_config = config + except Exception as e: + logger.warning(f"DynamicConfigWatcher: Error loading config file: {e}") + + self._sleep_or_break() + + def close(self): + """ + Closes the watcher thread. + """ + self.running = False + self.watcher_thread.join() + logger.info("DynamicConfigWatcher: Closed") + + +def InitializeDynamicConfigWatcher( + config_json: str, + watch_interval: int, + init_config: DynamicRouterConfig, + app: FastAPI, +): + """ + Initializes the DynamicConfigWatcher with the given config json and watch interval. + """ + return DynamicConfigWatcher(config_json, watch_interval, init_config, app) + + +def GetDynamicConfigWatcher() -> DynamicConfigWatcher: + """ + Returns the DynamicConfigWatcher singleton. + """ + return DynamicConfigWatcher(_create=False) diff --git a/src/vllm_router/engine_stats.py b/src/vllm_router/engine_stats.py index 2e8705a1..bf150890 100644 --- a/src/vllm_router/engine_stats.py +++ b/src/vllm_router/engine_stats.py @@ -1,27 +1,18 @@ import threading import time from dataclasses import dataclass -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import requests from prometheus_client.parser import text_string_to_metric_families from vllm_router.log import init_logger from vllm_router.service_discovery import GetServiceDiscovery +from vllm_router.utils import SingletonMeta logger = init_logger(__name__) -class SingletonMeta(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - instance = super().__call__(*args, **kwargs) - cls._instances[cls] = instance - return cls._instances[cls] - - @dataclass class EngineStats: # Number of running requests @@ -92,10 +83,12 @@ def __init__(self, scrape_interval: float): raise ValueError( "EngineStatsScraper must be initialized with scrape_interval" ) - self.service_discovery = GetServiceDiscovery() # (remains unchanged) self.engine_stats: Dict[str, EngineStats] = {} self.engine_stats_lock = threading.Lock() self.scrape_interval = scrape_interval + + # scrape thread + self.running = True self.scrape_thread = threading.Thread(target=self._scrape_worker, daemon=True) self.scrape_thread.start() self._initialized = True @@ -108,7 +101,7 @@ def _scrape_one_endpoint(self, url: str): url (str): The URL of the serving engine (does not contain endpoint) """ try: - response = requests.get(url + "/metrics") + response = requests.get(url + "/metrics", timeout=self.scrape_interval) response.raise_for_status() engine_stats = EngineStats.FromVllmScrape(response.text) except Exception as e: @@ -126,7 +119,7 @@ def _scrape_metrics(self): """ collected_engine_stats = {} - endpoints = self.service_discovery.get_endpoint_info() + endpoints = GetServiceDiscovery().get_endpoint_info() logger.info(f"Scraping metrics from {len(endpoints)} serving engine(s)") for info in endpoints: url = info.url @@ -142,6 +135,16 @@ def _scrape_metrics(self): for url, stats in collected_engine_stats.items(): self.engine_stats[url] = stats + def _sleep_or_break(self, check_interval: float = 1): + """ + Sleep for self.scrape_interval seconds if self.running is True. + Otherwise, break the loop. + """ + for _ in range(int(self.scrape_interval / check_interval)): + if not self.running: + break + time.sleep(check_interval) + def _scrape_worker(self): """ Periodically scrape metrics from all serving engines in the background. @@ -151,9 +154,9 @@ def _scrape_worker(self): metrics from all serving engines and store them in self.engine_stats. """ - while True: + while self.running: self._scrape_metrics() - time.sleep(self.scrape_interval) + self._sleep_or_break() def get_engine_stats(self) -> Dict[str, EngineStats]: """ @@ -175,6 +178,13 @@ def get_health(self) -> bool: """ return self.scrape_thread.is_alive() + def close(self): + """ + Stop the background thread and cleanup resources. + """ + self.running = False + self.scrape_thread.join() + def InitializeEngineStatsScraper(scrape_interval: float) -> EngineStatsScraper: return EngineStatsScraper(scrape_interval) diff --git a/src/vllm_router/router.py b/src/vllm_router/router.py index 0a3f52ee..1f723b9e 100644 --- a/src/vllm_router/router.py +++ b/src/vllm_router/router.py @@ -1,4 +1,5 @@ import argparse +import json import logging import threading import time @@ -12,6 +13,11 @@ from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest from vllm_router.batch import BatchProcessor, initialize_batch_processor +from vllm_router.dynamic_config import ( + DynamicRouterConfig, + GetDynamicConfigWatcher, + InitializeDynamicConfigWatcher, +) from vllm_router.engine_stats import GetEngineStatsScraper, InitializeEngineStatsScraper from vllm_router.files import Storage, initialize_storage from vllm_router.httpx_client import HTTPXClientWrapper @@ -26,7 +32,12 @@ InitializeServiceDiscovery, ServiceDiscoveryType, ) -from vllm_router.utils import set_ulimit, validate_url +from vllm_router.utils import ( + parse_static_model_names, + parse_static_urls, + set_ulimit, + validate_url, +) from vllm_router.version import __version__ httpx_client_wrapper = HTTPXClientWrapper() @@ -41,6 +52,21 @@ async def lifespan(app: FastAPI): yield await httpx_client_wrapper.stop() + # Close the threaded-components + logger.info("Closing engine stats scraper") + engine_stats_scraper = GetEngineStatsScraper() + engine_stats_scraper.close() + + logger.info("Closing service discovery module") + service_discovery = GetServiceDiscovery() + service_discovery.close() + + # Close the optional dynamic config watcher + dyn_cfg_watcher = GetDynamicConfigWatcher() + if dyn_cfg_watcher is not None: + logger.info("Closing dynamic config watcher") + dyn_cfg_watcher.close() + app = FastAPI(lifespan=lifespan) @@ -544,7 +570,18 @@ async def health() -> Response: return JSONResponse( content={"status": "Engine stats scraper is down."}, status_code=503 ) - return Response(status_code=200) + + if GetDynamicConfigWatcher() is not None: + dynamic_config = GetDynamicConfigWatcher().get_current_config() + return JSONResponse( + content={ + "status": "healthy", + "dynamic_config": json.loads(dynamic_config.to_json_str()), + }, + status_code=200, + ) + else: + return JSONResponse(content={"status": "healthy"}, status_code=200) # --- Prometheus Metrics Endpoint --- @@ -725,6 +762,13 @@ def parse_args(): help="The interval in seconds to log statistics.", ) + parser.add_argument( + "--dynamic-config-json", + type=str, + default=None, + help="The path to the json file containing the dynamic configuration.", + ) + # Add --version argument parser.add_argument( "--version", @@ -738,22 +782,6 @@ def parse_args(): return args -def parse_static_urls(args): - urls = args.static_backends.split(",") - backend_urls = [] - for url in urls: - if validate_url(url): - backend_urls.append(url) - else: - logger.warning(f"Skipping invalid URL: {url}") - return backend_urls - - -def parse_static_model_names(args): - models = args.static_models.split(",") - return models - - def InitializeAll(args): """ Initialize all the components of the router with the given arguments. @@ -767,8 +795,8 @@ def InitializeAll(args): if args.service_discovery == "static": InitializeServiceDiscovery( ServiceDiscoveryType.STATIC, - urls=parse_static_urls(args), - models=parse_static_model_names(args), + urls=parse_static_urls(args.static_backends), + models=parse_static_model_names(args.static_models), ) elif args.service_discovery == "k8s": InitializeServiceDiscovery( @@ -800,6 +828,11 @@ def InitializeAll(args): app.state.request_stats_monitor = GetRequestStatsMonitor() app.state.router = GetRoutingLogic() + # Initialize dynamic config watcher + if args.dynamic_config_json: + init_config = DynamicRouterConfig.from_args(args) + InitializeDynamicConfigWatcher(args.dynamic_config_json, 10, init_config, app) + def log_stats(interval: int = 10): """ diff --git a/src/vllm_router/routing_logic.py b/src/vllm_router/routing_logic.py index 0b5daf21..e69d3f58 100644 --- a/src/vllm_router/routing_logic.py +++ b/src/vllm_router/routing_logic.py @@ -10,6 +10,7 @@ from vllm_router.log import init_logger from vllm_router.request_stats import RequestStats from vllm_router.service_discovery import EndpointInfo +from vllm_router.utils import SingletonABCMeta logger = init_logger(__name__) @@ -19,16 +20,6 @@ class RoutingLogic(str, enum.Enum): SESSION_BASED = "session" -class SingletonABCMeta(abc.ABCMeta): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - instance = super().__call__(*args, **kwargs) - cls._instances[cls] = instance - return cls._instances[cls] - - class RoutingInterface(metaclass=SingletonABCMeta): @abc.abstractmethod def route_request( @@ -196,6 +187,16 @@ def InitializeRoutingLogic( raise ValueError(f"Invalid routing logic {routing_logic}") +def ReconfigureRoutingLogic( + routing_logic: RoutingLogic, *args, **kwargs +) -> RoutingInterface: + # Remove the existing routers from the singleton registry + for cls in (SessionRouter, RoundRobinRouter): + if cls in SingletonABCMeta._instances: + del SingletonABCMeta._instances[cls] + return InitializeRoutingLogic(routing_logic, *args, **kwargs) + + def GetRoutingLogic() -> RoutingInterface: # Look up in our singleton registry which router (if any) has been created. for cls in (SessionRouter, RoundRobinRouter): diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 77ad2120..3a1248b7 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -53,6 +53,12 @@ def get_health(self) -> bool: """ return True + def close(self) -> None: + """ + Close the service discovery module. + """ + pass + class StaticServiceDiscovery(ServiceDiscovery): def __init__(self, urls: List[str], models: List[str]): @@ -106,6 +112,7 @@ def __init__(self, namespace: str, port: str, label_selector=None): self.k8s_watcher = watch.Watch() # Start watching engines + self.running = True self.watcher_thread = threading.Thread(target=self._watch_engines, daemon=True) self.watcher_thread.start() @@ -143,25 +150,31 @@ def _get_model_name(self, pod_ip) -> Optional[str]: return model_name def _watch_engines(self): - # TODO (ApostaC): Add error handling - - for event in self.k8s_watcher.stream( - self.k8s_api.list_namespaced_pod, - namespace=self.namespace, - label_selector=self.label_selector, - ): - pod = event["object"] - event_type = event["type"] - pod_name = pod.metadata.name - pod_ip = pod.status.pod_ip - is_pod_ready = self._check_pod_ready(pod.status.container_statuses) - if is_pod_ready: - model_name = self._get_model_name(pod_ip) - else: - model_name = None - self._on_engine_update( - pod_name, pod_ip, event_type, is_pod_ready, model_name - ) + # TODO (ApostaC): remove the hard-coded timeouts + + while self.running: + try: + for event in self.k8s_watcher.stream( + self.k8s_api.list_namespaced_pod, + namespace=self.namespace, + label_selector=self.label_selector, + timeout_seconds=30, + ): + pod = event["object"] + event_type = event["type"] + pod_name = pod.metadata.name + pod_ip = pod.status.pod_ip + is_pod_ready = self._check_pod_ready(pod.status.container_statuses) + if is_pod_ready: + model_name = self._get_model_name(pod_ip) + else: + model_name = None + self._on_engine_update( + pod_name, pod_ip, event_type, is_pod_ready, model_name + ) + except Exception as e: + logger.error(f"K8s watcher error: {e}") + time.sleep(0.5) def _add_engine(self, engine_name: str, engine_ip: str, model_name: str): logger.info( @@ -240,6 +253,37 @@ def get_health(self) -> bool: """ return self.watcher_thread.is_alive() + def close(self): + """ + Close the service discovery module. + """ + self.running = False + self.k8s_watcher.stop() + self.watcher_thread.join() + + +def _create_service_discovery( + service_discovery_type: ServiceDiscoveryType, *args, **kwargs +) -> ServiceDiscovery: + """ + Create a service discovery module with the given type and arguments. + + Args: + service_discovery_type: the type of service discovery module + *args: positional arguments for the service discovery module + **kwargs: keyword arguments for the service discovery module + + Returns: + the created service discovery module + """ + + if service_discovery_type == ServiceDiscoveryType.STATIC: + return StaticServiceDiscovery(*args, **kwargs) + elif service_discovery_type == ServiceDiscoveryType.K8S: + return K8sServiceDiscovery(*args, **kwargs) + else: + raise ValueError("Invalid service discovery type") + def InitializeServiceDiscovery( service_discovery_type: ServiceDiscoveryType, *args, **kwargs @@ -263,13 +307,28 @@ def InitializeServiceDiscovery( if _global_service_discovery is not None: raise ValueError("Service discovery module already initialized") - if service_discovery_type == ServiceDiscoveryType.STATIC: - _global_service_discovery = StaticServiceDiscovery(*args, **kwargs) - elif service_discovery_type == ServiceDiscoveryType.K8S: - _global_service_discovery = K8sServiceDiscovery(*args, **kwargs) - else: - raise ValueError("Invalid service discovery type") + _global_service_discovery = _create_service_discovery( + service_discovery_type, *args, **kwargs + ) + return _global_service_discovery + + +def ReconfigureServiceDiscovery( + service_discovery_type: ServiceDiscoveryType, *args, **kwargs +) -> ServiceDiscovery: + """ + Reconfigure the service discovery module with the given type and arguments. + """ + global _global_service_discovery + if _global_service_discovery is None: + raise ValueError("Service discovery module not initialized") + + new_service_discovery = _create_service_discovery( + service_discovery_type, *args, **kwargs + ) + _global_service_discovery.close() + _global_service_discovery = new_service_discovery return _global_service_discovery diff --git a/src/vllm_router/utils.py b/src/vllm_router/utils.py index a3dbe43c..9e84a0ae 100644 --- a/src/vllm_router/utils.py +++ b/src/vllm_router/utils.py @@ -1,6 +1,43 @@ +import abc import re import resource +from vllm_router.log import init_logger + +logger = init_logger(__name__) + + +class SingletonMeta(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + """ + Note: if the class is called with _create=False, it will return None + if the instance does not exist. + """ + if cls not in cls._instances: + if kwargs.get("_create") is False: + return None + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + +class SingletonABCMeta(abc.ABCMeta): + _instances = {} + + def __call__(cls, *args, **kwargs): + """ + Note: if the class is called with _create=False, it will return None + if the instance does not exist. + """ + if cls not in cls._instances: + if kwargs.get("create") is False: + return None + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + def validate_url(url: str) -> bool: """ @@ -40,3 +77,19 @@ def set_ulimit(target_soft_limit=65535): current_soft, e, ) + + +def parse_static_urls(static_backends: str): + urls = static_backends.split(",") + backend_urls = [] + for url in urls: + if validate_url(url): + backend_urls.append(url) + else: + logger.warning(f"Skipping invalid URL: {url}") + return backend_urls + + +def parse_static_model_names(static_models: str): + models = static_models.split(",") + return models