diff --git a/model_zoo/gpt-3/ppfleetx/distributed/apis/env.py b/model_zoo/gpt-3/ppfleetx/distributed/apis/env.py
index 3df1adce58d1..2383f3d65e2c 100644
--- a/model_zoo/gpt-3/ppfleetx/distributed/apis/env.py
+++ b/model_zoo/gpt-3/ppfleetx/distributed/apis/env.py
@@ -31,33 +31,67 @@ def set_seed(seed):
+    # NOTE(shenliang03): For the parameter init seed:
+    # seed: dp/mp_undistributed_parameter/sharding is the same; others are different.
+    # For the compute seed (dropout):
+    # global seed: only the mp group is the same.
+    # local seed: all groups are different.
+
     if dist.get_world_size() > 1:
         # obtain rank message of hybrid parallel
         hcg = get_hcg()
+
         mp_rank = hcg.get_model_parallel_rank()
+        mp_size = hcg.get_model_parallel_world_size()
+
         pp_rank = hcg.get_stage_id()
-        data_world_rank = get_data_world_rank()
-        data_world_size = get_data_world_size()
+        pp_size = hcg.get_pipe_parallel_world_size()
+
+        dp_rank = hcg.get_data_parallel_rank()
+        dp_size = hcg.get_data_parallel_world_size()
+
+        sharding_rank = hcg.get_sharding_parallel_rank()
+        # sharding_size = hcg.get_sharding_parallel_world_size()
     else:
-        mp_rank, pp_rank, data_world_rank, data_world_size = 0, 0, 0, 1
+        mp_rank, mp_size = 0, 1
+        pp_rank, pp_size = 0, 1
+        dp_rank, dp_size = 0, 1
+        sharding_rank, _ = 0, 1
 
     # NOTE: the commented seeds are set only for precision validation
     # seed += 100 * pp_rank
-    # random.seed(seed)
-    # np.random.seed(seed)
-    # paddle.seed(seed)
-
-    random.seed(seed + data_world_rank)
-    np.random.seed(seed + data_world_rank)
-    paddle.seed(seed + data_world_rank)
+    random.seed(seed + 100 * pp_rank)
+    np.random.seed(seed + 100 * pp_rank)
+
+    # seed = mp_rank +
+    #        pp_rank * (mp_size) +
+    #        dp_rank * (mp_size * pp_size) +
+    #        sharding_rank * (mp_size * pp_size * dp_size)
+    # The seed offset is in order to avoid conflicts with the parameter initialization seed.
+
+    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
+    global_seed = (
+        seed_offset
+        + pp_rank * (mp_size)
+        + dp_rank * (mp_size * pp_size)
+        + sharding_rank * (mp_size * pp_size * dp_size)
+    )
+
+    seed_offset += paddle.distributed.get_world_size()
+    local_seed = (
+        seed_offset
+        + mp_rank
+        + pp_rank * (mp_size)
+        + dp_rank * (mp_size * pp_size)
+        + sharding_rank * (mp_size * pp_size * dp_size)
+    )
 
-    # local_seed/ global_seed is used to control dropout in ModelParallel
-    local_seed = seed + 123 + mp_rank * 10 + pp_rank * 1000 + data_world_size
-    global_seed = seed + data_world_rank
     tracker = get_rng_state_tracker()
     tracker.add("global_seed", global_seed)
     tracker.add("local_seed", local_seed)
 
+    paddle.seed(global_seed)
+
+    logger.info("The global seed is set to {} and local seed is set to {}.".format(global_seed, local_seed))
 
     global _seed
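
A quick way to sanity-check the seed layout in this diff is to enumerate every rank combination in a hypothetical hybrid mesh and verify the two properties the NOTE claims: `local_seed` is unique on every rank, while `global_seed` is shared exactly within a model-parallel group. The sketch below reimplements the arithmetic from `set_seed` as plain Python; the 2x2x2x2 mesh sizes and the base `seed` are made up for illustration.

```python
from itertools import product

# Hypothetical hybrid-parallel mesh (sizes chosen only for illustration).
mp_size, pp_size, dp_size, sharding_size = 2, 2, 2, 2
world_size = mp_size * pp_size * dp_size * sharding_size
seed = 42

global_seeds, local_seeds = {}, {}
for mp_rank, pp_rank, dp_rank, sharding_rank in product(
    range(mp_size), range(pp_size), range(dp_size), range(sharding_size)
):
    # Same arithmetic as in set_seed() in the diff above.
    seed_offset = seed + 1024 + world_size
    global_seed = (
        seed_offset
        + pp_rank * mp_size
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    seed_offset += world_size
    local_seed = (
        seed_offset
        + mp_rank
        + pp_rank * mp_size
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    rank = (mp_rank, pp_rank, dp_rank, sharding_rank)
    global_seeds[rank] = global_seed
    local_seeds[rank] = local_seed

# local_seed differs on every rank: dropout masks are decorrelated everywhere.
assert len(set(local_seeds.values())) == world_size

# global_seed is constant within an mp group (ranks that differ only in
# mp_rank) and distinct across groups, so mp-replicated dropout stays in sync.
groups = {}
for (mp_rank, pp_rank, dp_rank, sh_rank), g in global_seeds.items():
    groups.setdefault((pp_rank, dp_rank, sh_rank), set()).add(g)
assert all(len(g) == 1 for g in groups.values())
assert len({g.pop() for g in groups.values()}) == pp_size * dp_size * sharding_size
```

The uniqueness falls out of the mixed-radix encoding: `mp_rank + pp_rank * mp_size + dp_rank * (mp_size * pp_size) + sharding_rank * (mp_size * pp_size * dp_size)` maps each rank tuple to a distinct integer in `[0, world_size)`, and dropping the `mp_rank` term from `global_seed` collapses exactly the mp dimension. Bumping `seed_offset` by `world_size` between the two formulas keeps the global and local ranges from colliding.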
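
For context on how the two registered states are consumed: in Paddle's hybrid-parallel layers, the tracker returned by `get_rng_state_tracker()` exposes an `rng_state(name)` context manager that temporarily swaps in the named RNG state, which is how dropout can be switched between the mp-synchronized `"global_seed"` stream and the per-rank `"local_seed"` stream. A minimal sketch, assuming a Paddle environment where `set_seed` has already run and registered both states:

```python
import paddle
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

x = paddle.ones([4, 8])
tracker = get_rng_state_tracker()

# Per-rank stream: every rank draws a different mask, so activations
# that are already sharded across ranks stay decorrelated.
with tracker.rng_state("local_seed"):
    y_local = paddle.nn.functional.dropout(x, p=0.1)

# mp-shared stream: all ranks in a model-parallel group draw the same
# mask, keeping dropout consistent on replicated activations.
with tracker.rng_state("global_seed"):
    y_global = paddle.nn.functional.dropout(x, p=0.1)
```

Note that `tracker.rng_state(...)` raises if the name was never registered via `tracker.add(...)`, which is why `set_seed` must run before any dropout under the tracker.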