Skip to content

Commit

Permalink
refactor: standard fastapi project structure for better main… (#217)
Browse files Browse the repository at this point in the history
* refactor: maintain standard fastapi project structure for better maintainability

Signed-off-by: BrianPark314 <[email protected]>

* chore: merge main into branch

Signed-off-by: BrianPark314 <[email protected]>

* fix: run pre-commit

Signed-off-by: BrianPark314 <[email protected]>

* fix: setup.py

Signed-off-by: BrianPark314 <[email protected]>

* fix: service discovery time import

Signed-off-by: BrianPark314 <[email protected]>

* fix: service discovery time import

Signed-off-by: BrianPark314 <[email protected]>

* fix: delete unused file

Signed-off-by: BrianPark314 <[email protected]>

* fix: merge main and fix log_stats.py

Signed-off-by: BrianPark314 <[email protected]>

* chore: run pre-commit

Signed-off-by: BrianPark314 <[email protected]>

* fix: correct experimental imports

Signed-off-by: BrianPark314 <[email protected]>

* fix: parser error

Signed-off-by: BrianPark314 <[email protected]>

* fix: experimental feature flag

Signed-off-by: BrianPark314 <[email protected]>

* fix: wrapper call issue

Signed-off-by: BrianPark314 <[email protected]>

* fix: add missing lifespan

Signed-off-by: BrianPark314 <[email protected]>

* chore: add TODO comment

Signed-off-by: BrianPark314 <[email protected]>

---------

Signed-off-by: BrianPark314 <[email protected]>
Co-authored-by: BrianPark314 <[email protected]>
  • Loading branch information
BrianPark314 and BrianPark314 authored Mar 4, 2025
1 parent 1b5e499 commit 95efecc
Show file tree
Hide file tree
Showing 40 changed files with 1,177 additions and 1,130 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
install_requires=install_requires,
entry_points={
"console_scripts": [
"vllm-router=vllm_router.router:main",
"vllm-router=vllm_router.app:main",
],
},
description="The router for vLLM",
Expand Down
3 changes: 2 additions & 1 deletion src/tests/test_file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import pytest

from vllm_router.files import FileStorage, OpenAIFile
from vllm_router.services.files_service.file_storage import FileStorage
from vllm_router.services.files_service.openai_files import OpenAIFile

TEST_BASE_PATH = "/tmp/test_vllm_files"
pytest_plugins = ("pytest_asyncio",)
Expand Down
8 changes: 2 additions & 6 deletions src/tests/test_session_router.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import sys
from typing import Dict, List
from unittest.mock import Mock
from typing import Dict

import pytest

from vllm_router.routing_logic import SessionRouter
from vllm_router.routers.routing_logic import SessionRouter


class EndpointInfo:
Expand Down
12 changes: 6 additions & 6 deletions src/tests/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import unittest

# Import the classes and helper functions from your module.
from vllm_router.request_stats import (
GetRequestStatsMonitor,
InitializeRequestStatsMonitor,
from vllm_router.stats.request_stats import (
RequestStatsMonitor,
SingletonMeta,
get_request_stats_monitor,
initialize_request_stats_monitor,
)


Expand All @@ -19,9 +19,9 @@ def setUp(self):
def test_singleton_initialization(self):
sliding_window = 10.0
# First initialization using the helper.
monitor1 = InitializeRequestStatsMonitor(sliding_window)
monitor1 = initialize_request_stats_monitor(sliding_window)
# Subsequent retrieval using GetRequestStatsMonitor() should return the same instance.
monitor2 = GetRequestStatsMonitor()
monitor2 = get_request_stats_monitor()
self.assertIs(
monitor1,
monitor2,
Expand All @@ -39,7 +39,7 @@ def test_singleton_initialization(self):
def test_initialization_without_parameter_after_initialized(self):
sliding_window = 10.0
# First, initialize with the sliding_window.
monitor1 = InitializeRequestStatsMonitor(sliding_window)
monitor1 = initialize_request_stats_monitor(sliding_window)
# Now, calling the constructor without a parameter should not raise an error
# and should return the already initialized instance.
monitor2 = RequestStatsMonitor()
Expand Down
230 changes: 230 additions & 0 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
import logging
import threading
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI

from vllm_router.dynamic_config import (
DynamicRouterConfig,
get_dynamic_config_watcher,
initialize_dynamic_config_watcher,
)
from vllm_router.experimental import get_feature_gates, initialize_feature_gates

try:
# Semantic cache integration
from vllm_router.experimental.semantic_cache import (
GetSemanticCache,
enable_semantic_cache,
initialize_semantic_cache,
is_semantic_cache_enabled,
)
from vllm_router.experimental.semantic_cache_integration import (
add_semantic_cache_args,
check_semantic_cache,
semantic_cache_hit_ratio,
semantic_cache_hits,
semantic_cache_latency,
semantic_cache_misses,
semantic_cache_size,
store_in_semantic_cache,
)

semantic_cache_available = True
except ImportError:
semantic_cache_available = False

from vllm_router.httpx_client import HTTPXClientWrapper
from vllm_router.parsers.parser import parse_args
from vllm_router.routers.batches_router import batches_router
from vllm_router.routers.files_router import files_router
from vllm_router.routers.main_router import main_router
from vllm_router.routers.metrics_router import metrics_router
from vllm_router.routers.routing_logic import (
get_routing_logic,
initialize_routing_logic,
)
from vllm_router.service_discovery import (
ServiceDiscoveryType,
get_service_discovery,
initialize_service_discovery,
)
from vllm_router.services.batch_service import initialize_batch_processor
from vllm_router.services.files_service import initialize_storage
from vllm_router.stats.engine_stats import (
get_engine_stats_scraper,
initialize_engine_stats_scraper,
)
from vllm_router.stats.log_stats import log_stats
from vllm_router.stats.request_stats import (
get_request_stats_monitor,
initialize_request_stats_monitor,
)
from vllm_router.utils import parse_static_model_names, parse_static_urls, set_ulimit

logger = logging.getLogger("uvicorn")


@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.httpx_client_wrapper.start()
if hasattr(app.state, "batch_processor"):
await app.state.batch_processor.initialize()
yield
await app.state.httpx_client_wrapper.stop()

# Close the threaded-components
logger.info("Closing engine stats scraper")
engine_stats_scraper = get_engine_stats_scraper()
engine_stats_scraper.close()

logger.info("Closing service discovery module")
service_discovery = get_service_discovery()
service_discovery.close()

# Close the optional dynamic config watcher
dyn_cfg_watcher = get_dynamic_config_watcher()
if dyn_cfg_watcher is not None:
logger.info("Closing dynamic config watcher")
dyn_cfg_watcher.close()


# TODO: This method needs refactoring, since it has nested if statements and is too long
def initialize_all(app: FastAPI, args):
"""
Initialize all the components of the router with the given arguments.
Args:
app (FastAPI): FastAPI application
args: the parsed command-line arguments
Raises:
ValueError: if the service discovery type is invalid
"""
if args.service_discovery == "static":
initialize_service_discovery(
ServiceDiscoveryType.STATIC,
urls=parse_static_urls(args.static_backends),
models=parse_static_model_names(args.static_models),
)
elif args.service_discovery == "k8s":
initialize_service_discovery(
ServiceDiscoveryType.K8S,
namespace=args.k8s_namespace,
port=args.k8s_port,
label_selector=args.k8s_label_selector,
)
else:
raise ValueError(f"Invalid service discovery type: {args.service_discovery}")

# Initialize singletons via custom functions.
initialize_engine_stats_scraper(args.engine_stats_interval)
initialize_request_stats_monitor(args.request_stats_window)

if args.enable_batch_api:
logger.info("Initializing batch API")
app.state.batch_storage = initialize_storage(
args.file_storage_class, args.file_storage_path
)
app.state.batch_processor = initialize_batch_processor(
args.batch_processor, args.file_storage_path, app.state.batch_storage
)

initialize_routing_logic(args.routing_logic, session_key=args.session_key)

# Initialize feature gates
initialize_feature_gates(args.feature_gates)
# Check if the SemanticCache feature gate is enabled
feature_gates = get_feature_gates()
if semantic_cache_available:
if feature_gates.is_enabled("SemanticCache"):
# The feature gate is enabled, explicitly enable the semantic cache
enable_semantic_cache()

# Verify that the semantic cache was successfully enabled
if not is_semantic_cache_enabled():
logger.error("Failed to enable semantic cache feature")

logger.info("SemanticCache feature gate is enabled")

# Initialize the semantic cache with the model if specified
if args.semantic_cache_model:
logger.info(
f"Initializing semantic cache with model: {args.semantic_cache_model}"
)
logger.info(
f"Semantic cache directory: {args.semantic_cache_dir or 'default'}"
)
logger.info(
f"Semantic cache threshold: {args.semantic_cache_threshold}"
)

cache = initialize_semantic_cache(
embedding_model=args.semantic_cache_model,
cache_dir=args.semantic_cache_dir,
default_similarity_threshold=args.semantic_cache_threshold,
)

# Update cache size metric
if cache and hasattr(cache, "db") and hasattr(cache.db, "index"):
semantic_cache_size.labels(server="router").set(
cache.db.index.ntotal
)
logger.info(
f"Semantic cache initialized with {cache.db.index.ntotal} entries"
)

logger.info(
f"Semantic cache initialized with model {args.semantic_cache_model}"
)
else:
logger.warning(
"SemanticCache feature gate is enabled but no embedding model specified. "
"The semantic cache will not be functional without an embedding model. "
"Use --semantic-cache-model to specify an embedding model."
)
elif args.semantic_cache_model:
logger.warning(
"Semantic cache model specified but SemanticCache feature gate is not enabled. "
"Enable the feature gate with --feature-gates=SemanticCache=true"
)

# --- Hybrid addition: attach singletons to FastAPI state ---
app.state.engine_stats_scraper = get_engine_stats_scraper()
app.state.request_stats_monitor = get_request_stats_monitor()
app.state.router = get_routing_logic()

# Initialize dynamic config watcher
if args.dynamic_config_json:
init_config = DynamicRouterConfig.from_args(args)
initialize_dynamic_config_watcher(
args.dynamic_config_json, 10, init_config, app
)


app = FastAPI(lifespan=lifespan)
app.include_router(main_router)
app.include_router(files_router)
app.include_router(batches_router)
app.include_router(metrics_router)
app.state.httpx_client_wrapper = HTTPXClientWrapper()
app.state.semantic_cache_available = semantic_cache_available


def main():
args = parse_args()
initialize_all(app, args)
if args.log_stats:
threading.Thread(
target=log_stats, args=(args.log_stats_interval,), daemon=True
).start()

# Workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active.
set_ulimit()
uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
main()
14 changes: 7 additions & 7 deletions src/vllm_router/dynamic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from fastapi import FastAPI

from vllm_router.log import init_logger
from vllm_router.routing_logic import ReconfigureRoutingLogic
from vllm_router.routers.routing_logic import reconfigure_routing_logic
from vllm_router.service_discovery import (
ReconfigureServiceDiscovery,
ServiceDiscoveryType,
reconfigure_service_discovery,
)
from vllm_router.utils import SingletonMeta, parse_static_model_names, parse_static_urls

Expand Down Expand Up @@ -115,13 +115,13 @@ def reconfigure_service_discovery(self, config: DynamicRouterConfig):
Reconfigures the router with the given config.
"""
if config.service_discovery == "static":
ReconfigureServiceDiscovery(
reconfigure_service_discovery(
ServiceDiscoveryType.STATIC,
urls=parse_static_urls(config.static_backends),
models=parse_static_model_names(config.static_models),
)
elif config.service_discovery == "k8s":
ReconfigureServiceDiscovery(
reconfigure_service_discovery(
ServiceDiscoveryType.K8S,
namespace=config.k8s_namespace,
port=config.k8s_port,
Expand All @@ -138,7 +138,7 @@ def reconfigure_routing_logic(self, config: DynamicRouterConfig):
"""
Reconfigures the router with the given config.
"""
routing_logic = ReconfigureRoutingLogic(
routing_logic = reconfigure_routing_logic(
config.routing_logic, session_key=config.session_key
)
self.app.state.router = routing_logic
Expand Down Expand Up @@ -209,7 +209,7 @@ def close(self):
logger.info("DynamicConfigWatcher: Closed")


def InitializeDynamicConfigWatcher(
def initialize_dynamic_config_watcher(
config_json: str,
watch_interval: int,
init_config: DynamicRouterConfig,
Expand All @@ -221,7 +221,7 @@ def InitializeDynamicConfigWatcher(
return DynamicConfigWatcher(config_json, watch_interval, init_config, app)


def GetDynamicConfigWatcher() -> DynamicConfigWatcher:
def get_dynamic_config_watcher() -> DynamicConfigWatcher:
"""
Returns the DynamicConfigWatcher singleton.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/vllm_router/experimental/semantic_cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from vllm_router.experimental.feature_gates import get_feature_gates
from vllm_router.experimental.semantic_cache.semantic_cache import (
GetSemanticCache,
InitializeSemanticCache,
SemanticCache,
initialize_semantic_cache,
)

logger = logging.getLogger(__name__)

__all__ = [
"SemanticCache",
"InitializeSemanticCache",
"initialize_semantic_cache",
"GetSemanticCache",
"is_semantic_cache_enabled",
"enable_semantic_cache",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from vllm_router.experimental.semantic_cache.db_adapters import (
FAISSAdapter,
VectorDBAdapterBase,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -318,7 +317,7 @@ def complete_store(
_semantic_cache_instance = None


def InitializeSemanticCache(
def initialize_semantic_cache(
embedding_model: str = "all-MiniLM-L6-v2",
cache_dir: str = None,
default_similarity_threshold: float = 0.95,
Expand Down
Loading

0 comments on commit 95efecc

Please sign in to comment.