Skip to content

Commit

Permalink
Added scalar type support
Browse files Browse the repository at this point in the history
Added SetScalar api to support scalar input

Ovxlib need update to support gather with scalar index

Type: New Feature
Signed-off-by: Feiyue Chen <[email protected]>
  • Loading branch information
chenfeiyue-cfy committed Oct 26, 2023
1 parent 1008179 commit d4187b3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/tim/vx/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ class Tensor {
virtual void unmap() = 0;
virtual bool IsPlaceHolder() = 0;
virtual bool IsConstTensor() = 0;
virtual bool IsScalar() = 0;
virtual bool SaveTensorToTextByFp32(std::string filename) = 0;
virtual void SetScalar(int8_t is_scalar) = 0;
virtual void* ConvertTensorToData(uint8_t* tensorData) = 0;
virtual float* ConvertTensorToFloat32Data() = 0;
};
Expand Down
8 changes: 8 additions & 0 deletions src/tim/vx/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@ float* TensorImpl::ConvertTensorToFloat32Data() {
graph_->graph(), vsi_nn_GetTensor(graph_->graph(), id_));
}

void TensorImpl::SetScalar(int8_t is_scalar) {
bool retn = vsi_nn_SetTensorIsScalar(vsi_nn_GetTensor(graph_->graph(), id_),is_scalar);
if (retn != VSI_SUCCESS) {
VSILOGE("Setting scalar fail!");
}
return;
}

bool TensorImpl::SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib,
void** old_ptr) {
bool retn = true;
Expand Down
9 changes: 9 additions & 0 deletions src/tim/vx/tensor_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,14 @@ class TensorImpl : public Tensor {
bool IsConstTensor() override {
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
}
bool IsScalar() override {
return vsi_nn_GetTensorIsScalar(vsi_nn_GetTensor(graph_->graph(), id_));
}
bool SaveTensorToTextByFp32(std::string filename) override;
void* ConvertTensorToData(uint8_t* tensorData) override;
float* ConvertTensorToFloat32Data() override;
void SetScalar(int8_t is_scalar) override;

GraphImpl* graph_;
vsi_nn_tensor_id_t id_;
TensorSpec spec_;
Expand Down Expand Up @@ -114,6 +119,9 @@ class TensorPlaceholder : public Tensor {
bool IsConstTensor() override {
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
}
bool IsScalar() override {
return false;
}
bool SaveTensorToTextByFp32(std::string filename) override {
(void)filename;
return false;
Expand All @@ -124,6 +132,7 @@ class TensorPlaceholder : public Tensor {
}
float* ConvertTensorToFloat32Data() override { return nullptr; }

void SetScalar(int8_t is_scalar) override { (void) is_scalar; return; }
vsi_nn_tensor_id_t id_;
TensorSpec spec_;
};
Expand Down

0 comments on commit d4187b3

Please sign in to comment.