[AutoParallel] Fix some bugs #71181

Merged
paddle/fluid/pybind/pybind.cc (35 additions, 0 deletions)
@@ -211,6 +211,7 @@ limitations under the License. */
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
@@ -965,6 +966,40 @@ void BindVjp(pybind11::module *m) {
                 }
               }
             }
+            auto input_values =
+                vjp_res[grad_index][j].defining_op()->operands_source();
+            auto output_values =
+                vjp_res[grad_index][j].defining_op()->results();
+            paddle::dialect::ProcessMeshAttribute op_mesh;
+            if (!vjp_res[grad_index][j].defining_op()->HasAttribute(
+                    kAttrOpDistAttr) &&
+                paddle::dialect::AllInputAreDist(input_values) &&
+                paddle::dialect::AllInputAreDist(output_values)) {
+              auto ctx = pir::IrContext::Instance();
+              if (paddle::dialect::HasDistInput(input_values, &op_mesh)) {
+                std::vector<pir::Attribute> dist_operand_attrs,
+                    dist_result_attrs;
+                for (size_t input_id = 0; input_id < input_values.size();
+                     ++input_id) {
+                  dist_operand_attrs.push_back(
+                      paddle::dialect::GetTensorDistAttr(
+                          input_values[input_id].type()));
+                }
+                for (size_t output_id = 0; output_id < output_values.size();
+                     ++output_id) {
+                  dist_result_attrs.push_back(
+                      paddle::dialect::GetTensorDistAttr(
+                          output_values[output_id].type()));
+                }
+                vjp_res[grad_index][j].defining_op()->set_attribute(
+                    kAttrOpDistAttr,
+                    paddle::dialect::OperationDistAttribute::get(
+                        ctx,
+                        op_mesh,
+                        dist_operand_attrs,
+                        dist_result_attrs));
+              }
+            }
             vjp_res[grad_index][j].set_type(inputs[idx][j].type());
           }
         }
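The block above fills in a missing operation-level dist attribute on a grad op only when the op has none yet and every input and output value already carries a tensor-level dist attribute, taking the mesh from a distributed input. A minimal Python sketch of that rule (the Op and TensorDistAttr classes below are hypothetical stand-ins, not the PIR API):

from dataclasses import dataclass, field

@dataclass
class TensorDistAttr:  # hypothetical per-value dist attribute
    mesh: str
    dims_mapping: list

@dataclass
class Op:  # hypothetical op: per-value dist attrs (or None) plus an attribute dict
    input_dist_attrs: list
    output_dist_attrs: list
    attrs: dict = field(default_factory=dict)

def complete_op_dist_attr(op):
    """Attach an op-level dist attr only if the op has none and every
    input/output value already carries a tensor-level dist attr."""
    if "op_dist_attr" in op.attrs:
        return
    values = op.input_dist_attrs + op.output_dist_attrs
    if not values or any(v is None for v in values):
        return
    op.attrs["op_dist_attr"] = {
        "mesh": op.input_dist_attrs[0].mesh,  # mesh taken from a dist input
        "operand_attrs": list(op.input_dist_attrs),
        "result_attrs": list(op.output_dist_attrs),
    }

grad_op = Op(
    input_dist_attrs=[TensorDistAttr("mesh0", [0, -1])],
    output_dist_attrs=[TensorDistAttr("mesh0", [0, -1])],
)
complete_op_dist_attr(grad_op)
assert "op_dist_attr" in grad_op.attrs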
python/paddle/distributed/auto_parallel/static/helper.py (12 additions, 1 deletion)
@@ -23,6 +23,7 @@
 from paddle.jit.dy2static.program_translator import (
     ProgramTranslator,
     StaticFunction,
+    unwrap_decorators,
 )
 from paddle.jit.dy2static.utils import as_not_paddle_func
 from paddle.nn import Layer
@@ -91,7 +92,17 @@ def _train(self, inputs, labels):
         self._label_vars[mode] = labels

         # step 2. call inner_layer.forward
-        self._output_vars[mode] = self.inner_layer(*inputs)
+        has_labels_arg = False
+        if isinstance(self.inner_layer, Layer):
+            _, fwd_func = unwrap_decorators(self.inner_layer.forward)
+            if "labels" in inspect.signature(fwd_func).parameters.keys():
+                has_labels_arg = True
+        if has_labels_arg:
+            self._output_vars[mode] = self.inner_layer(
+                *inputs, labels=labels[0]
+            )
+        else:
+            self._output_vars[mode] = self.inner_layer(*inputs)

         # step 3. calculate loss if needed
         new_inputs = self._prepare(self.output_vars, labels)
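The new branch passes labels=labels[0] to forward only when the (possibly decorated) forward signature actually declares a labels parameter. A standard-library-only sketch of that detection, assuming the decorator preserves the wrapped function the way unwrap_decorators expects:

import functools
import inspect

def to_static(fn):  # stand-in for a paddle.jit-style decorator
    @functools.wraps(fn)  # keeps __wrapped__ pointing at the original forward
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

class Net:
    @to_static
    def forward(self, x, labels=None):
        return x if labels is None else (x, labels)

net = Net()
fwd_func = inspect.unwrap(net.forward)  # follows __wrapped__, like unwrap_decorators
has_labels_arg = "labels" in inspect.signature(fwd_func).parameters
print(has_labels_arg)  # True -> _train would call forward(*inputs, labels=labels[0])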
python/paddle/distributed/auto_parallel/static/pir_pass.py (8 additions, 5 deletions)
@@ -86,9 +86,8 @@ def reshard_single_value(program, op, operand, attr):
 def reshard_combine_value(program, op, operand, attr):
     prev_var = operand.source()

-    assert (
-        prev_var.get_defining_op().name() == 'builtin.combine'
-    ), f"TensorList must be defined by builtin.combine op, but is {prev_var.get_defining_op().name()}."
+    if prev_var.get_defining_op().name() != 'builtin.combine':
+        return prev_var

     combine_op = prev_var.get_defining_op()
     array_attr = attr.as_array_attr()
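With the assert relaxed, a value that is not produced by builtin.combine is simply returned and handled as a single value. A small stub-based sketch of the resulting dispatch (the _Stub classes are hypothetical; only the method names mirror the pir API used above):

class _StubOp:  # hypothetical defining op
    def __init__(self, name):
        self._name = name
    def name(self):
        return self._name

class _StubValue:  # hypothetical value with a defining op
    def __init__(self, defining_op_name):
        self._op = _StubOp(defining_op_name)
    def get_defining_op(self):
        return self._op

class _StubOperand:  # hypothetical operand wrapping a source value
    def __init__(self, defining_op_name):
        self._value = _StubValue(defining_op_name)
    def source(self):
        return self._value

def needs_per_element_reshard(operand):
    """Only a TensorList built by builtin.combine is resharded element by
    element; any other producer now just hands back its value unchanged."""
    return operand.source().get_defining_op().name() == 'builtin.combine'

print(needs_per_element_reshard(_StubOperand('builtin.combine')))  # True
print(needs_per_element_reshard(_StubOperand('pd_op.matmul')))     # False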
@@ -135,7 +134,11 @@ def apply_partition_pass(program, block=None):
             operand = op.operand(in_idx)
             operand_attr = op.dist_attr.operand(in_idx)
             prev_var = operand.source()
-            if not prev_var.is_dist() or operand_attr == prev_var.dist_attr():
+            if (
+                not prev_var.is_dist()
+                or operand_attr == prev_var.dist_attr()
+                or not operand_attr
+            ):
                 continue

             assert (
@@ -481,7 +484,7 @@ def prune_op(block):
                 reverse_block_ops[i].erase()
             skip_idx = dtensor_to_local_idx + 1
             continue
-        elif op.name() in partition_skip_op_list:
+        elif op.name() in [*partition_skip_op_list, 'builtin.split']:
             can_delete = True
             for val in op.results():
                 if not val.use_empty():
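prune_op only erases an op whose name is in the skip list (with builtin.split now handled explicitly alongside it) and whose results are all unused. A self-contained sketch of that check, again with hypothetical stubs and a trimmed copy of partition_skip_op_list:

partition_skip_op_list = [  # trimmed copy of the list defined in utils.py below
    "builtin.combine",
    "pd_op.pylayer",
    "cf.yield",
]

class _Result:  # hypothetical result value
    def __init__(self, used):
        self._used = used
    def use_empty(self):
        return not self._used

class _Op:  # hypothetical op
    def __init__(self, name, results):
        self._name, self._results = name, results
    def name(self):
        return self._name
    def results(self):
        return self._results

def can_prune(op):
    """Erase a skipped op (builtin.split included explicitly) only when
    every one of its results is unused."""
    if op.name() not in [*partition_skip_op_list, 'builtin.split']:
        return False
    return all(val.use_empty() for val in op.results())

print(can_prune(_Op('builtin.split', [_Result(used=False)])))  # True
print(can_prune(_Op('builtin.split', [_Result(used=True)])))   # False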
python/paddle/distributed/auto_parallel/static/utils.py (1 addition, 2 deletions)
@@ -58,7 +58,6 @@

 partition_skip_op_list = [
     "builtin.combine",
-    "builtin.split",
     "pd_op.pylayer",
     "cf.yield",
     "cf.tuple_push",
@@ -1116,7 +1115,7 @@ def _complete_op_dist_attr(program, block=None):
     for op in block.ops:
         for sub_block in op.blocks():
             _complete_op_dist_attr(program, block=sub_block)
-        if op.name() in partition_skip_op_list:
+        if op.name() in [*partition_skip_op_list, 'builtin.split']:
             continue

         if op.dist_attr is None: