Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Paddle Tensor 规范化第二期】pow support complex #71230

Merged
merged 9 commits into from
Mar 4, 2025

Conversation

fangfangssj
Copy link
Contributor

@fangfangssj fangfangssj commented Feb 21, 2025

PR Category

User Experience

PR Types

Bug fixes

Description

Pcard-75624
为pd_op.pow增加了复数支持与cuda实现,将pd_op.elementwise_powGPU上的实现由std改为cuda实现

通过nsys分析,powkernel执行时间从平均7.364μs降到6.986μs,elementwise_powkernel从平均9.631μs降到6.744μs

Copy link

paddle-bot bot commented Feb 21, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@fangfangssj
Copy link
Contributor Author

95274fb44d842b8887025beded2e3fed
test_activation_op.py文件中有很多的报错,单独测试pow相关的测试通过

Comment on lines +207 to +211
funcs::CudaPowFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = factor.to<float>();
ActivationGPUImpl<T, Context, funcs::CudaPowFunctor<T>>(
dev_ctx, x, out, functor);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以特判一下factor为0和1情况,0的话直接构造 实部全1,虚部全0 的张量作为结果,1的话直接拷贝x即可

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Comment on lines 251 to 271
template <typename T, typename Context>
void PowGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const Scalar& factor,
DenseTensor* dx) {
PADDLE_ENFORCE_NOT_NULL(
dx, errors::NotFound("The output DenseTensor dX can not be nullptr"));
dev_ctx.template Alloc<T>(dx);
auto dout_flatten = EigenVector<T>::Flatten(
GET_DATA_SAFELY(&dout, "Input", "Out@GRAD", "PowGrad"));
auto dx_flatten = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad"));
auto x_flatten =
EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad"));
auto* place = dev_ctx.eigen_device();
phi::funcs::PowGradFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = factor.to<float>();
functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同理,特判一下factor=0的情况,1反向稍微麻烦点,可以正常调functor计算

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

+ 1j * np.random.uniform(0.1, 1, [1, 3, 6])
).astype(np.complex64),
}
self.attrs = {"factor": 2.0}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

factor建议多测几组,比如:[-3.4, -3, -2, -1, -0.1, 0, 0.7, 1, 1.3, 2, 3, 5]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加测试

Comment on lines 264 to 265
PADDLE_ENFORCE_NOT_NULL(
dx, errors::NotFound("The output DenseTensor dX can not be nullptr"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不应该用NotFound,image,而是InvalidArgument
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

const Scalar& factor,
DenseTensor* out) {
PADDLE_ENFORCE_NOT_NULL(out,
errors::NotFound("Output Out should not be nullptr"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,报错类型改一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

return;
}
PADDLE_ENFORCE_NOT_NULL(
dx, errors::NotFound("The output DenseTensor dX can not be nullptr"));
Copy link
Contributor

@HydrogenSulfate HydrogenSulfate Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

报错信息里,应该用dx而不是dX

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Copy link
Contributor

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 merged commit d5b4ed5 into PaddlePaddle:develop Mar 4, 2025
33 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants