Skip to content

Commit

Permalink
Check md5sum for const tensor
Browse files Browse the repository at this point in the history
Tensors with the same md5 will be created once and shared

Type: New feature

Signed-off-by: Chen Xin <[email protected]>
  • Loading branch information
Chen Xin committed Apr 12, 2023
1 parent 2789071 commit 47b3587
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 2 deletions.
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ option(TIM_VX_ENABLE_PLATFORM "Enable multi devices support"
option(TIM_VX_ENABLE_PLATFORM_LITE "Enable lite multi-device support" OFF)
option(TIM_VX_ENABLE_GRPC "Enable gPRC support" OFF)
option(TIM_VX_DBG_ENABLE_TENSOR_HNDL "Enable built-in tensor from handle: use malloced memory instead of VideoMemory by kernel driver" ON)
option(TIM_VX_ENABLE_MD5_CHECK "Enable md5sum check for const tensor" ON)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
Expand Down Expand Up @@ -46,6 +47,11 @@ if(${TIM_VX_ENABLE_40BIT})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DVSI_40BIT_VA_SUPPORT")
endif()

if(${TIM_VX_ENABLE_MD5_CHECK})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DENABLE_MD5_CHECK")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MD5_CHECK")
endif()

if(${TIM_VX_ENABLE_CUSTOM_OP})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
Expand Down Expand Up @@ -93,6 +99,9 @@ if(TIM_VX_ENABLE_GRPC)
include(cmake/gRPC.cmake)
endif()

if(TIM_VX_ENABLE_MD5_CHECK)
find_package(OpenSSL REQUIRED)
endif()
add_subdirectory("src/tim")

if(TIM_VX_BUILD_EXAMPLES)
Expand Down
3 changes: 2 additions & 1 deletion src/tim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${op_as_flags}")
add_library(${TARGET_NAME} ${${TARGET_NAME}_SRCS})
target_include_directories(${TARGET_NAME} PRIVATE ${INC_DIRS})
target_link_libraries(${TARGET_NAME} PUBLIC
-Wl,--no-whole-archive ${OVXDRV_LIBRARIES} ${EXTERNAL_LIBS})
-Wl,--no-whole-archive ${OVXDRV_LIBRARIES} ${EXTERNAL_LIBS} ${OPENSSL_CRYPTO_LIBRARY})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations ")

if(${TIM_VX_USE_EXTERNAL_OVXLIB})
#-Wl,--whole-archive should not applied to external library, but only for shared library
Expand Down
50 changes: 50 additions & 0 deletions src/tim/transform/layout_inference_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,56 @@ TEST(GroupedConv2d, kernel_bigger_than_input_SAME) {
infer_input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));
infer_graph->Run();

std::vector<float> output(golden.size());
EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output);
}

TEST(FC, share_const_tensor) {
auto ctx = tim::vx::Context::Create();
auto src_graph = ctx->CreateGraph();

tim::vx::ShapeType input_shape({2, 1});
tim::vx::ShapeType kernel_shape({2, 2});
tim::vx::ShapeType bias_shape({2});
tim::vx::ShapeType output_shape({2, 1});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec kernel_spec(tim::vx::DataType::FLOAT32, kernel_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, bias_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec tran_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::TRANSIENT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
std::vector<float> in_data = {1,4,};
std::vector<float> weight = {-3,3,2,1,};
std::vector<float> bias = {0.1, 0.4,};
std::vector<float> golden = {-8, 25};
auto input_tensor = src_graph->CreateTensor(input_spec);
auto weight_tensor = src_graph->CreateTensor(kernel_spec, weight.data());
auto bias_tensor = src_graph->CreateTensor(bias_spec, bias.data());
auto tran_tensor = src_graph->CreateTensor(tran_spec);
auto output_tensor = src_graph->CreateTensor(output_spec);

auto op1 = src_graph->CreateOperation<tim::vx::ops::FullyConnected>(0,2);
(*op1).BindInputs({input_tensor, weight_tensor, bias_tensor}).BindOutputs({tran_tensor});

auto op2 = src_graph->CreateOperation<tim::vx::ops::FullyConnected>(0,2);
(*op2).BindInputs({tran_tensor, weight_tensor, bias_tensor}).BindOutputs({output_tensor});
// Do layout inference
auto transform = tim::transform::LayoutInference(src_graph, ctx);
auto infer_graph = transform.first;
auto graph_io_map = transform.second;
infer_graph->Compile();

auto infer_input = graph_io_map[src_graph->InputsTensor()[0]];
auto infer_output = graph_io_map[src_graph->OutputsTensor()[0]];

infer_input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));
infer_graph->Run();

std::vector<float> output(golden.size());
EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output);
Expand Down
54 changes: 53 additions & 1 deletion src/tim/vx/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
#include "tim/vx/graph.h"
#include <algorithm>

#ifdef ENABLE_MD5_CHECK
#include <openssl/md5.h>
#endif

#include "context_private.h"
#include "graph_private.h"
#include "op_impl.h"
Expand Down Expand Up @@ -55,6 +59,32 @@ GraphImpl::GraphImpl(ContextImpl* context, const CompileOption& options)

GraphImpl::~GraphImpl() { vsi_nn_ReleaseGraph(&graph_); }

// Accessor for the per-graph cache that maps an md5 digest of a const
// tensor's data to the already-created tensor, enabling tensor sharing.
std::map<std::string, std::shared_ptr<tim::vx::Tensor>>&
GraphImpl::GetMd5TensorMap() {
  return md5_tensor_;
}
#ifdef ENABLE_MD5_CHECK
#define MD5_SECRET_LEN_16 (16)
#define MD5_BYTE_STRING_LEN (4)
// Compute the 32-character uppercase hex MD5 digest of |src|.
// Used by CreateTensor to deduplicate const tensors: data hashing to the
// same digest is shared instead of being recreated.
// NOTE(review): MD5_Init/MD5_Update/MD5_Final are deprecated since
// OpenSSL 3.0; consider migrating to the EVP digest API.
const std::string GraphImpl::commonMd5Secret32(const std::string& src) {
  MD5_CTX ctx;
  unsigned char md[MD5_SECRET_LEN_16] = {0};

  MD5_Init(&ctx);
  MD5_Update(&ctx, src.c_str(), src.size());
  MD5_Final(md, &ctx);

  std::string md5String;
  md5String.reserve(2 * MD5_SECRET_LEN_16);  // two hex chars per digest byte
  char tmp[MD5_BYTE_STRING_LEN] = {0};
  // Use the macro instead of a hard-coded 16 so the bound stays in sync
  // with the digest buffer size.
  for (int i = 0; i < MD5_SECRET_LEN_16; ++i) {
    // snprintf always null-terminates within sizeof(tmp); the per-iteration
    // memset of the original was redundant.
    snprintf(tmp, sizeof(tmp), "%02X", md[i]);
    md5String += tmp;
  }
  return md5String;
}
#endif

vsi_nn_graph_t* GraphImpl::graph() { return graph_; }

void GraphImpl::AddInput(vsi_nn_tensor_id_t id) {
Expand Down Expand Up @@ -135,7 +165,29 @@ void GraphImpl::PrintGraph() const { vsi_nn_PrintGraph(this->graph_); }

std::shared_ptr<Tensor> GraphImpl::CreateTensor(const TensorSpec& spec,
const void* data) {
auto tensor = std::make_shared<TensorImpl>(this, spec, data);
std::shared_ptr<tim::vx::Tensor> tensor;
#ifdef ENABLE_MD5_CHECK
if (spec.attr_ & TensorAttribute::CONSTANT && data != NULL) {
std::string md5;
uint32_t data_size = 1;
for (auto it = spec.shape_.begin(); it != spec.shape_.end(); ++it) {
data_size *= *it;
}
if (data_size < 512) {
md5 = commonMd5Secret32(std::string((const char*)data, data_size));
} else {
md5 = commonMd5Secret32(std::string((const char*)data, 512)); //Take first 512 bytes
}
if (GetMd5TensorMap().find(md5) != GetMd5TensorMap().end()) {
tensor = GetMd5TensorMap()[md5];
} else {
tensor = std::make_shared<TensorImpl>(this, spec, data);
GetMd5TensorMap()[md5] = tensor;
}
return tensor;
}
#endif
tensor = std::make_shared<TensorImpl>(this, spec, data);
if (spec.attr_ & TensorAttribute::INPUT) {
this->AddInput(tensor);
this->AddInput(tensor->GetId());
Expand Down
4 changes: 4 additions & 0 deletions src/tim/vx/graph_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "tim/vx/graph.h"

#include <vector>
#include <string>
#include <mutex>
#include <utility>
#include <map>
Expand All @@ -44,6 +45,8 @@ class GraphImpl : public Graph {
GraphImpl(ContextImpl* context, const CompileOption& options = CompileOption::DefaultOptions);
~GraphImpl();

const std::string commonMd5Secret32(const std::string& src);
std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GetMd5TensorMap();
/// Return the low-level graph object
vsi_nn_graph_t* graph();
void AddInput(vsi_nn_tensor_id_t id);
Expand Down Expand Up @@ -97,6 +100,7 @@ class GraphImpl : public Graph {
int32_t not_consumed_output_cnt_;
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_;
std::map<std::string, std::shared_ptr<tim::vx::Tensor>> md5_tensor_;

CompileOption options_;
private:
Expand Down

0 comments on commit 47b3587

Please sign in to comment.