remove fp8 scale when reducing on Navi
hyoon1 committed Feb 26, 2025
1 parent e689d99 · commit 047b9ce
Showing 1 changed file with 1 addition and 8 deletions.
csrc/rocm/attention.cu (9 changes: 1 addition & 8 deletions)
@@ -2322,16 +2322,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
-  const float out_scale =
-      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
-  acc *= out_scale;
   OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-  if constexpr (std::is_same<OUTT, bit8_t>::value) {
-    out_ptr[threadIdx.x] = hip_fp8(acc).data;
-  } else {
-    out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
-  }
+  out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
 }
 
 #else
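
For context, the deleted lines were the fp8 output epilogue of the reduce kernel: after normalizing the accumulator by the global exp-sum, it multiplied by the reciprocal of *fp8_out_scale_ptr and, when the output type OUTT was bit8_t, quantized through hip_fp8(acc).data. After this commit the Navi path always writes the plain normalized value via from_float<scalar_t>. The following is a minimal host-side C++ sketch of the before/after arithmetic, not the kernel itself: the epilogue_before/epilogue_after helper names are hypothetical, the device intrinsic __fdividef is replaced by ordinary division, and the fp8 quantization step is left out.

    #include <cstdio>

    // Epilogue before the commit: normalize by the global exp-sum, then
    // apply the optional reciprocal fp8 output scale. The real kernel
    // would afterwards quantize via hip_fp8(acc).data when OUTT is bit8_t.
    float epilogue_before(float acc, float global_exp_sum,
                          const float* fp8_out_scale_ptr) {
      const float inv_global_exp_sum = 1.0f / (global_exp_sum + 1e-6f);
      const float out_scale =
          (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
      return acc * inv_global_exp_sum * out_scale;  // scale removed on Navi
    }

    // Epilogue after the commit: the Navi path always writes the plain
    // softmax-normalized value; no fp8 scale, no quantization branch.
    float epilogue_after(float acc, float global_exp_sum) {
      return acc * (1.0f / (global_exp_sum + 1e-6f));
    }

    int main() {
      const float scale = 0.5f;
      // With an output scale of 0.5, the old epilogue doubled the result:
      // before ~= 1.0, after ~= 0.5 for acc = 2.0 and exp-sum = 4.0.
      std::printf("before: %f\n", epilogue_before(2.0f, 4.0f, &scale));
      std::printf("after:  %f\n", epilogue_after(2.0f, 4.0f));
    }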
