support sharding of master weights in dynamic mode
Waynezee committed Mar 9, 2025
1 parent bcfa081 commit 0735da2
Showing 1 changed file with 10 additions and 3 deletions.
python/paddle/distributed/auto_parallel/api.py (13 changes: 10 additions & 3 deletions)
@@ -1167,6 +1167,8 @@ def _shard_accumulator(self, param):
             self._inner_opt._master_weights[param.name] = (
                 self._shard_fn.shard_master_weight(param, master_weight)
             )
+            self._inner_opt._master_weights[param.name].name = target_name
+
         # shard the accumulators
         for key in self._inner_opt._accumulators.keys():
             accumulator = self._inner_opt._accumulators[key][target_name]
@@ -1414,14 +1416,12 @@ def shard_master_weight(
         self, param: Tensor, master_weight: Tensor
     ) -> Tensor:
         if param.is_dist():
+            placements = get_placement_with_sharding(param, self._sharding_axis)
             if isinstance(master_weight, pir.Value):
                 data_op = master_weight.get_defining_op()
                 assert (
                     data_op.name() == "pd_op.data"
                 ), "The master weight must be a result of data op."
-                placements = get_placement_with_sharding(
-                    param, self._sharding_axis
-                )
                 dim_map, partial_status = to_dim_map(
                     placements, len(master_weight.shape)
                 )
@@ -1439,6 +1439,13 @@
                         param.process_mesh, [], [dist_attr]
                     )
                 )
+
+            if paddle.in_dynamic_mode() and master_weight.is_dist():
+                master_weight = reshard(
+                    master_weight,
+                    mesh=param.process_mesh,
+                    placements=placements,
+                )
         return master_weight


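In dynamic (eager) mode the master weight is not a pir.Value, so the static-graph branch above is skipped; the commit therefore hoists the placements computation out of that branch and reshards an already-distributed master weight to those placements. Below is a minimal standalone sketch of that reshard step, not the Paddle source itself: the two-card mesh, the tensor shape, and the Shard(0) placement are illustrative assumptions, and the script must be launched on two ranks (e.g. via paddle.distributed.launch).

# Hypothetical demo script; mesh, shape, and placements are illustrative only.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# An fp32 master weight as an AMP optimizer would keep it, replicated at first.
master_weight = dist.shard_tensor(
    paddle.randn([8, 16], dtype="float32"),
    mesh,
    [dist.Replicate()],
)

# Placements a sharding shard_fn would derive from the parameter,
# e.g. split along dim 0 over the "x" mesh axis.
placements = [dist.Shard(0)]

if paddle.in_dynamic_mode() and master_weight.is_dist():
    # Mirror the new branch: reshard so each rank keeps only its slice.
    master_weight = dist.reshard(master_weight, mesh, placements)

print(master_weight.placements)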
