From 95efeccf042a73001c867b875532caf19deed592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Brian=20Park=20/=20=EB=B0=95=EA=B8=B0=EC=98=81?= <125226023+BrianPark314@users.noreply.github.com> Date: Wed, 5 Mar 2025 06:17:22 +0900 Subject: [PATCH] =?UTF-8?q?refactor:=20standard=20fastapi=20project=20stru?= =?UTF-8?q?cture=20for=20better=20main=E2=80=A6=20(#217)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: maintain standard fastapi project structure for better maintainability Signed-off-by: BrianPark314 * chore: merge main into branch Signed-off-by: BrianPark314 * fix: run pre-commit Signed-off-by: BrianPark314 * fix: setup.py Signed-off-by: BrianPark314 * fix: service discovery time import Signed-off-by: BrianPark314 * fix: service discovery time import Signed-off-by: BrianPark314 * fix: delete unused file Signed-off-by: BrianPark314 * fix: merge main and fix log_stats.py Signed-off-by: BrianPark314 * chore: run pre-commit Signed-off-by: BrianPark314 * fix: correct experimental imports Signed-off-by: BrianPark314 * fix: parser error Signed-off-by: BrianPark314 * fix: experimental feature flag Signed-off-by: BrianPark314 * fix: wrapper call issue Signed-off-by: BrianPark314 * fix: add missing lifespan Signed-off-by: BrianPark314 * chore: add TODO comment Signed-off-by: BrianPark314 --------- Signed-off-by: BrianPark314 Co-authored-by: BrianPark314 --- setup.py | 2 +- src/tests/test_file_storage.py | 3 +- src/tests/test_session_router.py | 8 +- src/tests/test_singleton.py | 12 +- src/vllm_router/app.py | 230 ++++ src/vllm_router/dynamic_config.py | 14 +- .../experimental/semantic_cache/__init__.py | 4 +- .../semantic_cache/semantic_cache.py | 3 +- src/vllm_router/files/__init__.py | 10 - src/vllm_router/httpx_client.py | 3 - src/vllm_router/parsers/__init__.py | 0 src/vllm_router/parsers/parser.py | 200 ++++ src/vllm_router/perf-test.sh | 2 +- src/vllm_router/protocols.py | 1 + src/vllm_router/router.py | 1055 ----------------- src/vllm_router/routers/__init__.py | 0 src/vllm_router/routers/batches_router.py | 100 ++ src/vllm_router/routers/files_router.py | 68 ++ src/vllm_router/routers/main_router.py | 159 +++ src/vllm_router/routers/metrics_router.py | 64 + .../{ => routers}/routing_logic.py | 15 +- src/vllm_router/run-router.sh | 4 +- src/vllm_router/service_discovery.py | 11 +- src/vllm_router/services/__init__.py | 0 .../batch_service}/__init__.py | 5 +- .../batch_service}/batch.py | 0 .../batch_service}/local_processor.py | 4 +- .../batch_service}/processor.py | 4 +- .../services/files_service/__init__.py | 8 + .../files_service}/file_storage.py | 4 +- .../files_service/openai_files.py} | 0 .../files_service}/storage.py | 4 +- .../services/metrics_service/__init__.py | 32 + .../metrics_service/prometheus_gauge.py | 0 .../services/request_service/__init__.py | 0 .../services/request_service/request.py | 178 +++ src/vllm_router/stats/__init__.py | 0 src/vllm_router/{ => stats}/engine_stats.py | 14 +- src/vllm_router/stats/log_stats.py | 82 ++ src/vllm_router/{ => stats}/request_stats.py | 4 +- 40 files changed, 1177 insertions(+), 1130 deletions(-) create mode 100644 src/vllm_router/app.py delete mode 100644 src/vllm_router/files/__init__.py create mode 100644 src/vllm_router/parsers/__init__.py create mode 100644 src/vllm_router/parsers/parser.py delete mode 100644 src/vllm_router/router.py create mode 100644 src/vllm_router/routers/__init__.py create mode 100644 src/vllm_router/routers/batches_router.py create mode 100644 
src/vllm_router/routers/files_router.py create mode 100644 src/vllm_router/routers/main_router.py create mode 100644 src/vllm_router/routers/metrics_router.py rename src/vllm_router/{ => routers}/routing_logic.py (95%) create mode 100644 src/vllm_router/services/__init__.py rename src/vllm_router/{batch => services/batch_service}/__init__.py (74%) rename src/vllm_router/{batch => services/batch_service}/batch.py (100%) rename src/vllm_router/{batch => services/batch_service}/local_processor.py (98%) rename src/vllm_router/{batch => services/batch_service}/processor.py (89%) create mode 100644 src/vllm_router/services/files_service/__init__.py rename src/vllm_router/{files => services/files_service}/file_storage.py (96%) rename src/vllm_router/{files/files.py => services/files_service/openai_files.py} (100%) rename src/vllm_router/{files => services/files_service}/storage.py (97%) create mode 100644 src/vllm_router/services/metrics_service/__init__.py create mode 100644 src/vllm_router/services/metrics_service/prometheus_gauge.py create mode 100644 src/vllm_router/services/request_service/__init__.py create mode 100644 src/vllm_router/services/request_service/request.py create mode 100644 src/vllm_router/stats/__init__.py rename src/vllm_router/{ => stats}/engine_stats.py (93%) create mode 100644 src/vllm_router/stats/log_stats.py rename src/vllm_router/{ => stats}/request_stats.py (99%) diff --git a/setup.py b/setup.py index b823f564..8f8c8b31 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ install_requires=install_requires, entry_points={ "console_scripts": [ - "vllm-router=vllm_router.router:main", + "vllm-router=vllm_router.app:main", ], }, description="The router for vLLM", diff --git a/src/tests/test_file_storage.py b/src/tests/test_file_storage.py index d0dcd800..df9d85f2 100644 --- a/src/tests/test_file_storage.py +++ b/src/tests/test_file_storage.py @@ -4,7 +4,8 @@ import pytest -from vllm_router.files import FileStorage, OpenAIFile +from vllm_router.services.files_service.file_storage import FileStorage +from vllm_router.services.files_service.openai_files import OpenAIFile TEST_BASE_PATH = "/tmp/test_vllm_files" pytest_plugins = ("pytest_asyncio",) diff --git a/src/tests/test_session_router.py b/src/tests/test_session_router.py index 7e857ade..4f12ab43 100644 --- a/src/tests/test_session_router.py +++ b/src/tests/test_session_router.py @@ -1,10 +1,6 @@ -import sys -from typing import Dict, List -from unittest.mock import Mock +from typing import Dict -import pytest - -from vllm_router.routing_logic import SessionRouter +from vllm_router.routers.routing_logic import SessionRouter class EndpointInfo: diff --git a/src/tests/test_singleton.py b/src/tests/test_singleton.py index cdacda7d..10e82a1b 100644 --- a/src/tests/test_singleton.py +++ b/src/tests/test_singleton.py @@ -2,11 +2,11 @@ import unittest # Import the classes and helper functions from your module. -from vllm_router.request_stats import ( - GetRequestStatsMonitor, - InitializeRequestStatsMonitor, +from vllm_router.stats.request_stats import ( RequestStatsMonitor, SingletonMeta, + get_request_stats_monitor, + initialize_request_stats_monitor, ) @@ -19,9 +19,9 @@ def setUp(self): def test_singleton_initialization(self): sliding_window = 10.0 # First initialization using the helper. - monitor1 = InitializeRequestStatsMonitor(sliding_window) + monitor1 = initialize_request_stats_monitor(sliding_window) # Subsequent retrieval using GetRequestStatsMonitor() should return the same instance. 
- monitor2 = GetRequestStatsMonitor() + monitor2 = get_request_stats_monitor() self.assertIs( monitor1, monitor2, @@ -39,7 +39,7 @@ def test_singleton_initialization(self): def test_initialization_without_parameter_after_initialized(self): sliding_window = 10.0 # First, initialize with the sliding_window. - monitor1 = InitializeRequestStatsMonitor(sliding_window) + monitor1 = initialize_request_stats_monitor(sliding_window) # Now, calling the constructor without a parameter should not raise an error # and should return the already initialized instance. monitor2 = RequestStatsMonitor() diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py new file mode 100644 index 00000000..d2643c63 --- /dev/null +++ b/src/vllm_router/app.py @@ -0,0 +1,230 @@ +import logging +import threading +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import FastAPI + +from vllm_router.dynamic_config import ( + DynamicRouterConfig, + get_dynamic_config_watcher, + initialize_dynamic_config_watcher, +) +from vllm_router.experimental import get_feature_gates, initialize_feature_gates + +try: + # Semantic cache integration + from vllm_router.experimental.semantic_cache import ( + GetSemanticCache, + enable_semantic_cache, + initialize_semantic_cache, + is_semantic_cache_enabled, + ) + from vllm_router.experimental.semantic_cache_integration import ( + add_semantic_cache_args, + check_semantic_cache, + semantic_cache_hit_ratio, + semantic_cache_hits, + semantic_cache_latency, + semantic_cache_misses, + semantic_cache_size, + store_in_semantic_cache, + ) + + semantic_cache_available = True +except ImportError: + semantic_cache_available = False + +from vllm_router.httpx_client import HTTPXClientWrapper +from vllm_router.parsers.parser import parse_args +from vllm_router.routers.batches_router import batches_router +from vllm_router.routers.files_router import files_router +from vllm_router.routers.main_router import main_router +from vllm_router.routers.metrics_router import metrics_router +from vllm_router.routers.routing_logic import ( + get_routing_logic, + initialize_routing_logic, +) +from vllm_router.service_discovery import ( + ServiceDiscoveryType, + get_service_discovery, + initialize_service_discovery, +) +from vllm_router.services.batch_service import initialize_batch_processor +from vllm_router.services.files_service import initialize_storage +from vllm_router.stats.engine_stats import ( + get_engine_stats_scraper, + initialize_engine_stats_scraper, +) +from vllm_router.stats.log_stats import log_stats +from vllm_router.stats.request_stats import ( + get_request_stats_monitor, + initialize_request_stats_monitor, +) +from vllm_router.utils import parse_static_model_names, parse_static_urls, set_ulimit + +logger = logging.getLogger("uvicorn") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + app.state.httpx_client_wrapper.start() + if hasattr(app.state, "batch_processor"): + await app.state.batch_processor.initialize() + yield + await app.state.httpx_client_wrapper.stop() + + # Close the threaded-components + logger.info("Closing engine stats scraper") + engine_stats_scraper = get_engine_stats_scraper() + engine_stats_scraper.close() + + logger.info("Closing service discovery module") + service_discovery = get_service_discovery() + service_discovery.close() + + # Close the optional dynamic config watcher + dyn_cfg_watcher = get_dynamic_config_watcher() + if dyn_cfg_watcher is not None: + logger.info("Closing dynamic config watcher") + dyn_cfg_watcher.close() + + +# 
TODO: This method needs refactoring, since it has nested if statements and is too long +def initialize_all(app: FastAPI, args): + """ + Initialize all the components of the router with the given arguments. + + Args: + app (FastAPI): FastAPI application + args: the parsed command-line arguments + + Raises: + ValueError: if the service discovery type is invalid + """ + if args.service_discovery == "static": + initialize_service_discovery( + ServiceDiscoveryType.STATIC, + urls=parse_static_urls(args.static_backends), + models=parse_static_model_names(args.static_models), + ) + elif args.service_discovery == "k8s": + initialize_service_discovery( + ServiceDiscoveryType.K8S, + namespace=args.k8s_namespace, + port=args.k8s_port, + label_selector=args.k8s_label_selector, + ) + else: + raise ValueError(f"Invalid service discovery type: {args.service_discovery}") + + # Initialize singletons via custom functions. + initialize_engine_stats_scraper(args.engine_stats_interval) + initialize_request_stats_monitor(args.request_stats_window) + + if args.enable_batch_api: + logger.info("Initializing batch API") + app.state.batch_storage = initialize_storage( + args.file_storage_class, args.file_storage_path + ) + app.state.batch_processor = initialize_batch_processor( + args.batch_processor, args.file_storage_path, app.state.batch_storage + ) + + initialize_routing_logic(args.routing_logic, session_key=args.session_key) + + # Initialize feature gates + initialize_feature_gates(args.feature_gates) + # Check if the SemanticCache feature gate is enabled + feature_gates = get_feature_gates() + if semantic_cache_available: + if feature_gates.is_enabled("SemanticCache"): + # The feature gate is enabled, explicitly enable the semantic cache + enable_semantic_cache() + + # Verify that the semantic cache was successfully enabled + if not is_semantic_cache_enabled(): + logger.error("Failed to enable semantic cache feature") + + logger.info("SemanticCache feature gate is enabled") + + # Initialize the semantic cache with the model if specified + if args.semantic_cache_model: + logger.info( + f"Initializing semantic cache with model: {args.semantic_cache_model}" + ) + logger.info( + f"Semantic cache directory: {args.semantic_cache_dir or 'default'}" + ) + logger.info( + f"Semantic cache threshold: {args.semantic_cache_threshold}" + ) + + cache = initialize_semantic_cache( + embedding_model=args.semantic_cache_model, + cache_dir=args.semantic_cache_dir, + default_similarity_threshold=args.semantic_cache_threshold, + ) + + # Update cache size metric + if cache and hasattr(cache, "db") and hasattr(cache.db, "index"): + semantic_cache_size.labels(server="router").set( + cache.db.index.ntotal + ) + logger.info( + f"Semantic cache initialized with {cache.db.index.ntotal} entries" + ) + + logger.info( + f"Semantic cache initialized with model {args.semantic_cache_model}" + ) + else: + logger.warning( + "SemanticCache feature gate is enabled but no embedding model specified. " + "The semantic cache will not be functional without an embedding model. " + "Use --semantic-cache-model to specify an embedding model." + ) + elif args.semantic_cache_model: + logger.warning( + "Semantic cache model specified but SemanticCache feature gate is not enabled. 
" + "Enable the feature gate with --feature-gates=SemanticCache=true" + ) + + # --- Hybrid addition: attach singletons to FastAPI state --- + app.state.engine_stats_scraper = get_engine_stats_scraper() + app.state.request_stats_monitor = get_request_stats_monitor() + app.state.router = get_routing_logic() + + # Initialize dynamic config watcher + if args.dynamic_config_json: + init_config = DynamicRouterConfig.from_args(args) + initialize_dynamic_config_watcher( + args.dynamic_config_json, 10, init_config, app + ) + + +app = FastAPI(lifespan=lifespan) +app.include_router(main_router) +app.include_router(files_router) +app.include_router(batches_router) +app.include_router(metrics_router) +app.state.httpx_client_wrapper = HTTPXClientWrapper() +app.state.semantic_cache_available = semantic_cache_available + + +def main(): + args = parse_args() + initialize_all(app, args) + if args.log_stats: + threading.Thread( + target=log_stats, args=(args.log_stats_interval,), daemon=True + ).start() + + # Workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active. + set_ulimit() + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/src/vllm_router/dynamic_config.py b/src/vllm_router/dynamic_config.py index bc7e2549..b01e8129 100644 --- a/src/vllm_router/dynamic_config.py +++ b/src/vllm_router/dynamic_config.py @@ -7,10 +7,10 @@ from fastapi import FastAPI from vllm_router.log import init_logger -from vllm_router.routing_logic import ReconfigureRoutingLogic +from vllm_router.routers.routing_logic import reconfigure_routing_logic from vllm_router.service_discovery import ( - ReconfigureServiceDiscovery, ServiceDiscoveryType, + reconfigure_service_discovery, ) from vllm_router.utils import SingletonMeta, parse_static_model_names, parse_static_urls @@ -115,13 +115,13 @@ def reconfigure_service_discovery(self, config: DynamicRouterConfig): Reconfigures the router with the given config. """ if config.service_discovery == "static": - ReconfigureServiceDiscovery( + reconfigure_service_discovery( ServiceDiscoveryType.STATIC, urls=parse_static_urls(config.static_backends), models=parse_static_model_names(config.static_models), ) elif config.service_discovery == "k8s": - ReconfigureServiceDiscovery( + reconfigure_service_discovery( ServiceDiscoveryType.K8S, namespace=config.k8s_namespace, port=config.k8s_port, @@ -138,7 +138,7 @@ def reconfigure_routing_logic(self, config: DynamicRouterConfig): """ Reconfigures the router with the given config. """ - routing_logic = ReconfigureRoutingLogic( + routing_logic = reconfigure_routing_logic( config.routing_logic, session_key=config.session_key ) self.app.state.router = routing_logic @@ -209,7 +209,7 @@ def close(self): logger.info("DynamicConfigWatcher: Closed") -def InitializeDynamicConfigWatcher( +def initialize_dynamic_config_watcher( config_json: str, watch_interval: int, init_config: DynamicRouterConfig, @@ -221,7 +221,7 @@ def InitializeDynamicConfigWatcher( return DynamicConfigWatcher(config_json, watch_interval, init_config, app) -def GetDynamicConfigWatcher() -> DynamicConfigWatcher: +def get_dynamic_config_watcher() -> DynamicConfigWatcher: """ Returns the DynamicConfigWatcher singleton. 
""" diff --git a/src/vllm_router/experimental/semantic_cache/__init__.py b/src/vllm_router/experimental/semantic_cache/__init__.py index 72528aa6..78c14b42 100644 --- a/src/vllm_router/experimental/semantic_cache/__init__.py +++ b/src/vllm_router/experimental/semantic_cache/__init__.py @@ -10,15 +10,15 @@ from vllm_router.experimental.feature_gates import get_feature_gates from vllm_router.experimental.semantic_cache.semantic_cache import ( GetSemanticCache, - InitializeSemanticCache, SemanticCache, + initialize_semantic_cache, ) logger = logging.getLogger(__name__) __all__ = [ "SemanticCache", - "InitializeSemanticCache", + "initialize_semantic_cache", "GetSemanticCache", "is_semantic_cache_enabled", "enable_semantic_cache", diff --git a/src/vllm_router/experimental/semantic_cache/semantic_cache.py b/src/vllm_router/experimental/semantic_cache/semantic_cache.py index bc7020b1..7b76ac4a 100644 --- a/src/vllm_router/experimental/semantic_cache/semantic_cache.py +++ b/src/vllm_router/experimental/semantic_cache/semantic_cache.py @@ -8,7 +8,6 @@ from vllm_router.experimental.semantic_cache.db_adapters import ( FAISSAdapter, - VectorDBAdapterBase, ) logger = logging.getLogger(__name__) @@ -318,7 +317,7 @@ def complete_store( _semantic_cache_instance = None -def InitializeSemanticCache( +def initialize_semantic_cache( embedding_model: str = "all-MiniLM-L6-v2", cache_dir: str = None, default_similarity_threshold: float = 0.95, diff --git a/src/vllm_router/files/__init__.py b/src/vllm_router/files/__init__.py deleted file mode 100644 index bc7cf78a..00000000 --- a/src/vllm_router/files/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from vllm_router.files.file_storage import FileStorage -from vllm_router.files.files import OpenAIFile -from vllm_router.files.storage import Storage, initialize_storage - -__all__ = [ - "OpenAIFile", - "Storage", - "FileStorage", - "initialize_storage", -] diff --git a/src/vllm_router/httpx_client.py b/src/vllm_router/httpx_client.py index 231f5fd6..bda1563f 100644 --- a/src/vllm_router/httpx_client.py +++ b/src/vllm_router/httpx_client.py @@ -1,7 +1,4 @@ -import logging - import httpx -from fastapi import FastAPI from vllm_router.log import init_logger diff --git a/src/vllm_router/parsers/__init__.py b/src/vllm_router/parsers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py new file mode 100644 index 00000000..912fb5a6 --- /dev/null +++ b/src/vllm_router/parsers/parser.py @@ -0,0 +1,200 @@ +import argparse + +from vllm_router.version import __version__ + +try: + # Semantic cache integration + from vllm_router.experimental.semantic_cache import ( + GetSemanticCache, + enable_semantic_cache, + initialize_semantic_cache, + is_semantic_cache_enabled, + ) + from vllm_router.experimental.semantic_cache_integration import ( + add_semantic_cache_args, + check_semantic_cache, + semantic_cache_hit_ratio, + semantic_cache_hits, + semantic_cache_latency, + semantic_cache_misses, + semantic_cache_size, + store_in_semantic_cache, + ) + + semantic_cache_available = True +except ImportError: + semantic_cache_available = False + + +# --- Argument Parsing and Initialization --- +def validate_args(args): + if args.service_discovery == "static": + if args.static_backends is None: + raise ValueError( + "Static backends must be provided when using static service discovery." + ) + if args.static_models is None: + raise ValueError( + "Static models must be provided when using static service discovery." 
+ ) + if args.service_discovery == "k8s" and args.k8s_port is None: + raise ValueError("K8s port must be provided when using K8s service discovery.") + if args.routing_logic == "session" and args.session_key is None: + raise ValueError( + "Session key must be provided when using session routing logic." + ) + if args.log_stats and args.log_stats_interval <= 0: + raise ValueError("Log stats interval must be greater than 0.") + if args.engine_stats_interval <= 0: + raise ValueError("Engine stats interval must be greater than 0.") + if args.request_stats_window <= 0: + raise ValueError("Request stats window must be greater than 0.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run the FastAPI app.") + parser.add_argument( + "--host", default="0.0.0.0", help="The host to run the server on." + ) + parser.add_argument( + "--port", type=int, default=8001, help="The port to run the server on." + ) + parser.add_argument( + "--service-discovery", + required=True, + choices=["static", "k8s"], + help="The service discovery type.", + ) + parser.add_argument( + "--static-backends", + type=str, + default=None, + help="The URLs of static backends, separated by commas. E.g., http://localhost:8000,http://localhost:8001", + ) + parser.add_argument( + "--static-models", + type=str, + default=None, + help="The models of static backends, separated by commas. E.g., model1,model2", + ) + parser.add_argument( + "--k8s-port", + type=int, + default=8000, + help="The port of vLLM processes when using K8s service discovery.", + ) + parser.add_argument( + "--k8s-namespace", + type=str, + default="default", + help="The namespace of vLLM pods when using K8s service discovery.", + ) + parser.add_argument( + "--k8s-label-selector", + type=str, + default="", + help="The label selector to filter vLLM pods when using K8s service discovery.", + ) + parser.add_argument( + "--routing-logic", + type=str, + required=True, + choices=["roundrobin", "session"], + help="The routing logic to use", + ) + parser.add_argument( + "--session-key", + type=str, + default=None, + help="The key (in the header) to identify a session.", + ) + + # Batch API + # TODO(gaocegege): Make these batch api related arguments to a separate config. + parser.add_argument( + "--enable-batch-api", + action="store_true", + help="Enable the batch API for processing files.", + ) + parser.add_argument( + "--file-storage-class", + type=str, + default="local_file", + choices=["local_file"], + help="The file storage class to use.", + ) + parser.add_argument( + "--file-storage-path", + type=str, + default="/tmp/vllm_files", + help="The path to store files.", + ) + parser.add_argument( + "--batch-processor", + type=str, + default="local", + choices=["local"], + help="The batch processor to use.", + ) + + # Monitoring + parser.add_argument( + "--engine-stats-interval", + type=int, + default=30, + help="The interval in seconds to scrape engine statistics.", + ) + parser.add_argument( + "--request-stats-window", + type=int, + default=60, + help="The sliding window in seconds to compute request statistics.", + ) + parser.add_argument( + "--log-stats", action="store_true", help="Log statistics periodically." 
+ ) + parser.add_argument( + "--log-stats-interval", + type=int, + default=10, + help="The interval in seconds to log statistics.", + ) + + parser.add_argument( + "--dynamic-config-json", + type=str, + default=None, + help="The path to the json file containing the dynamic configuration.", + ) + + # Add --version argument + parser.add_argument( + "--version", + action="version", + version=f"%(prog)s {__version__}", + help="Show version and exit", + ) + + if semantic_cache_available: + add_semantic_cache_args(parser) + + # Add feature gates argument + parser.add_argument( + "--feature-gates", + type=str, + default="", + help="Comma-separated list of feature gates (e.g., 'SemanticCache=true')", + ) + + # Add log level argument + parser.add_argument( + "--log-level", + type=str, + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + help="Log level for uvicorn. Default is 'info'.", + ) + + args = parser.parse_args() + validate_args(args) + return args diff --git a/src/vllm_router/perf-test.sh b/src/vllm_router/perf-test.sh index 24bb2e43..a42663bf 100644 --- a/src/vllm_router/perf-test.sh +++ b/src/vllm_router/perf-test.sh @@ -8,7 +8,7 @@ fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # Run router.py from the correct directory -python3 "$SCRIPT_DIR/router.py" --port "$1" \ +python3 "$SCRIPT_DIR/app.py" --port "$1" \ --service-discovery static \ --static-backends "http://localhost:9004,http://localhost:9001,http://localhost:9002,http://localhost:9003" \ --static-models "fake_model_name,fake_model_name,fake_model_name,fake_model_name" \ diff --git a/src/vllm_router/protocols.py b/src/vllm_router/protocols.py index 2bf553fb..db9ab627 100644 --- a/src/vllm_router/protocols.py +++ b/src/vllm_router/protocols.py @@ -1,3 +1,4 @@ +import time from typing import List, Optional from pydantic import BaseModel, ConfigDict, Field, model_validator diff --git a/src/vllm_router/router.py b/src/vllm_router/router.py deleted file mode 100644 index 258185e5..00000000 --- a/src/vllm_router/router.py +++ /dev/null @@ -1,1055 +0,0 @@ -import argparse -import json -import logging -import threading -import time -import uuid -from contextlib import asynccontextmanager -from urllib.parse import urlparse - -import uvicorn -from fastapi import FastAPI, Request, UploadFile -from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest - -from vllm_router.batch import BatchProcessor, initialize_batch_processor -from vllm_router.dynamic_config import ( - DynamicRouterConfig, - GetDynamicConfigWatcher, - InitializeDynamicConfigWatcher, -) -from vllm_router.engine_stats import GetEngineStatsScraper, InitializeEngineStatsScraper - -# Import experimental feature gates and semantic cache -from vllm_router.experimental.feature_gates import ( - get_feature_gates, - initialize_feature_gates, -) - -try: - # Semantic cache integration - from vllm_router.experimental.semantic_cache import ( - GetSemanticCache, - InitializeSemanticCache, - enable_semantic_cache, - is_semantic_cache_enabled, - ) - from vllm_router.experimental.semantic_cache_integration import ( - add_semantic_cache_args, - check_semantic_cache, - semantic_cache_hit_ratio, - semantic_cache_hits, - semantic_cache_latency, - semantic_cache_misses, - semantic_cache_size, - store_in_semantic_cache, - ) - - semantic_cache_available = True -except ImportError: - semantic_cache_available = False - -from vllm_router.files import Storage, 
initialize_storage -from vllm_router.httpx_client import HTTPXClientWrapper -from vllm_router.protocols import ModelCard, ModelList -from vllm_router.request_stats import ( - GetRequestStatsMonitor, - InitializeRequestStatsMonitor, -) -from vllm_router.routing_logic import GetRoutingLogic, InitializeRoutingLogic -from vllm_router.service_discovery import ( - GetServiceDiscovery, - InitializeServiceDiscovery, - ServiceDiscoveryType, -) -from vllm_router.utils import ( - parse_static_model_names, - parse_static_urls, - set_ulimit, - validate_url, -) -from vllm_router.version import __version__ - -httpx_client_wrapper = HTTPXClientWrapper() -from vllm_router.log import init_logger - -logger = init_logger(__name__) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - httpx_client_wrapper.start() - if hasattr(app.state, "batch_processor"): - await app.state.batch_processor.initialize() - yield - await httpx_client_wrapper.stop() - - # Close the threaded-components - logger.info("Closing engine stats scraper") - engine_stats_scraper = GetEngineStatsScraper() - engine_stats_scraper.close() - - logger.info("Closing service discovery module") - service_discovery = GetServiceDiscovery() - service_discovery.close() - - # Close the optional dynamic config watcher - dyn_cfg_watcher = GetDynamicConfigWatcher() - if dyn_cfg_watcher is not None: - logger.info("Closing dynamic config watcher") - dyn_cfg_watcher.close() - - -app = FastAPI(lifespan=lifespan) - -# --- Prometheus Gauges --- -# Existing metrics -num_requests_running = Gauge( - "vllm:num_requests_running", "Number of running requests", ["server"] -) -num_requests_waiting = Gauge( - "vllm:num_requests_waiting", "Number of waiting requests", ["server"] -) -current_qps = Gauge("vllm:current_qps", "Current Queries Per Second", ["server"]) -avg_decoding_length = Gauge( - "vllm:avg_decoding_length", "Average Decoding Length", ["server"] -) -num_prefill_requests = Gauge( - "vllm:num_prefill_requests", "Number of Prefill Requests", ["server"] -) -num_decoding_requests = Gauge( - "vllm:num_decoding_requests", "Number of Decoding Requests", ["server"] -) - -# New metrics per dashboard update -healthy_pods_total = Gauge( - "vllm:healthy_pods_total", "Number of healthy vLLM pods", ["server"] -) -avg_latency = Gauge( - "vllm:avg_latency", "Average end-to-end request latency", ["server"] -) -avg_itl = Gauge("vllm:avg_itl", "Average Inter-Token Latency", ["server"]) -num_requests_swapped = Gauge( - "vllm:num_requests_swapped", "Number of swapped requests", ["server"] -) - - -# --- Request Processing & Routing --- -# TODO: better request id system -async def process_request( - method, header, body, backend_url, request_id, endpoint, debug_request=None -): - """ - Process a request by sending it to the chosen backend. - - Args: - method: The HTTP method to use when sending the request to the backend. - header: The headers to send with the request to the backend. - body: The content of the request to send to the backend. - backend_url: The URL of the backend to send the request to. - request_id: A unique identifier for the request. - endpoint: The endpoint to send the request to on the backend. - debug_request: The original request object from the client, used for - optional debug logging. - - Yields: - The response headers and status code, followed by the response content. - - Raises: - HTTPError: If the backend returns a 4xx or 5xx status code. 
- """ - first_token = False - total_len = 0 - start_time = time.time() - app.state.request_stats_monitor.on_new_request(backend_url, request_id, start_time) - # Check if this is a streaming request - is_streaming = False - try: - request_json = json.loads(body) - is_streaming = request_json.get("stream", False) - except: - # If we can't parse the body as JSON, assume it's not streaming - pass - - # For non-streaming requests, collect the full response to cache it properly - full_response = bytearray() if not is_streaming else None - - client = httpx_client_wrapper() - async with client.stream( - method=method, - url=backend_url + endpoint, - headers=dict(header), - content=body, - timeout=None, - ) as backend_response: - # Yield headers and status code first. - yield backend_response.headers, backend_response.status_code - # Stream response content. - async for chunk in backend_response.aiter_bytes(): - total_len += len(chunk) - if not first_token: - first_token = True - app.state.request_stats_monitor.on_request_response( - backend_url, request_id, time.time() - ) - # For non-streaming requests, collect the full response - if full_response is not None: - full_response.extend(chunk) - yield chunk - - app.state.request_stats_monitor.on_request_complete( - backend_url, request_id, time.time() - ) - - if semantic_cache_available: - # if debug_request: - # logger.debug(f"Finished the request with request id: {debug_request.headers.get('x-request-id', None)} at {time.time()}") - # Store in semantic cache if applicable - # Use the full response for non-streaming requests, or the last chunk for streaming - cache_chunk = bytes(full_response) if full_response is not None else chunk - await store_in_semantic_cache( - endpoint=endpoint, method=method, body=body, chunk=cache_chunk - ) - - -async def route_general_request(request: Request, endpoint: str): - """ - Route the incoming request to the backend server and stream the response back to the client. - - This function extracts the requested model from the request body and retrieves the - corresponding endpoints. It uses routing logic to determine the best server URL to handle - the request, then streams the request to that server. If the requested model is not available, - it returns an error response. - - Args: - request (Request): The incoming HTTP request. - endpoint (str): The endpoint to which the request should be routed. - - Returns: - StreamingResponse: A response object that streams data from the backend server to the client. 
- """ - - in_router_time = time.time() - request_id = str(uuid.uuid4()) - request_body = await request.body() - request_json = await request.json() # TODO (ApostaC): merge two awaits into one - requested_model = request_json.get("model", None) - if requested_model is None: - return JSONResponse( - status_code=400, - content={"error": "Invalid request: missing 'model' in request body."}, - ) - - # TODO (ApostaC): merge two awaits into one - endpoints = GetServiceDiscovery().get_endpoint_info() - engine_stats = request.app.state.engine_stats_scraper.get_engine_stats() - request_stats = request.app.state.request_stats_monitor.get_request_stats( - time.time() - ) - - endpoints = list(filter(lambda x: x.model_name == requested_model, endpoints)) - if not endpoints: - return JSONResponse( - status_code=400, content={"error": f"Model {requested_model} not found."} - ) - - logger.debug(f"Routing request {request_id} for model: {requested_model}") - server_url = request.app.state.router.route_request( - endpoints, engine_stats, request_stats, request - ) - curr_time = time.time() - logger.info( - f"Routing request {request_id} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}" - ) - stream_generator = process_request( - request.method, - request.headers, - request_body, - server_url, - request_id, - endpoint=endpoint, - ) - headers, status_code = await anext(stream_generator) - return StreamingResponse( - stream_generator, - status_code=status_code, - headers={key: value for key, value in headers.items()}, - media_type="text/event-stream", - ) - - -# --- File Endpoints --- -@app.post("/v1/files") -async def route_files(request: Request): - """ - Handle file upload requests and save the files to the configured storage. - - Args: - request (Request): The incoming HTTP request. - - Returns: - JSONResponse: A JSON response containing the file metadata. - - Raises: - JSONResponse: A JSON response with a 400 status code if the request is invalid, - or a 500 status code if an error occurs during file saving. - """ - form = await request.form() - purpose = form.get("purpose", "unknown") - if "file" not in form: - return JSONResponse( - status_code=400, content={"error": "Missing required parameter 'file'"} - ) - file_obj: UploadFile = form["file"] - file_content = await file_obj.read() - try: - storage: Storage = app.state.batch_storage - file_info = await storage.save_file( - file_name=file_obj.filename, content=file_content, purpose=purpose - ) - return JSONResponse(content=file_info.metadata()) - except Exception as e: - return JSONResponse( - status_code=500, content={"error": f"Failed to save file: {str(e)}"} - ) - - -@app.get("/v1/files/{file_id}") -async def route_get_file(file_id: str): - try: - storage: Storage = app.state.batch_storage - file = await storage.get_file(file_id) - return JSONResponse(content=file.metadata()) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"File {file_id} not found"} - ) - - -@app.get("/v1/files/{file_id}/content") -async def route_get_file_content(file_id: str): - try: - # TODO(gaocegege): Stream the file content with chunks to support - # openai uploads interface. 
- storage: Storage = app.state.batch_storage - file_content = await storage.get_file_content(file_id) - return Response(content=file_content) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"File {file_id} not found"} - ) - - -@app.post("/v1/batches") -async def route_batches(request: Request): - """Handle batch requests that process files with specified endpoints.""" - try: - request_json = await request.json() - - # Validate required fields - if "input_file_id" not in request_json: - return JSONResponse( - status_code=400, - content={"error": "Missing required parameter 'input_file_id'"}, - ) - if "endpoint" not in request_json: - return JSONResponse( - status_code=400, - content={"error": "Missing required parameter 'endpoint'"}, - ) - - # Verify file exists - storage: Storage = app.state.batch_storage - file_id = request_json["input_file_id"] - try: - await storage.get_file(file_id) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"File {file_id} not found"} - ) - - batch_processor: BatchProcessor = app.state.batch_processor - batch = await batch_processor.create_batch( - input_file_id=file_id, - endpoint=request_json["endpoint"], - completion_window=request_json.get("completion_window", "5s"), - metadata=request_json.get("metadata", None), - ) - - # Return metadata as attribute, not a callable. - return JSONResponse(content=batch.to_dict()) - - except Exception as e: - return JSONResponse( - status_code=500, - content={"error": f"Failed to process batch request: {str(e)}"}, - ) - - -@app.get("/v1/batches/{batch_id}") -async def route_get_batch(batch_id: str): - try: - batch_processor: BatchProcessor = app.state.batch_processor - batch = await batch_processor.retrieve_batch(batch_id) - return JSONResponse(content=batch.to_dict()) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"Batch {batch_id} not found"} - ) - - -@app.get("/v1/batches") -async def route_list_batches(limit: int = 20, after: str = None): - try: - batch_processor: BatchProcessor = app.state.batch_processor - batches = await batch_processor.list_batches(limit=limit, after=after) - - # Convert batches to response format - batch_data = [batch.to_dict() for batch in batches] - - response = { - "object": "list", - "data": batch_data, - "first_id": batch_data[0]["id"] if batch_data else None, - "last_id": batch_data[-1]["id"] if batch_data else None, - "has_more": len(batch_data) - == limit, # If we got limit items, there may be more - } - - return JSONResponse(content=response) - except FileNotFoundError: - return JSONResponse(status_code=404, content={"error": "No batches found"}) - - -@app.delete("/v1/batches/{batch_id}") -async def route_cancel_batch(batch_id: str): - try: - batch_processor: BatchProcessor = app.state.batch_processor - batch = await batch_processor.cancel_batch(batch_id) - return JSONResponse(content=batch.to_dict()) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"Batch {batch_id} not found"} - ) - - -@app.post("/v1/batches") -async def route_batches(request: Request): - """Handle batch requests that process files with specified endpoints.""" - try: - request_json = await request.json() - - # Validate required fields - if "input_file_id" not in request_json: - return JSONResponse( - status_code=400, - content={"error": "Missing required parameter 'input_file_id'"}, - ) - if "endpoint" not in request_json: - return JSONResponse( - 
status_code=400, - content={"error": "Missing required parameter 'endpoint'"}, - ) - - # Verify file exists - storage: Storage = app.state.batch_storage - file_id = request_json["input_file_id"] - try: - await storage.get_file(file_id) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"File {file_id} not found"} - ) - - batch_processor: BatchProcessor = app.state.batch_processor - batch = await batch_processor.create_batch( - input_file_id=file_id, - endpoint=request_json["endpoint"], - completion_window=request_json.get("completion_window", "5s"), - metadata=request_json.get("metadata", None), - ) - - # Return metadata as attribute, not a callable. - return JSONResponse(content=batch.to_dict()) - - except Exception as e: - return JSONResponse( - status_code=500, - content={"error": f"Failed to process batch request: {str(e)}"}, - ) - - -@app.get("/v1/batches/{batch_id}") -async def route_get_batch(batch_id: str): - try: - batch_processor: BatchProcessor = app.state.batch_processor - batch = await batch_processor.retrieve_batch(batch_id) - return JSONResponse(content=batch.to_dict()) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"Batch {batch_id} not found"} - ) - - -@app.get("/v1/batches") -async def route_list_batches(limit: int = 20, after: str = None): - try: - batch_processor: BatchProcessor = app.state.batch_processor - batches = await batch_processor.list_batches(limit=limit, after=after) - - # Convert batches to response format - batch_data = [batch.to_dict() for batch in batches] - - response = { - "object": "list", - "data": batch_data, - "first_id": batch_data[0]["id"] if batch_data else None, - "last_id": batch_data[-1]["id"] if batch_data else None, - "has_more": len(batch_data) - == limit, # If we got limit items, there may be more - } - - return JSONResponse(content=response) - except FileNotFoundError: - return JSONResponse(status_code=404, content={"error": "No batches found"}) - - -@app.delete("/v1/batches/{batch_id}") -async def route_cancel_batch(batch_id: str): - try: - batch_processor: BatchProcessor = app.state.batch_processor - batch = await batch_processor.cancel_batch(batch_id) - return JSONResponse(content=batch.to_dict()) - except FileNotFoundError: - return JSONResponse( - status_code=404, content={"error": f"Batch {batch_id} not found"} - ) - - -@app.post("/v1/chat/completions") -async def route_chat_completition(request: Request): - if semantic_cache_available: - # Check if the request can be served from the semantic cache - logger.debug("Received chat completion request, checking semantic cache") - cache_response = await check_semantic_cache(request=request) - - if cache_response: - logger.info("Serving response from semantic cache") - return cache_response - - logger.debug("No cache hit, forwarding request to backend") - return await route_general_request(request, "/v1/chat/completions") - - -@app.post("/v1/completions") -async def route_completition(request: Request): - return await route_general_request(request, "/v1/completions") - - -@app.post("/v1/embeddings") -async def route_embeddings(request: Request): - return await route_general_request(request, "/v1/embeddings") - - -@app.post("/v1/rerank") -async def route_v1_rerank(request: Request): - return await route_general_request(request, "/v1/rerank") - - -@app.post("/rerank") -async def route_rerank(request: Request): - return await route_general_request(request, "/rerank") - - -@app.post("/v1/score") -async def 
route_v1_score(request: Request): - return await route_general_request(request, "/v1/score") - - -@app.post("/score") -async def route_score(request: Request): - return await route_general_request(request, "/score") - - -@app.get("/version") -async def show_version(): - ver = {"version": __version__} - return JSONResponse(content=ver) - - -@app.get("/v1/models") -async def show_models(): - """ - Returns a list of all models available in the stack. - - Args: - None - - Returns: - JSONResponse: A JSON response containing the list of models. - - Raises: - Exception: If there is an error in retrieving the endpoint information. - """ - endpoints = GetServiceDiscovery().get_endpoint_info() - existing_models = set() - model_cards = [] - for endpoint in endpoints: - if endpoint.model_name in existing_models: - continue - model_card = ModelCard( - id=endpoint.model_name, - object="model", - created=endpoint.added_timestamp, - owned_by="vllm", - ) - model_cards.append(model_card) - existing_models.add(endpoint.model_name) - model_list = ModelList(data=model_cards) - return JSONResponse(content=model_list.model_dump()) - - -@app.get("/health") -async def health() -> Response: - """ - Endpoint to check the health status of various components. - - This function verifies the health of the service discovery module and - the engine stats scraper. If either component is down, it returns a - 503 response with the appropriate status message. If both components - are healthy, it returns a 200 OK response. - - Returns: - Response: A JSONResponse with status code 503 if a component is - down, or a plain Response with status code 200 if all components - are healthy. - """ - - if not GetServiceDiscovery().get_health(): - return JSONResponse( - content={"status": "Service discovery module is down."}, status_code=503 - ) - if not GetEngineStatsScraper().get_health(): - return JSONResponse( - content={"status": "Engine stats scraper is down."}, status_code=503 - ) - - if GetDynamicConfigWatcher() is not None: - dynamic_config = GetDynamicConfigWatcher().get_current_config() - return JSONResponse( - content={ - "status": "healthy", - "dynamic_config": json.loads(dynamic_config.to_json_str()), - }, - status_code=200, - ) - else: - return JSONResponse(content={"status": "healthy"}, status_code=200) - - -# --- Prometheus Metrics Endpoint --- -@app.get("/metrics") -async def metrics(): - # Retrieve request stats from the monitor. - """ - Endpoint to expose Prometheus metrics for the vLLM router. - - This function gathers request statistics, engine metrics, and health status - of the service endpoints to update Prometheus gauges. It exports metrics - such as queries per second (QPS), average decoding length, number of prefill - and decoding requests, average latency, average inter-token latency, number - of swapped requests, and the number of healthy pods for each server. The - metrics are used to monitor the performance and health of the vLLM router - services. - - Returns: - Response: A HTTP response containing the latest Prometheus metrics in - the appropriate content type. 
- """ - - stats = GetRequestStatsMonitor().get_request_stats(time.time()) - for server, stat in stats.items(): - current_qps.labels(server=server).set(stat.qps) - # Assuming stat contains the following attributes: - avg_decoding_length.labels(server=server).set(stat.avg_decoding_length) - num_prefill_requests.labels(server=server).set(stat.in_prefill_requests) - num_decoding_requests.labels(server=server).set(stat.in_decoding_requests) - num_requests_running.labels(server=server).set( - stat.in_prefill_requests + stat.in_decoding_requests - ) - avg_latency.labels(server=server).set(stat.avg_latency) - avg_itl.labels(server=server).set(stat.avg_itl) - num_requests_swapped.labels(server=server).set(stat.num_swapped_requests) - # For healthy pods, we use a hypothetical function from service discovery. - healthy = {} - endpoints = GetServiceDiscovery().get_endpoint_info() - for ep in endpoints: - # Assume each endpoint object has an attribute 'healthy' (1 if healthy, 0 otherwise). - healthy[ep.url] = 1 if getattr(ep, "healthy", True) else 0 - for server, value in healthy.items(): - healthy_pods_total.labels(server=server).set(value) - return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST) - - -# --- Argument Parsing and Initialization --- -def validate_args(args): - if args.service_discovery == "static": - if args.static_backends is None: - raise ValueError( - "Static backends must be provided when using static service discovery." - ) - if args.static_models is None: - raise ValueError( - "Static models must be provided when using static service discovery." - ) - if args.service_discovery == "k8s" and args.k8s_port is None: - raise ValueError("K8s port must be provided when using K8s service discovery.") - if args.routing_logic == "session" and args.session_key is None: - raise ValueError( - "Session key must be provided when using session routing logic." - ) - if args.log_stats and args.log_stats_interval <= 0: - raise ValueError("Log stats interval must be greater than 0.") - if args.engine_stats_interval <= 0: - raise ValueError("Engine stats interval must be greater than 0.") - if args.request_stats_window <= 0: - raise ValueError("Request stats window must be greater than 0.") - - -def parse_args(): - parser = argparse.ArgumentParser(description="Run the FastAPI app.") - parser.add_argument( - "--host", default="0.0.0.0", help="The host to run the server on." - ) - parser.add_argument( - "--port", type=int, default=8001, help="The port to run the server on." - ) - parser.add_argument( - "--service-discovery", - required=True, - choices=["static", "k8s"], - help="The service discovery type.", - ) - parser.add_argument( - "--static-backends", - type=str, - default=None, - help="The URLs of static backends, separated by commas. E.g., http://localhost:8000,http://localhost:8001", - ) - parser.add_argument( - "--static-models", - type=str, - default=None, - help="The models of static backends, separated by commas. 
E.g., model1,model2", - ) - parser.add_argument( - "--k8s-port", - type=int, - default=8000, - help="The port of vLLM processes when using K8s service discovery.", - ) - parser.add_argument( - "--k8s-namespace", - type=str, - default="default", - help="The namespace of vLLM pods when using K8s service discovery.", - ) - parser.add_argument( - "--k8s-label-selector", - type=str, - default="", - help="The label selector to filter vLLM pods when using K8s service discovery.", - ) - parser.add_argument( - "--routing-logic", - type=str, - required=True, - choices=["roundrobin", "session"], - help="The routing logic to use", - ) - parser.add_argument( - "--session-key", - type=str, - default=None, - help="The key (in the header) to identify a session.", - ) - - # Batch API - # TODO(gaocegege): Make these batch api related arguments to a separate config. - parser.add_argument( - "--enable-batch-api", - action="store_true", - help="Enable the batch API for processing files.", - ) - parser.add_argument( - "--file-storage-class", - type=str, - default="local_file", - choices=["local_file"], - help="The file storage class to use.", - ) - parser.add_argument( - "--file-storage-path", - type=str, - default="/tmp/vllm_files", - help="The path to store files.", - ) - parser.add_argument( - "--batch-processor", - type=str, - default="local", - choices=["local"], - help="The batch processor to use.", - ) - - # Monitoring - parser.add_argument( - "--engine-stats-interval", - type=int, - default=30, - help="The interval in seconds to scrape engine statistics.", - ) - parser.add_argument( - "--request-stats-window", - type=int, - default=60, - help="The sliding window in seconds to compute request statistics.", - ) - parser.add_argument( - "--log-stats", action="store_true", help="Log statistics periodically." - ) - parser.add_argument( - "--log-stats-interval", - type=int, - default=10, - help="The interval in seconds to log statistics.", - ) - - parser.add_argument( - "--dynamic-config-json", - type=str, - default=None, - help="The path to the json file containing the dynamic configuration.", - ) - - # Add --version argument - parser.add_argument( - "--version", - action="version", - version=f"%(prog)s {__version__}", - help="Show version and exit", - ) - - if semantic_cache_available: - # Add semantic cache arguments - add_semantic_cache_args(parser) - - # Add feature gates argument - parser.add_argument( - "--feature-gates", - type=str, - default="", - help="Comma-separated list of feature gates (e.g., 'SemanticCache=true')", - ) - - # Add log level argument - parser.add_argument( - "--log-level", - type=str, - default="info", - choices=["critical", "error", "warning", "info", "debug", "trace"], - help="Log level for uvicorn. Default is 'info'.", - ) - args = parser.parse_args() - validate_args(args) - return args - - -def InitializeAll(args): - """ - Initialize all the components of the router with the given arguments. 
- - Args: - args: the parsed command-line arguments - - Raises: - ValueError: if the service discovery type is invalid - """ - if args.service_discovery == "static": - InitializeServiceDiscovery( - ServiceDiscoveryType.STATIC, - urls=parse_static_urls(args.static_backends), - models=parse_static_model_names(args.static_models), - ) - elif args.service_discovery == "k8s": - InitializeServiceDiscovery( - ServiceDiscoveryType.K8S, - namespace=args.k8s_namespace, - port=args.k8s_port, - label_selector=args.k8s_label_selector, - ) - else: - raise ValueError(f"Invalid service discovery type: {args.service_discovery}") - - # Initialize singletons via custom functions. - InitializeEngineStatsScraper(args.engine_stats_interval) - InitializeRequestStatsMonitor(args.request_stats_window) - - if args.enable_batch_api: - logger.info("Initializing batch API") - app.state.batch_storage = initialize_storage( - args.file_storage_class, args.file_storage_path - ) - app.state.batch_processor = initialize_batch_processor( - args.batch_processor, args.file_storage_path, app.state.batch_storage - ) - - InitializeRoutingLogic(args.routing_logic, session_key=args.session_key) - - # Initialize feature gates - initialize_feature_gates(args.feature_gates) - # Check if the SemanticCache feature gate is enabled - feature_gates = get_feature_gates() - if semantic_cache_available: - if feature_gates.is_enabled("SemanticCache"): - # The feature gate is enabled, explicitly enable the semantic cache - enable_semantic_cache() - - # Verify that the semantic cache was successfully enabled - if not is_semantic_cache_enabled(): - logger.error("Failed to enable semantic cache feature") - - logger.info("SemanticCache feature gate is enabled") - - # Initialize the semantic cache with the model if specified - if args.semantic_cache_model: - logger.info( - f"Initializing semantic cache with model: {args.semantic_cache_model}" - ) - logger.info( - f"Semantic cache directory: {args.semantic_cache_dir or 'default'}" - ) - logger.info( - f"Semantic cache threshold: {args.semantic_cache_threshold}" - ) - - cache = InitializeSemanticCache( - embedding_model=args.semantic_cache_model, - cache_dir=args.semantic_cache_dir, - default_similarity_threshold=args.semantic_cache_threshold, - ) - - # Update cache size metric - if cache and hasattr(cache, "db") and hasattr(cache.db, "index"): - semantic_cache_size.labels(server="router").set( - cache.db.index.ntotal - ) - logger.info( - f"Semantic cache initialized with {cache.db.index.ntotal} entries" - ) - - logger.info( - f"Semantic cache initialized with model {args.semantic_cache_model}" - ) - else: - logger.warning( - "SemanticCache feature gate is enabled but no embedding model specified. " - "The semantic cache will not be functional without an embedding model. " - "Use --semantic-cache-model to specify an embedding model." - ) - elif args.semantic_cache_model: - logger.warning( - "Semantic cache model specified but SemanticCache feature gate is not enabled. 
" - "Enable the feature gate with --feature-gates=SemanticCache=true" - ) - - # --- Hybrid addition: attach singletons to FastAPI state --- - app.state.engine_stats_scraper = GetEngineStatsScraper() - app.state.request_stats_monitor = GetRequestStatsMonitor() - app.state.router = GetRoutingLogic() - - # Initialize dynamic config watcher - if args.dynamic_config_json: - init_config = DynamicRouterConfig.from_args(args) - InitializeDynamicConfigWatcher(args.dynamic_config_json, 10, init_config, app) - - -def log_stats(interval: int = 10): - """ - Periodically logs the engine and request statistics for each service endpoint. - - This function retrieves the current service endpoints and their corresponding - engine and request statistics, and logs them at a specified interval. The - statistics include the number of running and queued requests, GPU cache hit - rate, queries per second (QPS), average latency, average inter-token latency - (ITL), and more. These statistics are also updated in the Prometheus metrics. - - Args: - interval (int): The interval in seconds at which statistics are logged. - Default is 10 seconds. - """ - - while True: - time.sleep(interval) - logstr = "\n" + "=" * 50 + "\n" - endpoints = GetServiceDiscovery().get_endpoint_info() - engine_stats = app.state.engine_stats_scraper.get_engine_stats() - request_stats = app.state.request_stats_monitor.get_request_stats(time.time()) - for endpoint in endpoints: - url = endpoint.url - logstr += f"Model: {endpoint.model_name}\n" - logstr += f"Server: {url}\n" - if url in engine_stats: - es = engine_stats[url] - logstr += ( - f" Engine Stats: Running Requests: {es.num_running_requests}, " - f"Queued Requests: {es.num_queuing_requests}, " - f"GPU Cache Hit Rate: {es.gpu_prefix_cache_hit_rate:.2f}\n" - ) - else: - logstr += " Engine Stats: No stats available\n" - if url in request_stats: - rs = request_stats[url] - logstr += ( - f" Request Stats: QPS: {rs.qps:.2f}, " - f"Avg Latency: {rs.avg_latency}, " - f"Avg ITL: {rs.avg_itl}, " - f"Prefill Requests: {rs.in_prefill_requests}, " - f"Decoding Requests: {rs.in_decoding_requests}, " - f"Swapped Requests: {rs.num_swapped_requests}, " - f"Finished: {rs.finished_requests}, " - f"Uptime: {rs.uptime:.2f} sec\n" - ) - current_qps.labels(server=url).set(rs.qps) - avg_decoding_length.labels(server=url).set(rs.avg_decoding_length) - num_prefill_requests.labels(server=url).set(rs.in_prefill_requests) - num_decoding_requests.labels(server=url).set(rs.in_decoding_requests) - num_requests_running.labels(server=url).set( - rs.in_prefill_requests + rs.in_decoding_requests - ) - avg_latency.labels(server=url).set(rs.avg_latency) - avg_itl.labels(server=url).set(rs.avg_itl) - num_requests_swapped.labels(server=url).set(rs.num_swapped_requests) - else: - logstr += " Request Stats: No stats available\n" - logstr += "-" * 50 + "\n" - logstr += "=" * 50 + "\n" - logger.info(logstr) - - -def main(): - args = parse_args() - InitializeAll(args) - if args.log_stats: - threading.Thread( - target=log_stats, args=(args.log_stats_interval,), daemon=True - ).start() - - # Workaround to avoid footguns where uvicorn drops requests with too - # many concurrent requests active. 
- set_ulimit() - uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level) - - -if __name__ == "__main__": - main() diff --git a/src/vllm_router/routers/__init__.py b/src/vllm_router/routers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vllm_router/routers/batches_router.py b/src/vllm_router/routers/batches_router.py new file mode 100644 index 00000000..8a388d31 --- /dev/null +++ b/src/vllm_router/routers/batches_router.py @@ -0,0 +1,100 @@ +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +from vllm_router.services.batch_service.processor import BatchProcessor +from vllm_router.services.files_service import Storage + +batches_router = APIRouter() + + +@batches_router.post("/v1/batches") +async def route_batches(request: Request): + """Handle batch requests that process files with specified endpoints.""" + try: + request_json = await request.json() + + # Validate required fields + if "input_file_id" not in request_json: + return JSONResponse( + status_code=400, + content={"error": "Missing required parameter 'input_file_id'"}, + ) + if "endpoint" not in request_json: + return JSONResponse( + status_code=400, + content={"error": "Missing required parameter 'endpoint'"}, + ) + + # Verify file exists + storage: Storage = request.app.state.batch_storage + file_id = request_json["input_file_id"] + try: + await storage.get_file(file_id) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"File {file_id} not found"} + ) + + batch_processor: BatchProcessor = request.app.state.batch_processor + batch = await batch_processor.create_batch( + input_file_id=file_id, + endpoint=request_json["endpoint"], + completion_window=request_json.get("completion_window", "5s"), + metadata=request_json.get("metadata", None), + ) + + # Return metadata as attribute, not a callable. 
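+        # The batch record was created above from input_file_id, endpoint,
+        # completion_window, and optional metadata; serialize it for the response.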
+ return JSONResponse(content=batch.to_dict()) + + except Exception as e: + return JSONResponse( + status_code=500, + content={"error": f"Failed to process batch request: {str(e)}"}, + ) + + +@batches_router.get("/v1/batches/{batch_id}") +async def route_get_batch(request: Request, batch_id: str): + try: + batch_processor: BatchProcessor = request.app.state.batch_processor + batch = await batch_processor.retrieve_batch(batch_id) + return JSONResponse(content=batch.to_dict()) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"Batch {batch_id} not found"} + ) + + +@batches_router.get("/v1/batches") +async def route_list_batches(request: Request, limit: int = 20, after: str = None): + try: + batch_processor: BatchProcessor = request.app.state.batch_processor + batches = await batch_processor.list_batches(limit=limit, after=after) + + # Convert batches to response format + batch_data = [batch.to_dict() for batch in batches] + + response = { + "object": "list", + "data": batch_data, + "first_id": batch_data[0]["id"] if batch_data else None, + "last_id": batch_data[-1]["id"] if batch_data else None, + "has_more": len(batch_data) + == limit, # If we got limit items, there may be more + } + + return JSONResponse(content=response) + except FileNotFoundError: + return JSONResponse(status_code=404, content={"error": "No batches found"}) + + +@batches_router.delete("/v1/batches/{batch_id}") +async def route_cancel_batch(request: Request, batch_id: str): + try: + batch_processor: BatchProcessor = request.app.state.batch_processor + batch = await batch_processor.cancel_batch(batch_id) + return JSONResponse(content=batch.to_dict()) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"Batch {batch_id} not found"} + ) diff --git a/src/vllm_router/routers/files_router.py b/src/vllm_router/routers/files_router.py new file mode 100644 index 00000000..1354c325 --- /dev/null +++ b/src/vllm_router/routers/files_router.py @@ -0,0 +1,68 @@ +from fastapi import APIRouter, Request, UploadFile +from fastapi.responses import JSONResponse, Response + +from vllm_router.services.files_service import Storage + +files_router = APIRouter() + + +# --- File Endpoints --- +@files_router.post("/v1/files") +async def route_files(request: Request): + """ + Handle file upload requests and save the files to the configured storage. + + Args: + request (Request): The incoming HTTP request. + + Returns: + JSONResponse: A JSON response containing the file metadata. + + Raises: + JSONResponse: A JSON response with a 400 status code if the request is invalid, + or a 500 status code if an error occurs during file saving. 
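+
+    Example (illustrative only; host, port, and file name depend on the deployment):
+        import requests
+        requests.post(
+            "http://localhost:8000/v1/files",
+            files={"file": open("batch_input.jsonl", "rb")},
+            data={"purpose": "batch"},
+        )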
+ """ + form = await request.form() + purpose = form.get("purpose", "unknown") + if "file" not in form: + return JSONResponse( + status_code=400, content={"error": "Missing required parameter 'file'"} + ) + file_obj: UploadFile = form["file"] + file_content = await file_obj.read() + try: + storage: Storage = request.app.state.batch_storage + file_info = await storage.save_file( + file_name=file_obj.filename, content=file_content, purpose=purpose + ) + return JSONResponse(content=file_info.metadata()) + except Exception as e: + return JSONResponse( + status_code=500, content={"error": f"Failed to save file: {str(e)}"} + ) + + +@files_router.get("/v1/files/{file_id}") +async def route_get_file(request: Request, file_id: str): + try: + storage: Storage = request.app.state.batch_storage + file = await storage.get_file(file_id) + return JSONResponse(content=file.metadata()) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"File {file_id} not found"} + ) + + +@files_router.get("/v1/files/{file_id}/content") +async def route_get_file_content(request: Request, file_id: str): + try: + # TODO(gaocegege): Stream the file content with chunks to support + # openai uploads interface. + storage: Storage = request.app.state.batch_storage + file_content = await storage.get_file_content(file_id) + return Response(content=file_content) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"File {file_id} not found"} + ) diff --git a/src/vllm_router/routers/main_router.py b/src/vllm_router/routers/main_router.py new file mode 100644 index 00000000..120996b6 --- /dev/null +++ b/src/vllm_router/routers/main_router.py @@ -0,0 +1,159 @@ +import json + +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, Response + +from vllm_router.dynamic_config import get_dynamic_config_watcher +from vllm_router.log import init_logger +from vllm_router.protocols import ModelCard, ModelList +from vllm_router.service_discovery import get_service_discovery +from vllm_router.services.request_service.request import route_general_request +from vllm_router.stats.engine_stats import get_engine_stats_scraper +from vllm_router.version import __version__ + +try: + # Semantic cache integration + from vllm_router.experimental.semantic_cache import ( + GetSemanticCache, + enable_semantic_cache, + initialize_semantic_cache, + is_semantic_cache_enabled, + ) + from vllm_router.experimental.semantic_cache_integration import ( + add_semantic_cache_args, + check_semantic_cache, + semantic_cache_hit_ratio, + semantic_cache_hits, + semantic_cache_latency, + semantic_cache_misses, + semantic_cache_size, + store_in_semantic_cache, + ) + + semantic_cache_available = True +except ImportError: + semantic_cache_available = False + +main_router = APIRouter() + +logger = init_logger(__name__) + + +@main_router.post("/v1/chat/completions") +async def route_chat_completion(request: Request): + # Check if the request can be served from the semantic cache + logger.debug("Received chat completion request, checking semantic cache") + cache_response = await check_semantic_cache(request=request) + + if cache_response: + logger.info("Serving response from semantic cache") + return cache_response + + logger.debug("No cache hit, forwarding request to backend") + return await route_general_request(request, "/v1/chat/completions") + + +@main_router.post("/v1/completions") +async def route_completion(request: Request): + return await route_general_request(request, 
"/v1/completions") + + +@main_router.post("/v1/embeddings") +async def route_embeddings(request: Request): + return await route_general_request(request, "/v1/embeddings") + + +@main_router.post("/v1/rerank") +async def route_v1_rerank(request: Request): + return await route_general_request(request, "/v1/rerank") + + +@main_router.post("/rerank") +async def route_rerank(request: Request): + return await route_general_request(request, "/rerank") + + +@main_router.post("/v1/score") +async def route_v1_score(request: Request): + return await route_general_request(request, "/v1/score") + + +@main_router.post("/score") +async def route_score(request: Request): + return await route_general_request(request, "/score") + + +@main_router.get("/version") +async def show_version(): + ver = {"version": __version__} + return JSONResponse(content=ver) + + +@main_router.get("/v1/models") +async def show_models(): + """ + Returns a list of all models available in the stack. + + Args: + None + + Returns: + JSONResponse: A JSON response containing the list of models. + + Raises: + Exception: If there is an error in retrieving the endpoint information. + """ + endpoints = get_service_discovery().get_endpoint_info() + existing_models = set() + model_cards = [] + for endpoint in endpoints: + if endpoint.model_name in existing_models: + continue + model_card = ModelCard( + id=endpoint.model_name, + object="model", + created=endpoint.added_timestamp, + owned_by="vllm", + ) + model_cards.append(model_card) + existing_models.add(endpoint.model_name) + model_list = ModelList(data=model_cards) + return JSONResponse(content=model_list.model_dump()) + + +@main_router.get("/health") +async def health() -> Response: + """ + Endpoint to check the health status of various components. + + This function verifies the health of the service discovery module and + the engine stats scraper. If either component is down, it returns a + 503 response with the appropriate status message. If both components + are healthy, it returns a 200 OK response. + + Returns: + Response: A JSONResponse with status code 503 if a component is + down, or a plain Response with status code 200 if all components + are healthy. 
+ """ + + if not get_service_discovery().get_health(): + return JSONResponse( + content={"status": "Service discovery module is down."}, status_code=503 + ) + if not get_engine_stats_scraper().get_health(): + return JSONResponse( + content={"status": "Engine stats scraper is down."}, status_code=503 + ) + + if get_dynamic_config_watcher() is not None: + dynamic_config = get_dynamic_config_watcher().get_current_config() + return JSONResponse( + content={ + "status": "healthy", + "dynamic_config": json.loads(dynamic_config.to_json_str()), + }, + status_code=200, + ) + else: + return JSONResponse(content={"status": "healthy"}, status_code=200) diff --git a/src/vllm_router/routers/metrics_router.py b/src/vllm_router/routers/metrics_router.py new file mode 100644 index 00000000..0a5218a4 --- /dev/null +++ b/src/vllm_router/routers/metrics_router.py @@ -0,0 +1,64 @@ +import time + +from fastapi import APIRouter, Response +from prometheus_client import CONTENT_TYPE_LATEST, generate_latest + +from vllm_router.service_discovery import get_service_discovery +from vllm_router.services.metrics_service import ( + avg_decoding_length, + avg_itl, + avg_latency, + current_qps, + healthy_pods_total, + num_decoding_requests, + num_prefill_requests, + num_requests_running, + num_requests_swapped, +) +from vllm_router.stats.request_stats import get_request_stats_monitor + +metrics_router = APIRouter() + + +# --- Prometheus Metrics Endpoint --- +@metrics_router.get("/metrics") +async def metrics(): + # Retrieve request stats from the monitor. + """ + Endpoint to expose Prometheus metrics for the vLLM router. + + This function gathers request statistics, engine metrics, and health status + of the service endpoints to update Prometheus gauges. It exports metrics + such as queries per second (QPS), average decoding length, number of prefill + and decoding requests, average latency, average inter-token latency, number + of swapped requests, and the number of healthy pods for each server. The + metrics are used to monitor the performance and health of the vLLM router + services. + + Returns: + Response: A HTTP response containing the latest Prometheus metrics in + the appropriate content type. + """ + + stats = get_request_stats_monitor().get_request_stats(time.time()) + for server, stat in stats.items(): + current_qps.labels(server=server).set(stat.qps) + # Assuming stat contains the following attributes: + avg_decoding_length.labels(server=server).set(stat.avg_decoding_length) + num_prefill_requests.labels(server=server).set(stat.in_prefill_requests) + num_decoding_requests.labels(server=server).set(stat.in_decoding_requests) + num_requests_running.labels(server=server).set( + stat.in_prefill_requests + stat.in_decoding_requests + ) + avg_latency.labels(server=server).set(stat.avg_latency) + avg_itl.labels(server=server).set(stat.avg_itl) + num_requests_swapped.labels(server=server).set(stat.num_swapped_requests) + # For healthy pods, we use a hypothetical function from service discovery. + healthy = {} + endpoints = get_service_discovery().get_endpoint_info() + for ep in endpoints: + # Assume each endpoint object has an attribute 'healthy' (1 if healthy, 0 otherwise). 
+ healthy[ep.url] = 1 if getattr(ep, "healthy", True) else 0 + for server, value in healthy.items(): + healthy_pods_total.labels(server=server).set(value) + return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST) diff --git a/src/vllm_router/routing_logic.py b/src/vllm_router/routers/routing_logic.py similarity index 95% rename from src/vllm_router/routing_logic.py rename to src/vllm_router/routers/routing_logic.py index e69d3f58..147684dc 100644 --- a/src/vllm_router/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -1,15 +1,14 @@ import abc import enum -import hashlib -from typing import Dict, List, Optional +from typing import Dict, List from fastapi import Request from uhashring import HashRing -from vllm_router.engine_stats import EngineStats from vllm_router.log import init_logger -from vllm_router.request_stats import RequestStats from vllm_router.service_discovery import EndpointInfo +from vllm_router.stats.engine_stats import EngineStats +from vllm_router.stats.request_stats import RequestStats from vllm_router.utils import SingletonABCMeta logger = init_logger(__name__) @@ -174,7 +173,7 @@ def route_request( # Instead of managing a global _global_router, we can define the initialization functions as: -def InitializeRoutingLogic( +def initialize_routing_logic( routing_logic: RoutingLogic, *args, **kwargs ) -> RoutingInterface: if routing_logic == RoutingLogic.ROUND_ROBIN: @@ -187,17 +186,17 @@ def InitializeRoutingLogic( raise ValueError(f"Invalid routing logic {routing_logic}") -def ReconfigureRoutingLogic( +def reconfigure_routing_logic( routing_logic: RoutingLogic, *args, **kwargs ) -> RoutingInterface: # Remove the existing routers from the singleton registry for cls in (SessionRouter, RoundRobinRouter): if cls in SingletonABCMeta._instances: del SingletonABCMeta._instances[cls] - return InitializeRoutingLogic(routing_logic, *args, **kwargs) + return initialize_routing_logic(routing_logic, *args, **kwargs) -def GetRoutingLogic() -> RoutingInterface: +def get_routing_logic() -> RoutingInterface: # Look up in our singleton registry which router (if any) has been created. 
for cls in (SessionRouter, RoundRobinRouter): if cls in SingletonABCMeta._instances: diff --git a/src/vllm_router/run-router.sh b/src/vllm_router/run-router.sh index b1b6ddbb..58643900 100755 --- a/src/vllm_router/run-router.sh +++ b/src/vllm_router/run-router.sh @@ -5,7 +5,7 @@ if [[ $# -ne 1 ]]; then fi # Use this command when testing with k8s service discovery -# python3 -m vllm_router.router --port "$1" \ +# python3 -m vllm_router.app --port "$1" \ # --service-discovery k8s \ # --k8s-label-selector release=test \ # --k8s-namespace default \ @@ -15,7 +15,7 @@ fi # --log-stats # Use this command when testing with static service discovery -python3 -m vllm_router.router --port "$1" \ +python3 -m vllm_router.app --port "$1" \ --service-discovery static \ --static-backends "http://localhost:9000" \ --static-models "fake_model_name" \ diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 3a1248b7..7643bc98 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -285,7 +285,7 @@ def _create_service_discovery( raise ValueError("Invalid service discovery type") -def InitializeServiceDiscovery( +def initialize_service_discovery( service_discovery_type: ServiceDiscoveryType, *args, **kwargs ) -> ServiceDiscovery: """ @@ -313,7 +313,7 @@ def InitializeServiceDiscovery( return _global_service_discovery -def ReconfigureServiceDiscovery( +def reconfigure_service_discovery( service_discovery_type: ServiceDiscoveryType, *args, **kwargs ) -> ServiceDiscovery: """ @@ -332,7 +332,7 @@ def ReconfigureServiceDiscovery( return _global_service_discovery -def GetServiceDiscovery() -> ServiceDiscovery: +def get_service_discovery() -> ServiceDiscovery: """ Get the initialized service discovery module. 
@@ -352,15 +352,14 @@ def GetServiceDiscovery() -> ServiceDiscovery:
 
 if __name__ == "__main__":
     # Test the service discovery
     # k8s_sd = K8sServiceDiscovery("default", 8000, "release=test")
-    InitializeServiceDiscovery(
+    initialize_service_discovery(
         ServiceDiscoveryType.K8S,
         namespace="default",
         port=8000,
         label_selector="release=test",
     )
-    k8s_sd = GetServiceDiscovery()
-    import time
+    k8s_sd = get_service_discovery()
 
     time.sleep(1)
     while True:
diff --git a/src/vllm_router/services/__init__.py b/src/vllm_router/services/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/vllm_router/batch/__init__.py b/src/vllm_router/services/batch_service/__init__.py
similarity index 74%
rename from src/vllm_router/batch/__init__.py
rename to src/vllm_router/services/batch_service/__init__.py
index 1afc7185..140308b2 100644
--- a/src/vllm_router/batch/__init__.py
+++ b/src/vllm_router/services/batch_service/__init__.py
@@ -1,6 +1,5 @@
-from vllm_router.batch.batch import BatchEndpoint, BatchInfo, BatchRequest, BatchStatus
-from vllm_router.batch.processor import BatchProcessor
-from vllm_router.files import Storage
+from vllm_router.services.batch_service.processor import BatchProcessor
+from vllm_router.services.files_service import Storage
 
 
 def initialize_batch_processor(
diff --git a/src/vllm_router/batch/batch.py b/src/vllm_router/services/batch_service/batch.py
similarity index 100%
rename from src/vllm_router/batch/batch.py
rename to src/vllm_router/services/batch_service/batch.py
diff --git a/src/vllm_router/batch/local_processor.py b/src/vllm_router/services/batch_service/local_processor.py
similarity index 98%
rename from src/vllm_router/batch/local_processor.py
rename to src/vllm_router/services/batch_service/local_processor.py
index 506be8a6..9a7c3339 100644
--- a/src/vllm_router/batch/local_processor.py
+++ b/src/vllm_router/services/batch_service/local_processor.py
@@ -9,9 +9,9 @@ import aiosqlite
 
-from vllm_router.batch.batch import BatchInfo, BatchStatus
-from vllm_router.batch.processor import BatchProcessor
-from vllm_router.files import Storage
 from vllm_router.log import init_logger
+from vllm_router.services.batch_service.batch import BatchInfo, BatchStatus
+from vllm_router.services.batch_service.processor import BatchProcessor
+from vllm_router.services.files_service import Storage
 
 logger = init_logger(__name__)
diff --git a/src/vllm_router/batch/processor.py b/src/vllm_router/services/batch_service/processor.py
similarity index 89%
rename from src/vllm_router/batch/processor.py
rename to src/vllm_router/services/batch_service/processor.py
index ce321da4..196ac3be 100644
--- a/src/vllm_router/batch/processor.py
+++ b/src/vllm_router/services/batch_service/processor.py
@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional
 
-from vllm_router.batch.batch import BatchInfo
-from vllm_router.files import Storage
+from vllm_router.services.batch_service.batch import BatchInfo
+from vllm_router.services.files_service.storage import Storage
 
 
 class BatchProcessor(ABC):
diff --git a/src/vllm_router/services/files_service/__init__.py b/src/vllm_router/services/files_service/__init__.py
new file mode 100644
index 00000000..a0545ffb
--- /dev/null
+++ b/src/vllm_router/services/files_service/__init__.py
@@ -0,0 +1,10 @@
+from vllm_router.services.files_service.file_storage import FileStorage
+from vllm_router.services.files_service.openai_files import OpenAIFile
+from vllm_router.services.files_service.storage import Storage, initialize_storage
+
+__all__ = [
+    "OpenAIFile",
+    "Storage",
+    "FileStorage",
+    "initialize_storage",
+]
diff --git a/src/vllm_router/files/file_storage.py b/src/vllm_router/services/files_service/file_storage.py
similarity index 
96% rename from src/vllm_router/files/file_storage.py rename to src/vllm_router/services/files_service/file_storage.py index d8d7ce21..6482b68f 100644 --- a/src/vllm_router/files/file_storage.py +++ b/src/vllm_router/services/files_service/file_storage.py @@ -4,9 +4,9 @@ import aiofiles -from vllm_router.files.files import OpenAIFile -from vllm_router.files.storage import Storage from vllm_router.log import init_logger +from vllm_router.services.files_service.openai_files import OpenAIFile +from vllm_router.services.files_service.storage import Storage logger = init_logger(__name__) diff --git a/src/vllm_router/files/files.py b/src/vllm_router/services/files_service/openai_files.py similarity index 100% rename from src/vllm_router/files/files.py rename to src/vllm_router/services/files_service/openai_files.py diff --git a/src/vllm_router/files/storage.py b/src/vllm_router/services/files_service/storage.py similarity index 97% rename from src/vllm_router/files/storage.py rename to src/vllm_router/services/files_service/storage.py index 5247448c..0f87e376 100644 --- a/src/vllm_router/files/storage.py +++ b/src/vllm_router/services/files_service/storage.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List -from vllm_router.files.files import OpenAIFile +from vllm_router.services.files_service.openai_files import OpenAIFile class Storage(ABC): @@ -150,7 +150,7 @@ def initialize_storage(storage_type: str, base_path: str = None) -> Storage: like base_path should be in a config object. """ if storage_type == "local_file": - from vllm_router.files.file_storage import FileStorage + from vllm_router.services.files_service.file_storage import FileStorage return FileStorage(base_path) else: diff --git a/src/vllm_router/services/metrics_service/__init__.py b/src/vllm_router/services/metrics_service/__init__.py new file mode 100644 index 00000000..926da993 --- /dev/null +++ b/src/vllm_router/services/metrics_service/__init__.py @@ -0,0 +1,32 @@ +from prometheus_client import Gauge + +# --- Prometheus Gauges --- +# Existing metrics +num_requests_running = Gauge( + "vllm:num_requests_running", "Number of running requests", ["server"] +) +num_requests_waiting = Gauge( + "vllm:num_requests_waiting", "Number of waiting requests", ["server"] +) +current_qps = Gauge("vllm:current_qps", "Current Queries Per Second", ["server"]) +avg_decoding_length = Gauge( + "vllm:avg_decoding_length", "Average Decoding Length", ["server"] +) +num_prefill_requests = Gauge( + "vllm:num_prefill_requests", "Number of Prefill Requests", ["server"] +) +num_decoding_requests = Gauge( + "vllm:num_decoding_requests", "Number of Decoding Requests", ["server"] +) + +# New metrics per dashboard update +healthy_pods_total = Gauge( + "vllm:healthy_pods_total", "Number of healthy vLLM pods", ["server"] +) +avg_latency = Gauge( + "vllm:avg_latency", "Average end-to-end request latency", ["server"] +) +avg_itl = Gauge("vllm:avg_itl", "Average Inter-Token Latency", ["server"]) +num_requests_swapped = Gauge( + "vllm:num_requests_swapped", "Number of swapped requests", ["server"] +) diff --git a/src/vllm_router/services/metrics_service/prometheus_gauge.py b/src/vllm_router/services/metrics_service/prometheus_gauge.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vllm_router/services/request_service/__init__.py b/src/vllm_router/services/request_service/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vllm_router/services/request_service/request.py 
b/src/vllm_router/services/request_service/request.py new file mode 100644 index 00000000..9acad304 --- /dev/null +++ b/src/vllm_router/services/request_service/request.py @@ -0,0 +1,178 @@ +# --- Request Processing & Routing --- +# TODO: better request id system +import json +import time +import uuid + +from fastapi import Request +from fastapi.responses import JSONResponse, StreamingResponse + +from vllm_router.log import init_logger +from vllm_router.service_discovery import get_service_discovery + +try: + # Semantic cache integration + from vllm_router.experimental.semantic_cache import ( + GetSemanticCache, + enable_semantic_cache, + initialize_semantic_cache, + is_semantic_cache_enabled, + ) + from vllm_router.experimental.semantic_cache_integration import ( + add_semantic_cache_args, + check_semantic_cache, + semantic_cache_hit_ratio, + semantic_cache_hits, + semantic_cache_latency, + semantic_cache_misses, + semantic_cache_size, + store_in_semantic_cache, + ) + + semantic_cache_available = True +except ImportError: + semantic_cache_available = False + + +logger = init_logger(__name__) + + +async def process_request( + request: Request, body, backend_url, request_id, endpoint, debug_request=None +): + """ + Process a request by sending it to the chosen backend. + + Args: + request(Request): Request object. + body: The content of the request to send to the backend. + backend_url: The URL of the backend to send the request to. + request_id: A unique identifier for the request. + endpoint: The endpoint to send the request to on the backend. + debug_request: The original request object from the client, used for + optional debug logging. + + Yields: + The response headers and status code, followed by the response content. + + Raises: + HTTPError: If the backend returns a 4xx or 5xx status code. + """ + first_token = False + total_len = 0 + start_time = time.time() + request.app.state.request_stats_monitor.on_new_request( + backend_url, request_id, start_time + ) + # Check if this is a streaming request + is_streaming = False + try: + request_json = json.loads(body) + is_streaming = request_json.get("stream", False) + except: + # If we can't parse the body as JSON, assume it's not streaming + pass + + # For non-streaming requests, collect the full response to cache it properly + full_response = bytearray() if not is_streaming else None + + async with request.app.state.httpx_client_wrapper().stream( + method=request.method, + url=backend_url + endpoint, + headers=dict(request.headers), + content=body, + timeout=None, + ) as backend_response: + # Yield headers and status code first. + yield backend_response.headers, backend_response.status_code + # Stream response content. 
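+        # The first chunk marks time-to-first-token for the request stats
+        # monitor; non-streaming bodies are also accumulated into full_response
+        # so the complete payload can be stored in the semantic cache afterwards.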
+ async for chunk in backend_response.aiter_bytes(): + total_len += len(chunk) + if not first_token: + first_token = True + request.app.state.request_stats_monitor.on_request_response( + backend_url, request_id, time.time() + ) + # For non-streaming requests, collect the full response + if full_response is not None: + full_response.extend(chunk) + yield chunk + + request.app.state.request_stats_monitor.on_request_complete( + backend_url, request_id, time.time() + ) + + # if debug_request: + # logger.debug(f"Finished the request with request id: {debug_request.headers.get('x-request-id', None)} at {time.time()}") + # Store in semantic cache if applicable + # Use the full response for non-streaming requests, or the last chunk for streaming + if request.app.state.semantic_cache_available: + cache_chunk = bytes(full_response) if full_response is not None else chunk + await store_in_semantic_cache( + endpoint=endpoint, method=request.method, body=body, chunk=cache_chunk + ) + + +async def route_general_request(request: Request, endpoint: str): + """ + Route the incoming request to the backend server and stream the response back to the client. + + This function extracts the requested model from the request body and retrieves the + corresponding endpoints. It uses routing logic to determine the best server URL to handle + the request, then streams the request to that server. If the requested model is not available, + it returns an error response. + + Args: + request (Request): The incoming HTTP request. + endpoint (str): The endpoint to which the request should be routed. + + Returns: + StreamingResponse: A response object that streams data from the backend server to the client. + """ + + in_router_time = time.time() + request_id = str(uuid.uuid4()) + request_body = await request.body() + request_json = await request.json() # TODO (ApostaC): merge two awaits into one + requested_model = request_json.get("model", None) + if requested_model is None: + return JSONResponse( + status_code=400, + content={"error": "Invalid request: missing 'model' in request body."}, + ) + + # TODO (ApostaC): merge two awaits into one + endpoints = get_service_discovery().get_endpoint_info() + engine_stats = request.app.state.engine_stats_scraper.get_engine_stats() + request_stats = request.app.state.request_stats_monitor.get_request_stats( + time.time() + ) + + endpoints = list(filter(lambda x: x.model_name == requested_model, endpoints)) + if not endpoints: + return JSONResponse( + status_code=400, content={"error": f"Model {requested_model} not found."} + ) + + logger.debug(f"Routing request {request_id} for model: {requested_model}") + server_url = request.app.state.router.route_request( + endpoints, engine_stats, request_stats, request + ) + curr_time = time.time() + logger.info( + f"Routing request {request_id} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}" + ) + stream_generator = process_request( + request, + request_body, + server_url, + request_id, + endpoint=endpoint, + ) + headers, status_code = await anext(stream_generator) + return StreamingResponse( + stream_generator, + status_code=status_code, + headers={key: value for key, value in headers.items()}, + media_type="text/event-stream", + ) diff --git a/src/vllm_router/stats/__init__.py b/src/vllm_router/stats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vllm_router/engine_stats.py b/src/vllm_router/stats/engine_stats.py similarity index 93% rename from src/vllm_router/engine_stats.py rename 
to src/vllm_router/stats/engine_stats.py index bf150890..0e3b69cc 100644 --- a/src/vllm_router/engine_stats.py +++ b/src/vllm_router/stats/engine_stats.py @@ -1,13 +1,13 @@ import threading import time from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Dict import requests from prometheus_client.parser import text_string_to_metric_families from vllm_router.log import init_logger -from vllm_router.service_discovery import GetServiceDiscovery +from vllm_router.service_discovery import get_service_discovery from vllm_router.utils import SingletonMeta logger = init_logger(__name__) @@ -25,7 +25,7 @@ class EngineStats: gpu_cache_usage_perc: float = 0.0 @staticmethod - def FromVllmScrape(vllm_scrape: str): + def from_vllm_scrape(vllm_scrape: str): """ Parse the vllm scrape string and return a EngineStats object @@ -103,7 +103,7 @@ def _scrape_one_endpoint(self, url: str): try: response = requests.get(url + "/metrics", timeout=self.scrape_interval) response.raise_for_status() - engine_stats = EngineStats.FromVllmScrape(response.text) + engine_stats = EngineStats.from_vllm_scrape(response.text) except Exception as e: logger.error(f"Failed to scrape metrics from {url}: {e}") return None @@ -119,7 +119,7 @@ def _scrape_metrics(self): """ collected_engine_stats = {} - endpoints = GetServiceDiscovery().get_endpoint_info() + endpoints = get_service_discovery().get_endpoint_info() logger.info(f"Scraping metrics from {len(endpoints)} serving engine(s)") for info in endpoints: url = info.url @@ -186,10 +186,10 @@ def close(self): self.scrape_thread.join() -def InitializeEngineStatsScraper(scrape_interval: float) -> EngineStatsScraper: +def initialize_engine_stats_scraper(scrape_interval: float) -> EngineStatsScraper: return EngineStatsScraper(scrape_interval) -def GetEngineStatsScraper() -> EngineStatsScraper: +def get_engine_stats_scraper() -> EngineStatsScraper: # This call returns the already-initialized instance (or raises an error if not yet initialized) return EngineStatsScraper() diff --git a/src/vllm_router/stats/log_stats.py b/src/vllm_router/stats/log_stats.py new file mode 100644 index 00000000..2922bf05 --- /dev/null +++ b/src/vllm_router/stats/log_stats.py @@ -0,0 +1,82 @@ +import time + +from fastapi import FastAPI + +from vllm_router.log import init_logger +from vllm_router.service_discovery import get_service_discovery +from vllm_router.services.metrics_service import ( + avg_decoding_length, + avg_itl, + avg_latency, + current_qps, + num_decoding_requests, + num_prefill_requests, + num_requests_running, + num_requests_swapped, +) + +logger = init_logger(__name__) + + +def log_stats(app: FastAPI, interval: int = 10): + """ + Periodically logs the engine and request statistics for each service endpoint. + + This function retrieves the current service endpoints and their corresponding + engine and request statistics, and logs them at a specified interval. The + statistics include the number of running and queued requests, GPU cache hit + rate, queries per second (QPS), average latency, average inter-token latency + (ITL), and more. These statistics are also updated in the Prometheus metrics. + + Args: + app (FastAPI): FastAPI application + interval (int): The interval in seconds at which statistics are logged. + Default is 10 seconds. 
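+
+    Note:
+        This loop never returns, so it is expected to run in a background
+        (daemon) thread rather than on the main thread.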
+ """ + + while True: + time.sleep(interval) + logstr = "\n" + "=" * 50 + "\n" + endpoints = get_service_discovery().get_endpoint_info() + engine_stats = app.state.engine_stats_scraper.get_engine_stats() + request_stats = app.state.request_stats_monitor.get_request_stats(time.time()) + for endpoint in endpoints: + url = endpoint.url + logstr += f"Model: {endpoint.model_name}\n" + logstr += f"Server: {url}\n" + if url in engine_stats: + es = engine_stats[url] + logstr += ( + f" Engine Stats: Running Requests: {es.num_running_requests}, " + f"Queued Requests: {es.num_queuing_requests}, " + f"GPU Cache Hit Rate: {es.gpu_prefix_cache_hit_rate:.2f}\n" + ) + else: + logstr += " Engine Stats: No stats available\n" + if url in request_stats: + rs = request_stats[url] + logstr += ( + f" Request Stats: QPS: {rs.qps:.2f}, " + f"Avg Latency: {rs.avg_latency}, " + f"Avg ITL: {rs.avg_itl}, " + f"Prefill Requests: {rs.in_prefill_requests}, " + f"Decoding Requests: {rs.in_decoding_requests}, " + f"Swapped Requests: {rs.num_swapped_requests}, " + f"Finished: {rs.finished_requests}, " + f"Uptime: {rs.uptime:.2f} sec\n" + ) + current_qps.labels(server=url).set(rs.qps) + avg_decoding_length.labels(server=url).set(rs.avg_decoding_length) + num_prefill_requests.labels(server=url).set(rs.in_prefill_requests) + num_decoding_requests.labels(server=url).set(rs.in_decoding_requests) + num_requests_running.labels(server=url).set( + rs.in_prefill_requests + rs.in_decoding_requests + ) + avg_latency.labels(server=url).set(rs.avg_latency) + avg_itl.labels(server=url).set(rs.avg_itl) + num_requests_swapped.labels(server=url).set(rs.num_swapped_requests) + else: + logstr += " Request Stats: No stats available\n" + logstr += "-" * 50 + "\n" + logstr += "=" * 50 + "\n" + logger.info(logstr) diff --git a/src/vllm_router/request_stats.py b/src/vllm_router/stats/request_stats.py similarity index 99% rename from src/vllm_router/request_stats.py rename to src/vllm_router/stats/request_stats.py index 16a43fb0..36552bde 100644 --- a/src/vllm_router/request_stats.py +++ b/src/vllm_router/stats/request_stats.py @@ -282,9 +282,9 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: return ret -def InitializeRequestStatsMonitor(sliding_window_size: float): +def initialize_request_stats_monitor(sliding_window_size: float): return RequestStatsMonitor(sliding_window_size) -def GetRequestStatsMonitor(): +def get_request_stats_monitor(): return RequestStatsMonitor()