Skip to content

Commit

Permalink
[Misc] Implement Singleton Design Pattern for EngineStat Scraper, RequestStat Monitor, and Router (#131)
Browse files Browse the repository at this point in the history

* update singleton implementation pattern for engine_stat, request_stat and router scraper

Signed-off-by: sitloboi2012 <[email protected]>

* update fastapi to use state as built-in singleton method aside with custom singleton

Signed-off-by: sitloboi2012 <[email protected]>

---------

Signed-off-by: sitloboi2012 <[email protected]>
  • Loading branch information
sitloboi2012 authored Feb 20, 2025
1 parent 8134ea5 commit 5247c69
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 149 deletions.
62 changes: 62 additions & 0 deletions src/tests/test_singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# test_singleton.py
import unittest

# Import the classes and helper functions from your module.
from vllm_router.request_stats import (
GetRequestStatsMonitor,
InitializeRequestStatsMonitor,
RequestStatsMonitor,
SingletonMeta,
)


class TestRequestStatsMonitorSingleton(unittest.TestCase):
    """Verify that RequestStatsMonitor behaves as a process-wide singleton."""

    def setUp(self):
        # Start each test from a clean slate: no cached singleton instance.
        self._clear_singleton()

    def tearDown(self):
        # Also clean up AFTER each test so the instance created here does not
        # leak into other test modules that share SingletonMeta._instances.
        self._clear_singleton()

    @staticmethod
    def _clear_singleton():
        # Drop the cached RequestStatsMonitor instance, if any.
        SingletonMeta._instances.pop(RequestStatsMonitor, None)

    def test_singleton_initialization(self):
        sliding_window = 10.0
        # First initialization using the helper.
        monitor1 = InitializeRequestStatsMonitor(sliding_window)
        # Subsequent retrieval using GetRequestStatsMonitor() should return
        # the same instance.
        monitor2 = GetRequestStatsMonitor()
        self.assertIs(
            monitor1,
            monitor2,
            "GetRequestStatsMonitor should return the initialized singleton.",
        )

        # Directly calling the constructor with the same parameter should
        # also return the same instance.
        monitor3 = RequestStatsMonitor(sliding_window)
        self.assertIs(
            monitor1,
            monitor3,
            "Direct constructor calls should return the same singleton instance.",
        )

    def test_initialization_without_parameter_after_initialized(self):
        sliding_window = 10.0
        # First, initialize with the sliding_window.
        monitor1 = InitializeRequestStatsMonitor(sliding_window)
        # Now, calling the constructor without a parameter should not raise
        # an error and should return the already initialized instance.
        monitor2 = RequestStatsMonitor()
        self.assertIs(
            monitor1,
            monitor2,
            "Calling RequestStatsMonitor() without parameter after initialization should return the singleton.",
        )

    def test_initialization_without_parameter_before_initialized(self):
        # setUp already removed any cached instance, so the very first
        # construction without sliding_window_size must raise ValueError.
        with self.assertRaises(ValueError):
            RequestStatsMonitor()  # sliding_window_size is required on first init.


# Allow running this test file directly (e.g. `python test_singleton.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
71 changes: 23 additions & 48 deletions src/vllm_router/engine_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@

logger = init_logger(__name__)

_global_engine_stats_scraper: "Optional[EngineStatsScraper]" = None

class SingletonMeta(type):
    """Metaclass that makes every class using it a process-wide singleton.

    The first call to the class creates and caches the instance; every later
    call (regardless of arguments) returns the cached instance unchanged.
    """

    _instances = {}
    # Guards first-time creation. Without it, two threads can both pass the
    # membership check and construct two instances — a real risk here, since
    # this module runs a background scraper thread.
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            with cls._lock:
                # Double-checked locking: another thread may have created the
                # instance while we were waiting on the lock.
                if cls not in cls._instances:
                    cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


@dataclass
Expand Down Expand Up @@ -63,7 +71,7 @@ def FromVllmScrape(vllm_scrape: str):
)


class EngineStatsScraper:
class EngineStatsScraper(metaclass=SingletonMeta):
def __init__(self, scrape_interval: float):
"""
Initialize the scraper to periodically fetch metrics from all serving engines.
Expand All @@ -77,19 +85,27 @@ def __init__(self, scrape_interval: float):
not been initialized.
"""
self.service_discovery = GetServiceDiscovery()
# Allow multiple calls but require the first call provide scrape_interval.
if hasattr(self, "_initialized"):
return
if scrape_interval is None:
raise ValueError(
"EngineStatsScraper must be initialized with scrape_interval"
)
self.service_discovery = GetServiceDiscovery() # (remains unchanged)
self.engine_stats: Dict[str, EngineStats] = {}
self.engine_stats_lock = threading.Lock()
self.scrape_interval = scrape_interval
self.scrape_thread = threading.Thread(target=self._scrape_worker, daemon=True)
self.scrape_thread.start()
self._initialized = True

def _scrape_one_endpoint(self, url: str):
"""
Scrape metrics from a single serving engine.
Args:
url (str): The base URL of the serving engine.
url (str): The URL of the serving engine (does not contain endpoint)
"""
try:
response = requests.get(url + "/metrics")
Expand Down Expand Up @@ -161,50 +177,9 @@ def get_health(self) -> bool:


def InitializeEngineStatsScraper(scrape_interval: float) -> EngineStatsScraper:
    """
    Initialize (or fetch) the EngineStatsScraper singleton.

    Because EngineStatsScraper uses SingletonMeta, calling the constructor
    here either creates the one instance (first call) or returns the
    already-created one (subsequent calls).

    Args:
        scrape_interval (float): The interval (in seconds) to scrape metrics.

    Returns:
        The singleton EngineStatsScraper instance.

    Raises:
        ValueError: if the service discovery module has not been initialized,
            or if scrape_interval is None on the first call.
    """
    return EngineStatsScraper(scrape_interval)


def GetEngineStatsScraper() -> EngineStatsScraper:
    """
    Return the already-initialized EngineStatsScraper singleton.

    Returns:
        The singleton EngineStatsScraper instance.

    Raises:
        An error if called before InitializeEngineStatsScraper.
    """
    # NOTE(review): EngineStatsScraper.__init__ declares scrape_interval with
    # no default, so calling this before initialization raises TypeError
    # (missing argument), not ValueError as RequestStatsMonitor does —
    # confirm whether `scrape_interval: float = None` was intended there.
    return EngineStatsScraper()
62 changes: 20 additions & 42 deletions src/vllm_router/request_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@

logger = init_logger(__name__)

_global_request_stats_monitor = None

class SingletonMeta(type):
    """Metaclass implementing the singleton pattern.

    Instantiating a class that uses this metaclass creates the instance on
    the first call and returns that same cached instance on every later
    call, ignoring any new constructor arguments.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Instances are never None, so None is a safe "not yet created" sentinel.
        cached = cls._instances.get(cls)
        if cached is None:
            cached = super().__call__(*args, **kwargs)
            cls._instances[cls] = cached
        return cached


@dataclass
Expand Down Expand Up @@ -70,7 +78,7 @@ def get_sum(self) -> float:
return sum(self.values)


class RequestStatsMonitor:
class RequestStatsMonitor(metaclass=SingletonMeta):
"""
Monitors the request statistics of all serving engines.
"""
Expand All @@ -79,16 +87,14 @@ class RequestStatsMonitor:
# arrived requests in the sliding window, but the inter_token_latency and
# ttft are calculated based on the number of completed requests in the
# sliding window.
def __init__(self, sliding_window_size: float):
"""
Args:
sliding_window_size: The size of the sliding window (in seconds)
to store the request statistics
"""
def __init__(self, sliding_window_size: float = None):
if hasattr(self, "_initialized"):
return
if sliding_window_size is None:
raise ValueError(
"RequestStatsMonitor must be initialized with sliding_window_size"
)
self.sliding_window_size = sliding_window_size

# Finished requests for each serving engine
# The elements in the deque should be sorted by 'complete' time
self.qps_monitors: Dict[str, MovingAverageMonitor] = {}
self.ttft_monitors: Dict[str, MovingAverageMonitor] = {}

Expand All @@ -101,7 +107,6 @@ def __init__(self, sliding_window_size: float):
self.in_prefill_requests: Dict[str, int] = {}
self.in_decoding_requests: Dict[str, int] = {}
self.finished_requests: Dict[str, int] = {}

# New monitors for overall latency and decoding length
self.latency_monitors: Dict[str, MovingAverageMonitor] = {}
self.decoding_length_monitors: Dict[str, MovingAverageMonitor] = {}
Expand All @@ -110,6 +115,7 @@ def __init__(self, sliding_window_size: float):
self.swapped_requests: Dict[str, int] = {}

self.first_query_time: float = None
self._initialized = True

def on_new_request(self, engine_url: str, request_id: str, timestamp: float):
"""
Expand Down Expand Up @@ -262,36 +268,8 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]:


def InitializeRequestStatsMonitor(sliding_window_size: float):
    """
    Initialize (or fetch) the global RequestStatsMonitor singleton.

    Because RequestStatsMonitor uses SingletonMeta, the constructor call
    creates the one instance on first use and returns the cached instance
    on every later call.

    Args:
        sliding_window_size: The size of the sliding window (in seconds)
            used to aggregate request statistics.

    Returns:
        The singleton RequestStatsMonitor instance.
    """
    return RequestStatsMonitor(sliding_window_size)


def GetRequestStatsMonitor():
    """
    Return the already-initialized RequestStatsMonitor singleton.

    Returns:
        The singleton RequestStatsMonitor instance.

    Raises:
        ValueError: if called before InitializeRequestStatsMonitor —
            RequestStatsMonitor.__init__ requires sliding_window_size on
            the first construction.
    """
    return RequestStatsMonitor()
36 changes: 24 additions & 12 deletions src/vllm_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ async def process_request(
first_token = False
total_len = 0
start_time = time.time()
GetRequestStatsMonitor().on_new_request(backend_url, request_id, start_time)
logger.info(f"Started request {request_id} for backend {backend_url}")
app.state.request_stats_monitor.on_new_request(backend_url, request_id, start_time)

client = httpx_client_wrapper()
async with client.stream(
Expand All @@ -122,15 +121,17 @@ async def process_request(
total_len += len(chunk)
if not first_token:
first_token = True
GetRequestStatsMonitor().on_request_response(
app.state.request_stats_monitor.on_request_response(
backend_url, request_id, time.time()
)
yield chunk

GetRequestStatsMonitor().on_request_complete(backend_url, request_id, time.time())
logger.info(f"Completed request {request_id} for backend {backend_url}")
# Optional debug logging can be enabled here.
# logger.debug(f"Finished the request with id: {debug_request.headers.get('x-request-id', None)} at {time.time()}")
app.state.request_stats_monitor.on_request_complete(
backend_url, request_id, time.time()
)

# if debug_request:
# logger.debug(f"Finished the request with request id: {debug_request.headers.get('x-request-id', None)} at {time.time()}")


async def route_general_request(request: Request, endpoint: str):
Expand Down Expand Up @@ -161,17 +162,21 @@ async def route_general_request(request: Request, endpoint: str):
content={"error": "Invalid request: missing 'model' in request body."},
)

# TODO (ApostaC): merge two awaits into one
endpoints = GetServiceDiscovery().get_endpoint_info()
engine_stats = GetEngineStatsScraper().get_engine_stats()
request_stats = GetRequestStatsMonitor().get_request_stats(time.time())
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
request_stats = request.app.state.request_stats_monitor.get_request_stats(
time.time()
)

endpoints = list(filter(lambda x: x.model_name == requested_model, endpoints))
if not endpoints:
return JSONResponse(
status_code=400, content={"error": f"Model {requested_model} not found."}
)

logger.debug(f"Routing request {request_id} for model: {requested_model}")
server_url = GetRoutingLogic().route_request(
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
)
curr_time = time.time()
Expand Down Expand Up @@ -740,6 +745,8 @@ def InitializeAll(args):
)
else:
raise ValueError(f"Invalid service discovery type: {args.service_discovery}")

# Initialize singletons via custom functions.
InitializeEngineStatsScraper(args.engine_stats_interval)
InitializeRequestStatsMonitor(args.request_stats_window)

Expand All @@ -754,6 +761,11 @@ def InitializeAll(args):

InitializeRoutingLogic(args.routing_logic, session_key=args.session_key)

# --- Hybrid addition: attach singletons to FastAPI state ---
app.state.engine_stats_scraper = GetEngineStatsScraper()
app.state.request_stats_monitor = GetRequestStatsMonitor()
app.state.router = GetRoutingLogic()


def log_stats(interval: int = 10):
"""
Expand All @@ -774,8 +786,8 @@ def log_stats(interval: int = 10):
time.sleep(interval)
logstr = "\n" + "=" * 50 + "\n"
endpoints = GetServiceDiscovery().get_endpoint_info()
engine_stats = GetEngineStatsScraper().get_engine_stats()
request_stats = GetRequestStatsMonitor().get_request_stats(time.time())
engine_stats = app.state.engine_stats_scraper.get_engine_stats()
request_stats = app.state.request_stats_monitor.get_request_stats(time.time())
for endpoint in endpoints:
url = endpoint.url
logstr += f"Model: {endpoint.model_name}\n"
Expand Down
Loading

0 comments on commit 5247c69

Please sign in to comment.