-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
crossentropy grad op #3186
crossentropy grad op #3186
Conversation
paddle/framework/operator.h
Outdated
@@ -319,5 +319,14 @@ class OperatorWithKernel : public OperatorBase { | |||
virtual void InferShape(const InferShapeContext& ctx) const = 0; | |||
}; | |||
|
|||
namespace op_helpers { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle/operators/cross_entropy_op.cc
Outdated
auto X_grad = ctx.Output<Tensor>(GenGradName("X")); | ||
// auto Y_grad = ctx.Input<Tensor>(GenGradName("Y")); | ||
auto X = ctx.Input<Tensor>("X"); | ||
// auto Y = ctx.Input<Tensor>("Y"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Delete unwanted code instead of commenting them out.
paddle/operators/cross_entropy_op.h
Outdated
@@ -21,25 +21,53 @@ namespace operators { | |||
template <typename Place, typename T> | |||
class OnehotCrossEntropyOpKernel : public OpKernel { | |||
public: | |||
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); } | |||
constexpr T kLOG_THRESHOLD() const { return static_cast<T>(1e-20); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's follow the code style https://google.github.io/styleguide/cppguide.html#Constant_Names and change kLOG_THRESHOLD
into
static const double kLogThreshold = 1e-20;
We mustn't define it as a function template without any constaint on type T
, as it is obvious that we cannoot type cast 1e-20 into any integral value like LOG_THRESHOLD<int>()
.
Let's fix this special case in this PR and fix others in another separate PR.
No description provided.