Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable custom paged attention kernel for Navi 3/4 #446

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
NUM_BLOCKS = 128 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_NAVI = "gfx1" in GPU_ARCH


@torch.inference_mode()
Expand Down Expand Up @@ -83,7 +85,7 @@ def main(
if version == "v2":
if current_platform.is_rocm():
global PARTITION_SIZE
if not args.custom_paged_attn:
if not args.custom_paged_attn and not ON_NAVI:
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
Expand Down Expand Up @@ -170,6 +172,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
k_scale,
v_scale,
None,
ON_NAVI,
)
else:
raise ValueError(f"Invalid version: {version}")
Expand Down
1,964 changes: 1,828 additions & 136 deletions csrc/rocm/attention.cu

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion csrc/rocm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale);
const c10::optional<torch::Tensor>& fp8_out_scale,
bool is_navi);
3 changes: 2 additions & 1 deletion csrc/rocm/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"
" Tensor? fp8_out_scale) -> ()");
" Tensor? fp8_out_scale,"
" bool is_navi) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
rocm_ops.def(
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
Expand Down
12 changes: 9 additions & 3 deletions tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))):
pytest.skip()

is_rocm_navi = is_navi()
if (version == "rocm" and is_rocm_navi
and (kv_cache_dtype == "fp8" or head_size != 128
or block_size != 16 or use_alibi)):
pytest.skip()

global PARTITION_SIZE

current_platform.seed_everything(seed)
Expand Down Expand Up @@ -287,14 +293,14 @@ def test_paged_attention(
k_scale,
v_scale,
None,
PARTITION_SIZE,
is_rocm_navi,
)

opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, None, PARTITION_SIZE),
kv_cache_dtype, k_scale, v_scale, None, is_rocm_navi),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

Expand Down Expand Up @@ -446,4 +452,4 @@ def test_multi_query_kv_attention(
)
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
3 changes: 2 additions & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,14 @@ def paged_attention_rocm(
k_scale: torch.Tensor,
v_scale: torch.Tensor,
fp8_out_scale: Optional[torch.Tensor],
is_navi: bool = False,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale,
fp8_out_scale)
fp8_out_scale, is_navi)


# pos encoding ops
Expand Down
39 changes: 27 additions & 12 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_navi

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
Expand All @@ -24,8 +25,8 @@

_PARTITION_SIZE_ROCM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])
_ON_NAVI3_NAVI4 = any(arch in _GPU_ARCH for arch in ["gfx11", "gfx12"])


class ROCmFlashAttentionBackend(AttentionBackend):
Expand Down Expand Up @@ -777,7 +778,8 @@ def forward(
gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
decode_meta.max_decode_seq_len, self.kv_cache_dtype,
self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
Expand Down Expand Up @@ -831,6 +833,7 @@ def forward(
layer._k_scale,
layer._v_scale,
fp8_out_scale if cpa_fp8_out else None,
is_navi(),
)
if cpa_fp8_out:
return out.view(num_seqs, num_heads * head_size)
Expand Down Expand Up @@ -896,13 +899,25 @@ def _sdpa_attention(
return output


def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (_ON_MI250_MI300 and not _ON_NAVI
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024)
def _use_rocm_custom_paged_attention(
qtype: torch.dtype,
head_size: int,
block_size: int,
gqa_ratio: int,
max_seq_len: int,
kv_cache_dtype: str,
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
if _ON_NAVI3_NAVI4:
return (envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 and alibi_slopes is None
and kv_cache_dtype == "auto")
else:
return (_ON_MI250_MI300
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024)
2 changes: 1 addition & 1 deletion vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_USE_ROCM_SKINNY_GEMM", "True").lower() in
("true", "1")),

# custom paged attention implemented for MI3* cards
# custom paged attention implemented for MI3*/Navi3*/Navi4* cards
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1")),
Expand Down