Added general Float16 support
Added Float16 type definition from third-party
Refine float16 bias handling in conv2d
Refine float16 case in conv2d
Caution: float16 headers are only included when building unit_test

Type: New Feature
Signed-off-by: Feiyue Chen <[email protected]>
chenfeiyue-cfy committed Aug 10, 2023
1 parent 35e50d7 commit 4b288fb
Showing 12 changed files with 3,771 additions and 197 deletions.
1 change: 1 addition & 0 deletions BUILD
@@ -134,6 +134,7 @@ cc_binary(
 cc_test (
     name = "unit_test",
     copts = ["-std=c++14", "-Werror"],
+    includes = ["third_party/half"],
     srcs = [
         "src/tim/vx/test_utils.h",
     ] + glob(["src/tim/**/*_test.cc"]),
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -98,6 +98,8 @@ if(TIM_VX_ENABLE_TEST)
     FetchContent_Populate(googletest)
     add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR})
   endif()
+
+  include_directories(third_party/half)
 endif()
 
 if(TIM_VX_ENABLE_GRPC)
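Because the include path is added only inside the `TIM_VX_ENABLE_TEST` branch (hence the "Caution" in the commit message), the half headers are visible to unit tests but not to the library itself. A minimal test-side sketch, assuming `third_party/half` is the common single-header `half.hpp` library exposing `half_float::half` (the exact header name is an assumption):

```cpp
// Sketch only: assumes third_party/half provides the single-header half.hpp
// with half_float::half. Available solely in the unit_test build, since the
// include path is wired up only under TIM_VX_ENABLE_TEST.
#include <vector>

#include "half.hpp"

std::vector<half_float::half> MakeFp16Data(const std::vector<float>& src) {
  std::vector<half_float::half> out;
  out.reserve(src.size());
  for (float v : src) {
    out.push_back(half_float::half(v));  // narrow float32 -> float16
  }
  return out;
}
```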
3 changes: 0 additions & 3 deletions include/tim/vx/ops/conv2d.h
@@ -99,12 +99,9 @@ class Conv2d : public BuiltinOp {
   const int32_t multiplier_;
   const DataLayout kernel_layout_;
 
-#if defined(__clang__) && (__clang_major__ >= 15)
-#define TIM_VX_OPS_CONV2D_WITH_F16BIAS 1
  private:
   void OnBindInputPostProc(const std::shared_ptr<Tensor>& tensor,
                            int32_t input_idx) override;
-#endif
 };
 
 }  // namespace ops
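The guard can be dropped because the implementation no longer touches the compiler-specific `_Float16` type (see the conv2d.cc diff below): the bias is now widened through `ConvertTensorToFloat32Data()` rather than an element-wise loop over `_Float16` values, so `OnBindInputPostProc` is declared unconditionally. For context, a sketch of what the removed guard was protecting; the version check and type come from the removed lines, not from any new API:

```cpp
// _Float16 is a compiler extension rather than standard C++; clang only
// enables it on common x86 targets as of clang 15, which is what the
// removed `__clang_major__ >= 15` check accounted for.
#if defined(__clang__) && (__clang_major__ >= 15)
using fp16_storage_t = _Float16;  // native 16-bit IEEE half type
#endif
```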
28 changes: 10 additions & 18 deletions src/tim/vx/ops/conv2d.cc
@@ -42,8 +42,8 @@ Conv2d::Conv2d(Graph* graph, const std::array<uint32_t, 4> pad,
                const std::array<uint32_t, 2>& stride,
                const std::array<uint32_t, 2>& dilation, int32_t multiplier,
                DataLayout input_layout, DataLayout kernel_layout)
-    : Conv2d(graph, 0, PadType::AUTO, {0, 0}, stride, dilation, pad,
-             multiplier, input_layout, kernel_layout) {}
+    : Conv2d(graph, 0, PadType::AUTO, {0, 0}, stride, dilation, pad, multiplier,
+             input_layout, kernel_layout) {}
 
 Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
                const std::array<uint32_t, 2>& ksize,
@@ -88,41 +88,33 @@ std::shared_ptr<Operation> Conv2d::Clone(std::shared_ptr<Graph>& graph) const {
                         this->kernel_layout_);
 }
 
-const std::vector<std::shared_ptr<Tensor>> Conv2d::ConstantInputsTensor() const {
-  if (this->IsAllInputsConst()) {
+const std::vector<std::shared_ptr<Tensor>> Conv2d::ConstantInputsTensor()
+    const {
+  if (this->IsAllInputsConst()) {
     return {this->impl_->inputs_tensor_[0]};
   } else {
     return {};
   }
 }
 
-// Handle float16 bias if clang compiler is no less than 15.0.0 version
-#ifdef TIM_VX_OPS_CONV2D_WITH_F16BIAS
+// Handle float16 bias
 void Conv2d::OnBindInputPostProc(const std::shared_ptr<Tensor>& tensor,
                                  int32_t input_idx) {
   if (tensor->GetDataType() == vx::DataType::FLOAT16 &&
       tensor->IsConstTensor() && impl_->inputs_tensor_.size() == 3) {
-    uint32_t bias_size = 1;
-    for (auto i : tensor->GetShape()) {
-      bias_size *= i;
-    }
-    std::vector<_Float16> in(bias_size);
-    tensor->CopyDataFromTensor(in.data());
+    float* float32_bias = tensor->ConvertTensorToFloat32Data();
 
-    std::vector<float> out(bias_size);
-    for (uint i = 0; i < bias_size; i++) {
-      out[i] = static_cast<float>(in[i]);
-    }
     TensorSpec fp32bias_spec(tim::vx::DataType::FLOAT32, tensor->GetShape(),
                              tim::vx::TensorAttribute::CONSTANT);
-    auto out_tensor = impl_->graph_->CreateTensor(fp32bias_spec, out.data());
+    auto out_tensor = impl_->graph_->CreateTensor(fp32bias_spec, float32_bias);
+    vsi_nn_Free(float32_bias);
 
     impl_->inputs_tensor_[2] = out_tensor;
     impl_->node()->input.tensors[input_idx] = out_tensor->GetId();
     impl_->graph_->RenewTensorConsumersMap(tensor, out_tensor, this);
   }
 }
-#endif
 
 }  // namespace ops
 }  // namespace vx
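With this hook in place, a graph can bind a constant FLOAT16 bias to Conv2d and have it transparently promoted to FLOAT32 at bind time. A minimal usage sketch with hypothetical shapes and values, assuming the usual TIM-VX graph-building API and that the remaining Conv2d constructor parameters (multiplier, layouts) are defaulted:

```cpp
// Usage sketch: Context/Graph/TensorSpec follow the public TIM-VX API; the
// Conv2d arguments mirror the (pad, stride, dilation) signature visible in
// this diff. Shapes and data values are illustrative only.
#include <array>
#include <cstdint>
#include <vector>

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/conv2d.h"

void BuildConv2dWithFp16Bias() {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT16, {4, 4, 1, 1},
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec weight_spec(tim::vx::DataType::FLOAT16, {3, 3, 1, 1},
                                  tim::vx::TensorAttribute::CONSTANT);
  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT16, {1},
                                tim::vx::TensorAttribute::CONSTANT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT16, {4, 4, 1, 1},
                                  tim::vx::TensorAttribute::OUTPUT);

  // Raw IEEE-754 half bits: 0x3C00 encodes 1.0 in float16.
  std::vector<uint16_t> weight_data(3 * 3, 0x3C00);
  std::vector<uint16_t> bias_data(1, 0x3C00);

  auto input = graph->CreateTensor(input_spec);
  auto weight = graph->CreateTensor(weight_spec, weight_data.data());
  auto bias = graph->CreateTensor(bias_spec, bias_data.data());
  auto output = graph->CreateTensor(output_spec);

  auto conv = graph->CreateOperation<tim::vx::ops::Conv2d>(
      std::array<uint32_t, 4>({1, 1, 1, 1}),  // pad
      std::array<uint32_t, 2>({1, 1}),        // stride
      std::array<uint32_t, 2>({1, 1}));       // dilation

  // Binding the constant FLOAT16 bias as the third input fires
  // OnBindInputPostProc, which rebinds input 2 to an equivalent FLOAT32
  // constant tensor.
  conv->BindInputs({input, weight, bias}).BindOutputs({output});
}
```

Note that the driver-allocated buffer returned by `ConvertTensorToFloat32Data()` is released with `vsi_nn_Free` once the fp32 tensor has been created, so the conversion introduces no leak while avoiding the element-wise `_Float16` loop of the previous version.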