Skip to content

Commit

Permalink
fix(regression): call error when context parameter is present (#5247)
Browse files Browse the repository at this point in the history
fix: call error when context paremeter is present

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Feb 25, 2025
1 parent 2109875 commit 74076a7
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/_bentoml_impl/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,34 +652,34 @@ async def api_endpoint(self, name: str, request: Request) -> Response:
ctx = self.service.context
serde = ALL_SERDE.get(media_type, ALL_SERDE["application/json"])()
input_data = await method.input_spec.from_http_request(request, serde)
input_args: tuple[t.Any, ...] = ()
input_params: dict[str, t.Any] = {}
if method.ctx_param is not None:
input_params[method.ctx_param] = ctx
call_args: tuple[t.Any, ...] = ()
call_kwargs: dict[str, t.Any] = {}
if getattr(method.input_spec, "__root_input__", False):
if isinstance(input_data, IORootModel):
input_args = t.cast(t.Tuple[t.Any], (input_data.root,))
call_args = t.cast(t.Tuple[t.Any], (input_data.root,))
else:
input_args = (input_data,)
call_args = (input_data,)
else:
input_params = {k: getattr(input_data, k) for k in input_data.model_fields}
if ARGS in input_params:
input_args = (*input_args, input_params.pop(ARGS))
if KWARGS in input_params:
input_params.update(input_params.pop(KWARGS))
call_kwargs = {k: getattr(input_data, k) for k in input_data.model_fields}
if method.ctx_param is not None:
call_kwargs[method.ctx_param] = ctx
if ARGS in call_kwargs:
call_args = (*call_args, call_kwargs.pop(ARGS))
if KWARGS in call_kwargs:
call_kwargs.update(call_kwargs.pop(KWARGS))

original_func = get_original_func(func)

if method.batchable:
output = await self.batch_infer(name, input_args, input_params)
output = await self.batch_infer(name, call_args, call_kwargs)
elif inspect.iscoroutinefunction(original_func):
output = await func(*input_args, **input_params)
output = await func(*call_args, **call_kwargs)
elif inspect.isasyncgenfunction(original_func):
output = func(*input_args, **input_params)
output = func(*call_args, **call_kwargs)
elif inspect.isgeneratorfunction(original_func):

async def inner() -> t.AsyncGenerator[t.Any, None]:
gen = func(*input_args, **input_params)
gen = func(*call_args, **call_kwargs)
while True:
try:
yield await self._to_thread(next, gen)
Expand All @@ -692,7 +692,7 @@ async def inner() -> t.AsyncGenerator[t.Any, None]:

output = inner()
else:
output = await self._to_thread(func, *input_args, **input_params)
output = await self._to_thread(func, *call_args, **call_kwargs)

if isinstance(output, Response):
response = output
Expand Down

0 comments on commit 74076a7

Please sign in to comment.