Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry-pick] fix weight quant kernel bug when n div 64 != 0 #60184

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
input_data, output_data, numel, total_k, total_n);
}
}

template <typename T, int VectorSize = 8>
__global__ void per_channel_quant_gpu(const T* weight_data,
int8_t* quanted_weight_data,
Expand Down Expand Up @@ -160,7 +161,6 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
}
}
}

template <typename T, typename GPUContext>
void weight_quant_gpu(const GPUContext& dev_ctx,
const T* weight_data,
Expand All @@ -174,8 +174,15 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
constexpr int kBlockSize = 64;
constexpr int kWarpNum = kBlockSize / kWarpSize;
constexpr int kVectorSize = 128 / sizeof(T) / 8;
PADDLE_ENFORCE_EQ(total_n % kVectorSize,
0,
phi::errors::PreconditionNotMet(
"Currently, weight_quant_gpu kernel only support n "
"with multiple of %d, please use",
kVectorSize));
int vec_total_n = total_n / kVectorSize;
int kGridSize = max(vec_total_n / kBlockSize, static_cast<int>(1));
int kGridSize =
max((vec_total_n + kBlockSize - 1) / kBlockSize, static_cast<int>(1));
per_channel_quant_gpu<T, kVectorSize><<<kGridSize, kBlockSize>>>(
weight_data, quanted_weight_data, scale_data, total_k, vec_total_n);
}
Expand Down
42 changes: 42 additions & 0 deletions test/quantization/test_weight_only_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,5 +399,47 @@ def test_weightonly_linear_backward(self):
np.testing.assert_allclose(quant_x.grad, x.grad, rtol=1e-3, atol=1e-3)


@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase11(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
self.in_features = 128
self.out_features = 288


@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase12(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.weight_dtype = "int8"
self.in_features = 128
self.out_features = 288


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase13(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
self.in_features = 128
self.out_features = 288


if __name__ == '__main__':
unittest.main()