Add fake_quantize_op. #11359
Conversation
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
//                   ctx->Outputs("OutScales")[0],
//                   "Mean and MeanOut should share the same memory");
//}
Please remove the commented lines.
The comment is for the Python test; the commented lines are used for training.
                 "Input(X) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
               "Output(Out) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"), "");
Please add the error message.
done
ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE(ctx->HasOutput("OutScales"), "");
Please add the error message.
done
public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor of scale operator.");
    AddInput("InScales", "(Tensor) scale buffer").AsDispensable();
Please add more comments explaining why this argument is optional: when it is needed and when it is not. The same applies to the following arguments.
done
namespace operators {

template <typename T>
__global__ void find_abs_max_kernel(const int n, const T* in, T* out) {
find_abs_max_kernel -> FindAbsMaxKernel
Please follow Google C++ code style: https://google.github.io/styleguide/cppguide.html#Function_Names
Please modify other code with the same problem.
done
float find_abs_max_gpu(const platform::CUDADeviceContext& ctx,
                       const float* array, int length) {
  float host_max;
  int NUM_THREADS = 1024;
NUM_THREADS -> kNumThreads. Please follow Google code style.
done
      cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
      cudaSuccess, "cudaMemcpy failed");
  return host_max;
}
Maybe thrust::reduce or thrust::max_element could be used to find the maximum value more simply.
This will be slow.
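For reference, a host-side sketch of the abs-max reduction being discussed (the function name `FindAbsMax` is hypothetical, not the operator's actual helper); the thrust variant the reviewer suggests is outlined in a comment, along with the overhead concern behind the "slow" reply:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Host-side reference for the abs-max reduction. A thrust version would be
// roughly:
//   thrust::transform_reduce(in, in + n, AbsFunctor(), 0.f,
//                            thrust::maximum<float>());
// but each thrust call launches its own kernels, which is the per-step
// overhead the author is worried about for a per-iteration quantize op.
float FindAbsMax(const float* array, int length) {
  float abs_max = 0.0f;
  for (int i = 0; i < length; ++i) {
    abs_max = std::max(abs_max, std::fabs(array[i]));  // reduce over |x|
  }
  return abs_max;
}
```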
int window_size = context.Attr<int>("window_size");
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
LOG(ERROR) << "bin_cnt:" << bin_cnt;
Remove this line.
done
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
    context.Output<framework::Tensor>("OutMovingScale");
scale = find_abs_max(const_cast<framework::Tensor*>(in), in->numel());
You can use Eigen's method here:
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/adamax_op.h#L57
cwiseMax is an element-wise max operation; I need a reduce-max op.
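To make the distinction concrete, here is a plain C++ sketch of the two operations (both function names are illustrative, not from the PR): an element-wise max of two arrays, which is what `cwiseMax` computes, versus a reduction over one array, which is what the scale calculation needs:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Element-wise max (what cwiseMax does): combines two arrays of the same
// length into a third array of that length.
std::vector<float> ElementwiseMax(const std::vector<float>& a,
                                  const std::vector<float>& b) {
  std::vector<float> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) out[i] = std::max(a[i], b[i]);
  return out;
}

// Reduction max over |x| (what the scale needs): collapses one array to a
// single scalar.
float ReduceAbsMax(const std::vector<float>& a) {
  float m = 0.0f;
  for (float v : a) m = std::max(m, std::fabs(v));
  return m;
}
```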
done
Need unit testing.
  }
}

apply_saturate(const_cast<framework::Tensor*>(in), tensor, -scale, scale);
You can also refer to https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/clip_op.h#L70 for a simpler implementation.
done
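The saturation step under discussion clips every element to the symmetric range [-scale, scale]. A minimal host-side sketch (the name `ApplySaturate` mirrors the PR's `apply_saturate` but this is not its exact code; clip_op expresses the same thing with Eigen's `cwiseMin`/`cwiseMax` chain):

```cpp
#include <algorithm>
#include <vector>

// Clip each element of *data into [min_v, max_v]; for fake quantization the
// range is [-scale, scale]. clip_op's Eigen form is roughly
// x.cwiseMin(max_v).cwiseMax(min_v).
void ApplySaturate(std::vector<float>* data, float min_v, float max_v) {
  for (float& v : *data) v = std::min(std::max(v, min_v), max_v);
}
```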
AddComment(R"DOC(
FakeQuantize operator

$$Out = scale*X$$
Please add comments describing how scale is calculated.
done
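For abs-max fake quantization, the scale is the running absolute maximum of the input; values are clipped to [-scale, scale], rounded onto bin_cnt levels, then mapped back to floats so the output stays in floating point. A self-contained sketch under those assumptions (`FakeQuantize` is a hypothetical helper, not the operator's code):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of abs-max fake quantization: quantize to bit_length bits and
// immediately dequantize, so Out approximates X with quantization error.
std::vector<float> FakeQuantize(const std::vector<float>& x, int bit_length) {
  const int bin_cnt = (1 << (bit_length - 1)) - 1;  // e.g. 127 for 8 bits
  float scale = 0.0f;                                // abs-max of the input
  for (float v : x) scale = std::max(scale, std::fabs(v));
  std::vector<float> out(x.size());                  // zero-initialized
  if (scale == 0.0f) return out;                     // avoid divide-by-zero
  for (size_t i = 0; i < x.size(); ++i) {
    float clipped = std::min(std::max(x[i], -scale), scale);
    out[i] = std::round(clipped * bin_cnt / scale) * scale / bin_cnt;
  }
  return out;
}
```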
$$Out = scale*X$$
)DOC");
AddAttr<std::string>("quantize_type",
quantize_type -> scale_type would be more accurate?
If the quantization method is non-uniform, a scale is not needed, so I don't think this should be scale_type.
My understanding is that quantize_type usually refers to a quantization scheme such as Abs-Max or Min-Max, whereas this attribute is meant to indicate how the scale is computed, right?
For non-uniform quantization, the mapping from floating-point input to fixed-point output may be a function or a discrete value mapping, so there is no scale operation at all.
CI did not pass; you need to update to the latest develop code.
The unit testing did not pass.
Approved. @dangqingqing will refine and add more unit testing.
* Add a fake_quantize_op, which quantizes an input tensor to a tensor with lower bits.
* Add quantization code for testing.