From 969907a9aa079d4693c575b7e3271b8e4867c192 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Tue, 18 Feb 2025 09:01:32 +0000 Subject: [PATCH 1/8] fix --- python/paddle/distributed/auto_parallel/static/helper.py | 2 +- python/paddle/distributed/auto_parallel/static/pir_pass.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/helper.py b/python/paddle/distributed/auto_parallel/static/helper.py index 4b33190e964313..61d99924fcc6be 100644 --- a/python/paddle/distributed/auto_parallel/static/helper.py +++ b/python/paddle/distributed/auto_parallel/static/helper.py @@ -91,7 +91,7 @@ def _train(self, inputs, labels): self._label_vars[mode] = labels # step 2. call inner_layer.forward - self._output_vars[mode] = self.inner_layer(*inputs) + self._output_vars[mode] = self.inner_layer(*inputs, labels=labels[0]) # step 3. calculate loss if needed new_inputs = self._prepare(self.output_vars, labels) diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py index 602fae14c92f6c..64e56ace2499e1 100644 --- a/python/paddle/distributed/auto_parallel/static/pir_pass.py +++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py @@ -135,7 +135,11 @@ def apply_partition_pass(program, block=None): operand = op.operand(in_idx) operand_attr = op.dist_attr.operand(in_idx) prev_var = operand.source() - if not prev_var.is_dist() or operand_attr == prev_var.dist_attr(): + if ( + not prev_var.is_dist() + or operand_attr == prev_var.dist_attr() + or not operand_attr + ): continue assert ( From 1197fbd69b3ae6fb19b7dc3c0cbd48a7bc4f37f3 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 19 Feb 2025 02:45:07 +0000 Subject: [PATCH 2/8] fix --- .../auto_parallel/static/helper.py | 658 +----------------- 1 file changed, 1 insertion(+), 657 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/helper.py b/python/paddle/distributed/auto_parallel/static/helper.py index 61d99924fcc6be..a9cc79cc9d7f19 100644 --- a/python/paddle/distributed/auto_parallel/static/helper.py +++ b/python/paddle/distributed/auto_parallel/static/helper.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,659 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import copy -import inspect -import logging -from collections import defaultdict - -import paddle -from paddle import core -from paddle.jit import not_to_static, to_static -from paddle.jit.dy2static.program_translator import ( - ProgramTranslator, - StaticFunction, -) -from paddle.jit.dy2static.utils import as_not_paddle_func -from paddle.nn import Layer -from paddle.static import Parameter, global_scope, program_guard -from paddle.static.amp.fp16_utils import ( - DEFAULT_AMP_OPTIONS, - prepare_op_amp_options, -) - -from .converter import Converter -from .dist_attribute import TensorDistAttr -from .process_group import get_world_process_group -from .utils import get_logger, to_list - - -class ProxyLayer(Layer): - """ - ProxyLayer implements all logic for converting dygraph model into - static Program IR. 
Meanwhile, it provides conventional interfaces for - auto parallel to visit feed/fetch/loss/metric variables. - """ - - def __init__(self, layer, loss_func, metrics): - super().__init__() - # NOTE: All verify logics are finished in Engine.Prepare - self.inner_layer = layer - self.loss_func = loss_func - self.metrics = metrics - # train / eval / predict - self.mode = None - - # generated program vars - self._input_vars = defaultdict(list) - self._label_vars = defaultdict(list) - self._output_vars = defaultdict(list) - self._loss_vars = defaultdict(list) - self._loss_names = defaultdict(list) - self._metric_vars = defaultdict(list) - - # Consider ProxyLayer as not Paddle inner function because it contains - # user-defined layer. - for fn_name in [ - "_train", - "_eval", - "_predict", - "call_loss", - "call_metrics", - ]: - as_not_paddle_func( - f"{inspect.getmodule(ProxyLayer).__name__}.ProxyLayer.{fn_name}" - ) - - @paddle.jit.not_to_static - def append_loss_to_shadow_output(self, mode): - name = paddle.utils.unique_name.generate('loss') - paddle._C_ops.set_persistable_value(self._loss_vars[mode], name) - self._loss_names[mode] = name - - def _train(self, inputs, labels): - """ - Train process of inner_layer with forward/loss/metric logic. - """ - # step 1. save feed variables of Program - mode = 'train' - self._input_vars[mode] = inputs - self._label_vars[mode] = labels - - # step 2. call inner_layer.forward - self._output_vars[mode] = self.inner_layer(*inputs, labels=labels[0]) - - # step 3. calculate loss if needed - new_inputs = self._prepare(self.output_vars, labels) - self._loss_vars[mode] = self.call_loss(new_inputs) - if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ - "FLAGS_enable_pir_api" - ]: - self.append_loss_to_shadow_output(mode) - - # step 4. calculate metrics if needed - self._metric_vars[mode] = self.call_metrics(new_inputs) - - def _eval(self, inputs, labels): - """ - Evaluate process of inner_layer with forward/loss/metric logic. - """ - # TODO(dev): we can reuse codes with self._train after making - # sure if they can. - - # step 1. save feed variables of Program - mode = 'eval' - self._input_vars[mode] = inputs - self._label_vars[mode] = labels - - # step 2. call inner_layer.forward - self._output_vars[mode] = self.inner_layer(*inputs) - - # step 3. calculate loss if needed - new_inputs = self._prepare(self.output_vars, labels) - self._loss_vars[mode] = self.call_loss(new_inputs) - if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ - "FLAGS_enable_pir_api" - ]: - self.append_loss_to_shadow_output(mode) - - # step 4. calculate metrics if needed - self._metric_vars[mode] = self.call_metrics(new_inputs) - - def _predict(self, inputs, labels): - """ - Predict process of inner_layer with forward logic. - """ - # step 1. save feed variables of Program - mode = 'predict' - self._input_vars[mode] = inputs - self._label_vars[mode] = labels - - # step 2. call inner_layer.forward - self._output_vars[mode] = self.inner_layer(*inputs) - - @not_to_static - def _prepare(self, outputs, labels): - """ - Concat outputs and labels as a single list - - NOTE(dev): We use @not_to_static to avoid AST Analysis. - """ - return to_list(outputs) + to_list(labels) - - def call_loss(self, inputs): - """ - Apply Loss Function on outputs and labels. 
- - Args: - inputs: List[Variable] - - Returns: List[Variable] - """ - res = [] - if self.loss_func is not None: - res = self.loss_func(*inputs) - return res - - def call_metrics(self, inputs): - """ - Apply Metrics Function on outputs and labels. - - Args: - inputs: List[Variable] - - Returns: List[Variable] - """ - outs = [] - for metric in self.metrics: - outs.append(to_list(metric.compute(*inputs))) - - return outs - - def set_mode(self, mode): - self.mode = mode - self.training = mode == 'train' - - def clone(self): - return ProxyLayer(self.inner_layer, self.loss_func, self.metrics) - - @property - def input_vars(self): - return self._input_vars[self.mode] - - @property - def label_vars(self): - return self._label_vars[self.mode] - - @property - def output_vars(self): - return self._output_vars[self.mode] - - @property - def loss_vars(self): - return self._loss_vars[self.mode] - - @property - def loss_names(self): - return self._loss_names[self.mode] - - @property - def metric_vars(self): - return self._metric_vars[self.mode] - - @property - def startup_program(self): - return self.inner_layer._startup_program() - - -class BuildInfo: - def __init__(self): - self.clear() - - def has_cache(self, mode, update=False): - is_cache = self.states[mode] - if update: - self.cache(mode) - return is_cache - - def cache(self, mode): - self.states[mode] = True - - def clear(self): - self.states = defaultdict(bool) - - -class ProgramHelper: - """ - A Helper class for Engine to provides different Program IR according specified 'mode'. - """ - - def __init__(self, layer, loss_func, metrics, inputs_spec, labels_spec): - # original model config information - # TODO(Aurelius84): Implement append_backward and optimizer in ProxyLayer - # after distribute engine satisfy basic condition. - self.proxy_layer = ProxyLayer(layer, loss_func, metrics) - self.inputs_spec = inputs_spec - self.labels_spec = labels_spec - - self.build_info = BuildInfo() - self._logger = get_logger(logging.INFO) - self.lazy_init = False - self._all_params_dist_attr = {} - - def reset(self): - """ - Reset all state of current Object. - """ - self.build_info.clear() - self.proxy_layer = self.proxy_layer.clone() - - def build_program(self, mode): - """ - Convert dygraph model into static Program IR. - """ - assert mode in ['train', 'eval', 'predict'] - self.proxy_layer.set_mode(mode) - # skip if we has already built program. - if self.build_info.has_cache(mode, True): - self._logger.info( - f"Already build program with mode = {mode}, use cached program." - ) - return - - self._logger.info(f"start to build program for mode = {mode}.") - input_spec = [self.inputs_spec, self.labels_spec] - static_func = to_static( - self.static_func(), input_spec=input_spec, full_graph=True - ) - - func_name = '_' + mode - setattr(self.proxy_layer, func_name, static_func) - - # NOTE(dev): Because @to_static is a Lazy mechanism, so we explicitly call this to trigger - # generating Program IR immediately. 
- concrete_program = getattr(self.proxy_layer, func_name).concrete_program - - # TODO(zhiqiu): prepare_op_amp_options is not supported for PIR program - # It will to use dynamic-static unified amp in pir program, and there is - # no need to fit for prepare_op_amp_options - if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ - "FLAGS_enable_pir_api" - ]: - prepare_op_amp_options( - concrete_program.main_program, - ProgramTranslator.get_instance()._amp_records, - DEFAULT_AMP_OPTIONS, - ) - self._build_startup_program() - - def _build_startup_program(self): - """ - Create and Sync parameters into startup program. - """ - startup_program = self.startup_program - if len(startup_program.global_block().ops) > 1: - self.lazy_init = True - return - - for param in self.concrete_program.parameters: - Parameter( - name=param.name, - desc=param, - type=param.type, - shape=param.shape, - dtype=param.dtype, - stop_gradient=param.stop_gradient, - block=startup_program.global_block(), - ) - - def apply_optimizer(self, optimizer): - """ - Append backward and generate optimizer operations. - """ - self._verify_optimizer(optimizer) - self._logger.info( - "start to apply optimizer: %s ", type(optimizer).__name__ - ) - # clear optimizer parameters - original_params = optimizer._parameter_list - optimizer._parameter_list = None - with program_guard(self.main_program, self.startup_program): - res = optimizer.minimize(self.loss_vars[0]) - - # restore optimizer parameters - optimizer._parameter_list = original_params - return res - - def _verify_optimizer(self, optimizer): - assert optimizer is not None - assert hasattr( - optimizer, "minimize" - ), "Optimizer must have minimize() method." - assert ( - self.proxy_layer.mode == 'train' - ), f"Required mode == 'train', but received '{self.proxy_layer.mode}'" - assert ( - len(self.loss_vars) == 1 - ), f"Required len(loss_vars) == 1, but received len(loss_vars) = {len(self.loss_vars)}" - - def to(self, mode): - """ - Switch underly proxy layer mode into target mode. - """ - assert mode in ['train', 'eval', 'predict'] - func = getattr(self.proxy_layer, '_' + mode) - assert isinstance( - func, StaticFunction - ), "Please call build_program(mode) firstly." - self.proxy_layer.set_mode(mode) - - def static_func(self): - """ - Return StaticFunction instance with underly target mode. - """ - assert self.proxy_layer.mode in [ - 'train', - 'eval', - 'predict', - ], "Please call build_program(mode) firstly." 
- func_name = '_' + self.proxy_layer.mode - return getattr(self.proxy_layer, func_name) - - def init_pir(self, main_program, place): - # collect all params in current dist program - param_values = main_program.global_block().all_parameters() - value_name_to_value = {} - dy_param_name_to_pir_param_name = {} - for value in param_values: - value_name_to_value[value.name] = value - - dy_params = self.concrete_program.parameters[0] - pir_param = self.concrete_program.parameters[1] - - for i in range(len(pir_param)): - if pir_param[i].name in value_name_to_value: - dy_param_name_to_pir_param_name[dy_params[i].name] = pir_param[ - i - ].name - - is_comm = False - for param in dy_params: - if param.is_dist(): - process_mesh, dims_mapping = self._all_params_dist_attr[ - param.name - ] - var_dist_attr = TensorDistAttr() - var_dist_attr.process_mesh = process_mesh - var_dist_attr.dims_mapping = dims_mapping - is_comm = True - with paddle.no_grad(): - tmp = paddle.base.core.reshard(param, var_dist_attr) - if tmp._is_initialized(): - param.get_tensor()._share_data_with(tmp.get_tensor()) - else: - # Only setting the "param" to "None" can't release the memory - param.get_tensor()._clear() - param = None - - # create var in scope and share parameters to scope - if param is None: - continue - if param.name not in dy_param_name_to_pir_param_name: - # Release the redundant params - param.get_tensor()._clear() - continue - if not param._is_initialized(): - continue - if param.is_dense(): - value_name = dy_param_name_to_pir_param_name[param.name] - value = value_name_to_value[value_name] - # get param_var's dist_attr - assert ( - value.is_dist_dense_tensor_type() - ), f"param [{value.name}] is not dist tensor type" - dist_attr = { - "dims_mapping": value.dist_attr().dims_mapping, - "process_shape": value.dist_attr().process_mesh.shape, - "process_group": value.dist_attr().process_mesh.process_ids, - } - # slice param_value with dist_attr - # share sliced_param_value with param_tensor in global_scope - pir_scope_param = global_scope().var(value_name).get_tensor() - sliced_param = Converter.slice_with_dist_attr( - param.numpy(), dist_attr - ) - pir_scope_param.set(sliced_param, place) - param.get_tensor()._clear() - - elif param.is_dist(): - value_name = dy_param_name_to_pir_param_name[param.name] - value = value_name_to_value[value_name] - # assert value.is_dist_dense_tensor_type(), "param [{}] is not dist tensor type".format(value.name) - pir_scope_param = global_scope().var(value_name).get_tensor() - pir_scope_param._share_data_with( - param.get_tensor().get_tensor() - ) - param.get_tensor()._clear() - - world_group = get_world_process_group() - if ( - is_comm - and world_group.nranks > 1 - and paddle.distributed.get_world_size() > 1 - ): - paddle.disable_static() - barrier_tensor = paddle.full([1], 1, dtype="int32") - # barrier is not available in xpu for now - if not paddle.framework.core.is_compiled_with_xpu(): - paddle._legacy_C_ops.barrier( - barrier_tensor, barrier_tensor, 'ring_id', 0 - ) - paddle.enable_static() - - def init(self, main_program, place, dist_context): - if self.lazy_init: - return - - amp_strategy = dist_context.strategy.amp - amp_config = copy.deepcopy(amp_strategy.to_dict()) - need_cast_parameter = amp_strategy.enable and amp_config["level"] in [ - "o2", - "o3", - ] - is_comm = False - for param in self.concrete_program.parameters: - if param.is_dist(): - serial_main_program = self.concrete_program.main_program - var = serial_main_program.global_block().vars[param.name] - 
var_dist_attr = dist_context.get_tensor_dist_attr_for_program( - var - ) - is_comm = True - # No need to construct backward. - with paddle.no_grad(): - tmp = paddle.base.core.reshard(param, var_dist_attr) - if tmp._is_initialized(): - param.get_tensor()._share_data_with(tmp.get_tensor()) - else: - # Only setting the "param" to "None" can't release the memory - param.get_tensor()._clear() - param = None - paddle.device.synchronize() - - # create var in scope and share parameters to scope - if param is None: - continue - if param.name not in main_program.global_block().vars: - # Release the redundant params - param.get_tensor()._clear() - continue - if not param._is_initialized(): - continue - if param.is_dense(): - # get param_var's dist_attr - var = main_program.global_block().vars[param.name] - var_dist_attr = dist_context.get_tensor_dist_attr_for_program( - var - ) - dist_attr = { - "dims_mapping": var_dist_attr.dims_mapping, - "process_shape": var_dist_attr.process_mesh.shape, - "process_group": var_dist_attr.process_mesh.process_ids, - } - # slice param_value with dist_attr - # share sliced_param_value with param_tensor in global_scope - param_tensor = global_scope().var(param.name).get_tensor() - sliced_param = Converter.slice_with_dist_attr( - param.numpy(), dist_attr - ) - param_tensor.set(sliced_param, place) - if not need_cast_parameter: - param.get_tensor()._clear() - elif param.is_dist(): - dense_tensor = global_scope().var(param.name).get_tensor() - dense_tensor._share_data_with(param.get_tensor().get_tensor()) - - # transform the parameter in eager mode for amp. - if need_cast_parameter: - for param in self.concrete_program.parameters: - amp_dtype = amp_config["dtype"] - scope_var = global_scope().find_var(param.name) - # The parameter is not in this rank. - if not scope_var: - continue - # The parameter do not need to transform - if param.dtype in [paddle.float16, paddle.bfloat16]: - continue - scope_tensor = global_scope().var(param.name).get_tensor() - assert ( - scope_var and scope_tensor._is_initialized() - ), f"Parameter: {param.name} is not put into global_scope or not initialized." - param_used = param - # For the params without dist_attr. - # NOTE(lizhiyu): In principle, each param should have dist_attr. - if param.is_dense(): - # get param_var's dist_attr - var = main_program.global_block().vars[param.name] - var_dist_attr = ( - dist_context.get_tensor_dist_attr_for_program(var) - ) - dist_attr = { - "dims_mapping": var_dist_attr.dims_mapping, - "process_shape": var_dist_attr.process_mesh.shape, - "process_group": var_dist_attr.process_mesh.process_ids, - } - # slice param_value with dist_attr - sliced_param = Converter.slice_with_dist_attr( - param.numpy(), dist_attr - ) - with paddle.base.dygraph.guard(): - param_used = paddle.to_tensor( - sliced_param, place=param.place - ) - param.get_tensor()._clear() - with paddle.base.dygraph.guard(): - if amp_dtype == "float16": - with paddle.no_grad(): - with paddle.base.framework._dygraph_place_guard( - place=place - ): - t_casted = param_used.cast( - dtype=core.VarDesc.VarType.FP16 - ) - elif amp_dtype == "bfloat16": - with paddle.no_grad(): - with paddle.base.framework._dygraph_place_guard( - place=place - ): - t_casted = param_used.cast( - dtype=core.VarDesc.VarType.BF16 - ) - # NOTE(lizhiyu): Clear the origin param. Don't use `param_used.get_tensor().get_tensor()._clear()` to - # clear the `DistTensor`, because it can't clear the `_holder`, - # which `param_used.get_tensor().get_tensor()` will copy one `DenseTensor`. 
- param_used.get_tensor()._clear() - if t_casted.is_dist(): - scope_tensor._share_data_with( - t_casted.get_tensor().get_tensor() - ) - else: - scope_tensor._share_data_with(t_casted.get_tensor()) - - world_group = get_world_process_group() - if ( - is_comm - and world_group.nranks > 1 - and paddle.distributed.get_world_size() > 1 - ): - paddle.disable_static() - barrier_tensor = paddle.full([1], 1, dtype="int32") - # barrier is not available in xpu for now - if not paddle.framework.core.is_compiled_with_xpu(): - paddle._legacy_C_ops.barrier( - barrier_tensor, barrier_tensor, 'ring_id', 0 - ) - paddle.enable_static() - - def cache_whole_graph_dist_attr(self, all_params): - for param_value in all_params: - dist_attr = param_value.dist_attr() - if dist_attr: - process_mesh = dist_attr.process_mesh - dims_mapping = dist_attr.dims_mapping - self._all_params_dist_attr[param_value.name] = [ - process_mesh, - dims_mapping, - ] - - @property - def concrete_program(self): - return self.static_func().concrete_program - - @property - def main_program(self): - return self.concrete_program.main_program - - @property - def startup_program(self): - try: - return self.proxy_layer.startup_program - except Exception as err: - self._logger.warning( - "The startup_program is not built by `lazy init`." - ) - if isinstance(err, AssertionError): - return self.concrete_program.startup_program - raise err - - @property - def input_vars(self): - return to_list(self.proxy_layer.input_vars) - - @property - def output_vars(self): - return to_list(self.proxy_layer.output_vars) - - @property - def label_vars(self): - return to_list(self.proxy_layer.label_vars) - - @property - def loss_vars(self): - return to_list(self.proxy_layer.loss_vars) - - @property - def loss_names(self): - return to_list(self.proxy_layer.loss_names) - - @property - def metric_vars(self): - return to_list(self.proxy_layer.metric_vars) - - def named_parameters(self): - static_func = self.static_func() - partial_program = static_func.get_concrete_program( - self.inputs_spec, self.labels_spec - )[-1] - # TODO(xiongkun): support pir in the feature. - return {param.name: param for param in partial_program._params} From b5555fc574b85ef3d3944f651ca99013a45d939b Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 19 Feb 2025 02:47:47 +0000 Subject: [PATCH 3/8] fix --- .../auto_parallel/static/helper.py | 669 +++++++++++++++++- 1 file changed, 668 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/static/helper.py b/python/paddle/distributed/auto_parallel/static/helper.py index a9cc79cc9d7f19..46beda4e823dff 100644 --- a/python/paddle/distributed/auto_parallel/static/helper.py +++ b/python/paddle/distributed/auto_parallel/static/helper.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,670 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import copy +import inspect +import logging +from collections import defaultdict + +import paddle +from paddle import core +from paddle.jit import not_to_static, to_static +from paddle.jit.dy2static.program_translator import ( + ProgramTranslator, + StaticFunction, + unwrap_decorators, +) +from paddle.jit.dy2static.utils import as_not_paddle_func +from paddle.nn import Layer +from paddle.static import Parameter, global_scope, program_guard +from paddle.static.amp.fp16_utils import ( + DEFAULT_AMP_OPTIONS, + prepare_op_amp_options, +) + +from .converter import Converter +from .dist_attribute import TensorDistAttr +from .process_group import get_world_process_group +from .utils import get_logger, to_list + + +class ProxyLayer(Layer): + """ + ProxyLayer implements all logic for converting dygraph model into + static Program IR. Meanwhile, it provides conventional interfaces for + auto parallel to visit feed/fetch/loss/metric variables. + """ + + def __init__(self, layer, loss_func, metrics): + super().__init__() + # NOTE: All verify logics are finished in Engine.Prepare + self.inner_layer = layer + self.loss_func = loss_func + self.metrics = metrics + # train / eval / predict + self.mode = None + + # generated program vars + self._input_vars = defaultdict(list) + self._label_vars = defaultdict(list) + self._output_vars = defaultdict(list) + self._loss_vars = defaultdict(list) + self._loss_names = defaultdict(list) + self._metric_vars = defaultdict(list) + + # Consider ProxyLayer as not Paddle inner function because it contains + # user-defined layer. + for fn_name in [ + "_train", + "_eval", + "_predict", + "call_loss", + "call_metrics", + ]: + as_not_paddle_func( + f"{inspect.getmodule(ProxyLayer).__name__}.ProxyLayer.{fn_name}" + ) + + @paddle.jit.not_to_static + def append_loss_to_shadow_output(self, mode): + name = paddle.utils.unique_name.generate('loss') + paddle._C_ops.set_persistable_value(self._loss_vars[mode], name) + self._loss_names[mode] = name + + def _train(self, inputs, labels): + """ + Train process of inner_layer with forward/loss/metric logic. + """ + # step 1. save feed variables of Program + mode = 'train' + self._input_vars[mode] = inputs + self._label_vars[mode] = labels + + # step 2. call inner_layer.forward + has_labels_arg = False + if isinstance(self.inner_layer, Layer): + _, fwd_func = unwrap_decorators(self.inner_layer.forward) + if "labels" in inspect.signature(fwd_func).parameters.keys(): + has_labels_arg = True + if has_labels_arg: + self._output_vars[mode] = self.inner_layer( + *inputs, labels=labels[0] + ) + else: + self._output_vars[mode] = self.inner_layer(*inputs) + + # step 3. calculate loss if needed + new_inputs = self._prepare(self.output_vars, labels) + self._loss_vars[mode] = self.call_loss(new_inputs) + if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ + "FLAGS_enable_pir_api" + ]: + self.append_loss_to_shadow_output(mode) + + # step 4. calculate metrics if needed + self._metric_vars[mode] = self.call_metrics(new_inputs) + + def _eval(self, inputs, labels): + """ + Evaluate process of inner_layer with forward/loss/metric logic. + """ + # TODO(dev): we can reuse codes with self._train after making + # sure if they can. + + # step 1. save feed variables of Program + mode = 'eval' + self._input_vars[mode] = inputs + self._label_vars[mode] = labels + + # step 2. call inner_layer.forward + self._output_vars[mode] = self.inner_layer(*inputs) + + # step 3. 
calculate loss if needed + new_inputs = self._prepare(self.output_vars, labels) + self._loss_vars[mode] = self.call_loss(new_inputs) + if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ + "FLAGS_enable_pir_api" + ]: + self.append_loss_to_shadow_output(mode) + + # step 4. calculate metrics if needed + self._metric_vars[mode] = self.call_metrics(new_inputs) + + def _predict(self, inputs, labels): + """ + Predict process of inner_layer with forward logic. + """ + # step 1. save feed variables of Program + mode = 'predict' + self._input_vars[mode] = inputs + self._label_vars[mode] = labels + + # step 2. call inner_layer.forward + self._output_vars[mode] = self.inner_layer(*inputs) + + @not_to_static + def _prepare(self, outputs, labels): + """ + Concat outputs and labels as a single list + + NOTE(dev): We use @not_to_static to avoid AST Analysis. + """ + return to_list(outputs) + to_list(labels) + + def call_loss(self, inputs): + """ + Apply Loss Function on outputs and labels. + + Args: + inputs: List[Variable] + + Returns: List[Variable] + """ + res = [] + if self.loss_func is not None: + res = self.loss_func(*inputs) + return res + + def call_metrics(self, inputs): + """ + Apply Metrics Function on outputs and labels. + + Args: + inputs: List[Variable] + + Returns: List[Variable] + """ + outs = [] + for metric in self.metrics: + outs.append(to_list(metric.compute(*inputs))) + + return outs + + def set_mode(self, mode): + self.mode = mode + self.training = mode == 'train' + + def clone(self): + return ProxyLayer(self.inner_layer, self.loss_func, self.metrics) + + @property + def input_vars(self): + return self._input_vars[self.mode] + + @property + def label_vars(self): + return self._label_vars[self.mode] + + @property + def output_vars(self): + return self._output_vars[self.mode] + + @property + def loss_vars(self): + return self._loss_vars[self.mode] + + @property + def loss_names(self): + return self._loss_names[self.mode] + + @property + def metric_vars(self): + return self._metric_vars[self.mode] + + @property + def startup_program(self): + return self.inner_layer._startup_program() + + +class BuildInfo: + def __init__(self): + self.clear() + + def has_cache(self, mode, update=False): + is_cache = self.states[mode] + if update: + self.cache(mode) + return is_cache + + def cache(self, mode): + self.states[mode] = True + + def clear(self): + self.states = defaultdict(bool) + + +class ProgramHelper: + """ + A Helper class for Engine to provides different Program IR according specified 'mode'. + """ + + def __init__(self, layer, loss_func, metrics, inputs_spec, labels_spec): + # original model config information + # TODO(Aurelius84): Implement append_backward and optimizer in ProxyLayer + # after distribute engine satisfy basic condition. + self.proxy_layer = ProxyLayer(layer, loss_func, metrics) + self.inputs_spec = inputs_spec + self.labels_spec = labels_spec + + self.build_info = BuildInfo() + self._logger = get_logger(logging.INFO) + self.lazy_init = False + self._all_params_dist_attr = {} + + def reset(self): + """ + Reset all state of current Object. + """ + self.build_info.clear() + self.proxy_layer = self.proxy_layer.clone() + + def build_program(self, mode): + """ + Convert dygraph model into static Program IR. + """ + assert mode in ['train', 'eval', 'predict'] + self.proxy_layer.set_mode(mode) + # skip if we has already built program. + if self.build_info.has_cache(mode, True): + self._logger.info( + f"Already build program with mode = {mode}, use cached program." 
+ ) + return + + self._logger.info(f"start to build program for mode = {mode}.") + input_spec = [self.inputs_spec, self.labels_spec] + static_func = to_static( + self.static_func(), input_spec=input_spec, full_graph=True + ) + + func_name = '_' + mode + setattr(self.proxy_layer, func_name, static_func) + + # NOTE(dev): Because @to_static is a Lazy mechanism, so we explicitly call this to trigger + # generating Program IR immediately. + concrete_program = getattr(self.proxy_layer, func_name).concrete_program + + # TODO(zhiqiu): prepare_op_amp_options is not supported for PIR program + # It will to use dynamic-static unified amp in pir program, and there is + # no need to fit for prepare_op_amp_options + if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ + "FLAGS_enable_pir_api" + ]: + prepare_op_amp_options( + concrete_program.main_program, + ProgramTranslator.get_instance()._amp_records, + DEFAULT_AMP_OPTIONS, + ) + self._build_startup_program() + + def _build_startup_program(self): + """ + Create and Sync parameters into startup program. + """ + startup_program = self.startup_program + if len(startup_program.global_block().ops) > 1: + self.lazy_init = True + return + + for param in self.concrete_program.parameters: + Parameter( + name=param.name, + desc=param, + type=param.type, + shape=param.shape, + dtype=param.dtype, + stop_gradient=param.stop_gradient, + block=startup_program.global_block(), + ) + + def apply_optimizer(self, optimizer): + """ + Append backward and generate optimizer operations. + """ + self._verify_optimizer(optimizer) + self._logger.info( + "start to apply optimizer: %s ", type(optimizer).__name__ + ) + # clear optimizer parameters + original_params = optimizer._parameter_list + optimizer._parameter_list = None + with program_guard(self.main_program, self.startup_program): + res = optimizer.minimize(self.loss_vars[0]) + + # restore optimizer parameters + optimizer._parameter_list = original_params + return res + + def _verify_optimizer(self, optimizer): + assert optimizer is not None + assert hasattr( + optimizer, "minimize" + ), "Optimizer must have minimize() method." + assert ( + self.proxy_layer.mode == 'train' + ), f"Required mode == 'train', but received '{self.proxy_layer.mode}'" + assert ( + len(self.loss_vars) == 1 + ), f"Required len(loss_vars) == 1, but received len(loss_vars) = {len(self.loss_vars)}" + + def to(self, mode): + """ + Switch underly proxy layer mode into target mode. + """ + assert mode in ['train', 'eval', 'predict'] + func = getattr(self.proxy_layer, '_' + mode) + assert isinstance( + func, StaticFunction + ), "Please call build_program(mode) firstly." + self.proxy_layer.set_mode(mode) + + def static_func(self): + """ + Return StaticFunction instance with underly target mode. + """ + assert self.proxy_layer.mode in [ + 'train', + 'eval', + 'predict', + ], "Please call build_program(mode) firstly." 
+ func_name = '_' + self.proxy_layer.mode + return getattr(self.proxy_layer, func_name) + + def init_pir(self, main_program, place): + # collect all params in current dist program + param_values = main_program.global_block().all_parameters() + value_name_to_value = {} + dy_param_name_to_pir_param_name = {} + for value in param_values: + value_name_to_value[value.name] = value + + dy_params = self.concrete_program.parameters[0] + pir_param = self.concrete_program.parameters[1] + + for i in range(len(pir_param)): + if pir_param[i].name in value_name_to_value: + dy_param_name_to_pir_param_name[dy_params[i].name] = pir_param[ + i + ].name + + is_comm = False + for param in dy_params: + if param.is_dist(): + process_mesh, dims_mapping = self._all_params_dist_attr[ + param.name + ] + var_dist_attr = TensorDistAttr() + var_dist_attr.process_mesh = process_mesh + var_dist_attr.dims_mapping = dims_mapping + is_comm = True + with paddle.no_grad(): + tmp = paddle.base.core.reshard(param, var_dist_attr) + if tmp._is_initialized(): + param.get_tensor()._share_data_with(tmp.get_tensor()) + else: + # Only setting the "param" to "None" can't release the memory + param.get_tensor()._clear() + param = None + + # create var in scope and share parameters to scope + if param is None: + continue + if param.name not in dy_param_name_to_pir_param_name: + # Release the redundant params + param.get_tensor()._clear() + continue + if not param._is_initialized(): + continue + if param.is_dense(): + value_name = dy_param_name_to_pir_param_name[param.name] + value = value_name_to_value[value_name] + # get param_var's dist_attr + assert ( + value.is_dist_dense_tensor_type() + ), f"param [{value.name}] is not dist tensor type" + dist_attr = { + "dims_mapping": value.dist_attr().dims_mapping, + "process_shape": value.dist_attr().process_mesh.shape, + "process_group": value.dist_attr().process_mesh.process_ids, + } + # slice param_value with dist_attr + # share sliced_param_value with param_tensor in global_scope + pir_scope_param = global_scope().var(value_name).get_tensor() + sliced_param = Converter.slice_with_dist_attr( + param.numpy(), dist_attr + ) + pir_scope_param.set(sliced_param, place) + param.get_tensor()._clear() + + elif param.is_dist(): + value_name = dy_param_name_to_pir_param_name[param.name] + value = value_name_to_value[value_name] + # assert value.is_dist_dense_tensor_type(), "param [{}] is not dist tensor type".format(value.name) + pir_scope_param = global_scope().var(value_name).get_tensor() + pir_scope_param._share_data_with( + param.get_tensor().get_tensor() + ) + param.get_tensor()._clear() + + world_group = get_world_process_group() + if ( + is_comm + and world_group.nranks > 1 + and paddle.distributed.get_world_size() > 1 + ): + paddle.disable_static() + barrier_tensor = paddle.full([1], 1, dtype="int32") + # barrier is not available in xpu for now + if not paddle.framework.core.is_compiled_with_xpu(): + paddle._legacy_C_ops.barrier( + barrier_tensor, barrier_tensor, 'ring_id', 0 + ) + paddle.enable_static() + + def init(self, main_program, place, dist_context): + if self.lazy_init: + return + + amp_strategy = dist_context.strategy.amp + amp_config = copy.deepcopy(amp_strategy.to_dict()) + need_cast_parameter = amp_strategy.enable and amp_config["level"] in [ + "o2", + "o3", + ] + is_comm = False + for param in self.concrete_program.parameters: + if param.is_dist(): + serial_main_program = self.concrete_program.main_program + var = serial_main_program.global_block().vars[param.name] + 
var_dist_attr = dist_context.get_tensor_dist_attr_for_program( + var + ) + is_comm = True + # No need to construct backward. + with paddle.no_grad(): + tmp = paddle.base.core.reshard(param, var_dist_attr) + if tmp._is_initialized(): + param.get_tensor()._share_data_with(tmp.get_tensor()) + else: + # Only setting the "param" to "None" can't release the memory + param.get_tensor()._clear() + param = None + paddle.device.synchronize() + + # create var in scope and share parameters to scope + if param is None: + continue + if param.name not in main_program.global_block().vars: + # Release the redundant params + param.get_tensor()._clear() + continue + if not param._is_initialized(): + continue + if param.is_dense(): + # get param_var's dist_attr + var = main_program.global_block().vars[param.name] + var_dist_attr = dist_context.get_tensor_dist_attr_for_program( + var + ) + dist_attr = { + "dims_mapping": var_dist_attr.dims_mapping, + "process_shape": var_dist_attr.process_mesh.shape, + "process_group": var_dist_attr.process_mesh.process_ids, + } + # slice param_value with dist_attr + # share sliced_param_value with param_tensor in global_scope + param_tensor = global_scope().var(param.name).get_tensor() + sliced_param = Converter.slice_with_dist_attr( + param.numpy(), dist_attr + ) + param_tensor.set(sliced_param, place) + if not need_cast_parameter: + param.get_tensor()._clear() + elif param.is_dist(): + dense_tensor = global_scope().var(param.name).get_tensor() + dense_tensor._share_data_with(param.get_tensor().get_tensor()) + + # transform the parameter in eager mode for amp. + if need_cast_parameter: + for param in self.concrete_program.parameters: + amp_dtype = amp_config["dtype"] + scope_var = global_scope().find_var(param.name) + # The parameter is not in this rank. + if not scope_var: + continue + # The parameter do not need to transform + if param.dtype in [paddle.float16, paddle.bfloat16]: + continue + scope_tensor = global_scope().var(param.name).get_tensor() + assert ( + scope_var and scope_tensor._is_initialized() + ), f"Parameter: {param.name} is not put into global_scope or not initialized." + param_used = param + # For the params without dist_attr. + # NOTE(lizhiyu): In principle, each param should have dist_attr. + if param.is_dense(): + # get param_var's dist_attr + var = main_program.global_block().vars[param.name] + var_dist_attr = ( + dist_context.get_tensor_dist_attr_for_program(var) + ) + dist_attr = { + "dims_mapping": var_dist_attr.dims_mapping, + "process_shape": var_dist_attr.process_mesh.shape, + "process_group": var_dist_attr.process_mesh.process_ids, + } + # slice param_value with dist_attr + sliced_param = Converter.slice_with_dist_attr( + param.numpy(), dist_attr + ) + with paddle.base.dygraph.guard(): + param_used = paddle.to_tensor( + sliced_param, place=param.place + ) + param.get_tensor()._clear() + with paddle.base.dygraph.guard(): + if amp_dtype == "float16": + with paddle.no_grad(): + with paddle.base.framework._dygraph_place_guard( + place=place + ): + t_casted = param_used.cast( + dtype=core.VarDesc.VarType.FP16 + ) + elif amp_dtype == "bfloat16": + with paddle.no_grad(): + with paddle.base.framework._dygraph_place_guard( + place=place + ): + t_casted = param_used.cast( + dtype=core.VarDesc.VarType.BF16 + ) + # NOTE(lizhiyu): Clear the origin param. Don't use `param_used.get_tensor().get_tensor()._clear()` to + # clear the `DistTensor`, because it can't clear the `_holder`, + # which `param_used.get_tensor().get_tensor()` will copy one `DenseTensor`. 
+ param_used.get_tensor()._clear() + if t_casted.is_dist(): + scope_tensor._share_data_with( + t_casted.get_tensor().get_tensor() + ) + else: + scope_tensor._share_data_with(t_casted.get_tensor()) + + world_group = get_world_process_group() + if ( + is_comm + and world_group.nranks > 1 + and paddle.distributed.get_world_size() > 1 + ): + paddle.disable_static() + barrier_tensor = paddle.full([1], 1, dtype="int32") + # barrier is not available in xpu for now + if not paddle.framework.core.is_compiled_with_xpu(): + paddle._legacy_C_ops.barrier( + barrier_tensor, barrier_tensor, 'ring_id', 0 + ) + paddle.enable_static() + + def cache_whole_graph_dist_attr(self, all_params): + for param_value in all_params: + dist_attr = param_value.dist_attr() + if dist_attr: + process_mesh = dist_attr.process_mesh + dims_mapping = dist_attr.dims_mapping + self._all_params_dist_attr[param_value.name] = [ + process_mesh, + dims_mapping, + ] + + @property + def concrete_program(self): + return self.static_func().concrete_program + + @property + def main_program(self): + return self.concrete_program.main_program + + @property + def startup_program(self): + try: + return self.proxy_layer.startup_program + except Exception as err: + self._logger.warning( + "The startup_program is not built by `lazy init`." + ) + if isinstance(err, AssertionError): + return self.concrete_program.startup_program + raise err + + @property + def input_vars(self): + return to_list(self.proxy_layer.input_vars) + + @property + def output_vars(self): + return to_list(self.proxy_layer.output_vars) + + @property + def label_vars(self): + return to_list(self.proxy_layer.label_vars) + + @property + def loss_vars(self): + return to_list(self.proxy_layer.loss_vars) + + @property + def loss_names(self): + return to_list(self.proxy_layer.loss_names) + + @property + def metric_vars(self): + return to_list(self.proxy_layer.metric_vars) + + def named_parameters(self): + static_func = self.static_func() + partial_program = static_func.get_concrete_program( + self.inputs_spec, self.labels_spec + )[-1] + # TODO(xiongkun): support pir in the feature. + return {param.name: param for param in partial_program._params} From 7b040724f6fd82ef09cbccbc7cab7168bac4c8fa Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 24 Feb 2025 06:07:58 +0000 Subject: [PATCH 4/8] refine --- paddle/fluid/pybind/pybind.cc | 34 +++++++++++++++++++ .../auto_parallel/static/pir_pass.py | 5 ++- .../distributed/auto_parallel/static/utils.py | 1 - 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d4dc5eea636be8..a3f5b76ccd5542 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -211,6 +211,7 @@ limitations under the License. 
*/ #include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/fluid/imperative/layout_autotune.h" #include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h" #include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" @@ -965,6 +966,39 @@ void BindVjp(pybind11::module *m) { } } } + if (!vjp_res[grad_index][j].defining_op()->HasAttribute( + kAttrOpDistAttr)) { + auto ctx = pir::IrContext::Instance(); + auto input_values = + vjp_res[grad_index][j].defining_op()->operands_source(); + auto output_values = + vjp_res[grad_index][j].defining_op()->results(); + paddle::dialect::ProcessMeshAttribute op_mesh; + + if (paddle::dialect::HasDistInput(input_values, &op_mesh)) { + std::vector dist_operand_attrs, + dist_result_attrs; + for (int input_id = 0; input_id < input_values.size(); + ++input_id) { + dist_operand_attrs.push_back( + paddle::dialect::GetTensorDistAttr( + input_values[input_id].type())); + } + for (int output_id = 0; output_id < output_values.size(); + ++output_id) { + dist_result_attrs.push_back( + paddle::dialect::GetTensorDistAttr( + output_values[output_id].type())); + } + vjp_res[grad_index][j].defining_op()->set_attribute( + kAttrOpDistAttr, + paddle::dialect::OperationDistAttribute::get( + ctx, + op_mesh, + dist_operand_attrs, + dist_result_attrs)); + } + } vjp_res[grad_index][j].set_type(inputs[idx][j].type()); } } diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py index 64e56ace2499e1..74521f3aceb1e7 100644 --- a/python/paddle/distributed/auto_parallel/static/pir_pass.py +++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py @@ -86,9 +86,8 @@ def reshard_single_value(program, op, operand, attr): def reshard_combine_value(program, op, operand, attr): prev_var = operand.source() - assert ( - prev_var.get_defining_op().name() == 'builtin.combine' - ), f"TensorList must be defined by builtin.combine op, but is {prev_var.get_defining_op().name()}." 
+ if prev_var.get_defining_op().name() != 'builtin.combine': + return combine_op = prev_var.get_defining_op() array_attr = attr.as_array_attr() diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 01ecba71196194..f9ea2e61f6caf9 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -58,7 +58,6 @@ partition_skip_op_list = [ "builtin.combine", - "builtin.split", "pd_op.pylayer", "cf.yield", "cf.tuple_push", From fedb6dc7ed775c62e16deff65d28bdc4d3b82d16 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 24 Feb 2025 07:26:52 +0000 Subject: [PATCH 5/8] fix --- paddle/fluid/pybind/pybind.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a3f5b76ccd5542..2f943e192959ed 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -978,13 +978,13 @@ void BindVjp(pybind11::module *m) { if (paddle::dialect::HasDistInput(input_values, &op_mesh)) { std::vector dist_operand_attrs, dist_result_attrs; - for (int input_id = 0; input_id < input_values.size(); + for (size_t input_id = 0; input_id < input_values.size(); ++input_id) { dist_operand_attrs.push_back( paddle::dialect::GetTensorDistAttr( input_values[input_id].type())); } - for (int output_id = 0; output_id < output_values.size(); + for (size_t output_id = 0; output_id < output_values.size(); ++output_id) { dist_result_attrs.push_back( paddle::dialect::GetTensorDistAttr( From 646d8e988abb5cf89016538628f947a7de82ea97 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 24 Feb 2025 11:07:52 +0000 Subject: [PATCH 6/8] fix --- python/paddle/distributed/auto_parallel/static/pir_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py index 74521f3aceb1e7..224831cf1b80e5 100644 --- a/python/paddle/distributed/auto_parallel/static/pir_pass.py +++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py @@ -87,7 +87,7 @@ def reshard_combine_value(program, op, operand, attr): prev_var = operand.source() if prev_var.get_defining_op().name() != 'builtin.combine': - return + return prev_var combine_op = prev_var.get_defining_op() array_attr = attr.as_array_attr() From a78cdb78442e2c0fb901a31efe1fec8229e8c984 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Tue, 25 Feb 2025 09:04:39 +0000 Subject: [PATCH 7/8] fix ci bug --- paddle/fluid/pybind/pybind.cc | 15 ++++++++------- .../distributed/auto_parallel/static/pir_pass.py | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2f943e192959ed..cf9617d47833b5 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -966,15 +966,16 @@ void BindVjp(pybind11::module *m) { } } } + auto input_values = + vjp_res[grad_index][j].defining_op()->operands_source(); + auto output_values = + vjp_res[grad_index][j].defining_op()->results(); + paddle::dialect::ProcessMeshAttribute op_mesh; if (!vjp_res[grad_index][j].defining_op()->HasAttribute( - kAttrOpDistAttr)) { + kAttrOpDistAttr) && + paddle::dialect::AllInputAreDist(input_values) && + paddle::dialect::AllInputAreDist(output_values)) { auto ctx = pir::IrContext::Instance(); - auto input_values = - vjp_res[grad_index][j].defining_op()->operands_source(); - auto 
output_values = - vjp_res[grad_index][j].defining_op()->results(); - paddle::dialect::ProcessMeshAttribute op_mesh; - if (paddle::dialect::HasDistInput(input_values, &op_mesh)) { std::vector dist_operand_attrs, dist_result_attrs; diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py index 224831cf1b80e5..171292d4f60914 100644 --- a/python/paddle/distributed/auto_parallel/static/pir_pass.py +++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py @@ -484,7 +484,7 @@ def prune_op(block): reverse_block_ops[i].erase() skip_idx = dtensor_to_local_idx + 1 continue - elif op.name() in partition_skip_op_list: + elif op.name() in [*partition_skip_op_list, 'builtin.split']: can_delete = True for val in op.results(): if not val.use_empty(): From 2a1796dbe700a7546cc1ab5130257d9a6ad66c76 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 26 Feb 2025 03:49:48 +0000 Subject: [PATCH 8/8] fix --- python/paddle/distributed/auto_parallel/static/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index f9ea2e61f6caf9..36f221dbe2c269 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -1115,7 +1115,7 @@ def _complete_op_dist_attr(program, block=None): for op in block.ops: for sub_block in op.blocks(): _complete_op_dist_attr(program, block=sub_block) - if op.name() in partition_skip_op_list: + if op.name() in [*partition_skip_op_list, 'builtin.split']: continue if op.dist_attr is None:
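
The behavioral core of this series is the conditional label forwarding introduced in PATCH 3/8: PATCH 1/8 made ProxyLayer._train always call self.inner_layer(*inputs, labels=labels[0]), while PATCH 3/8 first unwraps the layer's forward with unwrap_decorators and passes labels only when inspect.signature reports a labels parameter. Below is a minimal, Paddle-free sketch of that detection pattern, assuming plain Python objects; DummyLayerWithLabels, DummyLayerWithoutLabels, and call_forward are hypothetical stand-ins, and inspect.unwrap is used here in place of Paddle's unwrap_decorators.

# Sketch of the signature-based dispatch added to ProxyLayer._train in PATCH 3/8.
# All names below are illustrative; only the detection pattern mirrors the patch.
import inspect


class DummyLayerWithLabels:
    def forward(self, x, labels=None):
        # A model that consumes labels inside forward (e.g. fused loss).
        return {"logits": x, "loss": labels}


class DummyLayerWithoutLabels:
    def forward(self, x):
        return {"logits": x}


def call_forward(layer, inputs, labels):
    # Mirror of the patched branch: unwrap any decorators around forward,
    # then pass labels only if the signature declares a `labels` parameter.
    fwd_func = inspect.unwrap(layer.forward)
    if "labels" in inspect.signature(fwd_func).parameters:
        return layer.forward(*inputs, labels=labels[0])
    return layer.forward(*inputs)


if __name__ == "__main__":
    print(call_forward(DummyLayerWithLabels(), ["x"], ["y"]))     # labels passed through
    print(call_forward(DummyLayerWithoutLabels(), ["x"], ["y"]))  # labels withheld

Models whose forward does not accept labels keep the original behavior: they are called with inputs only, and the loss is still computed afterwards by call_loss, as in the unmodified _train path.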