Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm] Enable chunked prefill/paged attention in MLA on ROCm #14316

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions vllm/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,7 @@ def _compute_prefill_context(
assert prefill_metadata.context_chunk_max_seq_lens is not None
assert prefill_metadata.context_lens_tensor is not None

has_context = prefill_metadata.context_lens_tensor.max() > 0
output = None
iters = len(prefill_metadata.context_chunk_seq_tot)

Expand Down Expand Up @@ -1322,7 +1323,8 @@ def _compute_prefill_context(
[0, q.shape[-1] - v.shape[-1]],
value=0)

if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and \
has_context is False:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
has_context is False:
not has_context:

attn_output, attn_softmax_lse = self.triton_fa_func(
q,
k,
Expand Down Expand Up @@ -1411,7 +1413,7 @@ def _forward_prefill(
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)

if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and has_context is False:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and has_context is False:
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:

output = self.triton_fa_func(
q,
k,
Expand Down
4 changes: 2 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3433,9 +3433,9 @@ def __post_init__(self):
self.compilation_config.level = CompilationLevel.NO_COMPILATION

if self.model_config and self.model_config.use_mla and \
not current_platform.is_cuda():
not (current_platform.is_cuda() or current_platform.is_rocm()):
logger.info(
"MLA is enabled on a non-cuda platform; forcing chunked "
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.chunked_prefill_enabled = False
Expand Down