forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LLM INFER] Append attn (PaddlePaddle#9244)
* refine paddle::empty(), fix memory error, support multi_stream for attention * fix and rename attention as append_attention * rename file --------- Co-authored-by: lizhenyun <[email protected]> Co-authored-by: lizhenyun01 <[email protected]>
- Loading branch information
Showing
45 changed files
with
18,321 additions
and
155 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
1,339 changes: 1,339 additions & 0 deletions
1,339
csrc/gpu/append_attn/append_attention_c16_impl.cuh
Large diffs are not rendered by default.
Oops, something went wrong.
1,580 changes: 1,580 additions & 0 deletions
1,580
csrc/gpu/append_attn/append_attention_c4_impl.cuh
Large diffs are not rendered by default.
Oops, something went wrong.
1,417 changes: 1,417 additions & 0 deletions
1,417
csrc/gpu/append_attn/append_attention_c8_impl.cuh
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,295 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
#pragma once | ||
|
||
#include "helper.h" | ||
#include "utils.cuh" | ||
|
||
template <typename T, typename OutT> | ||
void CascadeAppendAttentionC16Kernel( | ||
const AppendAttnMetaData& meta_data, | ||
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] | ||
const paddle::Tensor& | ||
cache_k, // [max_block_num, num_heads, block_size, head_dim] | ||
const paddle::Tensor& | ||
cache_v, // [max_block_num, num_heads, head_dim, block_size] | ||
const paddle::optional<paddle::Tensor>& attn_mask, | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
shift_bias, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
smooth_weight, // [num_kv_heads, head_dim] | ||
const paddle::Tensor& seq_lens_q, | ||
const paddle::Tensor& seq_lens_kv, | ||
const paddle::Tensor& seq_lens_encoder, | ||
const paddle::Tensor& padding_offsets, | ||
const paddle::Tensor& cum_offsets, | ||
const paddle::Tensor& block_table, | ||
const paddle::Tensor& batch_ids, | ||
const paddle::Tensor& tile_ids_per_batch, | ||
const int num_blocks, | ||
const int block_shape_q, | ||
const int max_seq_len, | ||
const int max_dec_len, | ||
const float in_scale, | ||
const int max_partition_size, | ||
const int encoder_max_partition_size, | ||
const int speculate_max_draft_token_num, | ||
const bool causal, | ||
const bool is_decoder, | ||
const bool enable_prefill, | ||
cudaStream_t& stream, | ||
paddle::Tensor* out); | ||
|
||
template <typename T, typename OutT> | ||
void CascadeAppendAttentionC8Kernel( | ||
const AppendAttnMetaData& meta_data, | ||
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] | ||
const paddle::Tensor& | ||
cache_k, // [max_block_num, num_heads, block_size, head_dim] | ||
const paddle::Tensor& | ||
cache_v, // [max_block_num, num_heads, head_dim, block_size] | ||
const paddle::optional<paddle::Tensor>& attn_mask, | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
shift_bias, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
smooth_weight, // [num_kv_heads, head_dim] | ||
const paddle::Tensor& seq_lens_q, | ||
const paddle::Tensor& seq_lens_kv, | ||
const paddle::Tensor& seq_lens_encoder, | ||
const paddle::Tensor& padding_offsets, | ||
const paddle::Tensor& cum_offsets, | ||
const paddle::Tensor& block_table, | ||
const paddle::Tensor& batch_ids, | ||
const paddle::Tensor& tile_ids_per_batch, | ||
const int num_blocks, | ||
const int block_shape_q, | ||
const int max_seq_len, | ||
const int max_dec_len, | ||
const float in_scale, | ||
const int max_partition_size, | ||
const int encoder_max_partition_size, | ||
const int speculate_max_draft_token_num, | ||
const bool causal, | ||
const bool is_decoder, | ||
const bool enable_prefill, | ||
cudaStream_t& stream, | ||
paddle::Tensor* out); | ||
|
||
template <typename T, typename OutT> | ||
void CascadeAppendAttentionC4Kernel( | ||
const AppendAttnMetaData& meta_data, | ||
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] | ||
const paddle::Tensor& | ||
cache_k, // [max_block_num, num_heads, block_size, head_dim] | ||
const paddle::Tensor& | ||
cache_v, // [max_block_num, num_heads, head_dim, block_size] | ||
const paddle::optional<paddle::Tensor>& attn_mask, | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
shift_bias, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
smooth_weight, // [num_kv_heads, head_dim] | ||
const paddle::Tensor& seq_lens_q, | ||
const paddle::Tensor& seq_lens_kv, | ||
const paddle::Tensor& seq_lens_encoder, | ||
const paddle::Tensor& padding_offsets, | ||
const paddle::Tensor& cum_offsets, | ||
const paddle::Tensor& block_table, | ||
const paddle::Tensor& batch_ids, | ||
const paddle::Tensor& tile_ids_per_batch, | ||
const int num_blocks, | ||
const int block_shape_q, | ||
const int max_seq_len, | ||
const int max_dec_len, | ||
const float in_scale, | ||
const int max_partition_size, | ||
const int encoder_max_partition_size, | ||
const int speculate_max_draft_token_num, | ||
const bool causal, | ||
const bool is_decoder, | ||
const bool enable_prefill, | ||
cudaStream_t& stream, | ||
paddle::Tensor* out); | ||
|
||
template <typename T, typename OutT> | ||
void CascadeAppendAttentionKernel( | ||
const AppendAttnMetaData& meta_data, | ||
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] | ||
const paddle::Tensor& | ||
cache_k, // [max_block_num, num_heads, block_size, head_dim] | ||
const paddle::Tensor& | ||
cache_v, // [max_block_num, num_heads, head_dim, block_size] | ||
const paddle::optional<paddle::Tensor>& attn_mask, | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_scale, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_k_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
cache_v_zp, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
shift_bias, // [num_kv_heads, head_dim] | ||
const paddle::optional<paddle::Tensor>& | ||
smooth_weight, // [num_kv_heads, head_dim] | ||
const paddle::Tensor& seq_lens_q, | ||
const paddle::Tensor& seq_lens_kv, | ||
const paddle::Tensor& seq_lens_encoder, | ||
const paddle::Tensor& padding_offsets, | ||
const paddle::Tensor& cum_offsets, | ||
const paddle::Tensor& block_table, | ||
const paddle::Tensor& batch_ids, | ||
const paddle::Tensor& tile_ids_per_batch, | ||
const std::string& cache_quant_type_str, | ||
const int num_blocks, | ||
const int block_shape_q, | ||
const int max_seq_len, | ||
const int max_dec_len, | ||
const float in_scale, | ||
const int max_partition_size, | ||
const int encoder_max_partition_size, | ||
const int speculate_max_draft_token_num, | ||
const bool causal, | ||
const bool is_decoder, | ||
const bool enable_prefill, | ||
cudaStream_t& stream, | ||
paddle::Tensor* out) { | ||
if (cache_quant_type_str == "none") { | ||
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data, | ||
qkv, | ||
cache_k, | ||
cache_v, | ||
attn_mask, | ||
cache_k_scale, | ||
cache_v_scale, | ||
cache_k_zp, | ||
cache_v_zp, | ||
shift_bias, | ||
smooth_weight, | ||
seq_lens_q, | ||
seq_lens_kv, | ||
seq_lens_encoder, | ||
padding_offsets, | ||
cum_offsets, | ||
block_table, | ||
batch_ids, | ||
tile_ids_per_batch, | ||
num_blocks, | ||
block_shape_q, | ||
max_seq_len, | ||
max_dec_len, | ||
in_scale, | ||
max_partition_size, | ||
encoder_max_partition_size, | ||
speculate_max_draft_token_num, | ||
causal, | ||
is_decoder, | ||
enable_prefill, | ||
stream, | ||
out); | ||
} else if (cache_quant_type_str == "cache_int8") { | ||
CascadeAppendAttentionC8Kernel<T, OutT>(meta_data, | ||
qkv, | ||
cache_k, | ||
cache_v, | ||
attn_mask, | ||
cache_k_scale, | ||
cache_v_scale, | ||
cache_k_zp, | ||
cache_v_zp, | ||
shift_bias, | ||
smooth_weight, | ||
seq_lens_q, | ||
seq_lens_kv, | ||
seq_lens_encoder, | ||
padding_offsets, | ||
cum_offsets, | ||
block_table, | ||
batch_ids, | ||
tile_ids_per_batch, | ||
num_blocks, | ||
block_shape_q, | ||
max_seq_len, | ||
max_dec_len, | ||
in_scale, | ||
max_partition_size, | ||
encoder_max_partition_size, | ||
speculate_max_draft_token_num, | ||
causal, | ||
is_decoder, | ||
enable_prefill, | ||
stream, | ||
out); | ||
} else if (cache_quant_type_str == "cache_int4_zp") { | ||
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data, | ||
qkv, | ||
cache_k, | ||
cache_v, | ||
attn_mask, | ||
cache_k_scale, | ||
cache_v_scale, | ||
cache_k_zp, | ||
cache_v_zp, | ||
shift_bias, | ||
smooth_weight, | ||
seq_lens_q, | ||
seq_lens_kv, | ||
seq_lens_encoder, | ||
padding_offsets, | ||
cum_offsets, | ||
block_table, | ||
batch_ids, | ||
tile_ids_per_batch, | ||
num_blocks, | ||
block_shape_q, | ||
max_seq_len, | ||
max_dec_len, | ||
in_scale, | ||
max_partition_size, | ||
encoder_max_partition_size, | ||
speculate_max_draft_token_num, | ||
causal, | ||
is_decoder, | ||
enable_prefill, | ||
stream, | ||
out); | ||
} else { | ||
PD_THROW( | ||
"cache_quant_type_str should be one of [none, cache_int8, " | ||
"cache_int4_zp]"); | ||
} | ||
} |
Oops, something went wrong.