[DCU] high performance LLM train and inference for DCU #8580

Merged 7 commits on Jun 28, 2024
14 changes: 14 additions & 0 deletions csrc/generation/get_padding_offset_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

__global__ void RemovePaddingV2(int64_t *output_data,
37 changes: 37 additions & 0 deletions csrc/generation/helper.h
@@ -15,12 +15,44 @@
#pragma once

#include "paddle/extension.h"
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#include <hipcub/hipcub.hpp>
#include <hiprand.h>
#include <hiprand_kernel.h>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#include <curand_kernel.h>
#endif

constexpr int kBlockSize = 256;
constexpr int kNumWaves = 16;

#ifdef PADDLE_WITH_HIP
inline hipError_t GetNumBlocks(int64_t n, int* num_blocks) {
int dev;
{
hipError_t err = hipGetDevice(&dev);
if (err != hipSuccess) { return err; }
}
int sm_count;
{
hipError_t err = hipDeviceGetAttribute(&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
if (err != hipSuccess) { return err; }
}
int tpm;
{
hipError_t err = hipDeviceGetAttribute(&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
if (err != hipSuccess) { return err; }
}
*num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return hipSuccess;
}
#else
inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
int dev;
{
@@ -41,6 +73,7 @@ inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
}
#endif

template<typename T>
__device__ T max_func(const T a, const T b) {
@@ -74,7 +107,11 @@ class PDTraits<paddle::DataType::FLOAT16> {
template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
#ifdef PADDLE_WITH_HIP
typedef hip_bfloat16 DataType;
#else
typedef __nv_bfloat16 DataType;
#endif
typedef paddle::bfloat16 data_t;
};

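Note on this hunk: GetNumBlocks caps the grid at sm_count * tpm / kBlockSize * kNumWaves, i.e. enough blocks for roughly kNumWaves waves across the device, and PDTraits maps a paddle::DataType onto the matching device-side type (hip_bfloat16 under PADDLE_WITH_HIP, __nv_bfloat16 otherwise). The sketch below shows how a kernel in this directory might consume the two together; ScaleKernel and LaunchScale are hypothetical names used only for illustration, and error handling is omitted.

template <typename T>
__global__ void ScaleKernel(const T* in, T* out, const float scale, const int64_t n) {
  // Grid-stride loop: GetNumBlocks may return fewer blocks than n / kBlockSize,
  // so each thread can process several elements.
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = idx; i < n; i += stride) {
    out[i] = static_cast<T>(static_cast<float>(in[i]) * scale);
  }
}

template <paddle::DataType D>
paddle::Tensor LaunchScale(const paddle::Tensor& x, const float scale) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;  // device type: half, hip_bfloat16, ...
  typedef typename traits_::data_t data_t;       // paddle type: paddle::float16, ...
  auto out = paddle::full(x.shape(), 0, D, x.place());
  const int64_t n = x.numel();
  int num_blocks = 0;
  GetNumBlocks(n, &num_blocks);  // returns hipSuccess / cudaSuccess
  ScaleKernel<DataType_><<<num_blocks, kBlockSize, 0, x.stream()>>>(
      reinterpret_cast<const DataType_*>(x.data<data_t>()),
      reinterpret_cast<DataType_*>(out.data<data_t>()),
      scale, n);
  return out;
}
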
18 changes: 17 additions & 1 deletion csrc/generation/quant_int8.cu
@@ -22,8 +22,13 @@
#include<sys/mman.h>
#include<stdio.h>
#include<algorithm>
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#else
#include<cuda_fp16.h>
#include<cuda_bf16.h>
#endif


constexpr int DequantKernelVecSize = 4;
@@ -52,11 +57,17 @@ __forceinline__ __device__ half add_mul<half>(half a, half b, half c) {
return __hmul(__hadd(a, b), c);
}

#ifdef PADDLE_WITH_HIP
template<>
__forceinline__ __device__ hip_bfloat16 add_mul<hip_bfloat16>(hip_bfloat16 a, hip_bfloat16 b, hip_bfloat16 c) {
return (a + b) * c;
}
#else
template<>
__forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
return __hmul(__hadd(a, b), c);
}

#endif


template <typename data_t>
@@ -173,8 +184,13 @@ std::vector<paddle::Tensor> LaunchQuantInt8(const paddle::Tensor& input,
auto output=paddle::full(input_shape, -1, paddle::DataType::INT8, input.place());
int m = input_shape[0];
int n = input_shape[1];
#ifdef PADDLE_WITH_HIP
dim3 grid(((n >> 2) + 63) / 64, (m + 7) / 8);
dim3 block(64, 8);
#else
dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);
#endif
auto stream = input.stream();
if (shift && smooth) {
QuantKernel<DataType_><<<grid, block, 0, stream>>>(reinterpret_cast<const DataType_*>(input.data<data_t>()),
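
Note on the launch geometry: the kernel quantizes four columns per thread, so the grid's x extent only needs to cover n / 4 elements (the n >> 2). The HIP branch switches to a 64x8 block, presumably so that blockDim.x matches the 64-lane DCU wavefront while the block stays at 512 threads, whereas the CUDA branch keeps the 32x32 block. A small sketch with the precedence made explicit, assuming that reading of the configuration:

// Hypothetical helper, not part of the PR: compute the grid for a given block shape.
inline dim3 QuantGridFor(const int m, const int n, const dim3 block) {
  const int cols = n >> 2;  // four columns handled per thread
  return dim3((cols + block.x - 1) / block.x, (m + block.y - 1) / block.y);
}
// DCU/HIP:  QuantGridFor(m, n, dim3(64, 8));   // x dimension = one wavefront
// CUDA:     QuantGridFor(m, n, dim3(32, 32));  // original geometry
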
7 changes: 6 additions & 1 deletion csrc/generation/rebuild_padding.cu
@@ -58,7 +58,12 @@ void InvokeRebuildPadding(T *output_data,
const int *padding_offset,
const int token_num,
const int dim_embed,
cudaStream_t stream) {
#ifdef PADDLE_WITH_HIP
hipStream_t stream
#else
cudaStream_t stream
#endif
) {
// src: [token_num, dim_embed]
// dst: [batch_size * max_seq_len, dim_embed]
RebuildPaddingKernel<<<token_num, 256, 0, stream>>>(
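
The stream-type change here (the same edit appears in transpose_removing_padding.cu below) repeats a backend #ifdef inside the parameter list. An alternative, shown only as a sketch and not what this PR does, is a single alias defined once in helper.h:

#ifdef PADDLE_WITH_HIP
typedef hipStream_t gpuStream_t;
#else
typedef cudaStream_t gpuStream_t;
#endif

// InvokeRebuildPadding (and InvokeTransposeRemovePadding) could then keep one signature:
// void InvokeRebuildPadding(T *output_data, const T *input_data,
//                           const int *padding_offset, const int token_num,
//                           const int dim_embed, gpuStream_t stream);
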
14 changes: 14 additions & 0 deletions csrc/generation/rebuild_padding_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <typename T, int VecSize>
14 changes: 14 additions & 0 deletions csrc/generation/set_value_by_flags_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

__global__ void set_value_by_flag_and_id_v2(const bool *stop_flags,
22 changes: 22 additions & 0 deletions csrc/generation/step.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// #define DEBUG_STEP
@@ -255,7 +269,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
max_decoder_block_num
);
#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
hipDeviceSynchronize();
#else
cudaDeviceSynchronize();
#endif
#endif
auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
const int grid_size = cpu_recover_lens.data<int>()[0];
@@ -287,7 +305,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
first_token_id
);
#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
hipDeviceSynchronize();
#else
cudaDeviceSynchronize();
#endif
#endif
}
}
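
The DEBUG_STEP synchronization now nests a backend #ifdef inside the debug #ifdef at every call site. A possible tidy-up, not part of this PR, is a single debug-only macro defined once:

#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
#define DEBUG_DEVICE_SYNC() hipDeviceSynchronize()
#else
#define DEBUG_DEVICE_SYNC() cudaDeviceSynchronize()
#endif
#else
#define DEBUG_DEVICE_SYNC() ((void)0)  // compiles away when DEBUG_STEP is undefined
#endif

// Each call site in StepPaddle then collapses to a single line placed after the
// kernel launch it is meant to flush:
//   DEBUG_DEVICE_SYNC();
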
14 changes: 14 additions & 0 deletions csrc/generation/stop_generation_multi_ends_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include<stdlib.h>
#include<string.h>
14 changes: 14 additions & 0 deletions csrc/generation/token_penalty_multi_scores_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"


7 changes: 6 additions & 1 deletion csrc/generation/transpose_removing_padding.cu
@@ -65,7 +65,12 @@ void InvokeTransposeRemovePadding(const T* input_data,
const int head_dim,
const int token_num,
const int* padding_offset,
cudaStream_t cu_stream) {
#ifdef PADDLE_WITH_HIP
hipStream_t cu_stream
#else
cudaStream_t cu_stream
#endif
) {
// [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head,
// head_dim]
constexpr int VEC_16B = 16;
35 changes: 31 additions & 4 deletions csrc/generation/write_int8_cache_kv.cu
@@ -14,8 +14,13 @@

#include "helper.h"

#ifdef PADDLE_WITH_HIP
constexpr int32_t WARP_SIZE = 64;
constexpr int32_t HALF_WARP = 32;
#else
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t HALF_WARP = 16;
#endif
constexpr float QUANT_MAX_BOUND = 127.0;
constexpr float QUANT_MIN_BOUND = -127.0;

@@ -47,14 +52,22 @@ struct MaxFunc{
template<>
struct MaxFunc<half>{
__device__ half operator()(half a, half b){
#if __CUDA_ARCH__ >= 800
#if (__CUDA_ARCH__ >= 800) || defined(PADDLE_WITH_HIP)
return __hmax(a, b);
#else
return max(static_cast<float>(a), static_cast<float>(b));
#endif
}
};

#ifdef PADDLE_WITH_HIP
template<>
struct MaxFunc<hip_bfloat16>{
__device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b){
return static_cast<hip_bfloat16>(max(static_cast<float>(a), static_cast<float>(b)));
}
};
#else
template<>
struct MaxFunc<__nv_bfloat16>{
__device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b){
@@ -65,6 +78,7 @@ struct MaxFunc<__nv_bfloat16>{
#endif
}
};
#endif

template<typename T>
struct AbsFunc{
@@ -76,14 +90,22 @@ struct AbsFunc{
template<>
struct AbsFunc<half>{
__device__ half operator()(half x){
#if __CUDA_ARCH__ >= 800
#if (__CUDA_ARCH__ >= 800) || defined(PADDLE_WITH_HIP)
return __habs(x);
#else
return abs(static_cast<float>(x));
#endif
}
};

#ifdef PADDLE_WITH_HIP
template<>
struct AbsFunc<hip_bfloat16>{
__device__ hip_bfloat16 operator()(hip_bfloat16 x) {
return static_cast<hip_bfloat16>(abs(static_cast<float>(x)));
}
};
#else
template<>
struct AbsFunc<__nv_bfloat16>{
__device__ __nv_bfloat16 operator()(__nv_bfloat16 x){
@@ -94,6 +116,7 @@ struct AbsFunc<__nv_bfloat16>{
#endif
}
};
#endif

template <typename T, typename Vec, int VecSize>
__inline__ __device__ T LocalReduceMax(Vec& vec) {
Expand All @@ -109,7 +132,11 @@ template <typename T>
__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) {
#pragma unroll
for (int mask = HALF_WARP; mask > 0; mask >>= 1){
#ifdef PADDLE_WITH_HIP
val = MaxFunc<T>()(val, static_cast<T>(__shfl_xor(static_cast<float>(val), mask, WARP_SIZE)));
#else
val = MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE));
#endif
}
return val;
}
@@ -147,7 +174,7 @@ __global__ void write_cache_k_int8_kernel(const T* k, const int64_t num_head, co
InVec abs_max_vec;
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
abs_max_vec[i] = 0.0f;
abs_max_vec[i] = static_cast<T>(0.0f);
}

T local_abs_max;
@@ -205,7 +232,7 @@ __global__ void write_cache_v_int8_kernel(const T* v, const int64_t num_head, co
InVec abs_max_vec;
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
abs_max_vec[i] = 0.0f;
abs_max_vec[i] = static_cast<T>(0.0f);
}

T local_abs_max;
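
Note on the warp constants: DCU wavefronts are 64 lanes wide, so WARP_SIZE/HALF_WARP become 64/32, and the butterfly reduction in WarpReduceAbsMax calls __shfl_xor through a float round-trip on HIP instead of the __shfl_xor_sync overload used on CUDA, presumably to avoid depending on half/bfloat16 shuffle overloads. The sketch below shows how a block-level abs-max reduction could sit on top of it; BlockReduceAbsMax is an assumed helper for illustration (the untruncated file may already contain an equivalent), and it assumes a one-dimensional thread block.

template <typename T>
__inline__ __device__ T BlockReduceAbsMax(T val, const unsigned lane_mask) {
  // One shared slot per warp/wavefront; a block has at most 1024 threads,
  // so the array size adapts automatically to WARP_SIZE = 32 or 64.
  static __shared__ T smem[1024 / WARP_SIZE];
  const int lane = threadIdx.x % WARP_SIZE;
  const int warp_id = threadIdx.x / WARP_SIZE;

  val = WarpReduceAbsMax(val, lane_mask);      // reduce within each warp/wavefront
  if (lane == 0) smem[warp_id] = val;          // lane 0 publishes its partial max
  __syncthreads();

  const int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
  val = (threadIdx.x < num_warps) ? smem[threadIdx.x] : static_cast<T>(0.0f);
  if (warp_id == 0) val = WarpReduceAbsMax(val, lane_mask);  // first warp combines partials
  return val;
}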