[Paddle Tensor Standardization Phase 2] pow support complex #71230
Conversation
Your PR was submitted successfully. Thank you for contributing to the open-source project!
funcs::CudaPowFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = factor.to<float>();
ActivationGPUImpl<T, Context, funcs::CudaPowFunctor<T>>(
    dev_ctx, x, out, functor);
You could special-case factor values of 0 and 1 here: for 0, directly construct a tensor with real part all 1 and imaginary part all 0 as the result; for 1, simply copy x.
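The suggested fast paths can be sketched in NumPy (a hedged illustration of the logic, not the PR's actual C++/CUDA code; `pow_with_fast_paths` is a hypothetical helper name):

```python
import numpy as np

def pow_with_fast_paths(x: np.ndarray, factor: float) -> np.ndarray:
    """Sketch of the suggested special cases for complex pow."""
    if factor == 0:
        # Real part all 1, imaginary part all 0, regardless of x.
        return np.ones_like(x)
    if factor == 1:
        # Plain copy of x; no functor launch needed.
        return x.copy()
    # General path: in the PR this is the CUDA pow functor.
    return np.power(x, factor)

x = np.array([1 + 2j, 3 - 1j], dtype=np.complex64)
```

Both fast paths avoid a kernel launch of the general functor entirely, which is where the benefit comes from.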
Fixed.
template <typename T, typename Context>
void PowGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& dout,
                   const Scalar& factor,
                   DenseTensor* dx) {
  PADDLE_ENFORCE_NOT_NULL(
      dx, errors::NotFound("The output DenseTensor dX can not be nullptr"));
  dev_ctx.template Alloc<T>(dx);
  auto dout_flatten = EigenVector<T>::Flatten(
      GET_DATA_SAFELY(&dout, "Input", "Out@GRAD", "PowGrad"));
  auto dx_flatten = EigenVector<T>::Flatten(
      GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad"));
  auto x_flatten =
      EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad"));
  auto* place = dev_ctx.eigen_device();
  phi::funcs::PowGradFunctor<T> functor;
  auto attrs = functor.GetAttrs();
  *(attrs[0].second) = factor.to<float>();
  functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
}
Likewise, special-case factor = 0 here. The backward pass for factor = 1 is slightly more involved, so it is fine to call the functor normally for that case.
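The reason factor = 0 can be short-circuited in the backward pass follows from the gradient formula d(x^n)/dx = n * x^(n-1): when n = 0 the gradient is identically zero. A minimal NumPy sketch (hypothetical helper name, not the PR's C++ kernel):

```python
import numpy as np

def pow_grad_with_fast_path(x, dout, factor):
    """dx = dout * factor * x**(factor - 1); for factor == 0 this is
    identically zero, so the functor call can be skipped."""
    if factor == 0:
        return np.zeros_like(x)
    # General path: in the PR this is PowGradFunctor on the device.
    return dout * factor * np.power(x, factor - 1)
```

For factor = 1 the general formula already reduces to dout (since x^0 = 1 for nonzero x), which is why running the normal functor there is acceptable.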
Fixed.
+ 1j * np.random.uniform(0.1, 1, [1, 3, 6])
).astype(np.complex64),
}
self.attrs = {"factor": 2.0}
It is recommended to test several more factor values, e.g.: [-3.4, -3, -2, -1, -0.1, 0, 0.7, 1, 1.3, 2, 3, 5]
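A minimal NumPy sketch of the suggested test matrix, building reference values with NumPy's complex power (principal branch); the comparison against paddle.pow is only indicated in a comment, since this is an illustration rather than the actual test file:

```python
import numpy as np

rng = np.random.default_rng(0)
# Same shape and value range as the test input above; real and
# imaginary parts are kept away from zero so negative powers are safe.
x = (
    rng.uniform(0.1, 1, [1, 3, 6]) + 1j * rng.uniform(0.1, 1, [1, 3, 6])
).astype(np.complex64)

factors = [-3.4, -3, -2, -1, -0.1, 0, 0.7, 1, 1.3, 2, 3, 5]
# In the real unit test, each entry would be compared against the
# output of paddle.pow(x, factor).
expected = {f: np.power(x, f) for f in factors}
```

Covering negative, fractional, zero, and unit exponents exercises both the general functor and the special-cased fast paths discussed above.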
Tests added.
PADDLE_ENFORCE_NOT_NULL(
    dx, errors::NotFound("The output DenseTensor dX can not be nullptr"));
Fixed.
const Scalar& factor,
DenseTensor* out) {
  PADDLE_ENFORCE_NOT_NULL(out,
                          errors::NotFound("Output Out should not be nullptr"));
Same as above: please change the error type.
Fixed.
  return;
}
PADDLE_ENFORCE_NOT_NULL(
    dx, errors::NotFound("The output DenseTensor dX can not be nullptr"));
In the error message, it should be dx rather than dX.
Fixed.
LGTM
PR Category
User Experience
PR Types
Bug fixes
Description
Pcard-75624
Added complex number support and a CUDA implementation for pd_op.pow, and switched the GPU implementation of pd_op.elementwise_pow from the std implementation to the CUDA implementation.
According to nsys profiling, the average execution time of PowKernel dropped from 7.364 μs to 6.986 μs, and ElementwisePowKernel dropped from an average of 9.631 μs to 6.744 μs.