Commit
revert router change
Signed-off-by: Huamin Chen <[email protected]>
rootfs committed Mar 1, 2025
1 parent 8903619 commit 0b09bb1
Showing 2 changed files with 102 additions and 154 deletions.
src/vllm_router/requirements.txt (4 changes: 1 addition & 3 deletions)
@@ -1,11 +1,9 @@
aiofiles==24.1.0
faiss-cpu>=1.7.4
fastapi==0.115.8
httpx==0.28.1
kubernetes==32.0.0
numpy==1.26.4
prometheus_client==0.21.1
python-multipart==0.0.20
sentence-transformers>=2.2.2
uhashring==2.3
uvicorn==0.34.0
uvicorn==0.34.0
src/vllm_router/router.py (252 changes: 101 additions & 151 deletions)
@@ -1,59 +1,24 @@
import argparse
import asyncio
import json
import logging
import os
import random
import re
import sys
import threading
import time
import uuid
from asyncio import Task
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from urllib.parse import urlparse

import httpx
import numpy as np
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, Gauge, generate_latest
from pydantic import BaseModel
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest

from vllm_router.batch import BatchProcessor, initialize_batch_processor
from vllm_router.engine_stats import GetEngineStatsScraper, InitializeEngineStatsScraper

# Import experimental feature gates and semantic cache
from vllm_router.experimental.feature_gates import (
get_feature_gates,
initialize_feature_gates,
)

# Semantic cache integration
from vllm_router.experimental.semantic_cache import (
GetSemanticCache,
InitializeSemanticCache,
enable_semantic_cache,
is_semantic_cache_enabled,
)
from vllm_router.experimental.semantic_cache_integration import (
add_semantic_cache_args,
check_semantic_cache,
semantic_cache_hit_ratio,
semantic_cache_hits,
semantic_cache_latency,
semantic_cache_misses,
semantic_cache_size,
store_in_semantic_cache,
)
from vllm_router.files import Storage, initialize_storage
from vllm_router.httpx_client import HTTPXClientWrapper
from vllm_router.protocols import ModelCard, ModelList
from vllm_router.request_stats import (
GetRequestStatsMonitor,
InitializeRequestStatsMonitor,
RequestStatsMonitor,
)
from vllm_router.routing_logic import GetRoutingLogic, InitializeRoutingLogic
from vllm_router.service_discovery import (
@@ -65,10 +30,7 @@
from vllm_router.version import __version__

httpx_client_wrapper = HTTPXClientWrapper()

from vllm_router.log import init_logger

logger = init_logger(__name__)
logger = logging.getLogger("uvicorn")


@asynccontextmanager
@@ -143,18 +105,6 @@ async def process_request(
start_time = time.time()
app.state.request_stats_monitor.on_new_request(backend_url, request_id, start_time)

# Check if this is a streaming request
is_streaming = False
try:
request_json = json.loads(body)
is_streaming = request_json.get("stream", False)
except:
# If we can't parse the body as JSON, assume it's not streaming
pass

# For non-streaming requests, collect the full response to cache it properly
full_response = bytearray() if not is_streaming else None

client = httpx_client_wrapper()
async with client.stream(
method=method,
@@ -173,11 +123,6 @@
app.state.request_stats_monitor.on_request_response(
backend_url, request_id, time.time()
)

# For non-streaming requests, collect the full response
if full_response is not None:
full_response.extend(chunk)

yield chunk

app.state.request_stats_monitor.on_request_complete(
@@ -187,13 +132,6 @@
# if debug_request:
# logger.debug(f"Finished the request with request id: {debug_request.headers.get('x-request-id', None)} at {time.time()}")

# Store in semantic cache if applicable
# Use the full response for non-streaming requests, or the last chunk for streaming
cache_chunk = bytes(full_response) if full_response is not None else chunk
await store_in_semantic_cache(
endpoint=endpoint, method=method, body=body, chunk=cache_chunk
)


async def route_general_request(request: Request, endpoint: str):
"""
@@ -416,17 +354,101 @@ async def route_cancel_batch(batch_id: str):
)


@app.post("/v1/chat/completions")
async def route_chat_completition(request: Request):
# Check if the request can be served from the semantic cache
logger.debug("Received chat completion request, checking semantic cache")
cache_response = await check_semantic_cache(request=request)
@app.post("/v1/batches")
async def route_batches(request: Request):
"""Handle batch requests that process files with specified endpoints."""
try:
request_json = await request.json()

# Validate required fields
if "input_file_id" not in request_json:
return JSONResponse(
status_code=400,
content={"error": "Missing required parameter 'input_file_id'"},
)
if "endpoint" not in request_json:
return JSONResponse(
status_code=400,
content={"error": "Missing required parameter 'endpoint'"},
)

# Verify file exists
storage: Storage = app.state.batch_storage
file_id = request_json["input_file_id"]
try:
await storage.get_file(file_id)
except FileNotFoundError:
return JSONResponse(
status_code=404, content={"error": f"File {file_id} not found"}
)

batch_processor: BatchProcessor = app.state.batch_processor
batch = await batch_processor.create_batch(
input_file_id=file_id,
endpoint=request_json["endpoint"],
completion_window=request_json.get("completion_window", "5s"),
metadata=request_json.get("metadata", None),
)

# Return metadata as attribute, not a callable.
return JSONResponse(content=batch.to_dict())

except Exception as e:
return JSONResponse(
status_code=500,
content={"error": f"Failed to process batch request: {str(e)}"},
)


@app.get("/v1/batches/{batch_id}")
async def route_get_batch(batch_id: str):
try:
batch_processor: BatchProcessor = app.state.batch_processor
batch = await batch_processor.retrieve_batch(batch_id)
return JSONResponse(content=batch.to_dict())
except FileNotFoundError:
return JSONResponse(
status_code=404, content={"error": f"Batch {batch_id} not found"}
)


@app.get("/v1/batches")
async def route_list_batches(limit: int = 20, after: str = None):
try:
batch_processor: BatchProcessor = app.state.batch_processor
batches = await batch_processor.list_batches(limit=limit, after=after)

# Convert batches to response format
batch_data = [batch.to_dict() for batch in batches]

response = {
"object": "list",
"data": batch_data,
"first_id": batch_data[0]["id"] if batch_data else None,
"last_id": batch_data[-1]["id"] if batch_data else None,
"has_more": len(batch_data)
== limit, # If we got limit items, there may be more
}

return JSONResponse(content=response)
except FileNotFoundError:
return JSONResponse(status_code=404, content={"error": "No batches found"})

if cache_response:
logger.info("Serving response from semantic cache")
return cache_response

logger.debug("No cache hit, forwarding request to backend")
@app.delete("/v1/batches/{batch_id}")
async def route_cancel_batch(batch_id: str):
try:
batch_processor: BatchProcessor = app.state.batch_processor
batch = await batch_processor.cancel_batch(batch_id)
return JSONResponse(content=batch.to_dict())
except FileNotFoundError:
return JSONResponse(
status_code=404, content={"error": f"Batch {batch_id} not found"}
)


@app.post("/v1/chat/completions")
async def route_chat_completition(request: Request):
return await route_general_request(request, "/v1/chat/completions")
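
As an aside on the batch endpoints introduced above, here is a hedged client-side sketch of how they might be exercised once the router is running. The base URL and file id are hypothetical, and the payload fields simply mirror what route_batches validates (input_file_id, endpoint, optional completion_window and metadata); this is an illustration, not part of the commit.

import asyncio

import httpx


async def demo_batches():
    # Hypothetical router address; adjust to wherever the router is serving.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Create a batch from a previously uploaded file (the id here is made up).
        created = await client.post(
            "/v1/batches",
            json={
                "input_file_id": "file-abc123",
                "endpoint": "/v1/chat/completions",
                "completion_window": "5s",
            },
        )
        batch = created.json()

        # Retrieve, list, and cancel use the other routes added in this hunk.
        status = await client.get(f"/v1/batches/{batch['id']}")
        listing = await client.get("/v1/batches", params={"limit": 20})
        cancelled = await client.delete(f"/v1/batches/{batch['id']}")
        print(status.json(), listing.json(), cancelled.json())


asyncio.run(demo_batches())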


@@ -711,26 +733,6 @@ def parse_args():
help="Show version and exit",
)

# Add semantic cache arguments
add_semantic_cache_args(parser)

# Add feature gates argument
parser.add_argument(
"--feature-gates",
type=str,
default="",
help="Comma-separated list of feature gates (e.g., 'SemanticCache=true')",
)

# Add log level argument
parser.add_argument(
"--log-level",
type=str,
default="info",
choices=["critical", "error", "warning", "info", "debug", "trace"],
help="Log level for uvicorn. Default is 'info'.",
)

args = parser.parse_args()
validate_args(args)
return args
@@ -793,58 +795,6 @@ def InitializeAll(args):

InitializeRoutingLogic(args.routing_logic, session_key=args.session_key)

# Initialize feature gates
initialize_feature_gates(args.feature_gates)
# Check if the SemanticCache feature gate is enabled
feature_gates = get_feature_gates()
if feature_gates.is_enabled("SemanticCache"):
# The feature gate is enabled, explicitly enable the semantic cache
enable_semantic_cache()

# Verify that the semantic cache was successfully enabled
if not is_semantic_cache_enabled():
logger.error("Failed to enable semantic cache feature")

logger.info("SemanticCache feature gate is enabled")

# Initialize the semantic cache with the model if specified
if args.semantic_cache_model:
logger.info(
f"Initializing semantic cache with model: {args.semantic_cache_model}"
)
logger.info(
f"Semantic cache directory: {args.semantic_cache_dir or 'default'}"
)
logger.info(f"Semantic cache threshold: {args.semantic_cache_threshold}")

cache = InitializeSemanticCache(
embedding_model=args.semantic_cache_model,
cache_dir=args.semantic_cache_dir,
default_similarity_threshold=args.semantic_cache_threshold,
)

# Update cache size metric
if cache and hasattr(cache, "db") and hasattr(cache.db, "index"):
semantic_cache_size.labels(server="router").set(cache.db.index.ntotal)
logger.info(
f"Semantic cache initialized with {cache.db.index.ntotal} entries"
)

logger.info(
f"Semantic cache initialized with model {args.semantic_cache_model}"
)
else:
logger.warning(
"SemanticCache feature gate is enabled but no embedding model specified. "
"The semantic cache will not be functional without an embedding model. "
"Use --semantic-cache-model to specify an embedding model."
)
elif args.semantic_cache_model:
logger.warning(
"Semantic cache model specified but SemanticCache feature gate is not enabled. "
"Enable the feature gate with --feature-gates=SemanticCache=true"
)

# --- Hybrid addition: attach singletons to FastAPI state ---
app.state.engine_stats_scraper = GetEngineStatsScraper()
app.state.request_stats_monitor = GetRequestStatsMonitor()
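
For context on the feature-gate plumbing removed in this hunk: the router previously accepted a comma-separated --feature-gates string (e.g. "SemanticCache=true") and handed it to initialize_feature_gates. The real parser lives in vllm_router.experimental.feature_gates and may differ; the function below is only a hypothetical sketch of that kind of parsing.

def parse_feature_gates(spec: str) -> dict:
    # Hypothetical sketch: turn "SemanticCache=true,Foo=false" into
    # {"SemanticCache": True, "Foo": False}. Not the actual
    # vllm_router.experimental.feature_gates implementation.
    gates = {}
    for entry in spec.split(","):
        entry = entry.strip()
        if not entry:
            continue
        name, _, value = entry.partition("=")
        gates[name.strip()] = value.strip().lower() == "true"
    return gates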
@@ -925,8 +875,8 @@ def main():
# Workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active.
set_ulimit()
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
main()
main()
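
Finally, a hedged sketch of exercising the simplified /v1/chat/completions route after this revert: requests now flow straight through route_general_request to the backend with no semantic-cache lookup. The address and model name are hypothetical, and the payload uses the standard OpenAI-style fields that the route merely proxies.

import httpx

# Hypothetical router address and model name; the route forwards the body to a backend.
with httpx.Client(base_url="http://localhost:8000", timeout=60.0) as client:
    with client.stream(
        "POST",
        "/v1/chat/completions",
        json={
            "model": "my-model",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,  # process_request relays backend chunks as they arrive
        },
    ) as response:
        for chunk in response.iter_bytes():
            print(chunk.decode(errors="replace"), end="")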
