[AutoParallel] Fix the bug in grad op of dtensor_to_local #71264

Merged 1 commit on Feb 27, 2025
@@ -57,6 +57,8 @@ paddle::Tensor dtensor_to_local_ad_function(
         ToTensorDistAttr(process_mesh, placements, input.dims());
 
     grad_node->SetGradDistAttr(grad_dist_attr);
+    grad_node->SetGradProcessMesh(process_mesh);
+    grad_node->SetGradPlacements(placements);
   }
 
   // Forward API Call
@@ -78,12 +78,15 @@ DtensorToLocalGradNode::operator()(
     VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
   }
 
+  std::shared_ptr<phi::DenseTensor> grad_out_ptr =
+      std::static_pointer_cast<phi::DenseTensor>(grad_out.impl());
   // Backward call dtensor_to_local_func function
   auto dist_grad_ptr = std::make_shared<phi::distributed::DistTensor>(
-      grad_out.dims(), grad_dist_attr_);
+      grad_out_ptr,
+      out_metas[0][0].DistTensorGlobalDims(),
+      grad_process_mesh_,
+      grad_placements_);
 
-  *(dist_grad_ptr->unsafe_mutable_value()) =
-      *(static_cast<phi::DenseTensor*>(grad_out.impl().get()));
   grad_input.set_impl(dist_grad_ptr);
 
   VLOG(5) << "Finish C++ API: dtensor_to_local_func";
paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h (+10, -0)
@@ -493,11 +493,21 @@ class DtensorToLocalGradNode : public egr::GradNodeBase {
     grad_dist_attr_ = dist_attr;
   }
 
+  void SetGradPlacements(const phi::distributed::Placements& placements) {
+    grad_placements_ = placements;
+  }
+
+  void SetGradProcessMesh(const phi::distributed::ProcessMesh& process_mesh) {
+    grad_process_mesh_ = process_mesh;
+  }
+
  private:
   // TensorWrappers
   egr::TensorWrapper input_;
 
   phi::distributed::TensorDistAttr grad_dist_attr_;
+  phi::distributed::Placements grad_placements_;
+  phi::distributed::ProcessMesh grad_process_mesh_;
 };
 
 class DtensorFromLocalGradNode : public egr::GradNodeBase {
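Taken together, the changes above make the backward of dtensor_to_local rebuild its output gradient as a DistTensor from the local gradient plus the global dims, process mesh and placements saved at forward time, instead of constructing it from grad_out's local dims and the dist attr alone. A rough Python-level sketch of that reconstruction step (run under a two-card launch; the dtensor_from_local helper, its import path and its signature are assumptions of this sketch, not taken from this diff):

    # Launch with: python -m paddle.distributed.launch --devices 0,1 demo.py
    import paddle
    import paddle.distributed as dist

    # Assumed helper: builds a dist tensor from this rank's local shard.
    from paddle.distributed.auto_parallel.api import dtensor_from_local

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    placements = [dist.Shard(0)]

    def dtensor_to_local_grad(local_grad, mesh, placements):
        # Roughly what the fixed grad node does: rebuild the gradient dist
        # tensor from the local dense gradient plus the process mesh and
        # placements recorded at forward time, not from the local dims alone.
        return dtensor_from_local(local_grad, mesh, placements)

    local_grad = paddle.ones([2, 8])  # this rank's shard of the gradient
    grad_x = dtensor_to_local_grad(local_grad, mesh, placements)
    print(grad_x.shape)  # [4, 8]: the global shape is restored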
@@ -173,7 +173,12 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
       op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
       op_name.compare(paddle::dialect::MpAllreduceSum_Op::name()) == 0 ||
       op_name.compare(paddle::dialect::CIdentity_Op::name()) == 0 ||
-      op_name.compare(paddle::dialect::CConcatOp::name()) == 0) {
+      op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
+      op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
+      op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
+      op_name.compare(paddle::dialect::AllToAllOp::name()) == 0 ||
+      op_name.compare(
+          paddle::dialect::CSoftmaxWithCrossEntropyOp::name()) == 0) {
     if (phi::is_gpu_place(place) && execution_stream == kDefaultStream) {
       if (origin_dev_ctx != nullptr) {
         // set stream
@@ -1523,6 +1523,7 @@ std::unordered_map<std::string, std::set<std::string>> GetNoNeedBufferValues(
           no_need_buffer_vars.insert(name);
         } else {
           no_need_buffer_vars.erase(name);
+          break;
         }
       }
     }
@@ -1535,6 +1536,7 @@ std::unordered_map<std::string, std::set<std::string>> GetNoNeedBufferValues(
           no_need_buffer_vars.insert(name);
         } else {
           no_need_buffer_vars.erase(name);
+          break;
         }
       }
     }
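The two added break statements stop scanning as soon as a value turns out to need its buffer, so a later no-need-buffer hit cannot re-insert a name that was just erased. A standalone Python sketch of that pattern (a conceptual analogue, not the C++ implementation above):

    # A value may only be reported as "no need buffer" if every use agrees;
    # once one use needs the buffer, later no-need uses must not re-add it.
    def collect_no_need_buffer(uses_by_name):
        no_need_buffer_vars = set()
        for name, uses in uses_by_name.items():
            for needs_buffer in uses:
                if not needs_buffer:
                    no_need_buffer_vars.add(name)
                else:
                    no_need_buffer_vars.discard(name)
                    break  # without this, a later no-need use re-adds `name`
        return no_need_buffer_vars

    # "x" has one use that needs its buffer, so it must not be reported.
    print(collect_no_need_buffer({"x": [False, True, False], "y": [False]}))
    # -> {'y'}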
paddle/fluid/pir/dialect/distributed/ir/dist_op.h (+1, -1)
@@ -108,7 +108,7 @@ class DtensorToLocalOp
   // void VerifySig();
 };
 
-class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp> {
+class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp, VjpInterface> {
  public:
   using Op::Op;
   static const char* name() { return "dist_op.moe_sub_mesh_tensors"; }
python/paddle/distributed/auto_parallel/api.py (+11, -8)
@@ -495,6 +495,7 @@ def _cal_global_shape(local_shape, mesh, placements):
 def moe_global_mesh_tensor(
     local_tensor_list, mesh, placements, local_mesh_dim=-1
 ):
+    placements = copy.deepcopy(placements)
     local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         mesh, placements, local_mesh_dim
     )
@@ -548,16 +549,17 @@ def moe_global_mesh_tensor(
         global_dims = _cal_global_shape(
             local_tensor._local_shape, mesh, placements
         )
-        return paddle.jit.dy2static.py_layer.StaticPyLayer(
-            _moe_global_mesh_tensor
-        ).apply(
+        dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
             local_tensor_list,
             local_mesh_list,
             local_placements,
             mesh,
             placements,
             global_dims,
         )
+        dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
+        dist_tensor.persistable = local_tensor_list[0].persistable
+        return dist_tensor
     else:
         raise NotImplementedError(
             "dtensor_from_local_list() are only supported in dynamic and pir mode."
@@ -691,6 +693,7 @@ def moe_sub_mesh_tensors(
     """
     Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
     """
+    global_placements = copy.deepcopy(global_placements)
    local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         global_mesh, global_placements, local_mesh_dim
     )
@@ -705,17 +708,17 @@ def moe_sub_mesh_tensors(
             global_placements,
         )
     elif paddle.framework.in_pir_mode():
-
-        return paddle.jit.dy2static.py_layer.StaticPyLayer(
-            _moe_sub_mesh_tensors
-        ).apply(
+        local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
             dist_tensor,
             local_mesh_list,
             local_placements,
             local_mesh_dim,
             global_mesh,
             global_placements,
         )
+        for local_tensor in local_tensors:
+            local_tensor.stop_gradient = dist_tensor.stop_gradient
+            local_tensor.persistable = dist_tensor.persistable
+        return local_tensors
     else:
         raise NotImplementedError(
             "moe_sub_mesh_tensors is only supported in dynamic mode."
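With these changes, both helpers deep-copy the incoming placements so the caller's list is not mutated, lower to the dedicated dist_op.moe_sub_mesh_tensors / dist_op.moe_global_mesh_tensor operations via paddle._C_ops in PIR mode instead of wrapping the dygraph implementation in a StaticPyLayer, and propagate stop_gradient and persistable from their inputs. A round-trip usage sketch (four-card launch assumed; the keyword names passed to moe_sub_mesh_tensors are a guess based on the code above, not confirmed by this diff):

    # Launch with: python -m paddle.distributed.launch --devices 0,1,2,3 demo.py
    import paddle
    import paddle.distributed as dist
    from paddle.distributed.auto_parallel.api import (
        moe_global_mesh_tensor,
        moe_sub_mesh_tensors,
    )

    # 2x2 global mesh; sub-meshes are taken along mesh dim 0 ("x"),
    # one per expert group.
    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    placements = [dist.Shard(0), dist.Replicate()]

    w = paddle.randn([4, 8])
    dist_w = dist.shard_tensor(w, mesh, placements)

    # Split the global tensor into one local tensor per sub-mesh ...
    local_ws = moe_sub_mesh_tensors(
        dist_w, global_mesh=mesh, local_mesh_dim=0, global_placements=placements
    )
    # ... and stitch the pieces back onto the global mesh.
    dist_w2 = moe_global_mesh_tensor(local_ws, mesh, placements, local_mesh_dim=0)
    print(dist_w2.shape, dist_w2.placements)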
python/paddle/distributed/auto_parallel/placement_type.py (+4, -2)
@@ -86,8 +86,10 @@ def to_dim_map(placements, tensor_dims):
         if placement.is_shard():
             shard_dim = cast(Shard, placement).get_dim()
             if dim_map[shard_dim] > -1:
-                raise Exception(
-                    "Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}"
+                import logging
+
+                logging.warning(
+                    f"Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}."
                 )
 
             dim_map[shard_dim] = i
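to_dim_map records, for each tensor dimension, the mesh dimension that shards it (-1 meaning not sharded); after this change a conflicting Shard placement only logs a warning and the later mesh dimension wins, instead of raising. A standalone sketch of that dim-map bookkeeping (not the Paddle implementation itself):

    import logging

    def to_dim_map_sketch(shard_dims, tensor_ndim):
        # shard_dims[i] is the tensor dim sharded by mesh dim i, or None.
        dim_map = [-1] * tensor_ndim
        for mesh_dim, tensor_dim in enumerate(shard_dims):
            if tensor_dim is None:
                continue
            if dim_map[tensor_dim] > -1:
                # The real code now warns and lets the later mesh dim win.
                logging.warning(
                    f"Tensor dim {tensor_dim} is already sharded on "
                    f"mesh dim {dim_map[tensor_dim]}."
                )
            dim_map[tensor_dim] = mesh_dim
        return dim_map

    print(to_dim_map_sketch([0, 0], tensor_ndim=2))  # warns, returns [1, -1]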
test/auto_parallel/pir/test_moe_api.py (+17, -8)
@@ -127,14 +127,23 @@ def check_results(
         local_meshes,
         local_dims_mapping,
     ):
-        # local_tensors_from_dtensor op
-        self.check_dist_attr(ops[4], local_meshes, local_dims_mapping)
-        # dtensor_from_local_list op
-        self.check_dist_attr(ops[5], [global_mesh], global_dims_mapping)
-        # grad op for dtensor_from_local_list
-        self.check_dist_attr(ops[10], local_meshes, local_dims_mapping)
-        # grad op for local_tensors_from_dtensor op
-        self.check_dist_attr(ops[11], [global_mesh], global_dims_mapping)
+        op_names = [
+            "dist_op.moe_sub_mesh_tensors",
+            "dist_op.moe_global_mesh_tensor",
+        ]
+        ops_to_check = [op for op in ops if op.name() in op_names]
+        # moe_sub_mesh_tensors op
+        self.check_dist_attr(ops_to_check[0], local_meshes, local_dims_mapping)
+        # moe_global_mesh_tensor op
+        self.check_dist_attr(
+            ops_to_check[1], [global_mesh], global_dims_mapping
+        )
+        # grad op for moe_global_mesh_tensor
+        self.check_dist_attr(ops_to_check[2], local_meshes, local_dims_mapping)
+        # grad op for moe_sub_mesh_tensors op
+        self.check_dist_attr(
+            ops_to_check[3], [global_mesh], global_dims_mapping
+        )
 
 
 if __name__ == "__main__":