【PIR / Dy2static】Fix ir program deconstruct bugs. #59764

Closed
2 changes: 2 additions & 0 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -228,6 +228,7 @@ inline void pir_run_program_ad_func(
const std::vector<paddle::Tensor>& params,
std::vector<paddle::Tensor*>& out, // NOLINT
std::vector<paddle::framework::Scope*>& step_scope, // NOLINT
const std::vector<PyObject*>& blocks_to_hold,
const paddle::framework::AttributeMap& attrs) {
// Prepare Autograd Meta
VLOG(2) << "start run pir run_program ad function.";
@@ -263,6 +264,7 @@ inline void pir_run_program_ad_func(
grad_node = std::make_shared<PirGradNodeRunProgram>(1, 2);
grad_node->GetMiddle().resize(middle_size);
grad_node->GetOutputs().resize(output_size);
grad_node->SetBlocks(blocks_to_hold);
for (size_t i = 0; i < middle_size; ++i) {
grad_node->GetMiddle()[i] =
paddle::Tensor(std::make_shared<phi::DenseTensor>());
106 changes: 62 additions & 44 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -14,6 +14,7 @@

#pragma once

#include <Python.h>
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/tensor_wrapper.h"
@@ -459,21 +460,16 @@ inline void PirRunProgramAPI(
auto param_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));

auto *forward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));

auto *forward_program =
forward_global_block->GetParentOp()->GetParentProgram();
auto *forward_program = reinterpret_cast<::pir::Program *>(
PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_program")));
auto *backward_program = reinterpret_cast<::pir::Program *>(
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_program")));

if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "ForwardProgram is :\n";
forward_program->Print(print_stream);
if (!is_test) {
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
} else {
@@ -502,9 +498,9 @@ inline void PirRunProgramAPI(
<< program_id;
// Step 1. share input_vars & parameters into scope
details::ShareTensorsIntoScopeByValue(
forward_global_block, x, input_values, global_inner_scope);
forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
forward_global_block, params, param_values, global_inner_scope);
forward_program->block(), params, param_values, global_inner_scope);
// Step 2. create new interpretercore
auto passed_kernel_program =
paddle::framework::ApplyIrPass(forward_program, place);
@@ -527,20 +523,20 @@ inline void PirRunProgramAPI(
// *backward_program);

// update interpretercore skip_gc_var
auto skip_names =
details::GetNameFromValue(forward_global_block, middle_values, false);
auto skip_names = details::GetNameFromValue(
forward_program->block(), middle_values, false);
auto skip_names_set =
std::set<std::string>(skip_names.begin(), skip_names.end());
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
attrs.at("no_need_buffers"));
auto no_need_buffer_names = details::GetNameFromValue(
forward_global_block, no_need_buffer_values, false);
forward_program->block(), no_need_buffer_values, false);
for (auto &name : no_need_buffer_names) {
VLOG(4) << "Find no need buffer vars with name:" << name;
skip_names_set.erase(name);
}
skip_names =
details::GetNameFromValue(forward_global_block, output_values, false);
skip_names = details::GetNameFromValue(
forward_program->block(), output_values, false);
skip_names_set.insert(skip_names.begin(), skip_names.end());
details::print_collection(skip_names_set);
interpreter_core->SetSkipGcVars(skip_names_set);
@@ -567,9 +563,9 @@ inline void PirRunProgramAPI(
interpreter_core = cached_value.core_;
// Step 2. update scope for cache interpretercore
details::ShareTensorsIntoScopeByValue(
forward_global_block, x, input_values, global_inner_scope);
forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
forward_global_block, params, param_values, global_inner_scope);
forward_program->block(), params, param_values, global_inner_scope);
// TODO(xiongkun): new ir how to build scope.
// if (interpreter_core->GetVariableScope()->GetMutableScope() !=
// global_inner_scope) {
@@ -580,7 +576,7 @@ inline void PirRunProgramAPI(
}

// interpretercore run
if (!forward_global_block->empty()) {
if (!forward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -593,9 +589,9 @@ inline void PirRunProgramAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Get Output, and Middle Outputs
details::ShareTensorsFromScopeByValue(
forward_global_block, out, output_values, global_inner_scope);
forward_program->block(), out, output_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
forward_global_block, middles, middle_values, global_inner_scope);
forward_program->block(), middles, middle_values, global_inner_scope);

VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());

@@ -1046,10 +1042,8 @@ inline void PirRunProgramGradAPI(

VLOG(4) << "global_inner_scope:" << global_inner_scope;

auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
auto *backward_program = reinterpret_cast<::pir::Program *>(
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_program")));

auto output_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g"));
@@ -1069,18 +1063,22 @@ inline void PirRunProgramGradAPI(
details::Trans2ContiguousTensorsInplace(out_grad);

// share x, param, middles, output_grads, out into scope.
details::ShareTensorsIntoScopeByValue(backward_program->block(),
out_grad,
output_grad_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, out_grad, output_grad_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_global_block,
backward_program->block(), x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
middles,
forward_middle_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
out,
forward_output_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, out, forward_output_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, params, parameter_values, global_inner_scope);
backward_program->block(), params, parameter_values, global_inner_scope);

auto &interpretercore_info_cache =
paddle::framework::InterpreterCoreInfoCache::Instance();
@@ -1134,11 +1132,11 @@ inline void PirRunProgramGradAPI(

// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
auto skip_names =
details::GetNameFromValue(backward_global_block, x_grad_values, false);
auto skip_names = details::GetNameFromValue(
backward_program->block(), x_grad_values, false);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
skip_names =
details::GetNameFromValue(backward_global_block, p_grad_values, false);
skip_names = details::GetNameFromValue(
backward_program->block(), p_grad_values, false);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
interpretercore_info_cache.UpdateSkipEagerDeleteVars(
@@ -1166,14 +1164,11 @@ inline void PirRunProgramGradAPI(

if (interpreter_core->GetVariableScope()->GetMutableScope() !=
global_inner_scope) {
// update scope (TODO(xiongkun): do we need this??)
// details::BuildScopeByBlock(
// *interpreter_core.get(), *backward_global_block, global_inner_scope);
interpreter_core->reset_scope(global_inner_scope);
}
}

if (!backward_global_block->empty()) {
if (!backward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -1188,9 +1183,11 @@ inline void PirRunProgramGradAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Step 4. get outputs
details::ShareTensorsFromScopeByValue(
backward_global_block, x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
backward_global_block, params_grad, p_grad_values, global_inner_scope);
backward_program->block(), x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(backward_program->block(),
params_grad,
p_grad_values,
global_inner_scope);
VLOG(4) << "after backward gc all vars";
global_inner_scope->SetCanReused(true);
details::GcScope(global_inner_scope);
@@ -1403,6 +1400,7 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
}
middles_.clear();
outputs_.clear();
ClearBlocks();
}
}
// Functor: perform backward computations
@@ -1463,7 +1461,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
params_grad_ptr,
place_hash_key_);
VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram";

egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad,
this->OutputMeta()[0]);
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&params_grad,
this->OutputMeta()[1]);
*executed_ = true;
return {x_grad, params_grad};
}
@@ -1553,16 +1554,33 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node = std::shared_ptr<PirGradNodeRunProgram>(
new PirGradNodeRunProgram(*this));
copied_node->SetBlocks(blocks_);
return copied_node;
}

public:
void SetBlocks(const std::vector<PyObject *> blocks) {
blocks_ = blocks;
for (auto &obj : blocks_) {
VLOG(4) << "program is not NULL, we increase the program ref counter.";
Py_INCREF(obj);
}
}

void ClearBlocks() {
for (auto &obj : blocks_) {
Py_DECREF(obj);
}
}

private:
// TensorWrappers
std::vector<paddle::Tensor> x_;
std::vector<paddle::Tensor> params_;
std::vector<paddle::Tensor> middles_;
std::vector<paddle::Tensor> outputs_;
std::vector<paddle::framework::Scope *> step_scope_;
std::vector<PyObject *> blocks_;

// Attribute Map
paddle::framework::AttributeMap attrs_;
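
Note: the core of this change is that PirGradNodeRunProgram now pins the Python objects that own the forward/backward ::pir::Program, so the programs cannot be destructed while the backward pass still needs them. Below is a minimal, self-contained sketch of that reference-holding idiom (hypothetical BlockHolder class, not the Paddle code itself; it assumes the caller holds the GIL, as CPython reference counting requires):

#include <Python.h>

#include <vector>

class BlockHolder {
 public:
  // Take shared ownership of the Python objects that keep the programs alive.
  void SetBlocks(const std::vector<PyObject *> &blocks) {
    blocks_ = blocks;
    for (PyObject *obj : blocks_) {
      Py_INCREF(obj);  // one extra reference per held object
    }
  }

  // Drop the extra references once the backward pass no longer needs them.
  void ClearBlocks() {
    for (PyObject *obj : blocks_) {
      Py_DECREF(obj);
    }
    blocks_.clear();
  }

 private:
  std::vector<PyObject *> blocks_;
};
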
5 changes: 3 additions & 2 deletions paddle/fluid/pybind/eager_legacy_custom_python_api.h
@@ -88,12 +88,13 @@ static PyObject *pir_eager_api_run_program(PyObject *self,
// TODO(zengjinle): support CUDA Graph on eager mode
VLOG(1) << "Start Pir ConstructAttrMapFromPyArgs";

std::vector<PyObject *> block_objs;
ConstructAttrMapForRunProgram(
"run_program", args, 5, PyTuple_GET_SIZE(args), attrs);
"run_program", args, 5, PyTuple_GET_SIZE(args), block_objs, attrs);

VLOG(1) << "Finish Pir ConstructAttrMapFromPyArgs";
tstate = PyEval_SaveThread();
pir_run_program_ad_func(X, Params, Out, OutScope, attrs);
pir_run_program_ad_func(X, Params, Out, OutScope, block_objs, attrs);
PyEval_RestoreThread(tstate);
tstate = nullptr;
Py_RETURN_NONE;
7 changes: 4 additions & 3 deletions paddle/fluid/pybind/op_function_common.cc
@@ -989,6 +989,7 @@ void ConstructAttrMapForRunProgram(
PyObject* args,
ssize_t attr_start,
ssize_t attr_end,
std::vector<PyObject*>& blocks_to_hold, // NOLINT
paddle::framework::AttributeMap& attrs) { // NOLINT
PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2,
0,
@@ -1021,11 +1022,11 @@

if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"global_block",
"forward_global_block",
"backward_global_block"})
} else if (std::set<std::string>(
{"global_block", "forward_program", "backward_program"})
.count(key)) {
CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
blocks_to_hold.push_back(obj);
} else if (std::set<std::string>({"is_test", "use_interpretorcore"})
.count(key)) {
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
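
Note: ConstructAttrMapForRunProgram receives the attributes as alternating key/value items of the Python argument tuple; for the program/block-valued keys it now also records the raw PyObject* so the grad node can pin them later. A minimal sketch of that collection step (hypothetical CollectProgramHandles helper; signatures and key set assumed to mirror the function above):

#include <Python.h>

#include <set>
#include <string>
#include <vector>

void CollectProgramHandles(PyObject *args,
                           Py_ssize_t attr_start,
                           Py_ssize_t attr_end,
                           std::vector<PyObject *> *blocks_to_hold) {
  static const std::set<std::string> kProgramKeys = {
      "global_block", "forward_program", "backward_program"};
  for (Py_ssize_t i = attr_start; i < attr_end; i += 2) {
    PyObject *key_obj = PyTuple_GET_ITEM(args, i);       // attribute name
    PyObject *value_obj = PyTuple_GET_ITEM(args, i + 1);  // attribute value
    const char *key = PyUnicode_AsUTF8(key_obj);
    if (key != nullptr && kProgramKeys.count(key) > 0) {
      // Record the borrowed handle; the grad node later takes its own
      // reference via Py_INCREF.
      blocks_to_hold->push_back(value_obj);
    }
  }
}
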
1 change: 1 addition & 0 deletions paddle/fluid/pybind/op_function_common.h
@@ -199,6 +199,7 @@ void ConstructAttrMapForRunProgram(
PyObject* args,
ssize_t attr_start,
ssize_t attr_end,
std::vector<PyObject*>& blocks_to_hold, // NOLINT
paddle::framework::AttributeMap& attrs); // NOLINT

unsigned long GetUnsignedLongFromArgs( // NOLINT
8 changes: 4 additions & 4 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -866,10 +866,10 @@ def _prune_unused_params(self, program):

def _prepare_attributes(self):
attrs = [
'forward_global_block',
self.program.forward_program.global_block(),
'backward_global_block',
self.program.backward_program.global_block(),
'forward_program',
self.program.forward_program,
'backward_program',
self.program.backward_program,
'is_test',
not self.training,
'program_id',
6 changes: 5 additions & 1 deletion test/dygraph_to_static/test_no_gradient.py
@@ -15,7 +15,10 @@
import unittest

import numpy
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import (
Dy2StTestBase,
test_legacy_and_pt_and_pir,
)

import paddle

@@ -33,6 +36,7 @@ def main_func(x, index):


class TestNoGradientCase(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_no_gradient(self):
paddle.disable_static()
x = paddle.randn([10, 3])