Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rm convertToSSA API,test=huawei_ascend_npu test=nvidia_tensorrt test=… #9233

Merged
merged 1 commit into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ lite_option(LITE_WITH_XCODE "when debug in xcode, its ON."
lite_option(LITE_WITH_ARM82_FP16 "when compile with arm v8.2 fp16, it's ON." OFF)
lite_option(LITE_WITH_ARM82_INT8_SDOT "when compile with arm v8.2 int8, it's ON." OFF)
lite_option(LITE_WITH_CODE_META_INFO "include git version in the header file." ON)
# whether convert input model which is not a DAG to SSA graph
lite_option(WITH_CONVERT_TO_SSA "whether convert input model which is not a DAG to SSA graph" ON)

# Thirdparty
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
Expand Down
3 changes: 0 additions & 3 deletions cmake/configure.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,3 @@ if (LITE_WITH_M1)
add_definitions("-DLITE_WITH_M1")
endif(LITE_WITH_M1)

if (WITH_CONVERT_TO_SSA STREQUAL ON)
add_definitions("-DWITH_CONVERT_TO_SSA")
endif(WITH_CONVERT_TO_SSA)
181 changes: 173 additions & 8 deletions lite/core/optimizer/mir/type_target_cast_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {

// record the copied node.
std::map<std::string, Node*> copied_nodes;
// record the origin node.
std::map<std::string, Node*> input_nodes;
std::vector<std::string> skip_ops = {
"while", "conditional_block", "write_back"};

Expand All @@ -48,8 +50,14 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (!node->IsStmt() || iter != skip_ops.end()) continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
if (!input_nodes.count(in->AsArg().name))
input_nodes[in->AsArg().name] = in;
ComplementInputs(graph.get(), node, in, &copied_nodes);
}
auto outlinks = node->outlinks;
for (auto* out : outlinks) {
ComplementOutputs(graph.get(), node, out, &input_nodes);
}
}
}

Expand Down Expand Up @@ -78,17 +86,174 @@ void TypeTargetTransformPass::ComplementInputs(
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type,
*decl_arg_type,
in,
graph,
inst_node,
copied_nodes,
valid_places_);
AddInputIoCopyInst(*in->AsArg().type,
*decl_arg_type,
in,
graph,
inst_node,
copied_nodes,
valid_places_);
}
}

// Inserts an io_copy statement *after* `inst_node` so that the value it
// writes into variable `out` (currently typed `from`) is copied onto the
// target expected downstream (`to`).
//
// Graph surgery performed (see inline comment below):
//   before: inst -> out(name)
//   after : inst -> out(name/target_trans) -> io_copy -> new_var(name)
// `out` is renamed in place; a fresh argument node carrying the original
// name (retargeted to to.target()) receives the copied data, so consumers
// that were not rewritten keep resolving the original variable name.
//
// @param from         type currently produced by `inst_node` for `out`.
// @param to           type expected on the consumer side (target mismatch
//                     is what triggers this call).
// @param out          output argument node to be bridged.
// @param graph        graph being mutated.
// @param inst_node    statement node producing `out`.
// @param valid_places candidate places used to create/pick an io_copy kernel.
void TypeTargetTransformPass::AddOutputIoCopyInst(
    const Type& from,
    const Type& to,
    Node* out,
    SSAGraph* graph,
    Node* inst_node,
    const std::vector<Place>& valid_places) {
  CHECK(!valid_places.empty()) << "valid_place should be set";
  // inst -> out node(new_name) -> io_copy_op -> new_var_node(out->AsArg().name)
  // So there will be a new Argument node and a new IoCopy Statement Node.
  CHECK(out->IsArg());
  auto new_name = string_format("%s/target_trans", out->AsArg().name.c_str());
  // NOTE: the new argument node deliberately keeps the ORIGINAL name;
  // `out` itself is renamed to `new_name` near the end of this function.
  auto* new_var_node = graph->NewArgumentNode(out->AsArg().name);

  // Set the place for new var node, the target should be equal to to.target()
  // The precision and layout should be equal to from.precision(), from.layout()
  bool is_tensor = from.IsTensor();
  if (!is_tensor) {
    CHECK(from.IsTensorList()) << "only support tensor or tensor_array.";
  }
  if (is_tensor) {
    new_var_node->AsArg().type =
        LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
  } else {
    new_var_node->AsArg().type =
        LiteType::GetTensorListTy(to.target(), from.precision(), from.layout());
  }
  auto* io_copy_inst = graph->NewInstructNode();
  std::string io_copy_type = "io_copy";
  // create Op and kernels.
  auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
  CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
  // CHECK(io_copy_op);
  // Create the new var manually.
  inst_node->AsStmt().op()->scope()->Var(new_name);

  // Create IoCopy Instruction.
  cpp::OpDesc op_desc;
  op_desc.SetType(io_copy_type);
  // Tensor and tensor_array flavors of io_copy use different argument names.
  if (is_tensor) {
    op_desc.SetInput("Input", {new_name});
    op_desc.SetOutput("Out", {out->AsArg().name});
  } else {
    op_desc.SetInput("InputArray", {new_name});
    op_desc.SetOutput("OutArray", {out->AsArg().name});
  }
  io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
  auto kernels = io_copy_op->CreateKernels(valid_places);
  // Pick the first kernel whose declared input type is compatible with
  // `from` and whose declared output target is compatible with `to`.
  bool is_found = false;
  std::vector<std::unique_ptr<KernelBase>> selected_kernels;
  for (auto& kernel : kernels) {
    const Type* in_arg_ty = nullptr;
    const Type* out_arg_ty = nullptr;
    if (is_tensor) {
      in_arg_ty = kernel->GetInputDeclType("Input");
      out_arg_ty = kernel->GetOutputDeclType("Out");
    } else {
      in_arg_ty = kernel->GetInputDeclType("InputArray");
      out_arg_ty = kernel->GetOutputDeclType("OutArray");
    }

    VLOG(4) << "------ kernel info -------";
    VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
    VLOG(4) << "from(last kernel output):" << from;
    VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
    VLOG(4) << "to:" << to << "\n";

    if (TypeCompatible(*in_arg_ty, from) &&
        TargetCompatibleTo(*out_arg_ty, to)) {
      VLOG(4) << "picked";
      is_found = true;
    }

    if (is_found) {
      selected_kernels.emplace_back(std::move(kernel));
      // we pick the kernel
      io_copy_inst->AsStmt(
          io_copy_type, std::move(selected_kernels), io_copy_op);
      break;
    }
    VLOG(4) << "not picked";
  }

  CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from
                  << ":" << inst_node->AsStmt().op_info()->Type() << " -> "
                  << to << ":" << out->AsArg().name;
  // Add new link, inst -> var -> io_copy_op -> new_var_node
  DirectedLink(out, io_copy_inst);
  DirectedLink(io_copy_inst, new_var_node);

  // Update the original instruction OpDesc.
  // Update its output var name to the io_copy_output_name
  auto* inst_node_op_desc = inst_node->AsStmt().op()->mutable_op_info();
  for (auto& op_output : *inst_node_op_desc->mutable_outputs()) {
    for (auto& var_name : op_output.second)
      if (var_name == out->AsArg().name) var_name = new_name;
  }
  // Update the input name of Ops whose input var is out var node
  // (the io_copy op just linked keeps its OpDesc output untouched, since
  // only mutable_inputs() are rewritten here).
  for (auto& op : out->outlinks) {
    if (!op->IsStmt()) continue;
    auto* op_desc = op->AsStmt().op()->mutable_op_info();
    for (auto& op_input : *op_desc->mutable_inputs())
      for (auto& var_name : op_input.second)
        if (var_name == out->AsArg().name) var_name = new_name;
  }
  // reset opdesc and update kernel information
  // `out` now acts as the intermediate variable feeding io_copy.
  out->AsArg().name = new_name;
  // NOTE(review): the originally selected kernel is saved and restored
  // around ResetOp — presumably ResetOp would otherwise re-pick kernels;
  // confirm against Stmt::ResetOp.
  auto original_selected_kernel =
      std::move(inst_node->AsStmt().kernels().front());
  auto update_op_info = *inst_node->AsStmt().op_info();
  inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
  inst_node->AsStmt().kernels().clear();
  inst_node->AsStmt().kernels().emplace_back(
      std::move(original_selected_kernel));

  for (auto& kernel : inst_node->AsStmt().kernels()) {
    VLOG(4) << "kernel info: " << kernel->name();
    inst_node->AsStmt().op()->AttachKernel(kernel.get());
  }

  graph->CheckValid();
}

// Checks one output `out` of `inst_node` against the recorded graph input
// node of the same name; when their targets are incompatible, inserts an
// io_copy after `inst_node` (via AddOutputIoCopyInst) to bridge the gap.
void TypeTargetTransformPass::ComplementOutputs(
    SSAGraph* graph,
    Node* inst_node,
    Node* out,
    std::map<std::string, Node*>* input_nodes) {
  // Skip `out` if it is no longer linked as an output of this instruction
  // (earlier graph surgery may have rewired it).
  auto& links = inst_node->outlinks;
  if (std::find(links.begin(), links.end(), out) == links.end()) return;

  CHECK(inst_node->IsStmt());
  auto& stmt = inst_node->AsStmt();
  VLOG(3) << "found Target tensor: " << out->AsArg().name;
  CHECK(out->IsRoleSet());
  CHECK(out->IsArg());
  CHECK(out->AsArg().type);

  // Only outputs whose name was recorded as an input node are of interest.
  auto recorded = input_nodes->find(out->AsArg().name);
  if (recorded == input_nodes->end()) return;

  const auto* expected_type = recorded->second->AsArg().type;
  if (TargetCompatibleTo(*out->AsArg().type, *expected_type)) return;

  VLOG(3) << "found Output Target unmatched tensor: " << out->AsArg().name
          << " for kernel " << stmt.op()->DebugString() << " "
          << *out->AsArg().type << " -> " << *expected_type;
  AddOutputIoCopyInst(*out->AsArg().type,
                      *expected_type,
                      out,
                      graph,
                      inst_node,
                      valid_places_);
}

void TypeTargetTransformPass::AddIoCopyInst(
void TypeTargetTransformPass::AddInputIoCopyInst(
const Type& from,
const Type& to,
Node* in,
Expand Down
26 changes: 19 additions & 7 deletions lite/core/optimizer/mir/type_target_cast_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,25 @@ class TypeTargetTransformPass : public ProgramPass {
Node* in,
std::map<std::string, Node*>* copied_nodes);

void AddIoCopyInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places);
void ComplementOutputs(SSAGraph* graph,
Node* inst_node,
Node* out,
std::map<std::string, Node*>* input_nodes);

void AddInputIoCopyInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places);

void AddOutputIoCopyInst(const Type& from,
const Type& to,
Node* out,
SSAGraph* graph,
Node* inst_node,
const std::vector<Place>& valid_places);

void SetValidPlaces(const std::vector<Place>& valid_places);

Expand Down
3 changes: 0 additions & 3 deletions lite/model_parser/model_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,6 @@ void LoadModelPb(const std::string &model_dir,
pb::ProgramDesc pb_prog(&pb_proto_prog);
// Transform to cpp::ProgramDesc
TransformProgramDescAnyToCpp(pb_prog, cpp_prog);
#ifdef WITH_CONVERT_TO_SSA
general::ssa::ConvertToSSA(cpp_prog);
#endif

// Load params data from file.
// NOTE: Only main block be used now.
Expand Down