Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix tune_cublaslt_gemm compile bug #8844

Merged
merged 3 commits into from
Jul 31, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions csrc/generation/tune_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ limitations under the License. */
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <sys/time.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <limits>
#include <list>
#include <vector>

#include "helper.h"

template <typename T>
Expand Down Expand Up @@ -466,11 +468,10 @@ class DevContext {};
class CPUContext : public DevContext {};

class CUBLASLTContext : public DevContext {
public:
CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle_)); }
public:
CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }

private:
cublasLtHandle_t handle_;
cublasLtHandle_t handle;
};

template <typename InT, typename OutT, typename DevContext>
Expand Down Expand Up @@ -502,7 +503,7 @@ void GEMMInt8<int8_t, int32_t, CPUContext>(CPUContext dev_ctx,
}

template <>
void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
const std::vector<int8_t>& AVec,
const std::vector<int8_t>& BVec,
std::vector<int32_t>& CVec,
Expand Down Expand Up @@ -711,8 +712,8 @@ void TuneCublasltGemm(const paddle::Tensor& M,
bool is_test,
bool is_read_from_file,
const std::string& path) {

// Ensure that M, K, and N are all one-dimensional Tensors. is_test != is_read_from_file
// Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
// is_read_from_file
assert(M.dims().size() == 1 && K.dims().size() == 1 && N.dims().size() == 1);
assert(is_test != is_read_from_file);

Expand All @@ -730,22 +731,34 @@ void TuneCublasltGemm(const paddle::Tensor& M,

int m_data = (int)M_data[0];
assert(m_data > 0 && 4 <= 8192);

std::vector<int> mm;

int m = 1, step = 1;
while (m <= m_data) {
while (m <= m_data) {
mm.push_back(m);
m += step;

// update step
switch (m) {
case 4: step = 4; break;
case 16: step = 16; break;
case 64: step = 32; break;
case 256: step = 64; break;
case 512: step = 128; break;
case 1024: step = 1024; break;
case 4:
step = 4;
break;
case 16:
step = 16;
break;
case 64:
step = 32;
break;
case 256:
step = 64;
break;
case 512:
step = 128;
break;
case 1024:
step = 1024;
break;
}
}

Expand All @@ -761,15 +774,15 @@ void TuneCublasltGemm(const paddle::Tensor& M,
if (dtype == "int8") {
CUBLASLTContext dev_ctx;
GEMMInt8(dev_ctx,
A,
B,
C,
m,
k,
n,
is_test, /*is_test*/
is_read_from_file, /*is_read_from_file*/
path);
A,
B,
C,
m,
k,
n,
is_test, /*is_test*/
is_read_from_file, /*is_read_from_file*/
path);
} else {
// other dtype
std::cout << "Not currently supported" << std::endl;
Expand Down
Loading