Skip to content

Commit

Permalink
[PIR] Refine and fix pir exe (PaddlePaddle#60443)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
zhangbo9674 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent a43fada commit a957d5e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ IfInstruction::IfInstruction(size_t id,
GetInputIds(op, *value_exec_info, &inputs);
auto true_outside_inputs =
GetExternalInputs(&true_branch_block, *value_exec_info, &inputs);
std::vector<pir::Value> false_outside_inputs;
auto& false_branch_block = if_op.false_block();
false_outside_inputs =
auto false_outside_inputs =
GetExternalInputs(&false_branch_block, *value_exec_info, &inputs);
// NOTE(chenxi67): the variable corresponding to container value if a
// <VariableRefArray> Type. It will recursively get the ID of internal
Expand Down Expand Up @@ -107,9 +106,14 @@ IfInstruction::IfInstruction(size_t id,
}
}
InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs);

InsertTuplePushContinerToOuts(
&if_op.false_block(), *value_exec_info, &outputs);

InsertInplacedExternalInputsToOuts(
&true_branch_block, true_outside_inputs, *value_exec_info, &outputs);
InsertInplacedExternalInputsToOuts(
&false_branch_block, false_outside_inputs, *value_exec_info, &outputs);

for (auto& item : outputs) {
auto& var_vec = item.second;
for (auto it = var_vec.begin(); it != var_vec.end();) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,33 +47,25 @@ WhileInstruction::WhileInstruction(
ValueExecutionInfo* parent_exe_info,
interpreter::ExecutionConfig execution_config)
: InstructionBase(id, place) {
op_ = op;
VLOG(6) << "finish process dist attributes";

SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";

VLOG(6) << "finish process inputs outputs index";

PADDLE_ENFORCE(op->isa<paddle::dialect::WhileOp>(),
phi::errors::PreconditionNotMet(
"While instruction only support While op"));

op_ = op;
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
body_block_ = &while_op.body();

cond_var_ = parent_exe_info->GetVarByValue(while_op.operand_source(0));
SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";

cond_var_ = parent_exe_info->GetVarByValue(while_op.operand_source(0));
for (size_t i = 1; i < while_op.num_operands(); ++i) {
inputs_.push_back(
parent_exe_info->GetVarByValue(while_op.operand_source(i)));
}

for (size_t i = 0; i < while_op.num_results(); ++i) {
outputs_.push_back(parent_exe_info->GetVarByValue(while_op.result(i)));
}

body_block_ = &while_op.body();

std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *parent_exe_info, &inputs);
auto body_outside_inputs =
Expand All @@ -94,8 +86,10 @@ WhileInstruction::WhileInstruction(
std::vector<int> outputs_id = GetValueIds(value, *parent_exe_info);
outputs.emplace(value, outputs_id);
}
InsertTuplePushContinerToOuts(body_block_, *parent_exe_info, &outputs);
}
InsertTuplePushContinerToOuts(body_block_, *parent_exe_info, &outputs);
InsertInplacedExternalInputsToOuts(
body_block_, body_outside_inputs, *parent_exe_info, &outputs);
SetOutputs(outputs);

Scope* body_scope = &(parent_exe_info->GetScope()->NewScope());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,55 @@ void InsertTuplePushContinerToOuts(
}
}

void InsertInplacedExternalInputsToOuts(
pir::Block* block,
const std::vector<pir::Value>& external_inputs,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* outputs) {
for (auto& op : *block) {
if (op.attributes().count("is_inplace") != 0 &&
op.attributes()
.at("is_inplace")
.dyn_cast<pir::BoolAttribute>()
.data()) {
std::string op_name = op.name();
if (op.attributes().count("op_name")) {
op_name = op.attributes()
.at("op_name")
.dyn_cast<pir::StrAttribute>()
.AsString();
}
pir::OpInfo op_info =
pir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
paddle::dialect::OpYamlInfoParser yaml_parser(
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
->get_op_info_(op_name),
paddle::dialect::IsLegacyOp(op_name));

for (size_t i = 0; i < op.num_results(); ++i) {
pir::Value value = op.result(i);
if (!IsInvalid(value)) {
VLOG(8) << "Number " << i << " result of " << op_name
<< " is not invalid, so skip build a variable.";
continue;
}
std::string value_name = yaml_parser.OutputNames()[i];
if (yaml_parser.HasInplace(value_name)) {
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
pir::Value inplace_value =
op.operand_source(yaml_parser.InputName2Id().at(inplace_name));
if (std::find(external_inputs.begin(),
external_inputs.end(),
inplace_value) != external_inputs.end()) {
outputs->emplace(value,
GetValueIds(inplace_value, value_exec_info));
}
}
}
}
}
}

bool GetCondData(const phi::DenseTensor& cond) {
if (paddle::platform::is_cpu_place(cond.place())) {
return cond.data<bool>()[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ void InsertTuplePushContinerToOuts(
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* outputs);

void InsertInplacedExternalInputsToOuts(
pir::Block* block,
const std::vector<pir::Value>& external_inputs,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* outputs);

bool GetCondData(const phi::DenseTensor& cond);
} // namespace framework
} // namespace paddle
30 changes: 1 addition & 29 deletions test/legacy_test/test_while_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy
from utils import compare_legacy_with_pt

import paddle
from paddle import base, set_flags
from paddle import base
from paddle.base import core
from paddle.base.backward import append_backward
from paddle.base.executor import Executor
Expand Down Expand Up @@ -82,7 +81,6 @@ def simple_net(self):
loss = paddle.mean(sum_result)
return loss, sum_result

# TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False).
@test_with_pir_api
def test_simple_net(self):
main_program = base.Program()
Expand All @@ -92,14 +90,6 @@ def test_simple_net(self):

append_backward(loss)

if in_pir_mode():
flag_1 = "FLAGS_enable_pir_in_executor_trace_run"
flag_2 = "FLAGS_new_executor_serial_run"
os.environ[flag_1] = 'True'
os.environ[flag_2] = 'True'
set_flags({flag_1: True})
set_flags({flag_2: True})

cpu = core.CPUPlace()
exe = Executor(cpu)
d = []
Expand All @@ -111,14 +101,8 @@ def test_simple_net(self):
feed={'d0': d[0], 'd1': d[1], 'd2': d[2]},
fetch_list=[sum_result],
)
if in_pir_mode():
del os.environ[flag_1]
del os.environ[flag_2]
set_flags({flag_1: False})
set_flags({flag_2: False})
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)

# TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False).
@test_with_pir_api
def test_simple_net_forward(self):
main_program = base.Program()
Expand All @@ -136,20 +120,8 @@ def test_simple_net_forward(self):
for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))

if in_pir_mode():
flag_1 = "FLAGS_enable_pir_in_executor_trace_run"
flag_2 = "FLAGS_new_executor_serial_run"
os.environ[flag_1] = 'True'
os.environ[flag_2] = 'True'
set_flags({flag_1: True})
set_flags({flag_2: True})
for _ in range(2):
exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]})
if in_pir_mode():
del os.environ[flag_1]
del os.environ[flag_2]
set_flags({flag_1: False})
set_flags({flag_2: False})

@compare_legacy_with_pt
@test_with_pir_api
Expand Down

0 comments on commit a957d5e

Please sign in to comment.