From ecde822ed66cad4238b89fcec3b2c001fa9e97fd Mon Sep 17 00:00:00 2001
From: yuguo-Jack <948529990@qq.com>
Date: Tue, 11 Jun 2024 16:37:58 +0800
Subject: [PATCH 1/3] [DCU] high performance LLM train and inference for DCU

---
 csrc/generation/get_padding_offset_v2.cu      | 14 +++++
 csrc/generation/helper.h                      | 37 ++++++++++++
 csrc/generation/quant_int8.cu                 | 18 +++++-
 csrc/generation/rebuild_padding.cu            |  7 ++-
 csrc/generation/rebuild_padding_v2.cu         | 14 +++++
 csrc/generation/set_value_by_flags_v2.cu      | 14 +++++
 csrc/generation/step.cu                       | 22 +++++++
 .../stop_generation_multi_ends_v2.cu          | 14 +++++
 .../token_penalty_multi_scores_v2.cu          | 14 +++++
 csrc/generation/transpose_removing_padding.cu |  7 ++-
 csrc/generation/write_int8_cache_kv.cu        | 35 +++++++++--
 csrc/setup_hip.py                             | 58 +++++++++++++++++++
 12 files changed, 247 insertions(+), 7 deletions(-)
 create mode 100644 csrc/setup_hip.py

diff --git a/csrc/generation/get_padding_offset_v2.cu b/csrc/generation/get_padding_offset_v2.cu
index 3acfad6cb8a7..080764ed9955 100644
--- a/csrc/generation/get_padding_offset_v2.cu
+++ b/csrc/generation/get_padding_offset_v2.cu
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/extension.h"
 
 __global__ void RemovePaddingV2(int64_t *output_data,
diff --git a/csrc/generation/helper.h b/csrc/generation/helper.h
index 4a74709aecae..fcc521318339 100644
--- a/csrc/generation/helper.h
+++ b/csrc/generation/helper.h
@@ -15,12 +15,44 @@
 #pragma once
 
 #include "paddle/extension.h"
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#include <hiprand.h>
+#include <hiprand_kernel.h>
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#else
 #include <cub/cub.cuh>
 #include <curand_kernel.h>
+#endif
 
 constexpr int kBlockSize = 256;
 constexpr int kNumWaves = 16;
 
+#ifdef PADDLE_WITH_HIP
+inline hipError_t GetNumBlocks(int64_t n, int* num_blocks) {
+  int dev;
+  {
+    hipError_t err = hipGetDevice(&dev);
+    if (err != hipSuccess) { return err; }
+  }
+  int sm_count;
+  {
+    hipError_t err = hipDeviceGetAttribute(&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
+    if (err != hipSuccess) { return err; }
+  }
+  int tpm;
+  {
+    hipError_t err = hipDeviceGetAttribute(&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
+    if (err != hipSuccess) { return err; }
+  }
+  *num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
+                                                   sm_count * tpm / kBlockSize * kNumWaves));
+  return hipSuccess;
+}
+#else
 inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
   int dev;
   {
@@ -41,6 +73,7 @@ inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
                                                    sm_count * tpm / kBlockSize * kNumWaves));
   return cudaSuccess;
 }
+#endif
 
 template <typename T>
 __device__ T max_func(const T a, const T b) {
@@ -74,7 +107,11 @@ class PDTraits<paddle::DataType::FLOAT16> {
 template <>
 class PDTraits<paddle::DataType::BFLOAT16> {
  public:
+#ifdef PADDLE_WITH_HIP
+  typedef hip_bfloat16 DataType;
+#else
   typedef __nv_bfloat16 DataType;
+#endif
   typedef paddle::bfloat16 data_t;
 };
 
diff --git a/csrc/generation/quant_int8.cu b/csrc/generation/quant_int8.cu
index 1e76f3563ae9..c34c6b701af9 100644
--- a/csrc/generation/quant_int8.cu
+++ b/csrc/generation/quant_int8.cu
@@ -22,8 +22,13 @@
 #include
 #include
 #include
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#else
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
+#endif
 
 constexpr int DequantKernelVecSize = 4;
 
@@ -52,11 +57,17 @@
 __forceinline__ __device__ half add_mul(half a, half b, half c) {
   return __hmul(__hadd(a, b), c);
 }
+#ifdef PADDLE_WITH_HIP
+template<>
+__forceinline__ __device__ hip_bfloat16 add_mul<hip_bfloat16>(hip_bfloat16 a, hip_bfloat16 b, hip_bfloat16 c) {
+  return (a + b) * c;
+}
+#else
 
 template<>
 __forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
   return __hmul(__hadd(a, b), c);
 }
-
+#endif
 
 template
@@ -173,8 +184,13 @@ std::vector<paddle::Tensor> LaunchQuantInt8(const paddle::Tensor& input,
   auto output=paddle::full(input_shape, -1, paddle::DataType::INT8, input.place());
   int m = input_shape[0];
   int n = input_shape[1];
+#ifdef PADDLE_WITH_HIP
+  dim3 grid(((n >> 2) + 63) / 64, (m + 7) / 8);
+  dim3 block(64, 8);
+#else
   dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32);
   dim3 block(32, 32);
+#endif
   auto stream = input.stream();
   if (shift && smooth) {
     QuantKernel<DataType_><<<grid, block, 0, stream>>>(reinterpret_cast<const DataType_*>(input.data<data_t>()),
diff --git a/csrc/generation/rebuild_padding.cu b/csrc/generation/rebuild_padding.cu
index 3c8dcc9be47f..2e8245fe2c99 100644
--- a/csrc/generation/rebuild_padding.cu
+++ b/csrc/generation/rebuild_padding.cu
@@ -58,7 +58,12 @@ void InvokeRebuildPadding(T *output_data,
                           const int *padding_offset,
                           const int token_num,
                           const int dim_embed,
-                          cudaStream_t stream) {
+#ifdef PADDLE_WITH_HIP
+                          hipStream_t stream
+#else
+                          cudaStream_t stream
+#endif
+                          ) {
   // src: [token_num, dim_embed]
   // dst: [batch_size * max_seq_len, dim_embed]
   RebuildPaddingKernel<<>>(
diff --git a/csrc/generation/rebuild_padding_v2.cu b/csrc/generation/rebuild_padding_v2.cu
index 4d61936952da..6a4b83c103d3 100644
--- a/csrc/generation/rebuild_padding_v2.cu
+++ b/csrc/generation/rebuild_padding_v2.cu
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 template
diff --git a/csrc/generation/set_value_by_flags_v2.cu b/csrc/generation/set_value_by_flags_v2.cu
index f954c8c96d1d..4171fb5b63c5 100644
--- a/csrc/generation/set_value_by_flags_v2.cu
+++ b/csrc/generation/set_value_by_flags_v2.cu
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/extension.h"
 
 __global__ void set_value_by_flag_and_id_v2(const bool *stop_flags,
diff --git a/csrc/generation/step.cu b/csrc/generation/step.cu
index b586db566916..93262e98a9e4 100644
--- a/csrc/generation/step.cu
+++ b/csrc/generation/step.cu
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 // #define DEBUG_STEP
@@ -255,7 +269,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
                 max_decoder_block_num
   );
 #ifdef DEBUG_STEP
+#ifdef PADDLE_WITH_HIP
+  hipDeviceSynchronize();
+#else
   cudaDeviceSynchronize();
+#endif
 #endif
   auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
   const int grid_size = cpu_recover_lens.data<int>()[0];
@@ -287,7 +305,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
                 first_token_id
   );
 #ifdef DEBUG_STEP
+#ifdef PADDLE_WITH_HIP
+  hipDeviceSynchronize();
+#else
   cudaDeviceSynchronize();
+#endif
 #endif
   }
 }
diff --git a/csrc/generation/stop_generation_multi_ends_v2.cu b/csrc/generation/stop_generation_multi_ends_v2.cu
index 7f23029681a5..726b41afa5d1 100644
--- a/csrc/generation/stop_generation_multi_ends_v2.cu
+++ b/csrc/generation/stop_generation_multi_ends_v2.cu
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/extension.h"
 #include
 #include
diff --git a/csrc/generation/token_penalty_multi_scores_v2.cu b/csrc/generation/token_penalty_multi_scores_v2.cu
index b1bbdd4a40d6..24ff966f6ca3 100644
--- a/csrc/generation/token_penalty_multi_scores_v2.cu
+++ b/csrc/generation/token_penalty_multi_scores_v2.cu
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 
diff --git a/csrc/generation/transpose_removing_padding.cu b/csrc/generation/transpose_removing_padding.cu
index 5b6b16a7faa2..0e7874272bb3 100644
--- a/csrc/generation/transpose_removing_padding.cu
+++ b/csrc/generation/transpose_removing_padding.cu
@@ -65,7 +65,12 @@ void InvokeTransposeRemovePadding(const T* input_data,
                                   const int head_dim,
                                   const int token_num,
                                   const int* padding_offset,
-                                  cudaStream_t cu_stream) {
+#ifdef PADDLE_WITH_HIP
+                                  hipStream_t cu_stream
+#else
+                                  cudaStream_t cu_stream
+#endif
+                                  ) {
   // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head,
   // head_dim]
   constexpr int VEC_16B = 16;
diff --git a/csrc/generation/write_int8_cache_kv.cu b/csrc/generation/write_int8_cache_kv.cu
index 3e423f0d9db7..7def48920495 100644
--- a/csrc/generation/write_int8_cache_kv.cu
+++ b/csrc/generation/write_int8_cache_kv.cu
@@ -14,8 +14,13 @@
 
 #include "helper.h"
 
+#ifdef PADDLE_WITH_HIP
+constexpr int32_t WARP_SIZE = 64;
+constexpr int32_t HALF_WARP = 32;
+#else
 constexpr int32_t WARP_SIZE = 32;
 constexpr int32_t HALF_WARP = 16;
+#endif
 constexpr float QUANT_MAX_BOUND = 127.0;
 constexpr float QUANT_MIN_BOUND = -127.0;
 
@@ -47,7 +52,7 @@ struct MaxFunc{
 template<>
 struct MaxFunc<half>{
   __device__ half operator()(half a, half b){
-#if __CUDA_ARCH__ >= 800
+#if (__CUDA_ARCH__ >= 800) || defined(PADDLE_WITH_HIP)
     return __hmax(a, b);
 #else
     return max(static_cast<float>(a), static_cast<float>(b));
@@ -55,6 +60,14 @@ struct MaxFunc<half>{
   }
 };
 
+#ifdef PADDLE_WITH_HIP
+template<>
+struct MaxFunc<hip_bfloat16>{
+  __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b){
+    return static_cast<hip_bfloat16>(max(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+#else
 template<>
 struct MaxFunc<__nv_bfloat16>{
   __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b){
@@ -65,6 +78,7 @@ struct MaxFunc<__nv_bfloat16>{
 #endif
   }
 };
+#endif
 
 template <typename T>
 struct AbsFunc{
@@ -76,7 +90,7 @@ struct AbsFunc{
 template<>
 struct AbsFunc<half>{
   __device__ half operator()(half x){
-  #if __CUDA_ARCH__ >= 800
+  #if (__CUDA_ARCH__ >= 800) || defined(PADDLE_WITH_HIP)
     return __habs(x);
   #else
     return abs(static_cast<float>(x));
@@ -84,6 +98,14 @@ struct AbsFunc<half>{
   }
 };
 
+#ifdef PADDLE_WITH_HIP
+template<>
+struct AbsFunc<hip_bfloat16>{
+  __device__ hip_bfloat16 operator()(hip_bfloat16 x) {
+    return static_cast<hip_bfloat16>(abs(static_cast<float>(x)));
+  }
+};
+#else
 template<>
 struct AbsFunc<__nv_bfloat16>{
   __device__ __nv_bfloat16 operator()(__nv_bfloat16 x){
@@ -94,6 +116,7 @@ struct AbsFunc<__nv_bfloat16>{
 #endif
   }
 };
+#endif
 
 template <typename T, typename Vec, int VecSize>
 __inline__ __device__ T LocalReduceMax(Vec& vec) {
@@ -109,7 +132,11 @@ template <typename T>
 __inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) {
 #pragma unroll
   for (int mask = HALF_WARP; mask > 0; mask >>= 1){
+#ifdef PADDLE_WITH_HIP
+    val = MaxFunc<T>()(val, static_cast<T>(__shfl_xor(static_cast<float>(val), mask, WARP_SIZE)));
+#else
     val = MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE));
+#endif
   }
   return val;
 }
@@ -147,7 +174,7 @@ __global__ void write_cache_k_int8_kernel(const T* k, const int64_t num_head, co
   InVec abs_max_vec;
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
-    abs_max_vec[i] = 0.0f;
+    abs_max_vec[i] = static_cast<T>(0.0f);
   }
 
   T local_abs_max;
@@ -205,7 +232,7 @@ __global__ void write_cache_v_int8_kernel(const T* v, const int64_t num_head, co
   InVec abs_max_vec;
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
-    abs_max_vec[i] = 0.0f;
+    abs_max_vec[i] = static_cast<T>(0.0f);
   }
 
   T local_abs_max;
diff --git a/csrc/setup_hip.py b/csrc/setup_hip.py
new file mode 100644
index 000000000000..74485bef3ba8
--- /dev/null
+++ b/csrc/setup_hip.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.utils.cpp_extension import CUDAExtension, setup
+
+setup(
+    name="paddlenlp_ops",
+    ext_modules=CUDAExtension(
+        sources=[
+            "./generation/save_with_output.cc",
+            "./generation/set_value_by_flags.cu",
+            "./generation/token_penalty_multi_scores.cu",
+            "./generation/token_penalty_multi_scores_v2.cu",
+            "./generation/stop_generation_multi_ends.cu",
+            "./generation/fused_get_rope.cu",
+            "./generation/get_padding_offset.cu",
+            "./generation/qkv_transpose_split.cu",
+            "./generation/rebuild_padding.cu",
+            "./generation/transpose_removing_padding.cu",
+            "./generation/write_cache_kv.cu",
+            "./generation/encode_rotary_qk.cu",
+            "./generation/get_padding_offset_v2.cu",
+            "./generation/rebuild_padding_v2.cu",
+            "./generation/set_value_by_flags_v2.cu",
+            "./generation/stop_generation_multi_ends_v2.cu",
+            "./generation/update_inputs.cu",
+            "./generation/get_output.cc",
+            "./generation/save_with_output_msg.cc",
+            "./generation/write_int8_cache_kv.cu",
+            "./generation/step.cu",
+            "./generation/quant_int8.cu",
+            "./generation/dequant_int8.cu",
+        ],
+        extra_compile_args={
+            "cxx": ["-O3"],
+            "hipcc": [
+                "-O3",
+                "-U__HIP_NO_HALF_OPERATORS__",
+                "-U__HIP_NO_HALF_CONVERSIONS__",
+                "-U__HIP_NO_BFLOAT16_OPERATORS__",
+                "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
+                "-U__HIP_NO_BFLOAT162_OPERATORS__",
+                "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
+            ],
+        },
+    ),
+)

From 9bd24f88f4cee72e802198db642591446a6efbf1 Mon Sep 17 00:00:00 2001
From: yuguo-Jack <948529990@qq.com>
Date: Tue, 18 Jun 2024 11:28:33 +0800
Subject: [PATCH 2/3] fix

---
 csrc/setup_hip.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/csrc/setup_hip.py b/csrc/setup_hip.py
index 74485bef3ba8..c9f2d8e216e4 100644
--- a/csrc/setup_hip.py
+++ b/csrc/setup_hip.py
@@ -46,6 +46,7 @@
             "cxx": ["-O3"],
             "hipcc": [
                 "-O3",
+                "--gpu-max-threads-per-block=1024",
                 "-U__HIP_NO_HALF_OPERATORS__",
                 "-U__HIP_NO_HALF_CONVERSIONS__",
                 "-U__HIP_NO_BFLOAT16_OPERATORS__",

From 1064402b547eec3727059769ce586ffc40295c1d Mon Sep 17 00:00:00 2001
From: yuguo-Jack <948529990@qq.com>
Date: Tue, 18 Jun 2024 17:52:48 +0800
Subject: [PATCH 3/3] fix

---
 csrc/setup_hip.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/csrc/setup_hip.py b/csrc/setup_hip.py
index c9f2d8e216e4..398e9cbf488a 100644
--- a/csrc/setup_hip.py
+++ b/csrc/setup_hip.py
@@ -41,6 +41,7 @@
             "./generation/step.cu",
             "./generation/quant_int8.cu",
             "./generation/dequant_int8.cu",
+            "./generation/flash_attn_bwd.cc",
         ],
         extra_compile_args={
            "cxx": ["-O3"],