Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Request rewriter interface in router #230

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
)
from vllm_router.services.batch_service import initialize_batch_processor
from vllm_router.services.files_service import initialize_storage
from vllm_router.services.request_service.rewriter import (
get_request_rewriter,
initialize_request_rewriter,
)
from vllm_router.stats.engine_stats import (
get_engine_stats_scraper,
initialize_engine_stats_scraper,
Expand Down Expand Up @@ -90,7 +94,6 @@ async def lifespan(app: FastAPI):
dyn_cfg_watcher.close()


# TODO: This method needs refactoring, since it has nested if statements and is too long
def initialize_all(app: FastAPI, args):
"""
Initialize all the components of the router with the given arguments.
Expand Down Expand Up @@ -194,6 +197,7 @@ def initialize_all(app: FastAPI, args):
app.state.engine_stats_scraper = get_engine_stats_scraper()
app.state.request_stats_monitor = get_request_stats_monitor()
app.state.router = get_routing_logic()
app.state.request_rewriter = get_request_rewriter()

# Initialize dynamic config watcher
if args.dynamic_config_json:
Expand Down
9 changes: 9 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ def parse_args():
help="The key (in the header) to identify a session.",
)

# Request rewriter arguments
parser.add_argument(
"--request-rewriter",
type=str,
default="noop",
choices=["noop"],
help="The request rewriter to use. Default is 'noop' (no rewriting).",
)

# Batch API
# TODO(gaocegege): Make these batch api related arguments to a separate config.
parser.add_argument(
Expand Down
18 changes: 18 additions & 0 deletions src/vllm_router/services/request_service/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

from vllm_router.log import init_logger
from vllm_router.service_discovery import get_service_discovery
from vllm_router.services.request_service.rewriter import (
get_request_rewriter,
is_request_rewriter_initialized,
)

try:
# Semantic cache integration
Expand Down Expand Up @@ -141,6 +145,20 @@ async def route_general_request(request: Request, endpoint: str):
content={"error": "Invalid request: missing 'model' in request body."},
)

# Apply request rewriting if enabled
if is_request_rewriter_initialized():
rewriter = get_request_rewriter()
rewritten_body = rewriter.rewrite_request(
request_body, requested_model, endpoint
)
logger.info(f"Request for model {requested_model} was rewritten")
request_body = rewritten_body
# Update request_json if the body was rewritten
try:
request_json = json.loads(request_body)
except:
logger.warning("Failed to parse rewritten request body as JSON")

# TODO (ApostaC): merge two awaits into one
endpoints = get_service_discovery().get_endpoint_info()
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
Expand Down
107 changes: 107 additions & 0 deletions src/vllm_router/services/request_service/rewriter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Request rewriter interface for vLLM router.
This module provides functionality to rewrite requests before they are sent to the backend.
"""

import abc
import json
from typing import Any, Dict

from vllm_router.log import init_logger
from vllm_router.utils import SingletonABCMeta

# Module-level logger, obtained from the project's init_logger helper and
# keyed by this module's name.
logger = init_logger(__name__)


class RequestRewriter(metaclass=SingletonABCMeta):
    """
    Interface that every request rewriter must implement.

    A rewriter is handed the request body before it is sent to the backend and
    returns the body that should be sent instead. Typical uses are prompt
    engineering, model-specific adjustments, or request normalization.
    """

    @abc.abstractmethod
    def rewrite_request(self, request_body: str, model: str, endpoint: str) -> str:
        """
        Produce the request body that should be forwarded to the backend.

        Args:
            request_body: The original request body as a string.
            model: The model name taken from the request.
            endpoint: The target endpoint of this request.

        Returns:
            The rewritten request body as a string.
        """


class NoopRequestRewriter(RequestRewriter):
    """
    Pass-through rewriter: every request body is handed back untouched.
    """

    def rewrite_request(self, request_body: str, model: str, endpoint: str) -> str:
        """
        Return ``request_body`` exactly as it was received.

        Args:
            request_body: The original request body as a string.
            model: The model name taken from the request (not used here).
            endpoint: The target endpoint of this request (not used here).

        Returns:
            The very same ``request_body`` that was passed in.
        """
        # Nothing to rewrite: the caller gets its own input back.
        return request_body


# Module-level holder for the active request rewriter. It starts out as None
# and is assigned by initialize_request_rewriter() or, lazily, by
# get_request_rewriter().
_request_rewriter_instance = None


def initialize_request_rewriter(rewriter_type: str, **kwargs) -> RequestRewriter:
    """
    Create the request rewriter and store it as the module-level instance.

    Args:
        rewriter_type: The requested kind of rewriter. It is currently only
            echoed in the log line; a NoopRequestRewriter is always created.
        **kwargs: Additional arguments for the rewriter (currently unused).

    Returns:
        The rewriter instance that was just stored.
    """
    global _request_rewriter_instance

    # TODO: Implement different rewriter types
    # Until then every type resolves to the pass-through rewriter.
    rewriter = NoopRequestRewriter()
    _request_rewriter_instance = rewriter
    logger.info(f"Initialized placeholder request rewriter (type: {rewriter_type})")

    return rewriter


def is_request_rewriter_initialized() -> bool:
    """
    Report whether a request rewriter instance is currently stored.

    Returns:
        bool: False while the module-level instance is still None,
        True once it has been assigned.
    """
    # The module global is only read here, so no ``global`` declaration is needed.
    if _request_rewriter_instance is None:
        return False
    return True


def get_request_rewriter() -> RequestRewriter:
    """
    Return the stored request rewriter, creating a fallback if none exists.

    When no instance has been stored yet, a NoopRequestRewriter is created,
    stored as the module-level instance, and returned; from then on the
    module-level instance is no longer None.

    Returns:
        The stored request rewriter instance.
    """
    global _request_rewriter_instance

    # Fast path: an instance was already set up earlier.
    if _request_rewriter_instance is not None:
        return _request_rewriter_instance

    # Lazy fallback: install the pass-through rewriter.
    _request_rewriter_instance = NoopRequestRewriter()
    return _request_rewriter_instance