fix the bug of setting grad's placements in dtensor_to_local
pkuzyc committed Feb 26, 2025
1 parent f04aa61 commit 79dea17
Showing 9 changed files with 59 additions and 23 deletions.

@@ -57,6 +57,8 @@ paddle::Tensor dtensor_to_local_ad_function(
ToTensorDistAttr(process_mesh, placements, input.dims());

grad_node->SetGradDistAttr(grad_dist_attr);
grad_node->SetGradProcessMesh(process_mesh);
grad_node->SetGradPlacements(placements);
}

// Forward API Call
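
The fix above records the forward tensor's ProcessMesh and Placements on the grad node, in addition to the dist attr, so the backward pass can rebuild the distributed gradient with the original layout. A minimal standalone Python sketch of the idea (illustrative only, not Paddle's C++ implementation; all names here are made up):

class DtensorToLocalGradSketch:
    """Caches the forward layout so the backward pass can restore it."""

    def __init__(self):
        self.grad_process_mesh = None
        self.grad_placements = None

    def set_grad_process_mesh(self, mesh):
        self.grad_process_mesh = mesh

    def set_grad_placements(self, placements):
        self.grad_placements = placements

    def backward(self, dense_grad):
        # Wrap the incoming dense gradient back into a "distributed" value
        # using the cached mesh and placements (previously only a dist attr
        # was kept, so the placements were lost at this point).
        return {
            "value": dense_grad,
            "process_mesh": self.grad_process_mesh,
            "placements": self.grad_placements,
        }

node = DtensorToLocalGradSketch()
node.set_grad_process_mesh([[0, 1], [2, 3]])
node.set_grad_placements(["Shard(0)", "Replicate()"])
print(node.backward([0.1, 0.2]))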

@@ -78,12 +78,15 @@ DtensorToLocalGradNode::operator()(
VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
}

std::shared_ptr<phi::DenseTensor> grad_out_ptr =
std::static_pointer_cast<phi::DenseTensor>(grad_out.impl());
// Backward call dtensor_to_local_func function
auto dist_grad_ptr = std::make_shared<phi::distributed::DistTensor>(
grad_out.dims(), grad_dist_attr_);
grad_out_ptr,
out_metas[0][0].DistTensorGlobalDims(),
grad_process_mesh_,
grad_placements_);

*(dist_grad_ptr->unsafe_mutable_value()) =
*(static_cast<phi::DenseTensor*>(grad_out.impl().get()));
grad_input.set_impl(dist_grad_ptr);

VLOG(5) << "Finish C++ API: dtensor_to_local_func";
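
The backward now constructs the DistTensor from the dense gradient together with the cached global dims, process mesh, and placements, instead of grad_out.dims() plus a dist attr alone. The shape arithmetic behind a "global dims" value is simple: a tensor dim sharded over a mesh dim has global size equal to the local size times that mesh dim's size. A simplified standalone sketch (placements encoded as "tensor dim or None" per mesh dim, which glosses over Paddle's Shard/Replicate/Partial objects and uneven splits):

def global_shape_from_local(local_shape, mesh_shape, shard_dims):
    # shard_dims[i] is the tensor dim sharded by mesh dim i, or None
    # if the tensor is replicated along that mesh dim.
    global_shape = list(local_shape)
    for mesh_dim, tensor_dim in enumerate(shard_dims):
        if tensor_dim is not None:
            global_shape[tensor_dim] *= mesh_shape[mesh_dim]
    return global_shape

# A [4, 8] local shard on a 2x2 mesh, sharded along tensor dim 0 on
# mesh dim 0 and replicated on mesh dim 1, has global shape [8, 8].
print(global_shape_from_local([4, 8], [2, 2], [0, None]))
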
10 changes: 10 additions & 0 deletions paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h
@@ -493,11 +493,21 @@ class DtensorToLocalGradNode : public egr::GradNodeBase {
grad_dist_attr_ = dist_attr;
}

void SetGradPlacements(const phi::distributed::Placements& placements) {
grad_placements_ = placements;
}

void SetGradProcessMesh(const phi::distributed::ProcessMesh& process_mesh) {
grad_process_mesh_ = process_mesh;
}

private:
// TensorWrappers
egr::TensorWrapper input_;

phi::distributed::TensorDistAttr grad_dist_attr_;
phi::distributed::Placements grad_placements_;
phi::distributed::ProcessMesh grad_process_mesh_;
};

class DtensorFromLocalGradNode : public egr::GradNodeBase {

@@ -173,7 +173,12 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
op_name.compare(paddle::dialect::MpAllreduceSum_Op::name()) == 0 ||
op_name.compare(paddle::dialect::CIdentity_Op::name()) == 0 ||
op_name.compare(paddle::dialect::CConcatOp::name()) == 0) {
op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
op_name.compare(paddle::dialect::AllToAllOp::name()) == 0 ||
op_name.compare(
paddle::dialect::CSoftmaxWithCrossEntropyOp::name()) == 0) {
if (phi::is_gpu_place(place) && execution_stream == kDefaultStream) {
if (origin_dev_ctx != nullptr) {
// set stream
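
Note that CConcatOp and AllGatherOp each appear twice in the condition above. When a dispatch list like this keeps growing, a name set keeps it duplicate-free and easy to scan; a purely illustrative Python sketch (the op-name strings below are placeholders, not the dialect's real identifiers):

# Placeholder op names; the real code compares against
# paddle::dialect::*Op::name() in C++.
DEFAULT_STREAM_COMM_OPS = {
    "c_allreduce_sum",
    "all_gather",
    "all_to_all",
    "mp_allreduce_sum",
    "c_identity",
    "c_concat",
    "c_softmax_with_cross_entropy",
}

def uses_default_stream_context(op_name):
    # Membership in a set is O(1) and cannot silently hold duplicates.
    return op_name in DEFAULT_STREAM_COMM_OPS

print(uses_default_stream_context("all_to_all"))  # True
print(uses_default_stream_context("matmul"))      # False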

@@ -1523,6 +1523,7 @@ std::unordered_map<std::string, std::set<std::string>> GetNoNeedBufferValues(
no_need_buffer_vars.insert(name);
} else {
no_need_buffer_vars.erase(name);
break;
}
}
}

@@ -1535,6 +1536,7 @@ std::unordered_map<std::string, std::set<std::string>> GetNoNeedBufferValues(
no_need_buffer_vars.insert(name);
} else {
no_need_buffer_vars.erase(name);
break;
}
}
}
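
From the shape of the two loops above, the added break makes the scan stop as soon as a name is found to actually need its buffer; without it, a later iteration could re-insert the name into the no-need-buffer set. A simplified standalone model of that scan (illustrative only):

def collect_no_need_buffer(buffer_needed_by_use):
    # buffer_needed_by_use maps a variable name to, for each of its
    # uses, whether that use needs the real buffer contents.
    no_need_buffer = set()
    for name, uses in buffer_needed_by_use.items():
        for needs_buffer in uses:
            if not needs_buffer:
                no_need_buffer.add(name)
            else:
                no_need_buffer.discard(name)
                break  # without this, the next use could re-add the name
    return no_need_buffer

# "x" has one use that needs the buffer, so it must not be reported.
print(collect_no_need_buffer({"x": [False, True, False], "y": [False, False]}))
# -> {'y'}
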
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/distributed/ir/dist_op.h
@@ -108,7 +108,7 @@ class DtensorToLocalOp
// void VerifySig();
};

class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp> {
class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp, VjpInterface> {
public:
using Op::Op;
static const char* name() { return "dist_op.moe_sub_mesh_tensors"; }
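
Registering VjpInterface on MoESubMeshTensorsOp lets the PIR op supply its reverse rule (a vector-Jacobian product), which is what allows grad ops for it to be built and checked in the test file below. As a reminder of the term, a tiny generic VJP example, unrelated to Paddle's interface:

def vjp_elementwise_mul(x, w, v):
    # For y = x * w (elementwise), the VJP maps the output cotangent v
    # (= dL/dy) to the input cotangents dL/dx = v * w and dL/dw = v * x.
    dx = [vi * wi for vi, wi in zip(v, w)]
    dw = [vi * xi for vi, xi in zip(v, x)]
    return dx, dw

dx, dw = vjp_elementwise_mul([1.0, 2.0], [3.0, 4.0], [1.0, 1.0])
print(dx, dw)  # [3.0, 4.0] [1.0, 2.0]
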
19 changes: 11 additions & 8 deletions python/paddle/distributed/auto_parallel/api.py
@@ -495,6 +495,7 @@ def _cal_global_shape(local_shape, mesh, placements):
def moe_global_mesh_tensor(
local_tensor_list, mesh, placements, local_mesh_dim=-1
):
placements = copy.deepcopy(placements)
local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
mesh, placements, local_mesh_dim
)
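
The added deepcopy guards the caller's placements list: the function derives local meshes and local placements from it, presumably mutating the list along the way, and without a copy that mutation would leak back into user code. A minimal illustration of the aliasing hazard (not Paddle code):

import copy

def derive_local_placements(placements):
    placements = copy.deepcopy(placements)  # the guard added above
    placements[0] = "Replicate()"           # stand-in for an internal tweak
    return placements

user_placements = ["Shard(0)", "Replicate()"]
derive_local_placements(user_placements)
print(user_placements)  # ['Shard(0)', 'Replicate()'], caller's list unchanged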

@@ -548,16 +549,17 @@ def moe_global_mesh_tensor(
global_dims = _cal_global_shape(
local_tensor._local_shape, mesh, placements
)
return paddle.jit.dy2static.py_layer.StaticPyLayer(
_moe_global_mesh_tensor
).apply(
dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
local_tensor_list,
local_mesh_list,
local_placements,
mesh,
placements,
global_dims,
)
dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
dist_tensor.persistable = local_tensor_list[0].persistable
return dist_tensor
else:
raise NotImplementedError(
"dtensor_from_local_list() are only supported in dynamic and pir mode."

@@ -691,6 +693,7 @@ def moe_sub_mesh_tensors(
"""
Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
"""
global_placements = copy.deepcopy(global_placements)
local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
global_mesh, global_placements, local_mesh_dim
)

@@ -705,17 +708,17 @@ def moe_sub_mesh_tensors(
global_placements,
)
elif paddle.framework.in_pir_mode():

return paddle.jit.dy2static.py_layer.StaticPyLayer(
_moe_sub_mesh_tensors
).apply(
local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
dist_tensor,
local_mesh_list,
local_placements,
local_mesh_dim,
global_mesh,
global_placements,
)
for local_tensor in local_tensors:
local_tensor.stop_gradient = dist_tensor.stop_gradient
local_tensor.persistable = dist_tensor.persistable
return local_tensors
else:
raise NotImplementedError(
"moe_sub_mesh_tensors is only supported in dynamic mode."
6 changes: 4 additions & 2 deletions python/paddle/distributed/auto_parallel/placement_type.py
@@ -86,8 +86,10 @@ def to_dim_map(placements, tensor_dims):
if placement.is_shard():
shard_dim = cast(Shard, placement).get_dim()
if dim_map[shard_dim] > -1:
raise Exception(
"Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}"
import logging

logging.warning(
f"Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}."
)

dim_map[shard_dim] = i
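
to_dim_map builds dim_map, which records for every tensor dim the mesh dim that shards it (-1 if none). When two mesh dims claim the same tensor dim, the code used to raise (with a message that was missing its f prefix, so the braces were never interpolated) and now only logs a warning and lets the later mesh dim win. A simplified standalone reconstruction of the mapping logic:

import logging

def to_dim_map_sketch(shard_dims, tensor_ndim):
    # shard_dims[i] is the tensor dim sharded by mesh dim i, or None.
    dim_map = [-1] * tensor_ndim
    for mesh_dim, shard_dim in enumerate(shard_dims):
        if shard_dim is None:
            continue
        if dim_map[shard_dim] > -1:
            logging.warning(
                f"Tensor dim {shard_dim} is already sharded on "
                f"mesh dim {dim_map[shard_dim]}."
            )
        dim_map[shard_dim] = mesh_dim
    return dim_map

# Both mesh dims shard tensor dim 0: a warning is logged and the
# later mesh dim (1) ends up in the map.
print(to_dim_map_sketch([0, 0], 2))  # [1, -1]
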
25 changes: 17 additions & 8 deletions test/auto_parallel/pir/test_moe_api.py
@@ -127,14 +127,23 @@ def check_results(
local_meshes,
local_dims_mapping,
):
# local_tensors_from_dtensor op
self.check_dist_attr(ops[4], local_meshes, local_dims_mapping)
# dtensor_from_local_list op
self.check_dist_attr(ops[5], [global_mesh], global_dims_mapping)
# grad op for dtensor_from_local_list
self.check_dist_attr(ops[10], local_meshes, local_dims_mapping)
# grad op for local_tensors_from_dtensor op
self.check_dist_attr(ops[11], [global_mesh], global_dims_mapping)
op_names = [
"dist_op.moe_sub_mesh_tensors",
"dist_op.moe_global_mesh_tensor",
]
ops_to_check = [op for op in ops if op.name() in op_names]
# moe_sub_mesh_tensors op
self.check_dist_attr(ops_to_check[0], local_meshes, local_dims_mapping)
# moe_global_mesh_tensor op
self.check_dist_attr(
ops_to_check[1], [global_mesh], global_dims_mapping
)
# grad op for moe_global_mesh_tensor
self.check_dist_attr(ops_to_check[2], local_meshes, local_dims_mapping)
# grad op for moe_sub_mesh_tensors op
self.check_dist_attr(
ops_to_check[3], [global_mesh], global_dims_mapping
)


if __name__ == "__main__":
