Skip to content

Commit

Permalink
fix bug of param num in custom op
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouheng.zheng committed May 5, 2022
1 parent 3f629d3 commit 6fccc02
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/tim/vx/ops/custom_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);

uint32_t param_num = op_this->param_list_.size();
uint32_t input_start = op_this->input_num_ + op_this->output_num_;

std::vector<tim::vx::DataType> input_types;
for (uint32_t i = 0; i < op_this->input_num_; i++) {
Expand Down Expand Up @@ -127,7 +128,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
kernel->unique_id =
std::hash<std::string>()(std::string(op_this->func_name_));
vx_param_description_t kernel_param_def[param_num];
vx_param_description_t kernel_param_def[param_num + input_start];

for (uint32_t i = 0; i < op_this->input_num_; i++) {
kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR,
Expand All @@ -145,7 +146,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,

kernel->info.parameters = kernel_param_def;
kernel->info.enumeration = KERNEL_ID_PLACEHOLDER;
kernel->info.numParams = param_num;
kernel->info.numParams = param_num + input_start;
kernel->info.initialize =
reinterpret_cast<vx_kernel_initialize_f>(op_this->init_kernel_);

Expand All @@ -162,11 +163,10 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,

auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource);
if (node) {
uint32_t input_start = op_this->input_num_ + op_this->output_num_;

std::vector<vsi_nn_kernel_node_param_t> node_params(param_num + input_start);
vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data();
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs,
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num + input_start, inputs,
op_this->input_num_, outputs,
op_this->output_num_);

Expand Down Expand Up @@ -196,7 +196,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
}

input_start = op_this->input_num_ + op_this->output_num_;
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num);
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num + input_start);
for (uint32_t i = 0; i < param_num; i++) {
vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]);
}
Expand Down

0 comments on commit 6fccc02

Please sign in to comment.