[Bugfix][Core] fix abort_seq_group and memory leak when n>1 #14326

Merged · 1 commit · Mar 6, 2025
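Summary of the change: with `SamplingParams(n>1)`, the engine fans a request out into child sequence groups whose `request_id`s are derived from the parent id (e.g. `foo_parallel_sample_0` for parent `foo`) and records the mapping in `LLMEngine.seq_id_to_seq_group`. Before this PR, `Scheduler.abort_seq_group` compared the child ids directly against the parent id being aborted, so the abort matched nothing, and the mapping entries were never removed, so the dict grew without bound. A minimal sketch of the failure mode (the model name and exact call sequence are illustrative, not taken from this PR):

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

# Illustrative model; any small model exercises the same flow.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# n=2 fans "foo" out into child groups such as "foo_parallel_sample_0",
# tracked in engine.seq_id_to_seq_group.
engine.add_request("foo", "Hello, my name is",
                   SamplingParams(n=2, max_tokens=16))
engine.step()

# Before this fix, the scheduler compared child ids against "foo" directly,
# so this abort was a no-op and the mapping entries were never pruned.
engine.abort_request("foo")
assert not engine.has_unfinished_requests()
assert not engine.seq_id_to_seq_group
```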
33 changes: 24 additions & 9 deletions vllm/core/scheduler.py
@@ -16,8 +16,9 @@
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
-                           SequenceStage, SequenceStatus)
+                           SequenceGroupBase, SequenceGroupMetadata,
+                           SequenceGroupMetadataDelta, SequenceStage,
+                           SequenceStatus)
 from vllm.utils import Device, PyObjectCache
 
 logger = init_logger(__name__)
@@ -561,7 +562,11 @@ def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
         # Only for testing purposes.
         self.swapped.append(seq_group)
 
-    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+    def abort_seq_group(
+        self,
+        request_id: Union[str, Iterable[str]],
+        seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
+    ) -> None:
         """Aborts a sequence group with the given ID.
 
         Check if the sequence group with the given ID
@@ -573,21 +578,29 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
         Args:
             request_id: The ID(s) of the sequence group to abort.
+            seq_id_to_seq_group: helper for groups with n>1
         """
         if isinstance(request_id, str):
             request_id = (request_id, )
         request_ids = set(request_id)
+        seq_id_to_seq_group = seq_id_to_seq_group or {}
         for state_queue in [self.waiting, self.running, self.swapped]:
             aborted_groups: List[SequenceGroup] = []
             for seq_group in state_queue:
                 if not request_ids:
                     # Using 'break' here may add two extra iterations,
                     # but is acceptable to reduce complexity.
                     break
-                if seq_group.request_id in request_ids:
+                # When n>1, seq_group.request_id looks like
+                # foo_parallel_sample_0, while request_ids is just foo, and
+                # we should resolve it to real_request_id to match.
+                if seq_group.request_id in seq_id_to_seq_group:
+                    real_request_id = seq_id_to_seq_group[
+                        seq_group.request_id].group_id
+                else:
+                    real_request_id = seq_group.request_id
+                if real_request_id in request_ids:
                     # Appending aborted group into pending list.
                     aborted_groups.append(seq_group)
-                    request_ids.remove(seq_group.request_id)
+                    # We can't remove real_request_id from request_ids here,
+                    # because there may be other seq groups sharing the same
+                    # real_request_id.
             for aborted_group in aborted_groups:
                 # Remove the sequence group from the state queue.
                 state_queue.remove(aborted_group)
@@ -598,6 +611,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
                         continue
                     seq.status = SequenceStatus.FINISHED_ABORTED
                     self.free_seq(seq)
+                if aborted_group.request_id in seq_id_to_seq_group:
+                    del seq_id_to_seq_group[aborted_group.request_id]
 
                 self._free_seq_group_cross_attn_blocks(aborted_group)

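The heart of the scheduler change is the id resolution above: a group's own `request_id` is first looked up in `seq_id_to_seq_group` and, if present, replaced by its `group_id` before being compared with the ids to abort. A standalone sketch of that resolution (the stub class below is hypothetical and merely stands in for `vllm.sequence.SequenceGroupBase`):

```python
from typing import Dict, Optional

class _GroupStub:
    """Hypothetical stand-in for vllm.sequence.SequenceGroupBase."""
    def __init__(self, group_id: str):
        self.group_id = group_id

def resolve_request_id(
    seq_group_request_id: str,
    seq_id_to_seq_group: Optional[Dict[str, _GroupStub]] = None,
) -> str:
    # Mirrors the resolution added to abort_seq_group: a child id such as
    # "foo_parallel_sample_0" maps back to its parent group id "foo".
    seq_id_to_seq_group = seq_id_to_seq_group or {}
    if seq_group_request_id in seq_id_to_seq_group:
        return seq_id_to_seq_group[seq_group_request_id].group_id
    return seq_group_request_id

mapping = {"foo_parallel_sample_0": _GroupStub("foo"),
           "foo_parallel_sample_1": _GroupStub("foo")}
assert resolve_request_id("foo_parallel_sample_0", mapping) == "foo"
assert resolve_request_id("bar", mapping) == "bar"  # n == 1: id passes through
```

This is also why the old `request_ids.remove(...)` had to go: both children of `foo` resolve to the same real id, so removing it after the first match would leave the sibling group un-aborted.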
8 changes: 7 additions & 1 deletion vllm/engine/llm_engine.py
@@ -887,7 +887,8 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         >>> engine.abort_request(request_id)
         """
         for scheduler in self.scheduler:
-            scheduler.abort_seq_group(request_id)
+            scheduler.abort_seq_group(
+                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
 
     def get_model_config(self) -> ModelConfig:
         """Gets the model configuration."""
@@ -1354,6 +1355,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
 
         finished_requests_ids = self.scheduler[
             virtual_engine].get_and_reset_finished_requests_ids()
+        # When n>1, elements in self.seq_id_to_seq_group should be deleted
+        # here, otherwise memory leaks.
+        for finished_request_id in finished_requests_ids:
+            if finished_request_id in self.seq_id_to_seq_group:
+                del self.seq_id_to_seq_group[finished_request_id]
 
         # Maybe switch from async mode to sync mode
         if not allow_async_output_proc and len(ctx.output_queue) > 0:
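The `step()` change covers the non-abort path: groups that finish normally also report their child request ids, which must be evicted from the same map. A sketch of that bookkeeping in isolation (plain dicts and a hard-coded finished list stand in for the engine's real state, as I read the diff):

```python
# Stand-in for LLMEngine.seq_id_to_seq_group after a request "foo" with n=2.
seq_id_to_seq_group = {
    "foo_parallel_sample_0": object(),
    "foo_parallel_sample_1": object(),
}

# Stand-in for scheduler.get_and_reset_finished_requests_ids(), which
# reports child ids once their groups finish.
finished_requests_ids = ["foo_parallel_sample_0", "foo_parallel_sample_1"]

# The cleanup added in step(): without it, the map grows for the lifetime
# of the engine, which is the memory leak named in the PR title.
for finished_request_id in finished_requests_ids:
    if finished_request_id in seq_id_to_seq_group:
        del seq_id_to_seq_group[finished_request_id]

assert not seq_id_to_seq_group
```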