From ea693916c9315b9fc860a1e6562fbeca148a470c Mon Sep 17 00:00:00 2001
From: sitloboi2012
Date: Sat, 15 Feb 2025 00:07:04 +0700
Subject: [PATCH] update based on Shaoting-Feng comment

---
 src/vllm_router/request_stats.py |  10 +++
 src/vllm_router/router.py        | 119 +++++++++++++++++++++++++++++--
 2 files changed, 125 insertions(+), 4 deletions(-)

diff --git a/src/vllm_router/request_stats.py b/src/vllm_router/request_stats.py
index 46c69b1d..369d4447 100644
--- a/src/vllm_router/request_stats.py
+++ b/src/vllm_router/request_stats.py
@@ -44,6 +44,16 @@ def __init__(self, sliding_window_size: float):
         self.values: Deque[float] = deque()
 
     def update(self, timestamp: float, value: float):
+        """
+        Update the throughput monitor with a new timestamp and value.
+
+        This method adds the new data point to the sliding window and
+        removes any data point that is older than the sliding window size.
+
+        Args:
+            timestamp: The timestamp of the data point.
+            value: The value of the data point.
+        """
         self.timestamps.append(timestamp)
         self.values.append(value)
         while (
diff --git a/src/vllm_router/router.py b/src/vllm_router/router.py
index 3fa2fc1c..983299ee 100644
--- a/src/vllm_router/router.py
+++ b/src/vllm_router/router.py
@@ -74,11 +74,28 @@ async def lifespan(app: FastAPI):
 
 
 # --- Request Processing & Routing ---
+# TODO: better request id system
 async def process_request(
     method, header, body, backend_url, request_id, endpoint, debug_request=None
 ):
     """
-    Async generator to stream data from the backend server to the client.
+    Process a request by sending it to the chosen backend.
+
+    Args:
+        method: The HTTP method to use when sending the request to the backend.
+        header: The headers to send with the request to the backend.
+        body: The content of the request to send to the backend.
+        backend_url: The URL of the backend to send the request to.
+        request_id: A unique identifier for the request.
+        endpoint: The endpoint to send the request to on the backend.
+        debug_request: The original request object from the client, used for
+            optional debug logging.
+
+    Yields:
+        The response headers and status code, followed by the response content.
+
+    Raises:
+        HTTPError: If the backend returns a 4xx or 5xx status code.
     """
     first_token = False
     total_len = 0
@@ -113,10 +130,26 @@ async def process_request(
 
 
 async def route_general_request(request: Request, endpoint: str):
+    """
+    Route the incoming request to the backend server and stream the response back to the client.
+
+    This function extracts the requested model from the request body and retrieves the
+    corresponding endpoints. It uses routing logic to determine the best server URL to handle
+    the request, then streams the request to that server. If the requested model is not available,
+    it returns an error response.
+
+    Args:
+        request (Request): The incoming HTTP request.
+        endpoint (str): The endpoint to which the request should be routed.
+
+    Returns:
+        StreamingResponse: A response object that streams data from the backend server to the client.
+    """
+
     in_router_time = time.time()
     request_id = str(uuid.uuid4())
     request_body = await request.body()
-    request_json = await request.json()
+    request_json = await request.json()  # TODO (ApostaC): merge two awaits into one
     requested_model = request_json.get("model", None)
     if requested_model is None:
         return JSONResponse(
@@ -137,7 +170,6 @@ async def route_general_request(request: Request, endpoint: str):
     server_url = GetRoutingLogic().route_request(
         endpoints, engine_stats, request_stats, request
     )
-    logger.info(f"Request {request_id} routed to {server_url}")
     curr_time = time.time()
     logger.info(
         f"Routing request {request_id} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
@@ -161,6 +193,19 @@
 # --- File Endpoints ---
 @app.post("/v1/files")
 async def route_files(request: Request):
+    """
+    Handle file upload requests and save the files to the configured storage.
+
+    Args:
+        request (Request): The incoming HTTP request.
+
+    Returns:
+        JSONResponse: A JSON response containing the file metadata.
+
+    Note:
+        A response with a 400 status code is returned if the request is invalid,
+        and a 500 status code if an error occurs during file saving.
+    """
     form = await request.form()
     purpose = form.get("purpose", "unknown")
     if "file" not in form:
@@ -194,7 +239,8 @@ async def route_get_file(file_id: str):
 @app.get("/v1/files/{file_id}/content")
 async def route_get_file_content(file_id: str):
     try:
-        # TODO: Stream file content in chunks to support large files.
+        # TODO (gaocegege): Stream the file content in chunks to support the
+        # OpenAI uploads interface.
         file_content = await FILE_STORAGE.get_file_content(file_id)
         return Response(content=file_content)
     except FileNotFoundError:
@@ -221,6 +267,18 @@ async def show_version():
 
 @app.get("/v1/models")
 async def show_models():
+    """
+    Returns a list of all models available in the stack.
+
+    Args:
+        None
+
+    Returns:
+        JSONResponse: A JSON response containing the list of models.
+
+    Raises:
+        Exception: If there is an error in retrieving the endpoint information.
+    """
     endpoints = GetServiceDiscovery().get_endpoint_info()
     existing_models = set()
     model_cards = []
@@ -241,6 +299,20 @@ async def show_models():
 
 @app.get("/health")
 async def health() -> Response:
+    """
+    Endpoint to check the health status of various components.
+
+    This function verifies the health of the service discovery module and
+    the engine stats scraper. If either component is down, it returns a
+    503 response with the appropriate status message. If both components
+    are healthy, it returns a 200 OK response.
+
+    Returns:
+        Response: A JSONResponse with status code 503 if a component is
+            down, or a plain Response with status code 200 if all components
+            are healthy.
+    """
+
     if not GetServiceDiscovery().get_health():
         return JSONResponse(
             content={"status": "Service discovery module is down."}, status_code=503
@@ -256,6 +328,22 @@ async def health() -> Response:
 @app.get("/metrics")
 async def metrics():
     # Retrieve request stats from the monitor.
+    """
+    Endpoint to expose Prometheus metrics for the vLLM router.
+
+    This function gathers request statistics, engine metrics, and health status
+    of the service endpoints to update Prometheus gauges. It exports metrics
+    such as queries per second (QPS), average decoding length, number of prefill
+    and decoding requests, average latency, average inter-token latency, number
+    of swapped requests, and the number of healthy pods for each server. The
+    metrics are used to monitor the performance and health of the vLLM router
+    services.
+
+    Returns:
+        Response: An HTTP response containing the latest Prometheus metrics in
+            the appropriate content type.
+    """
+
     stats = GetRequestStatsMonitor().get_request_stats(time.time())
     for server, stat in stats.items():
         current_qps.labels(server=server).set(stat.qps)
@@ -418,6 +506,15 @@ def parse_static_model_names(args):
 
 
 def InitializeAll(args):
+    """
+    Initialize all the components of the router with the given arguments.
+
+    Args:
+        args: The parsed command-line arguments.
+
+    Raises:
+        ValueError: If the service discovery type is invalid.
+    """
     if args.service_discovery == "static":
         InitializeServiceDiscovery(
             ServiceDiscoveryType.STATIC,
@@ -441,6 +538,20 @@ def InitializeAll(args):
 
 
 def log_stats(interval: int = 10):
+    """
+    Periodically logs the engine and request statistics for each service endpoint.
+
+    This function retrieves the current service endpoints and their corresponding
+    engine and request statistics, and logs them at the specified interval. The
+    statistics include the number of running and queued requests, GPU cache hit
+    rate, queries per second (QPS), average latency, average inter-token latency
+    (ITL), and more. These statistics are also updated in the Prometheus metrics.
+
+    Args:
+        interval (int): The interval in seconds at which statistics are logged.
+            Default is 10 seconds.
+    """
+
     while True:
         time.sleep(interval)
         logstr = "\n" + "=" * 50 + "\n"