From 3bdedabb80c0aa590f1e376fa812cada80d97af9 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 31 Jul 2024 15:30:08 +0800 Subject: [PATCH 01/30] add --- paddlenlp/trainer/checkpoint_converter.py | 450 ++++++++++++++++++++++ 1 file changed, 450 insertions(+) create mode 100644 paddlenlp/trainer/checkpoint_converter.py diff --git a/paddlenlp/trainer/checkpoint_converter.py b/paddlenlp/trainer/checkpoint_converter.py new file mode 100644 index 000000000000..6432a8793f3e --- /dev/null +++ b/paddlenlp/trainer/checkpoint_converter.py @@ -0,0 +1,450 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from functools import reduce + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed.checkpoint.metadata import ( + LocalTensorIndex, + LocalTensorMetadata, + Metadata, +) + +MODEL_WEIGHT_SUFFIX = ".pdparams" +OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" +SCHEDULER_NAME = "scheduler.pdparams" +MODEL_META_FILE_NAME = "model_meta.json" + + +def flatten_list(l): + return [item for sublist in l for item in sublist] + + +class DynamicToStaticShardingV2CheckpointConverter: + def __init__(self, dynamic_ckpt_path, model_state_global_shape): + self.path = dynamic_ckpt_path + self.model_state_global_shape = model_state_global_shape + self.model_meta = json.load(open(os.path.join(dynamic_ckpt_path, MODEL_META_FILE_NAME))) + ( + self.cur_rank_model_state_file_names, + self.cur_rank_optimizer_state_file_names, + ) = self.get_local_checkpoint_file_names() + self.cur_rank_loaded_state_dict = {} + ( + self.global_model_state_file_names, + self.global_optimizer_state_file_names, + ) = self.get_all_checkpoint_file_names() + + def get_local_checkpoint_file_names(self): + cur_rank_files = os.listdir(self.path) + cur_rank_model_state_file_names = [] + cur_rank_optimizer_state_file_names = [] + global_model_state_file_names = [] + global_optimizer_state_file_names = [] + for file in cur_rank_files: + if file.endswith(MODEL_WEIGHT_SUFFIX): + cur_rank_model_state_file_names.append(file) + elif file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + cur_rank_optimizer_state_file_names.append(file) + if SCHEDULER_NAME in cur_rank_model_state_file_names: + cur_rank_model_state_file_names.remove(SCHEDULER_NAME) + return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names + + def get_all_checkpoint_file_names(self): + cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names = self.get_local_checkpoint_file_names() + use_dist = True if paddle.distributed.get_world_size() > 1 else False + global_model_state_file_names = [] + global_optimizer_state_file_names = [] + if use_dist: + paddle.distributed.init_parallel_env() + paddle.distributed.all_gather_object(global_model_state_file_names, cur_rank_model_state_file_names) + paddle.distributed.all_gather_object( + global_optimizer_state_file_names, cur_rank_optimizer_state_file_names + ) + else: + global_model_state_file_names = 
[cur_rank_model_state_file_names] + global_optimizer_state_file_names = [cur_rank_optimizer_state_file_names] + + return global_model_state_file_names, global_optimizer_state_file_names + + def get_local_load_files(self, rank_to_files): + import copy + + file_to_ranks = {} + for rank, files in rank_to_files.items(): + for file in files: + if file not in file_to_ranks: + file_to_ranks[file] = [] + file_to_ranks[file].append(rank) + rank_to_not_read_files = copy.copy(rank_to_files) + rank_to_read_files = {rank: [] for rank in rank_to_not_read_files.keys()} + for file, ranks in file_to_ranks.items(): + if len(ranks) == 1: + rank = ranks[0] + rank_to_read_files[rank].append(file) + rank_to_not_read_files[rank].remove(file) + if len(rank_to_not_read_files[rank]) == 0: + rank_to_not_read_files.pop(rank) + + def get_least_read_files_ranks(rank_to_read_files): + nums = [(rank, len(files)) for rank, files in rank_to_read_files.items()] + nums = sorted(nums, key=lambda x: x[1]) + ranks = [rank for rank, num in nums if num == nums[0][1]] + return ranks + + def get_read_rank_file(rank_to_not_read_files, ranks): + if len(rank_to_not_read_files) == 0: + return (None, None) + nums = [(rank, len(files)) for rank, files in rank_to_not_read_files.items() if rank in ranks] + nums = sorted(nums, key=lambda x: x[1]) + rank = nums[0][0] + return (rank, rank_to_not_read_files[rank][0]) + + def update(rank_to_read_files, rank_to_not_read_files, rank_file): + rank, file = rank_file + if rank is None and file is None: + return + if rank not in rank_to_read_files: + rank_to_read_files[rank] = [] + rank_to_read_files[rank].append(file) + # update rank_to_not_read_files + file_to_ranks = {} + for r, files in rank_to_not_read_files.items(): + for f in files: + if f not in file_to_ranks: + file_to_ranks[f] = [] + file_to_ranks[f].append(r) + + if file in file_to_ranks: + for r in file_to_ranks[file]: + rank_to_not_read_files[r].remove(file) + if len(rank_to_not_read_files[r]) == 0: + rank_to_not_read_files.pop(r) + + while len(rank_to_not_read_files) > 0: + ranks = get_least_read_files_ranks(rank_to_read_files) + rank_file = get_read_rank_file(rank_to_not_read_files, ranks) + update(rank_to_read_files, rank_to_not_read_files, rank_file) + + cur_rank = paddle.distributed.get_rank() + if cur_rank in rank_to_read_files: + return rank_to_read_files[cur_rank] + else: + return [] + + def extract_distribution_strategy_from_file_name(self, file_name): + pp_degree = 0 + tp_degree = 0 + sharding_degree = 0 + pattern_pp = r"pp(\d+)" + pattern_tp = r"tp(\d+)" + pattern_shard = r"shard(\d+)" + match_pp = re.search(pattern_pp, file_name) + if match_pp: + pp_degree = int(match_pp.group(1)) + match_tp = re.search(pattern_tp, file_name) + if match_tp: + tp_degree = int(match_tp.group(1)) + match_shard = re.search(pattern_shard, file_name) + if match_shard: + sharding_degree = int(match_shard.group(1)) + return (tp_degree, pp_degree, sharding_degree) + + def gen_matadata_for_optimizer(self): + rank_access_files = {} + + for rank in range(paddle.distributed.get_world_size()): + rank_access_files[rank] = ( + self.global_model_state_file_names[rank] + self.global_optimizer_state_file_names[rank] + ) + + # Determine which files need to be read for each rank. 
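As a reference for the balancing logic in get_local_load_files above, here is a minimal standalone sketch of the idea, using hypothetical file names; the real method additionally assigns exclusively-visible files first and then repeatedly feeds the least-loaded reader:

    def assign_files_to_readers(rank_to_files):
        # Greedy sketch: a file visible to exactly one rank is read by that rank;
        # any shared file goes to whichever candidate currently reads the fewest files.
        assignment = {rank: [] for rank in rank_to_files}
        seen = {}
        for rank, files in rank_to_files.items():
            for f in files:
                seen.setdefault(f, []).append(rank)
        for f, ranks in sorted(seen.items()):
            target = ranks[0] if len(ranks) == 1 else min(ranks, key=lambda r: len(assignment[r]))
            assignment[target].append(f)
        return assignment

    # Hypothetical layout: shard00 is visible to both ranks, shard01 only to rank 1.
    print(assign_files_to_readers({0: ["tp00_pp00_shard00.pdopt"],
                                   1: ["tp00_pp00_shard00.pdopt", "tp00_pp00_shard01.pdopt"]}))
    # {0: ['tp00_pp00_shard00.pdopt'], 1: ['tp00_pp00_shard01.pdopt']}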
+ # When each node has a checkpoint path, it will fail + need_read_files = self.get_local_load_files(rank_access_files) + + sharded_tensor_infos = {} + file_to_model_state_names = {} + file_to_optimizer_state_names = {} + model_state_dict_info = {} + cur_rank_sharded_tensor_infos = {} + + for file in need_read_files: + if OPTIMIZER_WEIGHT_SUFFIX in file: + state_dict = paddle.load(os.path.join(self.path, file), return_numpy=True) + state_dict.pop("LR_Scheduler") + master_weights = state_dict.pop("master_weights") + # Extract master weights + for k, v in master_weights.items(): + state_dict[k] = v + # Based on the checkpoint file name, determine the pp_degree, tp_degree, and sharding_degree of the tensor in the current file. + distributed_rank = self.extract_distribution_strategy_from_file_name(file) + dist_strategy_key = ( + "tp" + "{:02d}".format(distributed_rank[0]) + "_" + "pp" + "{:02d}".format(distributed_rank[1]) + ) + + # Map model weight names to their corresponding names of master_weights in the optimizer state. + structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] + + # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, and then append the tp_degree. + renamed_state_dict = {} + for k, v in state_dict.items(): + for prame_name, opt_name in structure_name_mapping.items(): + if opt_name in k: + new_key = k.replace(opt_name, prame_name) + "_tp" + "{:02d}".format(distributed_rank[0]) + else: + new_key = k.replace(opt_name, prame_name) + renamed_state_dict[new_key] = v + # Calculate the local_shape. + cur_rank_sharded_tensor_infos[(new_key, file)] = [v.shape, str(v.dtype)] + + # Cache the renamed state dict + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + # Obtain the local_shape information of the tensor on all ranks. + all_rank_sharded_tensor_infos = [] + if use_dist: + paddle.distributed.all_gather_object(all_rank_sharded_tensor_infos, cur_rank_sharded_tensor_infos) + else: + all_rank_sharded_tensor_infos = [cur_rank_sharded_tensor_infos] + + global_sharded_tensor_infos = {} + for rank_sharded_tensor_infos in all_rank_sharded_tensor_infos: + for k, v in rank_sharded_tensor_infos.items(): + if k not in global_sharded_tensor_infos: + global_sharded_tensor_infos[k] = v + + # Collect sharding information. + key_to_sharded_info = {} + for k, v in global_sharded_tensor_infos.items(): + distributed_rank = self.extract_distribution_strategy_from_file_name(k[1]) + if k[0] not in key_to_sharded_info: + key_to_sharded_info[k[0]] = [[distributed_rank[2], v[0], v[1], k[1]]] + else: + key_to_sharded_info[k[0]].append([distributed_rank[2], v[0], v[1], k[1]]) + + # x[0] records the sharding rank. + for k, v in key_to_sharded_info.items(): + v.sort(key=lambda x: x[0]) + + state_dict_metadata = {} + storage_metadata = {} + + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
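The offset computation that follows is a running sum over the sharding ranks: each shard of a flattened optimizer state starts where the previous shard ends. A small illustration with made-up shard lengths:

    # Flattened shard lengths for one tensor, already sorted by sharding rank (hypothetical sizes).
    local_lengths = [1024, 1024, 512]

    global_offsets = []
    offset = 0
    for length in local_lengths:
        global_offsets.append((offset,))  # 1-D global offset of this shard
        offset += length

    print(global_offsets)  # [(0,), (1024,), (2048,)]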
+ for k, v in key_to_sharded_info.items(): + global_offset = 0 + for item in v: + local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, (global_offset,)) + global_offset += item[1][0] + if k not in state_dict_metadata: + state_dict_metadata[k] = [local_tensor_meta_data] + else: + state_dict_metadata[k].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + ".distcp" + + # Save the metadata and the renamed tensor read by this rank. + metadata = Metadata(state_dict_metadata, storage_metadata, None) + write_path = os.path.join(self.path, "tmp") + for file in self.cur_rank_loaded_state_dict: + paddle.save(self.cur_rank_loaded_state_dict[file], os.path.join(write_path, file + ".distcp")) + if 0 == paddle.distributed.get_rank(): + paddle.save(metadata, os.path.join(write_path, "0.metadata")) + + def concat_optimier_state_dict(self): + # Obtain the global_shape passed in semi-automatic parallel mode on each card in the static graph. + all_rank_model_state_global_shapes = [] + use_dist = True if paddle.distributed.get_world_size() > 1 else False + if use_dist: + paddle.distributed.all_gather_object(all_rank_model_state_global_shapes, self.model_state_global_shape) + else: + all_rank_model_state_global_shapes = [self.model_state_global_shape] + + self.model_state_global_shape = {} + for rank_model_state_global_shape in all_rank_model_state_global_shapes: + for k, v in rank_model_state_global_shape.items(): + self.model_state_global_shape[k] = v + + # Obtain the names and shapes of all model parameters. + global_model_state_shapes = {} + sharding_metas_keys = [] + pp_degree = self.model_meta["parallel_config"]["pp_degree"] + mp_degree = self.model_meta["parallel_config"]["mp_degree"] + for i in range(pp_degree): + for j in range(mp_degree): + sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) + for key in sharding_metas_keys: + param_meta = self.model_meta["sharding_metas"][key]["param_meta"] + for k, v in param_meta.items(): + global_model_state_shapes[k] = v[0] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + world_size = paddle.distributed.get_world_size() + + # Distribute all model parameters evenly across each card for loading + global_model_state_flattened_shapes = {} + global_model_state_size = 0 + for k, v in global_model_state_shapes.items(): + flattened_size = reduce(lambda x, y: x * y, v) + global_model_state_size += flattened_size + global_model_state_flattened_shapes[k] = flattened_size + + partition_model_state_keys = [] + avg_size = global_model_state_size // world_size + + cur_rank_model_state_keys = [] + + cur_rank_size = 0 + for k, v in global_model_state_flattened_shapes.items(): + cur_rank_size += v + cur_rank_model_state_keys.append(k) + if cur_rank_size > avg_size: + partition_model_state_keys.append(cur_rank_model_state_keys) + cur_rank_model_state_keys = [] + cur_rank_size = 0 + + # Since an absolutely even distribution is not achievable, some tanks may not need to load, but the load_state_dict interface might throw an error. Therefore, it is necessary to forcefully assign a parameter. + nend_append = world_size - len(partition_model_state_keys) + for i in range(nend_append): + partition_model_state_keys.append([partition_model_state_keys[0][0]]) + + cur_rank = paddle.distributed.get_rank() + + cur_rank_need_load_model_state_keys = partition_model_state_keys[cur_rank] + + # Generate the optimizer states corresponding to the model weights. 
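The loop below materialises, for every model weight assigned to this rank, the AdamW entries expected in the checkpoint: the two moments, the master weight itself, and the two beta power accumulators. A compact sketch of the same pattern for a single hypothetical weight (without the tensor-parallel suffix handled in the real loop):

    import paddle

    flattened_size = 4096          # hypothetical flattened parameter size
    key = "linear_0.w_0"           # hypothetical model weight name

    placeholders = {
        key + "_fp32_master_0_moment1_0": paddle.zeros((flattened_size,), "float32"),
        key + "_fp32_master_0_moment2_0": paddle.zeros((flattened_size,), "float32"),
        key: paddle.zeros((flattened_size,), "float32"),                       # master weight
        key + "_fp32_master_0_beta1_pow_acc_0": paddle.zeros((1,), "float32"),
        key + "_fp32_master_0_beta2_pow_acc_0": paddle.zeros((1,), "float32"),
    }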
+ optimizer_state_dict = {} + for key in cur_rank_need_load_model_state_keys: + for tp_rank in range(self.model_meta["parallel_config"]["mp_degree"]): + tp_rank_suffix = "_tp{:02d}".format(tp_rank) + optimizer_state_dict[key + "_fp32_master_0_moment1_0" + tp_rank_suffix] = paddle.zeros( + (global_model_state_flattened_shapes[key],), "float32" + ) + optimizer_state_dict[key + "_fp32_master_0_moment2_0" + tp_rank_suffix] = paddle.zeros( + (global_model_state_flattened_shapes[key],), "float32" + ) + optimizer_state_dict[key + tp_rank_suffix] = paddle.zeros( + (global_model_state_flattened_shapes[key],), "float32" + ) + optimizer_state_dict[key + "_fp32_master_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") + optimizer_state_dict[key + "_fp32_master_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") + + dist.load_state_dict(optimizer_state_dict, os.path.join(self.path, "tmp")) + + # Reshape + for k, v in optimizer_state_dict.items(): + if v.shape[0] > 1 and "_tp" in k: + for master_weight_key, shape in global_model_state_shapes.items(): + if master_weight_key in k: + reshaped_v = v.reshape(shape) + optimizer_state_dict[k] = reshaped_v + + concat_optimier_state_dict = {} + + optimizer_state_key_to_tp_keys = {} + for key in optimizer_state_dict.keys(): + # Count how each key is split into keys ending with ‘_tpXX’. + # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} + if "_pow_acc_0" not in key: + if key[:-5] not in optimizer_state_key_to_tp_keys: + optimizer_state_key_to_tp_keys[key[:-5]] = [key] + else: + optimizer_state_key_to_tp_keys[key[:-5]].append(key) + else: + optimizer_state_key_to_tp_keys[key] = [key] + for key, value in optimizer_state_key_to_tp_keys.items(): + if len(value) == 1: + continue + value.sort(key=lambda x: int(x[-2:])) + + for key, tp_keys in optimizer_state_key_to_tp_keys.items(): + # Optimizer states with a shape of 1 could be replicated; here, perform a check. + is_replicated = True + tp_tensor = optimizer_state_dict[tp_keys[0]] + for tp_key in tp_keys: + if not np.array_equal(tp_tensor.numpy(), optimizer_state_dict[tp_key].numpy()): + is_replicated = False + break + if is_replicated: + concat_optimier_state_dict[key] = tp_tensor + continue + else: + tp_tensors = [] + for tp_key in tp_keys: + tp_tensors.append(optimizer_state_dict[tp_key]) + # Derive the partition strategy based on the global_shape, then concatenate. + axis = 0 + global_shape = [] + # Find the global_shape. + for k, shape in self.model_state_global_shape.items(): + if k in tp_key: + global_shape = shape + break + assert len(global_shape) != 0 + tp_shape = tp_tensors[0].shape + assert (tp_shape[0] == global_shape[0] and len(tp_tensors) * tp_shape[1] == global_shape[1]) or ( + tp_shape[1] == global_shape[1] and len(tp_tensors) * tp_shape[0] == global_shape[0] + ) + if tp_shape[0] == global_shape[0]: + axis = 1 + concat_optimier_state_dict[key] = paddle.concat(tp_tensors, axis=axis) + + file_name = "{:02d}".format(cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in concat_optimier_state_dict.items(): + # Generate metadata. 
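For each concatenated tensor, the code below records a LocalTensorMetadata (offset, shape, dtype) and maps a LocalTensorIndex to the file that stores it, mirroring how this converter builds Metadata elsewhere. A minimal sketch for one hypothetical tensor kept whole on the current rank:

    import paddle
    from paddle.distributed.checkpoint.metadata import (
        LocalTensorIndex,
        LocalTensorMetadata,
        Metadata,
    )

    tensor = paddle.zeros((8, 4), "float32")          # hypothetical tensor owned by this rank
    key, file_name = "linear_0.w_0", "00.distcp"      # hypothetical names

    offset = (0, 0)                                    # whole tensor, so the offset is all zeros
    dtype = str(tensor.dtype).split(".")[1]            # "float32" rather than "paddle.float32"
    meta = LocalTensorMetadata(offset, tensor.shape, dtype)
    state_dict_metadata = {key: [meta]}
    storage_metadata = {LocalTensorIndex(key, offset): file_name}
    metadata = Metadata(state_dict_metadata, storage_metadata, None)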
+ local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype) + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + + save_path = os.path.join(self.path, "tmp2") + + if cur_rank == 0: + paddle.save(meta_data, os.path.join(save_path, "0.metadata")) + + paddle.save(concat_optimier_state_dict, os.path.join(save_path, file_name)) From 280b13f2edbfe674ed71702b6633bcbe3b7d7b3d Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 7 Aug 2024 20:45:04 +0800 Subject: [PATCH 02/30] Add the checkpoint conversion module --- paddlenlp/trainer/checkpoint_converter.py | 1201 +++++++++++++++------ 1 file changed, 856 insertions(+), 345 deletions(-) diff --git a/paddlenlp/trainer/checkpoint_converter.py b/paddlenlp/trainer/checkpoint_converter.py index 6432a8793f3e..00235c42a870 100644 --- a/paddlenlp/trainer/checkpoint_converter.py +++ b/paddlenlp/trainer/checkpoint_converter.py @@ -17,9 +17,11 @@ import re from functools import reduce -import numpy as np import paddle -import paddle.distributed as dist +from paddle.distributed.checkpoint.load_state_dict import ( + _load_state_dict, + get_local_load_files, +) from paddle.distributed.checkpoint.metadata import ( LocalTensorIndex, LocalTensorMetadata, @@ -32,31 +34,60 @@ MODEL_META_FILE_NAME = "model_meta.json" -def flatten_list(l): - return [item for sublist in l for item in sublist] +class CheckpointConverter: + def __init__(self, dynamic_ckpt_path, model_state, parameter_to_structured_name): + self.use_dist = True if paddle.distributed.get_world_size() > 1 else False + self.path = dynamic_ckpt_path + self.semi_auto_model_state = model_state + self.parameter_to_structured_name = parameter_to_structured_name + model_state_global_shape = {} + for k, v in model_state.items(): + model_state_global_shape[k] = v.shape + self.model_state_global_shape = self.gather_global_object(model_state_global_shape) + self.cur_rank = paddle.distributed.get_rank() + self.save_sharded_model = self.get_save_sharded_model_flag() -class DynamicToStaticShardingV2CheckpointConverter: - def __init__(self, dynamic_ckpt_path, model_state_global_shape): - self.path = dynamic_ckpt_path - self.model_state_global_shape = model_state_global_shape - self.model_meta = json.load(open(os.path.join(dynamic_ckpt_path, MODEL_META_FILE_NAME))) ( self.cur_rank_model_state_file_names, self.cur_rank_optimizer_state_file_names, ) = self.get_local_checkpoint_file_names() - self.cur_rank_loaded_state_dict = {} - ( - self.global_model_state_file_names, - self.global_optimizer_state_file_names, - ) = self.get_all_checkpoint_file_names() + + 
self.global_model_state_file_names = self.gather_global_object(self.cur_rank_model_state_file_names) + + self.global_optimizer_state_file_names = self.gather_global_object(self.cur_rank_optimizer_state_file_names) + + self.initial_distributed_configuration() + + def get_save_sharded_model_flag(self): + if self.cur_rank == 1: + save_sharded_model_flag = [os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] + else: + save_sharded_model_flag = [] + save_sharded_model_flag = self.gather_global_object(save_sharded_model_flag) + return save_sharded_model_flag[0] + + def gather_global_object(self, cur_rank_object): + all_rank_objects = [] + if self.use_dist: + paddle.distributed.all_gather_object(all_rank_objects, cur_rank_object) + else: + all_rank_objects = [all_rank_objects] + + if isinstance(cur_rank_object, list): + return [item for sublist in all_rank_objects for item in sublist] + elif isinstance(cur_rank_object, dict): + global_map = {} + for rank_map in all_rank_objects: + global_map.update(rank_map) + return global_map + else: + raise ValueError("cur_rank_object should be either a list or a dict") def get_local_checkpoint_file_names(self): cur_rank_files = os.listdir(self.path) cur_rank_model_state_file_names = [] cur_rank_optimizer_state_file_names = [] - global_model_state_file_names = [] - global_optimizer_state_file_names = [] for file in cur_rank_files: if file.endswith(MODEL_WEIGHT_SUFFIX): cur_rank_model_state_file_names.append(file) @@ -66,89 +97,7 @@ def get_local_checkpoint_file_names(self): cur_rank_model_state_file_names.remove(SCHEDULER_NAME) return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names - def get_all_checkpoint_file_names(self): - cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names = self.get_local_checkpoint_file_names() - use_dist = True if paddle.distributed.get_world_size() > 1 else False - global_model_state_file_names = [] - global_optimizer_state_file_names = [] - if use_dist: - paddle.distributed.init_parallel_env() - paddle.distributed.all_gather_object(global_model_state_file_names, cur_rank_model_state_file_names) - paddle.distributed.all_gather_object( - global_optimizer_state_file_names, cur_rank_optimizer_state_file_names - ) - else: - global_model_state_file_names = [cur_rank_model_state_file_names] - global_optimizer_state_file_names = [cur_rank_optimizer_state_file_names] - - return global_model_state_file_names, global_optimizer_state_file_names - - def get_local_load_files(self, rank_to_files): - import copy - - file_to_ranks = {} - for rank, files in rank_to_files.items(): - for file in files: - if file not in file_to_ranks: - file_to_ranks[file] = [] - file_to_ranks[file].append(rank) - rank_to_not_read_files = copy.copy(rank_to_files) - rank_to_read_files = {rank: [] for rank in rank_to_not_read_files.keys()} - for file, ranks in file_to_ranks.items(): - if len(ranks) == 1: - rank = ranks[0] - rank_to_read_files[rank].append(file) - rank_to_not_read_files[rank].remove(file) - if len(rank_to_not_read_files[rank]) == 0: - rank_to_not_read_files.pop(rank) - - def get_least_read_files_ranks(rank_to_read_files): - nums = [(rank, len(files)) for rank, files in rank_to_read_files.items()] - nums = sorted(nums, key=lambda x: x[1]) - ranks = [rank for rank, num in nums if num == nums[0][1]] - return ranks - - def get_read_rank_file(rank_to_not_read_files, ranks): - if len(rank_to_not_read_files) == 0: - return (None, None) - nums = [(rank, len(files)) for rank, files in rank_to_not_read_files.items() if 
rank in ranks] - nums = sorted(nums, key=lambda x: x[1]) - rank = nums[0][0] - return (rank, rank_to_not_read_files[rank][0]) - - def update(rank_to_read_files, rank_to_not_read_files, rank_file): - rank, file = rank_file - if rank is None and file is None: - return - if rank not in rank_to_read_files: - rank_to_read_files[rank] = [] - rank_to_read_files[rank].append(file) - # update rank_to_not_read_files - file_to_ranks = {} - for r, files in rank_to_not_read_files.items(): - for f in files: - if f not in file_to_ranks: - file_to_ranks[f] = [] - file_to_ranks[f].append(r) - - if file in file_to_ranks: - for r in file_to_ranks[file]: - rank_to_not_read_files[r].remove(file) - if len(rank_to_not_read_files[r]) == 0: - rank_to_not_read_files.pop(r) - - while len(rank_to_not_read_files) > 0: - ranks = get_least_read_files_ranks(rank_to_read_files) - rank_file = get_read_rank_file(rank_to_not_read_files, ranks) - update(rank_to_read_files, rank_to_not_read_files, rank_file) - - cur_rank = paddle.distributed.get_rank() - if cur_rank in rank_to_read_files: - return rank_to_read_files[cur_rank] - else: - return [] - - def extract_distribution_strategy_from_file_name(self, file_name): + def get_distribution_rank_from_file_name(self, file_name): pp_degree = 0 tp_degree = 0 sharding_degree = 0 @@ -166,285 +115,847 @@ def extract_distribution_strategy_from_file_name(self, file_name): sharding_degree = int(match_shard.group(1)) return (tp_degree, pp_degree, sharding_degree) - def gen_matadata_for_optimizer(self): - rank_access_files = {} + def initial_distributed_configuration(self): + self.pp_degree = 0 + self.tp_degree = 0 + self.sharding_degree = 0 + + all_files = self.global_model_state_file_names + self.global_optimizer_state_file_names + + for file in all_files: + (tp_degree, pp_degree, sharding_degree) = self.get_distribution_rank_from_file_name(file) + self.pp_degree = max(self.pp_degree, pp_degree) + self.tp_degree = max(self.tp_degree, tp_degree) + self.sharding_degree = max(self.sharding_degree, sharding_degree) + + self.pp_degree = self.pp_degree + 1 + self.tp_degree = self.tp_degree + 1 + self.sharding_degree = self.sharding_degree + 1 + + def infer_sharding_stage1_v(self): + sharding_stage1_v = [2] + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: + for k, v in state_dict.items(): + # Under shardingv2, the optimizer state is first flattened and then split. + if "_fp32_master_0_moment" in k and len(v.shape) != 1: + sharding_stage1_v = [1] + break + + sharding_stage1_v = self.gather_global_object(sharding_stage1_v) + if 1 in sharding_stage1_v: + return 1 + return 2 + + def infer_is_sharding_stage3(self): + if self.sharding_degree == 1: + return False + if self.pp_degree > 1 or self.tp_degree > 1: + # Currently, sharding stage 3 does not support concurrent use with tensor parallelism (TP) and pipeline parallelism (PP). 
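Both this layout check and the degree inference above depend on the tp/pp/shard indices parsed out of the checkpoint file names by get_distribution_rank_from_file_name. A standalone illustration of that parsing, with a hypothetical file name:

    import re

    def parse_distribution_rank(file_name):
        """Pull the (tp, pp, sharding) indices out of a name such as
        "optimizer.tp01_pp00_shard03.pdopt" (hypothetical); a missing group defaults to 0."""
        ranks = []
        for pattern in (r"tp(\d+)", r"pp(\d+)", r"shard(\d+)"):
            match = re.search(pattern, file_name)
            ranks.append(int(match.group(1)) if match else 0)
        return tuple(ranks)

    print(parse_distribution_rank("optimizer.tp01_pp00_shard03.pdopt"))  # (1, 0, 3)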
+ return False - for rank in range(paddle.distributed.get_world_size()): - rank_access_files[rank] = ( - self.global_model_state_file_names[rank] + self.global_optimizer_state_file_names[rank] + is_sharding_stage3 = True + + file_to_state_shape_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + state_shape_mapping = {} + for k, v in state_dict.items(): + state_shape_mapping[k] = v.shape + file_to_state_shape_mapping[file] = state_shape_mapping + global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) + + state_dict_std = global_file_to_state_shape_mapping[list(global_file_to_state_shape_mapping.keys())[0]] + + for file, state_dict in global_file_to_state_shape_mapping.items(): + if state_dict != state_dict_std: + is_sharding_stage3 = False + break + return is_sharding_stage3 + + def optimizer_state_name_to_master_weight_name(self, optimizer_state_name): + return optimizer_state_name.split(".")[0] + + def optimizer_state_file_name_to_model_state_file_name(self, optimizer_state_file_name): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) + for model_state_file in self.global_model_state_file_names: + distributed_rank = self.get_distribution_rank_from_file_name(model_state_file) + if tp_rank == distributed_rank[0] and pp_rank == distributed_rank[1]: + return model_state_file + return None + + def optimizer_key_to_model_state_key(self, optimizer_key): + adamw_optimizer_key_suffix = [ + ".w_0_fp32_master_0_beta1_pow_acc_0", + ".w_0_fp32_master_0_beta2_pow_acc_0", + ".w_0_fp32_master_0_moment1_0", + ".w_0_fp32_master_0_moment2_0", + ".w_0", + ] + model_state_key = optimizer_key + for suffix in adamw_optimizer_key_suffix: + if model_state_key.endswith(suffix): + # Remove the suffix from model_state_key + model_state_key = model_state_key[: -len(suffix)] + break + return model_state_key + + def partition_parameters(self, model_state_shapes, is_sort, shard_num): + """ + Partitions parameters among sharding ranks. + + Return: + Dict[int, List] + """ + # Copy from python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py + mapping = {} + for rank_ in range(shard_num): + mapping[rank_] = [] + sizes = [0] * shard_num + + parameters = model_state_shapes.copy() + + if is_sort: + parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True) + + for param in parameters: + rank = sizes.index(min(sizes)) + mapping[rank].append(param) + numel = reduce(lambda x, y: x * y, param[1], 1) + assert numel > 0, f"param [{param[0]}] should larger than 0, but it is [{numel}]" + sizes[rank] += numel + + return mapping + + def rename_using_model_meta(self, file_name): + if not hasattr(self, "model_meta"): + try: + self.model_meta = json.load(open(os.path.join(self.path, MODEL_META_FILE_NAME))) + except Exception as e: + print(e) + distributed_rank = self.get_distribution_rank_from_file_name(file_name) + dist_strategy_key = ( + "tp" + "{:02d}".format(distributed_rank[0]) + "_" + "pp" + "{:02d}".format(distributed_rank[1]) + ) + # Map model weight names to their corresponding names of master_weights in the optimizer state. 
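The renaming below inverts structure_name_mapping (structured model weight name to parameter name) and rewrites every optimizer-state key so that it begins with the structured name. A small illustration with hypothetical names:

    # Hypothetical entry from model_meta.json: structured name -> parameter name.
    structure_name_mapping = {"llama.embed_tokens.weight": "embedding_0.w_0"}

    # Invert it: master weight (parameter) name -> structured model weight name.
    param_to_structured = {v.split(".")[0]: k for k, v in structure_name_mapping.items()}

    optimizer_key = "embedding_0.w_0_fp32_master_0_moment1_0"
    master_weight_name = optimizer_key.split(".")[0]          # "embedding_0"
    structured = param_to_structured[master_weight_name]      # "llama.embed_tokens.weight"
    renamed = optimizer_key.replace(master_weight_name, structured)
    print(renamed)  # "llama.embed_tokens.weight.w_0_fp32_master_0_moment1_0"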
+ if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] + master_weight_name_to_model_weight_name_mapping = {} + for k, v in structure_name_mapping.items(): + master_weight_name_to_model_weight_name_mapping[v.split(".")[0]] = k + + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + state_dict = self.cur_rank_loaded_state_dict[file_name] + for k, v in state_dict.items(): + master_weight_name = self.optimizer_state_name_to_master_weight_name(k) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + new_k = k.replace(master_weight_name, model_weight_name) + renamed_state_dict[new_k] = v + return renamed_state_dict + else: + return self.cur_rank_loaded_state_dict[file_name] + + def rename_using_optimizer_state_order(self, file_name): + if not hasattr(self, "global_file_to_state_dict_keys_mapping"): + file_to_state_dict_keys_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + file_to_state_dict_keys_mapping[file] = list(state_dict.keys()) + + self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) + + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + model_state_file_name = self.optimizer_state_file_name_to_model_state_file_name(file) + assert model_state_file_name is not None + model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] + optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file] + + master_weight_name_to_model_weight_name_mapping = {} + for i in range(len(model_state_keys)): + master_weight_name = self.optimizer_state_name_to_master_weight_name(optimizer_state_keys[i]) + master_weight_name_to_model_weight_name_mapping[master_weight_name] = model_state_keys[i] + + state_dict = self.cur_rank_loaded_state_dict[file_name] + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for k, v in state_dict.items(): + master_weight_name = self.optimizer_state_name_to_master_weight_name(k) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + new_k = k.replace(master_weight_name, model_weight_name) + renamed_state_dict[new_k] = v + + return renamed_state_dict + else: + return self.cur_rank_loaded_state_dict[file_name] + + def load_state_dict_and_rename(self): + rank_access_files = {} + if self.save_sharded_model: + rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names + else: + rank_access_files[self.cur_rank] = ( + self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names ) - # Determine which files need to be read for each rank. 
- # When each node has a checkpoint path, it will fail - need_read_files = self.get_local_load_files(rank_access_files) + need_read_files = get_local_load_files(self.gather_global_object(rank_access_files)) - sharded_tensor_infos = {} - file_to_model_state_names = {} - file_to_optimizer_state_names = {} - model_state_dict_info = {} - cur_rank_sharded_tensor_infos = {} + self.cur_rank_loaded_state_dict = {} for file in need_read_files: - if OPTIMIZER_WEIGHT_SUFFIX in file: - state_dict = paddle.load(os.path.join(self.path, file), return_numpy=True) - state_dict.pop("LR_Scheduler") - master_weights = state_dict.pop("master_weights") - # Extract master weights - for k, v in master_weights.items(): - state_dict[k] = v - # Based on the checkpoint file name, determine the pp_degree, tp_degree, and sharding_degree of the tensor in the current file. - distributed_rank = self.extract_distribution_strategy_from_file_name(file) - dist_strategy_key = ( - "tp" + "{:02d}".format(distributed_rank[0]) + "_" + "pp" + "{:02d}".format(distributed_rank[1]) - ) + self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file)) - # Map model weight names to their corresponding names of master_weights in the optimizer state. - structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] + file_to_master_weights_keys = {} - # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, and then append the tp_degree. - renamed_state_dict = {} - for k, v in state_dict.items(): - for prame_name, opt_name in structure_name_mapping.items(): - if opt_name in k: - new_key = k.replace(opt_name, prame_name) + "_tp" + "{:02d}".format(distributed_rank[0]) + self.optimizer_state_with_master_weights = False + + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + state_dict.pop("LR_Scheduler") + if "master_weights" in state_dict: + self.optimizer_state_with_master_weights = True + master_weights = state_dict.pop("master_weights") + file_to_master_weights_keys[file] = list(master_weights.keys()) + for k, v in master_weights.items(): + # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. + k = k.replace("slice@", "") + state_dict[k] = v + + # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. + self.sharding_stage1_v = self.infer_sharding_stage1_v() + self.is_sharding_stage3 = self.infer_is_sharding_stage3() + + # In sharding stage3, the parameters need to be reordered based on whether they are sliced. + # The threshold for determining whether to slice is segment_size, with a default value of 2**20. + # However, sharding stage3 allows users to specify their own unsliced layers, which seems to be incompatible here. + if self.is_sharding_stage3: + segment_size = 2**20 + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(MODEL_WEIGHT_SUFFIX): + sliced_pramaeters = [] + unseliced_pramaeters = [] + sorted_state_dict = {} + for k, v in state_dict.items(): + if v.numel() > segment_size: + sliced_pramaeters.append(k) else: - new_key = k.replace(opt_name, prame_name) - renamed_state_dict[new_key] = v - # Calculate the local_shape. 
- cur_rank_sharded_tensor_infos[(new_key, file)] = [v.shape, str(v.dtype)] + unseliced_pramaeters.append(k) + for k in sliced_pramaeters + unseliced_pramaeters: + sorted_state_dict[k] = state_dict.pop(k) + self.cur_rank_loaded_state_dict[file] = sorted_state_dict - # Cache the renamed state dict - self.cur_rank_loaded_state_dict[file] = renamed_state_dict + self.global_file_to_master_weights_keys = self.gather_global_object(file_to_master_weights_keys) - use_dist = True if paddle.distributed.get_world_size() > 1 else False + # rename and record sharded_tensor_info + cur_rank_sharded_tensor_infos = {} - # Obtain the local_shape information of the tensor on all ranks. - all_rank_sharded_tensor_infos = [] - if use_dist: - paddle.distributed.all_gather_object(all_rank_sharded_tensor_infos, cur_rank_sharded_tensor_infos) + # 1. Handling the sharding stage1 v2 scenario, where the save_sharded_model flag must be enabled, independent of master_weights. + if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: + assert self.save_sharded_model + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, + # and then append the tp_degree. + renamed_state_dict = self.rename_using_model_meta(file) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + # 2. In handling the sharding stage1 v1 scenario, the optimizer states are distributed across different ranks. + # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. 
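partition_parameters mirrors the dygraph sharding optimizer: each parameter goes to the shard whose accumulated element count is currently smallest, optionally after sorting parameters by size. A standalone sketch of that greedy partitioning with hypothetical shapes:

    from functools import reduce

    def partition_parameters(model_state_shapes, is_sort, shard_num):
        # model_state_shapes: list of [name, shape]; returns {sharding_rank: [[name, shape], ...]}.
        mapping = {rank: [] for rank in range(shard_num)}
        sizes = [0] * shard_num
        parameters = list(model_state_shapes)
        if is_sort:
            parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True)
        for param in parameters:
            rank = sizes.index(min(sizes))      # least-loaded shard so far
            mapping[rank].append(param)
            sizes[rank] += reduce(lambda x, y: x * y, param[1], 1)
        return mapping

    # Hypothetical parameters split across two sharding ranks.
    print(partition_parameters([["w0", [1024, 1024]], ["w1", [1024]], ["w2", [512, 512]]], True, 2))
    # {0: [['w0', [1024, 1024]]], 1: [['w2', [512, 512]], ['w1', [1024]]]}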
+ elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + if not self.save_sharded_model: + file_to_state_dict_shapes_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + shapes = [] + for k, v in state_dict.items(): + shapes.append([k, v.shape]) + file_to_state_dict_shapes_mapping[file] = shapes + + global_file_to_state_dict_shapes_mapping = self.gather_global_object(file_to_state_dict_shapes_mapping) + + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + sharding_optimizer_state_shards = [] + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + for k, v in global_file_to_state_dict_shapes_mapping.items(): + (tp_rank_, pp_rank_, sharding_rank_) = self.get_distribution_rank_from_file_name(k) + if tp_rank == tp_rank_ and pp_rank == pp_rank_ and k.endswith(OPTIMIZER_WEIGHT_SUFFIX): + sharding_optimizer_state_shards.append([v, sharding_rank_]) + model_state_file_name = self.optimizer_state_file_name_to_model_state_file_name(file) + model_state_shapes = global_file_to_state_dict_shapes_mapping[model_state_file_name] + sharding_optimizer_state_shards.sort(key=lambda x: x[1]) + + partition_result_0 = self.partition_parameters(model_state_shapes, False, self.sharding_degree) + partition_result_1 = self.partition_parameters(model_state_shapes, True, self.sharding_degree) + + for k, v in partition_result_0.items(): + v = sorted(v, key=model_state_shapes.index) + partition_result_0[k] = v + + for k, v in partition_result_1.items(): + v = sorted(v, key=model_state_shapes.index) + partition_result_1[k] = v + + sharding_sort_parameters = False + + for i in range(len(sharding_optimizer_state_shards)): + if not sharding_sort_parameters: + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result_0[i] + for j in range(len(partitioned_shard)): + if partitioned_shard[j][1] != state_shard[j][1]: + sharding_sort_parameters = True + break + + if sharding_sort_parameters: + for i in range(len(sharding_optimizer_state_shards)): + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result_1[i] + for j in range(len(partitioned_shard)): + assert partitioned_shard[j][1] == state_shard[j][1] + + if sharding_sort_parameters: + partition_result = partition_result_1 + else: + partition_result = partition_result_0 + + master_weight_name_to_model_weight_name_mapping = {} + for i in range(len(sharding_optimizer_state_shards)): + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result[i] + for j in range(len(partitioned_shard)): + master_weight_name = self.optimizer_state_name_to_master_weight_name(state_shard[j][0]) + master_weight_name_to_model_weight_name_mapping[ + master_weight_name + ] = partitioned_shard[j][0] + + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + + # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. + # Therefore, the sharding information can now be directly removed. 
+ for k, v in state_dict.items(): + master_weight_name = self.optimizer_state_name_to_master_weight_name(k) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + new_k = k.replace(master_weight_name, model_weight_name) + renamed_state_dict[new_k] = v + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + else: + for k, v in state_dict.items(): + if k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[k] = [ + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ] + else: + cur_rank_sharded_tensor_infos[k].append( + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ) + else: + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + renamed_state_dict = self.rename_using_model_meta(file) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: - all_rank_sharded_tensor_infos = [cur_rank_sharded_tensor_infos] - - global_sharded_tensor_infos = {} - for rank_sharded_tensor_infos in all_rank_sharded_tensor_infos: - for k, v in rank_sharded_tensor_infos.items(): - if k not in global_sharded_tensor_infos: - global_sharded_tensor_infos[k] = v - - # Collect sharding information. - key_to_sharded_info = {} - for k, v in global_sharded_tensor_infos.items(): - distributed_rank = self.extract_distribution_strategy_from_file_name(k[1]) - if k[0] not in key_to_sharded_info: - key_to_sharded_info[k[0]] = [[distributed_rank[2], v[0], v[1], k[1]]] + # 3. Handling the case of disabling sharding, independent of master_weights, but without considering the save_sharded_model flag. 
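In this branch there may be no model_meta.json, so rename_using_optimizer_state_order recovers the mapping purely from ordering: the i-th master weight in the optimizer file corresponds to the i-th key of the matching model-state file. A simplified illustration with hypothetical key lists:

    # Hypothetical keys, in the order they appear in the saved state dicts.
    model_state_keys = ["llama.embed_tokens.weight", "llama.layers.0.mlp.gate_proj.weight"]
    optimizer_state_keys = ["embedding_0.w_0_fp32_master_0_moment1_0",
                            "linear_0.w_0_fp32_master_0_moment1_0"]

    master_to_model = {}
    for model_key, opt_key in zip(model_state_keys, optimizer_state_keys):
        master_to_model[opt_key.split(".")[0]] = model_key   # "embedding_0" -> "llama.embed_tokens.weight"

    print(master_to_model["linear_0"])  # "llama.layers.0.mlp.gate_proj.weight"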
+ if not self.save_sharded_model: + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + + renamed_state_dict = self.rename_using_optimizer_state_order(file) + + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + else: + for k, v in state_dict.items(): + if k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": -1}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": -1}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + else: - key_to_sharded_info[k[0]].append([distributed_rank[2], v[0], v[1], k[1]]) + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, + # and then append the tp_degree. + renamed_state_dict = self.rename_using_model_meta(file) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + # gather global sharded tensor infos + sharded_tensor_infos = self.gather_global_object({self.cur_rank: cur_rank_sharded_tensor_infos}) + + self.global_sharded_tensor_infos = {} + for rank, sharded_tensor_info in sharded_tensor_infos.items(): + for k, v in sharded_tensor_info.items(): + if k not in self.global_sharded_tensor_infos: + self.global_sharded_tensor_infos[k] = v + else: + self.global_sharded_tensor_infos[k] += v - # x[0] records the sharding rank. - for k, v in key_to_sharded_info.items(): - v.sort(key=lambda x: x[0]) + def gen_metadata_for_tp_sharded_tensor(self): + for k, v in self.global_sharded_tensor_infos.items(): + v.sort(key=lambda x: x[0]["tp_rank"]) state_dict_metadata = {} storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
- for k, v in key_to_sharded_info.items(): + for k, v in self.global_sharded_tensor_infos.items(): global_offset = 0 + local_shape = v[0][1] + model_state_name = self.optimizer_key_to_model_state_key(k) + if "_pow_acc_0" not in k: + global_shape = self.model_state_global_shape[model_state_name] + else: + global_shape = (1,) + + assert len(local_shape) == len(global_shape) + axis = -1 + for i in range(len(local_shape)): + if local_shape[i] != global_shape[i]: + axis = i + break + + is_replicated = axis == -1 + global_offset = [0] * len(local_shape) + + if is_replicated: + v = [v[0]] + for item in v: - local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, (global_offset,)) - global_offset += item[1][0] + local_tensor_meta_data = LocalTensorMetadata(tuple(global_offset), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, tuple(global_offset)) + global_offset[axis] += item[1][axis] if k not in state_dict_metadata: state_dict_metadata[k] = [local_tensor_meta_data] else: state_dict_metadata[k].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] + ".distcp" - - # Save the metadata and the renamed tensor read by this rank. - metadata = Metadata(state_dict_metadata, storage_metadata, None) - write_path = os.path.join(self.path, "tmp") - for file in self.cur_rank_loaded_state_dict: - paddle.save(self.cur_rank_loaded_state_dict[file], os.path.join(write_path, file + ".distcp")) - if 0 == paddle.distributed.get_rank(): - paddle.save(metadata, os.path.join(write_path, "0.metadata")) - - def concat_optimier_state_dict(self): - # Obtain the global_shape passed in semi-automatic parallel mode on each card in the static graph. - all_rank_model_state_global_shapes = [] - use_dist = True if paddle.distributed.get_world_size() > 1 else False - if use_dist: - paddle.distributed.all_gather_object(all_rank_model_state_global_shapes, self.model_state_global_shape) - else: - all_rank_model_state_global_shapes = [self.model_state_global_shape] - - self.model_state_global_shape = {} - for rank_model_state_global_shape in all_rank_model_state_global_shapes: - for k, v in rank_model_state_global_shape.items(): - self.model_state_global_shape[k] = v - - # Obtain the names and shapes of all model parameters. 
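The replacement code gathers the global parameter shapes from model_meta.json, whose sharding_metas section is keyed by the tpXX_ppXX pair. A sketch of that traversal, under the same assumed layout the converter relies on (param_meta values holding the shape in their first slot):

    import json

    # Assumed layout of model_meta.json, matching the accesses used in this converter.
    model_meta = json.loads("""{
      "sharding_metas": {
        "tp00_pp00": {"param_meta": {"llama.embed_tokens.weight": [[32000, 1024], "float32"]}}
      }
    }""")

    global_model_state_shapes = []
    for meta in model_meta["sharding_metas"].values():
        for name, info in meta["param_meta"].items():
            global_model_state_shapes.append([name, info[0]])   # info[0] is the shape

    print(global_model_state_shapes)  # [['llama.embed_tokens.weight', [32000, 1024]]]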
- global_model_state_shapes = {} - sharding_metas_keys = [] - pp_degree = self.model_meta["parallel_config"]["pp_degree"] - mp_degree = self.model_meta["parallel_config"]["mp_degree"] - for i in range(pp_degree): - for j in range(mp_degree): - sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) - for key in sharding_metas_keys: - param_meta = self.model_meta["sharding_metas"][key]["param_meta"] - for k, v in param_meta.items(): - global_model_state_shapes[k] = v[0] - - use_dist = True if paddle.distributed.get_world_size() > 1 else False - world_size = paddle.distributed.get_world_size() - - # Distribute all model parameters evenly across each card for loading - global_model_state_flattened_shapes = {} - global_model_state_size = 0 - for k, v in global_model_state_shapes.items(): - flattened_size = reduce(lambda x, y: x * y, v) - global_model_state_size += flattened_size - global_model_state_flattened_shapes[k] = flattened_size - - partition_model_state_keys = [] - avg_size = global_model_state_size // world_size - - cur_rank_model_state_keys = [] - - cur_rank_size = 0 - for k, v in global_model_state_flattened_shapes.items(): - cur_rank_size += v - cur_rank_model_state_keys.append(k) - if cur_rank_size > avg_size: - partition_model_state_keys.append(cur_rank_model_state_keys) - cur_rank_model_state_keys = [] - cur_rank_size = 0 - - # Since an absolutely even distribution is not achievable, some tanks may not need to load, but the load_state_dict interface might throw an error. Therefore, it is necessary to forcefully assign a parameter. - nend_append = world_size - len(partition_model_state_keys) - for i in range(nend_append): - partition_model_state_keys.append([partition_model_state_keys[0][0]]) - - cur_rank = paddle.distributed.get_rank() - - cur_rank_need_load_model_state_keys = partition_model_state_keys[cur_rank] - - # Generate the optimizer states corresponding to the model weights. - optimizer_state_dict = {} - for key in cur_rank_need_load_model_state_keys: - for tp_rank in range(self.model_meta["parallel_config"]["mp_degree"]): - tp_rank_suffix = "_tp{:02d}".format(tp_rank) - optimizer_state_dict[key + "_fp32_master_0_moment1_0" + tp_rank_suffix] = paddle.zeros( - (global_model_state_flattened_shapes[key],), "float32" - ) - optimizer_state_dict[key + "_fp32_master_0_moment2_0" + tp_rank_suffix] = paddle.zeros( - (global_model_state_flattened_shapes[key],), "float32" - ) - optimizer_state_dict[key + tp_rank_suffix] = paddle.zeros( - (global_model_state_flattened_shapes[key],), "float32" - ) - optimizer_state_dict[key + "_fp32_master_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") - optimizer_state_dict[key + "_fp32_master_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") - - dist.load_state_dict(optimizer_state_dict, os.path.join(self.path, "tmp")) - - # Reshape - for k, v in optimizer_state_dict.items(): - if v.shape[0] > 1 and "_tp" in k: - for master_weight_key, shape in global_model_state_shapes.items(): - if master_weight_key in k: - reshaped_v = v.reshape(shape) - optimizer_state_dict[k] = reshaped_v - - concat_optimier_state_dict = {} - - optimizer_state_key_to_tp_keys = {} - for key in optimizer_state_dict.keys(): - # Count how each key is split into keys ending with ‘_tpXX’. 
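Grouping works by stripping the trailing five-character "_tpXX" suffix so that all tensor-parallel copies of one optimizer state collapse onto the same base key, which are then ordered by tp rank. A compact illustration with hypothetical keys:

    keys = ["linear_0.w_0_fp32_master_0_moment1_0_tp00",
            "linear_0.w_0_fp32_master_0_moment1_0_tp01"]

    groups = {}
    for key in keys:
        groups.setdefault(key[:-5], []).append(key)   # drop the "_tpXX" suffix

    for base, tp_keys in groups.items():
        tp_keys.sort(key=lambda k: int(k[-2:]))       # order the copies by tp rank

    print(list(groups))  # ['linear_0.w_0_fp32_master_0_moment1_0']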
- # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} - if "_pow_acc_0" not in key: - if key[:-5] not in optimizer_state_key_to_tp_keys: - optimizer_state_key_to_tp_keys[key[:-5]] = [key] + storage_metadata[local_tensor_index] = item[3] + + metadata = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = self.cur_rank_loaded_state_dict + + return metadata, source_state_dict + + def gen_metadata_and_prepare_source_state_dict(self): + self.load_state_dict_and_rename() + if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: + for k, v in self.global_sharded_tensor_infos.items(): + v.sort(key=lambda x: x[0]["sharding_rank"]) + + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. + for k, v in self.global_sharded_tensor_infos.items(): + global_offset = 0 + for item in v: + if item[0]["tp_rank"] != -1: + k_with_tp_rank = k + "_tp" + "{:02d}".format(item[0]["tp_rank"]) + else: + k_with_tp_rank = k + local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k_with_tp_rank, (global_offset,)) + global_offset += item[1][0] + if k_with_tp_rank not in state_dict_metadata: + state_dict_metadata[k_with_tp_rank] = [local_tensor_meta_data] + else: + state_dict_metadata[k_with_tp_rank].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + + source_state_dict_for_merge_sharding = {} + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + for k, v in state_dict.items(): + if self.global_sharded_tensor_infos[k][0][0]["tp_rank"] != -1: + k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) + renamed_state_dict[k_with_tp_rank] = v + else: + renamed_state_dict[k] = v + + source_state_dict_for_merge_sharding[file_name] = renamed_state_dict + + assert self.model_meta is not None + global_model_state_shapes = [] + sharding_metas_keys = [] + for i in range(self.pp_degree): + for j in range(self.tp_degree): + sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) + for key in sharding_metas_keys: + param_meta = self.model_meta["sharding_metas"][key]["param_meta"] + for k, v in param_meta.items(): + global_model_state_shapes.append([k, v[0]]) + + # Distribute all model parameters evenly across each card for loading + + world_size = paddle.distributed.get_world_size() + + partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) + + partition_model_state_keys = [] + for cur_rank, partition_model_state in partition_mapping.items(): + partition_model_state_keys.append([item[0] for item in partition_model_state]) + + global_model_state_flattened_shapes = {} + for item in global_model_state_shapes: + name = item[0] + shape = item[1] + flattened_size = reduce(lambda x, y: x * y, shape) + global_model_state_flattened_shapes[name] = flattened_size + + cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] + + # Generate the optimizer states corresponding to the model weights. 
+ optimizer_state_dict = {} + for key in cur_rank_need_load_model_state_keys: + for tp_rank in range(self.tp_degree): + tp_rank_suffix = "_tp{:02d}".format(tp_rank) + optimizer_state_dict[key + ".w_0_fp32_master_0_moment1_0" + tp_rank_suffix] = paddle.zeros( + (global_model_state_flattened_shapes[key],), "float32" + ) + optimizer_state_dict[key + ".w_0_fp32_master_0_moment2_0" + tp_rank_suffix] = paddle.zeros( + (global_model_state_flattened_shapes[key],), "float32" + ) + if self.optimizer_state_with_master_weights: + optimizer_state_dict[key + ".w_0" + tp_rank_suffix] = paddle.zeros( + (global_model_state_flattened_shapes[key],), "float32" + ) + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. + + optimizer_state_dict[key + ".w_0_fp32_master_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros( + (1,), "float32" + ) + optimizer_state_dict[key + ".w_0_fp32_master_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros( + (1,), "float32" + ) + + # merge sharding + _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) + + # Reshape + for k, v in optimizer_state_dict.items(): + if v.shape[0] > 1 and "_tp" in k: + for item in global_model_state_shapes: + master_weight_key = item[0] + shape = item[1] + if master_weight_key in k and reduce(lambda a, b: a * b, shape) == v.numel(): + reshaped_v = v.reshape(shape) + optimizer_state_dict[k] = reshaped_v + + concat_optimier_state_dict = {} + + optimizer_state_key_to_tp_keys = {} + for key in optimizer_state_dict.keys(): + # Count how each key is split into keys ending with ‘_tpXX’. + # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} + key_removed_tp_rank = key[:-5] + if key_removed_tp_rank not in optimizer_state_key_to_tp_keys: + optimizer_state_key_to_tp_keys[key_removed_tp_rank] = [key] else: - optimizer_state_key_to_tp_keys[key[:-5]].append(key) - else: - optimizer_state_key_to_tp_keys[key] = [key] - for key, value in optimizer_state_key_to_tp_keys.items(): - if len(value) == 1: - continue - value.sort(key=lambda x: int(x[-2:])) - - for key, tp_keys in optimizer_state_key_to_tp_keys.items(): - # Optimizer states with a shape of 1 could be replicated; here, perform a check. - is_replicated = True - tp_tensor = optimizer_state_dict[tp_keys[0]] - for tp_key in tp_keys: - if not np.array_equal(tp_tensor.numpy(), optimizer_state_dict[tp_key].numpy()): - is_replicated = False - break - if is_replicated: - concat_optimier_state_dict[key] = tp_tensor - continue - else: + optimizer_state_key_to_tp_keys[key_removed_tp_rank].append(key) + + for key, value in optimizer_state_key_to_tp_keys.items(): + value.sort(key=lambda x: int(x[-2:])) + + for key, tp_keys in optimizer_state_key_to_tp_keys.items(): + model_state_name = self.optimizer_key_to_model_state_key(key) + local_shape = optimizer_state_dict[tp_keys[0]].shape + if "_pow_acc_0" not in key: + global_shape = self.model_state_global_shape[model_state_name] + else: + global_shape = (1,) + + assert len(local_shape) == len(global_shape) + + axis = -1 + for i in range(len(local_shape)): + if local_shape[i] != global_shape[i]: + axis = i + break + + is_replicated = axis == -1 tp_tensors = [] for tp_key in tp_keys: tp_tensors.append(optimizer_state_dict[tp_key]) - # Derive the partition strategy based on the global_shape, then concatenate. 
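The new code infers the concatenation axis as the first dimension where the local shard shape differs from the global shape (no differing dimension means the state is replicated), then joins the tensor-parallel slices with paddle.concat. A sketch with hypothetical shapes:

    import paddle

    global_shape = [8, 6]
    # Two hypothetical tensor-parallel slices, split along the second dimension.
    tp_tensors = [paddle.zeros((8, 3), "float32"), paddle.ones((8, 3), "float32")]

    axis = -1
    for i, (local_dim, global_dim) in enumerate(zip(tp_tensors[0].shape, global_shape)):
        if local_dim != global_dim:
            axis = i
            break

    if axis == -1:                       # identical shapes: the state is replicated
        merged = tp_tensors[0]
    else:
        merged = paddle.concat(tp_tensors, axis=axis)

    print(merged.shape)  # [8, 6]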
- axis = 0 - global_shape = [] - # Find the global_shape. - for k, shape in self.model_state_global_shape.items(): - if k in tp_key: - global_shape = shape - break - assert len(global_shape) != 0 - tp_shape = tp_tensors[0].shape - assert (tp_shape[0] == global_shape[0] and len(tp_tensors) * tp_shape[1] == global_shape[1]) or ( - tp_shape[1] == global_shape[1] and len(tp_tensors) * tp_shape[0] == global_shape[0] - ) - if tp_shape[0] == global_shape[0]: - axis = 1 - concat_optimier_state_dict[key] = paddle.concat(tp_tensors, axis=axis) - - file_name = "{:02d}".format(cur_rank) + ".distcp" - local_tensor_meta_data = {} - local_tensor_index = {} - for k, v in concat_optimier_state_dict.items(): - # Generate metadata. - local_shape = v.shape - global_offset = tuple([0] * len(local_shape)) - dtype = str(v.dtype) - local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) - local_tensor_index[k] = [LocalTensorIndex(k, global_offset), file_name] - - global_local_tensor_meta_data = [] - global_local_tensor_index = [] - - if use_dist: - paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) - paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) - else: - global_local_tensor_meta_data = [local_tensor_meta_data] - global_local_tensor_index = [local_tensor_index] - state_dict_metadata = {} - for tensor_meta_data in global_local_tensor_meta_data: - for k, v in tensor_meta_data.items(): - if k not in state_dict_metadata: - state_dict_metadata[k] = [v] + if not is_replicated: + # Derive the partition strategy based on the global_shape, then concatenate. + concat_optimier_state_dict[key] = paddle.concat(tp_tensors, axis=axis) else: - state_dict_metadata[k].append(v) + concat_optimier_state_dict[key] = tp_tensors[0] + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in concat_optimier_state_dict.items(): + # Generate metadata. 
+ local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] - storage_metadata = {} - for tensor_index in global_local_tensor_index: - for k, v in tensor_index.items(): - storage_metadata[v[0]] = v[1] + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: concat_optimier_state_dict} + + return meta_data, source_state_dict + + elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + return self.gen_metadata_for_tp_sharded_tensor() + else: + if self.is_sharding_stage3: + return + for k, v in self.global_sharded_tensor_infos.items(): + v.sort(key=lambda x: x[0]["sharding_rank"]) + + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
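# Single-process sketch of how the gathered per-rank metadata above is folded into
# the global maps; plain tuples stand in for LocalTensorMetadata / LocalTensorIndex
# and the two dicts below pretend to come from two ranks.
global_local_tensor_meta_data = [
    {"llama.w_0_moment1_0": ((0,), (64,), "float32")},  # from rank 0 ("00.distcp")
    {"llama.b_0_moment1_0": ((0,), (16,), "float32")},  # from rank 1 ("01.distcp")
]
global_local_tensor_index = [
    {"llama.w_0_moment1_0": (("llama.w_0_moment1_0", (0,)), "00.distcp")},
    {"llama.b_0_moment1_0": (("llama.b_0_moment1_0", (0,)), "01.distcp")},
]

state_dict_metadata = {}
for tensor_meta_data in global_local_tensor_meta_data:
    for k, v in tensor_meta_data.items():
        state_dict_metadata.setdefault(k, []).append(v)

storage_metadata = {}
for tensor_index in global_local_tensor_index:
    for k, (index, file_name) in tensor_index.items():
        storage_metadata[index] = file_name  # tensor index -> the ".distcp" file that stores it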
+ for k, v in self.global_sharded_tensor_infos.items(): + global_offset = 0 + for item in v: + local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, (global_offset,)) + global_offset += item[1][0] + if k not in state_dict_metadata: + state_dict_metadata[k] = [local_tensor_meta_data] + else: + state_dict_metadata[k].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] - meta_data = Metadata(state_dict_metadata, storage_metadata, None) + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - save_path = os.path.join(self.path, "tmp2") + model_state_shapes = [] + dtype = "" + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(MODEL_WEIGHT_SUFFIX): + for k, v in state_dict.items(): + model_state_shapes.append([k, v.shape]) + dtype = str(v.dtype).split(".")[1] - if cur_rank == 0: - paddle.save(meta_data, os.path.join(save_path, "0.metadata")) + dtypes = self.gather_global_object([dtype]) + for dtype_s in dtypes: + if len(dtype_s) > 0: + dtype = dtype_s - paddle.save(concat_optimier_state_dict, os.path.join(save_path, file_name)) + assert len(dtype) > 0 + + global_model_state_shapes = self.gather_global_object(model_state_shapes) + + partition_result = self.partition_parameters( + global_model_state_shapes, True, paddle.distributed.get_world_size() + ) + + cur_rank_merger_model_params = partition_result[self.cur_rank] + target_state_dict = {} + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + flatten_shape = reduce(lambda a, b: a * b, item[1]) + target_state_dict[key] = paddle.zeros(shape, dtype) + target_state_dict[key + ".w_0_moment1_0"] = paddle.zeros((flatten_shape,), "float32") + target_state_dict[key + ".w_0_moment2_0"] = paddle.zeros((flatten_shape,), "float32") + if self.optimizer_state_with_master_weights: + target_state_dict[key + ".w_0"] = paddle.zeros((flatten_shape,), "float32") + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. + + target_state_dict[key + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") + target_state_dict[key + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") + + # TODO(zhuxinming) To resolve hanging during the loading of weights in sharding stage 3. 
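# partition_parameters(...) is called above but its body is outside this hunk; the
# greedy, size-balanced assignment below is only a guess at the kind of layout such
# a partitioner produces.  Names and shapes are made up.
from functools import reduce


def partition_parameters_sketch(model_state_shapes, world_size):
    """model_state_shapes: list of [name, shape] -> {rank: [[name, shape], ...]}"""
    mapping = {rank: [] for rank in range(world_size)}
    sizes = [0] * world_size
    # place the biggest tensors first, always on the currently lightest rank
    for name, shape in sorted(model_state_shapes, key=lambda item: -reduce(lambda x, y: x * y, item[1])):
        rank = sizes.index(min(sizes))
        mapping[rank].append([name, shape])
        sizes[rank] += reduce(lambda x, y: x * y, shape)
    return mapping


print(partition_parameters_sketch([["w", [8, 8]], ["e", [4, 4]], ["b", [8]]], 2))
# {0: [['w', [8, 8]]], 1: [['e', [4, 4]], ['b', [8]]]}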
+ _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) + + else: + return self.gen_metadata_for_tp_sharded_tensor() + + def rename_semi_auto_state_dict(self): + need_remove_key_pattern = ["eager_tmp", "learning_rate", "@GRAD@MERG", "gradient_merge_"] + + need_remove_key = set() + for key in self.semi_auto_model_state.keys(): + for pattern in need_remove_key_pattern: + if pattern in key: + need_remove_key.add(key) + break + + for key in need_remove_key: + self.semi_auto_model_state.pop(key) + + adamw_optimizer_status_name_suffix_mappings = { + "_fp32_master_1_moment1_0": ".w_0_fp32_master_0_moment1_0", + "_fp32_master_1_moment2_0": ".w_0_fp32_master_0_moment2_0", + "_fp32_master_1_beta1_pow_acc_0": ".w_0_fp32_master_0_beta1_pow_acc_0", + "_fp32_master_1_beta2_pow_acc_0": ".w_0_fp32_master_0_beta2_pow_acc_0", + "_fp32_master_1": ".w_0", + } + + def rename(old_name, map1, map2): + for i in range(1, len(old_name)): + str1 = old_name[:i] + str2 = old_name[i:] + if (str1 in map1) and (str2 in map2): + transformed_str1 = map1[str1] + transformed_str2 = map2[str2] + return transformed_str1 + transformed_str2 + return None + + renamed_state_dict = {} + for key, value in self.semi_auto_model_state.items(): + if key in self.parameter_to_structured_name.values(): + new_name = key + else: + new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_status_name_suffix_mappings) + print(new_name) + assert new_name is not None + renamed_state_dict[new_name] = value + + self.semi_auto_model_state = renamed_state_dict + + def load_from_dynamic_checkpoint(self): + self.rename_semi_auto_state_dict() + metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() + if self.save_sharded_model: + model_params = {} + for k, v in self.semi_auto_model_state.items(): + if k in self.parameter_to_structured_name.values(): + model_params[k] = v + for k in model_params.keys(): + self.semi_auto_model_state.pop(k) + + appended_master_weight_names = [] + + for k, v in model_params.items(): + master_weight = k + ".w_0" + if master_weight not in self.semi_auto_model_state: + appended_master_weight_names.append(master_weight) + # TODO(zhuxinming) Create a new distributed tensor with the same distribution information as the corresponding parameter. 
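# Hedged illustration of the rename helper defined above: the raw optimizer state
# name is split into a prefix found in parameter_to_structured_name and a suffix
# found in the AdamW suffix mapping, and the translated halves are glued back
# together.  Both maps below are made up.
parameter_to_structured_name = {"linear_0.w_0": "llama.layers.0.linear.weight"}
suffix_mapping = {"_fp32_master_1_moment1_0": ".w_0_fp32_master_0_moment1_0"}


def rename(old_name, map1, map2):
    for i in range(1, len(old_name)):
        str1, str2 = old_name[:i], old_name[i:]
        if str1 in map1 and str2 in map2:
            return map1[str1] + map2[str2]
    return None


assert (
    rename("linear_0.w_0_fp32_master_1_moment1_0", parameter_to_structured_name, suffix_mapping)
    == "llama.layers.0.linear.weight.w_0_fp32_master_0_moment1_0"
)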
+ self.semi_auto_model_state[master_weight] = paddle.zeros(v._local_value().shape, "float32") + + _load_state_dict(self.semi_auto_model_state, source_state_dict, [metadata]) + for k, v in model_params.items(): + master_weight = self.semi_auto_model_state[k + ".w_0"] + # cast_master_weight = paddle.cast(master_weight, "bfloat16") + + for k in appended_master_weight_names: + self.semi_auto_model_state.pop(k) + + else: + _load_state_dict(self.semi_auto_model_state, source_state_dict, [metadata]) From e00d34cda77353a7d99f4fca11119f56fb6e34f1 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Thu, 8 Aug 2024 20:33:48 +0800 Subject: [PATCH 03/30] fix sharding stage1 v2 --- paddlenlp/trainer/checkpoint_converter.py | 116 +++++++++++++--------- 1 file changed, 68 insertions(+), 48 deletions(-) diff --git a/paddlenlp/trainer/checkpoint_converter.py b/paddlenlp/trainer/checkpoint_converter.py index 00235c42a870..73e9be3ab113 100644 --- a/paddlenlp/trainer/checkpoint_converter.py +++ b/paddlenlp/trainer/checkpoint_converter.py @@ -18,6 +18,7 @@ from functools import reduce import paddle +import paddle.distributed as dist from paddle.distributed.checkpoint.load_state_dict import ( _load_state_dict, get_local_load_files, @@ -39,7 +40,7 @@ def __init__(self, dynamic_ckpt_path, model_state, parameter_to_structured_name) self.use_dist = True if paddle.distributed.get_world_size() > 1 else False self.path = dynamic_ckpt_path self.semi_auto_model_state = model_state - self.parameter_to_structured_name = parameter_to_structured_name + self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) model_state_global_shape = {} for k, v in model_state.items(): model_state_global_shape[k] = v.shape @@ -138,7 +139,7 @@ def infer_sharding_stage1_v(self): if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: for k, v in state_dict.items(): # Under shardingv2, the optimizer state is first flattened and then split. - if "_fp32_master_0_moment" in k and len(v.shape) != 1: + if "_moment" in k and len(v.shape) != 1: sharding_stage1_v = [1] break @@ -186,10 +187,10 @@ def optimizer_state_file_name_to_model_state_file_name(self, optimizer_state_fil def optimizer_key_to_model_state_key(self, optimizer_key): adamw_optimizer_key_suffix = [ - ".w_0_fp32_master_0_beta1_pow_acc_0", - ".w_0_fp32_master_0_beta2_pow_acc_0", - ".w_0_fp32_master_0_moment1_0", - ".w_0_fp32_master_0_moment2_0", + ".w_0_beta1_pow_acc_0", + ".w_0_beta2_pow_acc_0", + ".w_0_moment1_0", + ".w_0_moment2_0", ".w_0", ] model_state_key = optimizer_key @@ -321,6 +322,25 @@ def load_state_dict_and_rename(self): k = k.replace("slice@", "") state_dict[k] = v + # Standardize the state names of the AdamW optimizer. + adamw_optimizer_param_suffix_name_mapping = { + ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", + ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", + ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + } + + unified_name_state_dict = {} + for k, v in state_dict.items(): + new_k = k + for suffix in adamw_optimizer_param_suffix_name_mapping: + if k.endswith(suffix): + new_k = k.replace(suffix, adamw_optimizer_param_suffix_name_mapping[suffix]) + break + unified_name_state_dict[new_k] = v + + self.cur_rank_loaded_state_dict[file] = unified_name_state_dict + # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. 
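# Small sketch of the AdamW suffix unification added above: keys saved with the long
# "_fp32_master_0_" suffixes are normalised to the short canonical form before any
# further renaming.  The input key below is made up.
adamw_optimizer_param_suffix_name_mapping = {
    ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0",
    ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0",
    ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0",
    ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0",
}


def unify_key(k):
    for suffix, short_suffix in adamw_optimizer_param_suffix_name_mapping.items():
        if k.endswith(suffix):
            return k[: -len(suffix)] + short_suffix
    return k


assert unify_key("linear_0.w_0_fp32_master_0_moment1_0") == "linear_0.w_0_moment1_0"
assert unify_key("linear_0.w_0") == "linear_0.w_0"  # master weights keep their name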
self.sharding_stage1_v = self.infer_sharding_stage1_v() self.is_sharding_stage3 = self.infer_is_sharding_stage3() @@ -636,15 +656,13 @@ def gen_metadata_and_prepare_source_state_dict(self): storage_metadata = {} # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. for k, v in self.global_sharded_tensor_infos.items(): - global_offset = 0 + global_offset = [0] * self.tp_degree for item in v: - if item[0]["tp_rank"] != -1: - k_with_tp_rank = k + "_tp" + "{:02d}".format(item[0]["tp_rank"]) - else: - k_with_tp_rank = k - local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k_with_tp_rank, (global_offset,)) - global_offset += item[1][0] + tp_rank = item[0]["tp_rank"] + k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) + local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k_with_tp_rank, (global_offset[tp_rank],)) + global_offset[tp_rank] += item[1][0] if k_with_tp_rank not in state_dict_metadata: state_dict_metadata[k_with_tp_rank] = [local_tensor_meta_data] else: @@ -687,12 +705,17 @@ def gen_metadata_and_prepare_source_state_dict(self): for cur_rank, partition_model_state in partition_mapping.items(): partition_model_state_keys.append([item[0] for item in partition_model_state]) - global_model_state_flattened_shapes = {} - for item in global_model_state_shapes: - name = item[0] - shape = item[1] - flattened_size = reduce(lambda x, y: x * y, shape) - global_model_state_flattened_shapes[name] = flattened_size + param_meta = {} + for i in range(self.tp_degree): + for j in range(self.pp_degree): + key = "tp{:02d}_pp{:02d}".format(i, j) + pm = self.model_meta["sharding_metas"][key]["param_meta"] + for k, v in pm.items(): + param_meta[k] = v + + param_flattened_shapes = {} + for k, v in param_meta.items(): + param_flattened_shapes[k] = reduce(lambda x, y: x * y, v[0]) cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] @@ -701,25 +724,21 @@ def gen_metadata_and_prepare_source_state_dict(self): for key in cur_rank_need_load_model_state_keys: for tp_rank in range(self.tp_degree): tp_rank_suffix = "_tp{:02d}".format(tp_rank) - optimizer_state_dict[key + ".w_0_fp32_master_0_moment1_0" + tp_rank_suffix] = paddle.zeros( - (global_model_state_flattened_shapes[key],), "float32" + optimizer_state_dict[key + ".w_0_moment1_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" ) - optimizer_state_dict[key + ".w_0_fp32_master_0_moment2_0" + tp_rank_suffix] = paddle.zeros( - (global_model_state_flattened_shapes[key],), "float32" + optimizer_state_dict[key + ".w_0_moment2_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" ) if self.optimizer_state_with_master_weights: optimizer_state_dict[key + ".w_0" + tp_rank_suffix] = paddle.zeros( - (global_model_state_flattened_shapes[key],), "float32" + (param_flattened_shapes[key],), "float32" ) # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. # Later, when these are compared with the global shape, we realize that they are replicated. 
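# Sketch of the per-tp-rank offset bookkeeping introduced above: each tensor-parallel
# variant of a key keeps its own running 1-D offset, so shards from successive
# sharding ranks of the same tp rank are placed one after another.  Numbers are made up.
tp_degree = 2
shards = [(0, 30), (1, 30), (0, 34), (1, 34)]  # (tp_rank, flattened length), sorted by sharding_rank

global_offset = [0] * tp_degree
layout = []
for tp_rank, local_len in shards:
    key_with_tp_rank = "llama.w_0_moment1_0_tp{:02d}".format(tp_rank)
    layout.append((key_with_tp_rank, (global_offset[tp_rank],), (local_len,)))
    global_offset[tp_rank] += local_len

assert global_offset == [64, 64]
assert layout[2] == ("llama.w_0_moment1_0_tp00", (30,), (34,))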
- optimizer_state_dict[key + ".w_0_fp32_master_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros( - (1,), "float32" - ) - optimizer_state_dict[key + ".w_0_fp32_master_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros( - (1,), "float32" - ) + optimizer_state_dict[key + ".w_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") + optimizer_state_dict[key + ".w_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") # merge sharding _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) @@ -727,13 +746,11 @@ def gen_metadata_and_prepare_source_state_dict(self): # Reshape for k, v in optimizer_state_dict.items(): if v.shape[0] > 1 and "_tp" in k: - for item in global_model_state_shapes: - master_weight_key = item[0] - shape = item[1] - if master_weight_key in k and reduce(lambda a, b: a * b, shape) == v.numel(): - reshaped_v = v.reshape(shape) - optimizer_state_dict[k] = reshaped_v - + param_name = self.optimizer_key_to_model_state_key(k[:-5]) + param_shape = param_meta[param_name][0] + assert v.numel() == reduce(lambda x, y: x * y, param_shape) + reshaped_v = v.reshape(param_shape) + optimizer_state_dict[k] = reshaped_v concat_optimier_state_dict = {} optimizer_state_key_to_tp_keys = {} @@ -821,7 +838,6 @@ def gen_metadata_and_prepare_source_state_dict(self): return self.gen_metadata_for_tp_sharded_tensor() else: if self.is_sharding_stage3: - return for k, v in self.global_sharded_tensor_infos.items(): v.sort(key=lambda x: x[0]["sharding_rank"]) @@ -900,11 +916,15 @@ def rename_semi_auto_state_dict(self): self.semi_auto_model_state.pop(key) adamw_optimizer_status_name_suffix_mappings = { - "_fp32_master_1_moment1_0": ".w_0_fp32_master_0_moment1_0", - "_fp32_master_1_moment2_0": ".w_0_fp32_master_0_moment2_0", - "_fp32_master_1_beta1_pow_acc_0": ".w_0_fp32_master_0_beta1_pow_acc_0", - "_fp32_master_1_beta2_pow_acc_0": ".w_0_fp32_master_0_beta2_pow_acc_0", + "_fp32_master_1_moment1_0": ".w_0_moment1_0", + "_fp32_master_1_moment2_0": ".w_0_moment2_0", + "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", "_fp32_master_1": ".w_0", + "_moment1_0": ".w_0_moment1_0", + "_moment2_0": ".w_0_moment2_0", + "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", } def rename(old_name, map1, map2): @@ -923,7 +943,6 @@ def rename(old_name, map1, map2): new_name = key else: new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_status_name_suffix_mappings) - print(new_name) assert new_name is not None renamed_state_dict[new_name] = value @@ -946,14 +965,15 @@ def load_from_dynamic_checkpoint(self): master_weight = k + ".w_0" if master_weight not in self.semi_auto_model_state: appended_master_weight_names.append(master_weight) - # TODO(zhuxinming) Create a new distributed tensor with the same distribution information as the corresponding parameter. 
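# The removed TODO above is addressed in the lines just below by building a temporary
# fp32 master weight with dist.shard_tensor (same mesh and placements as the parameter)
# and casting it back after loading.  The snippet here is only a simplified single-card
# analogue of that final cast, with made-up shapes.
import paddle

param = paddle.zeros([4, 4], dtype="bfloat16")              # stands in for the bf16 parameter
master_weight = paddle.full([4, 4], 0.1, dtype="float32")   # merged ".w_0" master weight
paddle.assign(paddle.cast(master_weight, "bfloat16"), param)  # hand the loaded value back to the parameter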
- self.semi_auto_model_state[master_weight] = paddle.zeros(v._local_value().shape, "float32") + tmp_tensor = paddle.zeros(v.shape, "float32") + dist_tmp_tensor = dist.shard_tensor(tmp_tensor, v.process_mesh, v.placements) + self.semi_auto_model_state[master_weight] = dist_tmp_tensor _load_state_dict(self.semi_auto_model_state, source_state_dict, [metadata]) for k, v in model_params.items(): master_weight = self.semi_auto_model_state[k + ".w_0"] - # cast_master_weight = paddle.cast(master_weight, "bfloat16") - + cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") + paddle.assign(cast_master_weight, v._local_value()) for k in appended_master_weight_names: self.semi_auto_model_state.pop(k) From 75b927903199014e128493d34e336136bf4b371b Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 9 Aug 2024 15:38:39 +0800 Subject: [PATCH 04/30] add flag --- paddlenlp/trainer/auto_trainer.py | 25 ++-- paddlenlp/trainer/checkpoint_converter.py | 133 +++++++++++++++------- paddlenlp/trainer/training_args.py | 51 +++++---- 3 files changed, 140 insertions(+), 69 deletions(-) diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index e19db0659d9c..9bf5699cf6ac 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -28,6 +28,7 @@ from ..utils.log import logger from .argparser import strtobool +from .ckpt_converter import CheckpointConverter from .trainer import SCALER_NAME, SCHEDULER_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME from .trainer_callback import TrainerState from .trainer_utils import ( # set_hyrbid_parallel_seed, @@ -663,11 +664,6 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): ) ) - ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH) - - if not os.path.isdir(ckpt_path): - raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") - if self.args.to_static: state_dict = self.model_wrapped.state_dict() else: @@ -681,12 +677,21 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): optim_state_dict = self.optimizer.state_dict() optim_state_dict.pop("LR_Scheduler", None) - state_dict = { - MODEL_NAME: model_state_dict, - OPTIMIZER_NAME: optim_state_dict, - } + state_dict = {} + for k, v in model_state_dict.items(): + state_dict[k] = v + for k, v in optim_state_dict.items(): + state_dict[k] = v - self._load_ckpt_func(state_dict, ckpt_path) + if self.args.resume_form_hybrid_parallel: + CheckpointConverter( + resume_from_checkpoint, state_dict, self.model_wrapped._parameter_to_structured_name + ).load_from_hybrid_parallel_checkpoint() + else: + ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH) + if not os.path.isdir(ckpt_path): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + self._load_ckpt_func(state_dict, ckpt_path) # release memory del state_dict diff --git a/paddlenlp/trainer/checkpoint_converter.py b/paddlenlp/trainer/checkpoint_converter.py index 73e9be3ab113..470f66f23cef 100644 --- a/paddlenlp/trainer/checkpoint_converter.py +++ b/paddlenlp/trainer/checkpoint_converter.py @@ -36,10 +36,10 @@ class CheckpointConverter: - def __init__(self, dynamic_ckpt_path, model_state, parameter_to_structured_name): + def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structured_name): self.use_dist = True if paddle.distributed.get_world_size() > 1 else False - self.path = dynamic_ckpt_path - self.semi_auto_model_state = model_state + self.path = hybrid_parallel_ckpt_path + 
self.auto_parallel_state_dict = model_state self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) model_state_global_shape = {} for k, v in model_state.items(): @@ -139,7 +139,7 @@ def infer_sharding_stage1_v(self): if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: for k, v in state_dict.items(): # Under shardingv2, the optimizer state is first flattened and then split. - if "_moment" in k and len(v.shape) != 1: + if len(v.shape) != 1: sharding_stage1_v = [1] break @@ -163,6 +163,8 @@ def infer_is_sharding_stage3(self): state_shape_mapping = {} for k, v in state_dict.items(): state_shape_mapping[k] = v.shape + if len(v.shape) != 1: + return False file_to_state_shape_mapping[file] = state_shape_mapping global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) @@ -174,10 +176,10 @@ def infer_is_sharding_stage3(self): break return is_sharding_stage3 - def optimizer_state_name_to_master_weight_name(self, optimizer_state_name): + def parse_master_weight_name_by(self, optimizer_state_name): return optimizer_state_name.split(".")[0] - def optimizer_state_file_name_to_model_state_file_name(self, optimizer_state_file_name): + def get_model_state_file_from(self, optimizer_state_file_name): (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) for model_state_file in self.global_model_state_file_names: distributed_rank = self.get_distribution_rank_from_file_name(model_state_file) @@ -249,7 +251,7 @@ def rename_using_model_meta(self, file_name): (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) state_dict = self.cur_rank_loaded_state_dict[file_name] for k, v in state_dict.items(): - master_weight_name = self.optimizer_state_name_to_master_weight_name(k) + master_weight_name = self.parse_master_weight_name_by(k) model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] new_k = k.replace(master_weight_name, model_weight_name) renamed_state_dict[new_k] = v @@ -266,22 +268,22 @@ def rename_using_optimizer_state_order(self, file_name): self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - model_state_file_name = self.optimizer_state_file_name_to_model_state_file_name(file) + if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + model_state_file_name = self.get_model_state_file_from(file_name) assert model_state_file_name is not None model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] - optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file] + optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file_name] master_weight_name_to_model_weight_name_mapping = {} for i in range(len(model_state_keys)): - master_weight_name = self.optimizer_state_name_to_master_weight_name(optimizer_state_keys[i]) + master_weight_name = self.parse_master_weight_name_by(optimizer_state_keys[i]) master_weight_name_to_model_weight_name_mapping[master_weight_name] = model_state_keys[i] state_dict = self.cur_rank_loaded_state_dict[file_name] renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) for k, v in state_dict.items(): - 
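# Hedged sketch of the order-based renaming used above when model_meta is absent
# (the loop body continues just below): master-weight prefixes in the .pdopt file are
# matched to model weights purely by their position in the corresponding .pdparams
# file.  All names here are made up.
def parse_master_weight_name_by(optimizer_state_name):
    return optimizer_state_name.split(".")[0]


model_state_keys = ["llama.embed_tokens.weight", "llama.layers.0.linear.weight"]
optimizer_state_keys = ["embedding_0.w_0_moment1_0", "linear_0.w_0_moment1_0"]

mapping = {}
for i in range(len(model_state_keys)):
    mapping[parse_master_weight_name_by(optimizer_state_keys[i])] = model_state_keys[i]

old_key = "embedding_0.w_0_beta1_pow_acc_0"
new_key = old_key.replace(parse_master_weight_name_by(old_key), mapping[parse_master_weight_name_by(old_key)])
assert new_key == "llama.embed_tokens.weight.w_0_beta1_pow_acc_0"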
master_weight_name = self.optimizer_state_name_to_master_weight_name(k) + master_weight_name = self.parse_master_weight_name_by(k) model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] new_k = k.replace(master_weight_name, model_weight_name) renamed_state_dict[new_k] = v @@ -419,7 +421,7 @@ def load_state_dict_and_rename(self): (tp_rank_, pp_rank_, sharding_rank_) = self.get_distribution_rank_from_file_name(k) if tp_rank == tp_rank_ and pp_rank == pp_rank_ and k.endswith(OPTIMIZER_WEIGHT_SUFFIX): sharding_optimizer_state_shards.append([v, sharding_rank_]) - model_state_file_name = self.optimizer_state_file_name_to_model_state_file_name(file) + model_state_file_name = self.get_model_state_file_from(file) model_state_shapes = global_file_to_state_dict_shapes_mapping[model_state_file_name] sharding_optimizer_state_shards.sort(key=lambda x: x[1]) @@ -462,7 +464,7 @@ def load_state_dict_and_rename(self): state_shard = sharding_optimizer_state_shards[i][0] partitioned_shard = partition_result[i] for j in range(len(partitioned_shard)): - master_weight_name = self.optimizer_state_name_to_master_weight_name(state_shard[j][0]) + master_weight_name = self.parse_master_weight_name_by(state_shard[j][0]) master_weight_name_to_model_weight_name_mapping[ master_weight_name ] = partitioned_shard[j][0] @@ -473,7 +475,7 @@ def load_state_dict_and_rename(self): # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. # Therefore, the sharding information can now be directly removed. for k, v in state_dict.items(): - master_weight_name = self.optimizer_state_name_to_master_weight_name(k) + master_weight_name = self.parse_master_weight_name_by(k) model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] new_k = k.replace(master_weight_name, model_weight_name) renamed_state_dict[new_k] = v @@ -548,7 +550,7 @@ def load_state_dict_and_rename(self): if k not in cur_rank_sharded_tensor_infos: cur_rank_sharded_tensor_infos[k] = [ [ - {"tp_rank": tp_rank, "sharding_rank": -1}, + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, v.shape, str(v.dtype).split(".")[1], file, @@ -557,7 +559,7 @@ def load_state_dict_and_rename(self): else: cur_rank_sharded_tensor_infos[k].append( [ - {"tp_rank": tp_rank, "sharding_rank": -1}, + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, v.shape, str(v.dtype).split(".")[1], file, @@ -840,16 +842,20 @@ def gen_metadata_and_prepare_source_state_dict(self): if self.is_sharding_stage3: for k, v in self.global_sharded_tensor_infos.items(): v.sort(key=lambda x: x[0]["sharding_rank"]) - state_dict_metadata = {} storage_metadata = {} # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
for k, v in self.global_sharded_tensor_infos.items(): global_offset = 0 for item in v: - local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, (global_offset,)) - global_offset += item[1][0] + if len(item[1]) == 1: + local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, (global_offset,)) + global_offset += item[1][0] + else: + global_offset = tuple([0] * len(item[1])) + local_tensor_meta_data = LocalTensorMetadata(global_offset, item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, global_offset) if k not in state_dict_metadata: state_dict_metadata[k] = [local_tensor_meta_data] else: @@ -857,7 +863,6 @@ def gen_metadata_and_prepare_source_state_dict(self): storage_metadata[local_tensor_index] = item[3] metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - model_state_shapes = [] dtype = "" for file, state_dict in self.cur_rank_loaded_state_dict.items(): @@ -896,9 +901,58 @@ def gen_metadata_and_prepare_source_state_dict(self): target_state_dict[key + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") target_state_dict[key + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") - # TODO(zhuxinming) To resolve hanging during the loading of weights in sharding stage 3. _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) + # Reshape + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + for k, v in target_state_dict.items(): + if key == self.optimizer_key_to_model_state_key(k): + if tuple(shape) != tuple(v.shape) and v.numel() == reduce(lambda x, y: x * y, shape): + reshaped_v = v.reshape(shape) + target_state_dict[k] = reshaped_v + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in target_state_dict.items(): + # Generate metadata. 
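# Sketch of the sharding-stage3 branch above: 1-D entries are flattened shards and
# get a running offset, while multi-dimensional entries are whole tensors and receive
# an all-zero offset tuple.  Shapes are made up.
items = [((30,), "float32"), ((34,), "float32"), ((4, 8), "float32")]  # (local_shape, dtype)

global_offset = 0
placed = []
for local_shape, dtype in items:
    if len(local_shape) == 1:
        placed.append(((global_offset,), local_shape, dtype))
        global_offset += local_shape[0]
    else:
        placed.append((tuple([0] * len(local_shape)), local_shape, dtype))

assert placed[1] == ((30,), (34,), "float32")
assert placed[2] == ((0, 0), (4, 8), "float32")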
+ local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: target_state_dict} + + return meta_data, source_state_dict else: return self.gen_metadata_for_tp_sharded_tensor() @@ -906,14 +960,14 @@ def rename_semi_auto_state_dict(self): need_remove_key_pattern = ["eager_tmp", "learning_rate", "@GRAD@MERG", "gradient_merge_"] need_remove_key = set() - for key in self.semi_auto_model_state.keys(): + for key in self.auto_parallel_state_dict.keys(): for pattern in need_remove_key_pattern: if pattern in key: need_remove_key.add(key) break for key in need_remove_key: - self.semi_auto_model_state.pop(key) + self.auto_parallel_state_dict.pop(key) adamw_optimizer_status_name_suffix_mappings = { "_fp32_master_1_moment1_0": ".w_0_moment1_0", @@ -938,44 +992,45 @@ def rename(old_name, map1, map2): return None renamed_state_dict = {} - for key, value in self.semi_auto_model_state.items(): + + for key, value in self.auto_parallel_state_dict.items(): + if key in self.parameter_to_structured_name.values(): new_name = key else: new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_status_name_suffix_mappings) + assert new_name is not None renamed_state_dict[new_name] = value - self.semi_auto_model_state = renamed_state_dict + self.auto_parallel_state_dict = renamed_state_dict - def load_from_dynamic_checkpoint(self): + def load_from_hybrid_parallel_checkpoint(self): self.rename_semi_auto_state_dict() metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() if self.save_sharded_model: model_params = {} - for k, v in self.semi_auto_model_state.items(): + for k, v in self.auto_parallel_state_dict.items(): if k in self.parameter_to_structured_name.values(): model_params[k] = v for k in model_params.keys(): - self.semi_auto_model_state.pop(k) + self.auto_parallel_state_dict.pop(k) appended_master_weight_names = [] - for k, v in model_params.items(): master_weight = k + ".w_0" - if master_weight not in self.semi_auto_model_state: + if master_weight not in self.auto_parallel_state_dict: appended_master_weight_names.append(master_weight) tmp_tensor = paddle.zeros(v.shape, "float32") dist_tmp_tensor = dist.shard_tensor(tmp_tensor, v.process_mesh, v.placements) - self.semi_auto_model_state[master_weight] = dist_tmp_tensor + self.auto_parallel_state_dict[master_weight] = dist_tmp_tensor - _load_state_dict(self.semi_auto_model_state, 
source_state_dict, [metadata]) + _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) for k, v in model_params.items(): - master_weight = self.semi_auto_model_state[k + ".w_0"] + master_weight = self.auto_parallel_state_dict[k + ".w_0"] cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") paddle.assign(cast_master_weight, v._local_value()) for k in appended_master_weight_names: - self.semi_auto_model_state.pop(k) - + self.auto_parallel_state_dict.pop(k) else: - _load_state_dict(self.semi_auto_model_state, source_state_dict, [metadata]) + _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 0eb13bffa1b1..1b1713f4a566 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -245,7 +245,6 @@ class TrainingArguments: enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further. enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further. - enable_sp_async_reduce_scatter, it supports async reduce_scatter in ColumnSequenceParallelLinear. It only works when set sp_async_reduce_scatter is True. It can accelerate sequence parallel further. enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly. sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False. sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False. @@ -630,7 +629,6 @@ class TrainingArguments: "enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. \n" "enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.\n" "enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.\n" - "enable_sp_async_reduce_scatter, it supports async reduce_scatter in ColumnSequenceParallelLinear. It only works when set sp_async_reduce_scatter is True. It can accelerate sequence parallel further.\n" "enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. 
instead of div accumute step on loss directly.\n" "sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.\n" "sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.\n" @@ -833,6 +831,11 @@ class TrainingArguments: metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"}, ) + resume_form_hybrid_parallel: Optional[bool] = field( + default=False, + metadata={"help": "Wether hybrid paralle checkpoints be loaded in automatic parallel mode"}, + ) + def __post_init__(self): env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) if env_local_rank != -1 and env_local_rank != self.local_rank and paddle.distributed.get_world_size() > 1: @@ -1024,13 +1027,6 @@ def __post_init__(self): logger.warning("set amp_master_grad to false since amp is disabled.") self.amp_master_grad = False - def split_parallel_config(parallel_config): - if "," in parallel_config: - parallel_config = set(parallel_config.split(",")) - else: - parallel_config = set(parallel_config.split(" ")) - return parallel_config - # use_hybrid_parallel if self.use_hybrid_parallel: @@ -1048,7 +1044,10 @@ def split_parallel_config(parallel_config): strategy = fleet.DistributedStrategy() assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel" if self.pipeline_parallel_degree > 1: - pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config) + if " " in self.pipeline_parallel_config: + pipeline_parallel_config = set(self.pipeline_parallel_config.split(" ")) + else: + pipeline_parallel_config = set(self.pipeline_parallel_config.split(",")) for x in pipeline_parallel_config: if len(x) > 0: if x not in [ @@ -1122,7 +1121,10 @@ def split_parallel_config(parallel_config): if self.tensor_parallel_degree > 1: strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed} - mp_config = split_parallel_config(self.tensor_parallel_config) + if " " in self.tensor_parallel_config: + mp_config = set(self.tensor_parallel_config.split(" ")) + else: + mp_config = set(self.tensor_parallel_config.split(",")) for x in mp_config: if len(x) > 0: @@ -1130,7 +1132,6 @@ def split_parallel_config(parallel_config): "enable_mp_async_allreduce", "enable_mp_skip_c_identity", "enable_mp_fused_linear_param_grad_add", - "enable_sp_async_reduce_scatter", "enable_delay_scale_loss", "sync_param", "sync_grad", @@ -1138,7 +1139,7 @@ def split_parallel_config(parallel_config): ]: raise ValueError( f"Found unknown tensor parallell config {x}, " - f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, enable_sp_async_reduce_scatter, enable_delay_scale_loss, sync_param, sync_grad and sync_moment." + f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, sync_param, sync_grad and sync_moment." ) try: if "enable_mp_async_allreduce" in mp_config: @@ -1156,8 +1157,6 @@ def split_parallel_config(parallel_config): warnings.warn( "enable_mp_fused_linear_param_grad_add only works with enable_mp_async_allreduce. It will not work." 
) - if "enable_sp_async_reduce_scatter" in mp_config: - strategy.hybrid_configs["mp_configs"].sp_async_reduce_scatter = True sync_param = "sync_param" in mp_config sync_grad = "sync_grad" in mp_config @@ -1231,8 +1230,10 @@ def is_segment_parallel_supported(): strategy.hybrid_configs = hybrid_configs if self.sharding_parallel_degree > 1: - sharding_parallel_config = split_parallel_config(self.sharding_parallel_config) - + if " " in self.sharding_parallel_config: + sharding_parallel_config = set(self.sharding_parallel_config.split(" ")) + else: + sharding_parallel_config = set(self.sharding_parallel_config.split(",")) for x in sharding_parallel_config: if len(x) > 0: if x not in [ @@ -1388,7 +1389,10 @@ def is_segment_parallel_supported(): # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1 if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1: - pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config) + if " " in self.pipeline_parallel_config: + pipeline_parallel_config = set(self.pipeline_parallel_config.split(" ")) + else: + pipeline_parallel_config = set(self.pipeline_parallel_config.split(",")) for x in pipeline_parallel_config: if len(x) > 0: if x not in [ @@ -1437,7 +1441,11 @@ def is_segment_parallel_supported(): if self.tensor_parallel_degree > 1: mp_optimization = strategy.mp_optimization - mp_config = split_parallel_config(self.tensor_parallel_config) + + if " " in self.tensor_parallel_config: + mp_config = set(self.tensor_parallel_config.split(" ")) + else: + mp_config = set(self.tensor_parallel_config.split(",")) for x in mp_config: if len(x) > 0: @@ -1470,7 +1478,10 @@ def is_segment_parallel_supported(): elif ShardingOption.FULL_SHARD in self.sharding: sharding.stage = 3 - sharding_parallel_config = split_parallel_config(self.sharding_parallel_config) + if " " in self.sharding_parallel_config: + sharding_parallel_config = set(self.sharding_parallel_config.split(" ")) + else: + sharding_parallel_config = set(self.sharding_parallel_config.split(",")) for x in sharding_parallel_config: if len(x) > 0: if x not in [ From 38b654421752b227bb185aca79fdb7fb7ed5b704 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 9 Aug 2024 15:46:24 +0800 Subject: [PATCH 05/30] add flag --- paddlenlp/trainer/training_args.py | 71 +++++++++++++++++------------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 1b1713f4a566..1d170ab4216b 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -245,6 +245,7 @@ class TrainingArguments: enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further. enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further. + enable_sp_async_reduce_scatter, it supports async reduce_scatter in ColumnSequenceParallelLinear. It only works when set sp_async_reduce_scatter is True. It can accelerate sequence parallel further. enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. 
instead of div accumute step on loss directly. sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False. sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False. @@ -270,7 +271,7 @@ class TrainingArguments: enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap. enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap. disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification. - enable_release_graHEADds, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration. + enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration. recompute (`bool`, *optional*, defaults to `False`): Recompute the forward pass to calculate gradients. Used for saving memory. Only support for networks with transformer blocks. @@ -354,6 +355,10 @@ class TrainingArguments: Whether skip profile timer, timer will record time usage of forward/ backward/ step, etc. distributed_dataloader (`bool`, *optional*): Whether to use distributed dataloader. Default is `False`. + release_grads (`bool`, *optional*): + Whether to release gradients during training. Default is `False`. + resume_form_hybrid_parallel (`bool`, *optional*): + Wether hybrid paralle checkpoints be loaded in auto parallel mode. """ output_dir: str = field( @@ -629,6 +634,7 @@ class TrainingArguments: "enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. \n" "enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.\n" "enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.\n" + "enable_sp_async_reduce_scatter, it supports async reduce_scatter in ColumnSequenceParallelLinear. It only works when set sp_async_reduce_scatter is True. It can accelerate sequence parallel further.\n" "enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. 
instead of div accumute step on loss directly.\n" "sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.\n" "sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.\n" @@ -830,10 +836,13 @@ class TrainingArguments: default=False, metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"}, ) + release_grads: Optional[bool] = field( + default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."} + ) resume_form_hybrid_parallel: Optional[bool] = field( default=False, - metadata={"help": "Wether hybrid paralle checkpoints be loaded in automatic parallel mode"}, + metadata={"help": "Wether hybrid paralle checkpoints be loaded in auto parallel mode."}, ) def __post_init__(self): @@ -1027,6 +1036,13 @@ def __post_init__(self): logger.warning("set amp_master_grad to false since amp is disabled.") self.amp_master_grad = False + def split_parallel_config(parallel_config): + if "," in parallel_config: + parallel_config = set(parallel_config.split(",")) + else: + parallel_config = set(parallel_config.split(" ")) + return parallel_config + # use_hybrid_parallel if self.use_hybrid_parallel: @@ -1044,10 +1060,7 @@ def __post_init__(self): strategy = fleet.DistributedStrategy() assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel" if self.pipeline_parallel_degree > 1: - if " " in self.pipeline_parallel_config: - pipeline_parallel_config = set(self.pipeline_parallel_config.split(" ")) - else: - pipeline_parallel_config = set(self.pipeline_parallel_config.split(",")) + pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config) for x in pipeline_parallel_config: if len(x) > 0: if x not in [ @@ -1121,10 +1134,7 @@ def __post_init__(self): if self.tensor_parallel_degree > 1: strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed} - if " " in self.tensor_parallel_config: - mp_config = set(self.tensor_parallel_config.split(" ")) - else: - mp_config = set(self.tensor_parallel_config.split(",")) + mp_config = split_parallel_config(self.tensor_parallel_config) for x in mp_config: if len(x) > 0: @@ -1132,6 +1142,7 @@ def __post_init__(self): "enable_mp_async_allreduce", "enable_mp_skip_c_identity", "enable_mp_fused_linear_param_grad_add", + "enable_sp_async_reduce_scatter", "enable_delay_scale_loss", "sync_param", "sync_grad", @@ -1139,7 +1150,7 @@ def __post_init__(self): ]: raise ValueError( f"Found unknown tensor parallell config {x}, " - f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, sync_param, sync_grad and sync_moment." + f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, enable_sp_async_reduce_scatter, enable_delay_scale_loss, sync_param, sync_grad and sync_moment." ) try: if "enable_mp_async_allreduce" in mp_config: @@ -1157,6 +1168,8 @@ def __post_init__(self): warnings.warn( "enable_mp_fused_linear_param_grad_add only works with enable_mp_async_allreduce. It will not work." ) + if "enable_sp_async_reduce_scatter" in mp_config: + strategy.hybrid_configs["mp_configs"].sp_async_reduce_scatter = True sync_param = "sync_param" in mp_config sync_grad = "sync_grad" in mp_config @@ -1165,15 +1178,23 @@ def __post_init__(self): # sync_param_name = [""] matches any parameter name. 
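# Quick check of the split_parallel_config helper reinstated in this hunk: it accepts
# either comma- or space-separated option strings and returns the options as a set.
def split_parallel_config(parallel_config):
    if "," in parallel_config:
        return set(parallel_config.split(","))
    return set(parallel_config.split(" "))


assert split_parallel_config("enable_mp_async_allreduce,enable_mp_skip_c_identity") == {
    "enable_mp_async_allreduce",
    "enable_mp_skip_c_identity",
}
assert "sync_param" in split_parallel_config("enable_mp_async_allreduce sync_param")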
# If sync_param, sync_grad and sync_moment are not set, the default value in Paddle is : # sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"]. + + if sync_param or sync_grad or sync_moment: + logger.info("setting sync_param_name") + strategy.sync_param_name = [""] + if sync_param: + logger.info("setting sync_param") strategy.hybrid_configs["mp_configs"].sync_param = True - strategy.hybrid_configs["mp_configs"].sync_param_name = [""] + if sync_grad: + logger.info("setting sync_grad") strategy.hybrid_configs["mp_configs"].sync_grad = True - strategy.hybrid_configs["mp_configs"].sync_grad_name = [""] + if sync_moment: + logger.info("setting sync_moment") strategy.hybrid_configs["mp_configs"].sync_moment = True - strategy.hybrid_configs["mp_configs"].sync_moment_name = [""] + except: warnings.warn( "The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported " @@ -1230,10 +1251,8 @@ def is_segment_parallel_supported(): strategy.hybrid_configs = hybrid_configs if self.sharding_parallel_degree > 1: - if " " in self.sharding_parallel_config: - sharding_parallel_config = set(self.sharding_parallel_config.split(" ")) - else: - sharding_parallel_config = set(self.sharding_parallel_config.split(",")) + sharding_parallel_config = split_parallel_config(self.sharding_parallel_config) + for x in sharding_parallel_config: if len(x) > 0: if x not in [ @@ -1389,10 +1408,7 @@ def is_segment_parallel_supported(): # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1 if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1: - if " " in self.pipeline_parallel_config: - pipeline_parallel_config = set(self.pipeline_parallel_config.split(" ")) - else: - pipeline_parallel_config = set(self.pipeline_parallel_config.split(",")) + pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config) for x in pipeline_parallel_config: if len(x) > 0: if x not in [ @@ -1441,11 +1457,7 @@ def is_segment_parallel_supported(): if self.tensor_parallel_degree > 1: mp_optimization = strategy.mp_optimization - - if " " in self.tensor_parallel_config: - mp_config = set(self.tensor_parallel_config.split(" ")) - else: - mp_config = set(self.tensor_parallel_config.split(",")) + mp_config = split_parallel_config(self.tensor_parallel_config) for x in mp_config: if len(x) > 0: @@ -1478,10 +1490,7 @@ def is_segment_parallel_supported(): elif ShardingOption.FULL_SHARD in self.sharding: sharding.stage = 3 - if " " in self.sharding_parallel_config: - sharding_parallel_config = set(self.sharding_parallel_config.split(" ")) - else: - sharding_parallel_config = set(self.sharding_parallel_config.split(",")) + sharding_parallel_config = split_parallel_config(self.sharding_parallel_config) for x in sharding_parallel_config: if len(x) > 0: if x not in [ From c3bf61aec722bb68fc715a30dabea307b24ada22 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 9 Aug 2024 15:59:19 +0800 Subject: [PATCH 06/30] fix conf --- paddlenlp/trainer/training_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 295c1ac655de..0563f67b8e0f 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -843,6 +843,7 @@ class TrainingArguments: default=False, metadata={"help": "Wether hybrid paralle checkpoints be loaded in auto parallel mode."}, ) + def __post_init__(self): env_local_rank 
= int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) if env_local_rank != -1 and env_local_rank != self.local_rank and paddle.distributed.get_world_size() > 1: From 2b4d3de3d92769987f7b314953b97a1e020ab542 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 9 Aug 2024 17:01:55 +0800 Subject: [PATCH 07/30] fix comments --- paddlenlp/trainer/checkpoint_converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlenlp/trainer/checkpoint_converter.py b/paddlenlp/trainer/checkpoint_converter.py index 470f66f23cef..56c7006dbf7b 100644 --- a/paddlenlp/trainer/checkpoint_converter.py +++ b/paddlenlp/trainer/checkpoint_converter.py @@ -400,7 +400,7 @@ def load_state_dict_and_rename(self): ) self.cur_rank_loaded_state_dict[file] = renamed_state_dict - # 2. In handling the sharding stage1 v1 scenario, the optimizer states are distributed across different ranks. + # 2. In handling the sharding stage1 v1 and stage2 scenario, the optimizer states are distributed across different ranks. # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: if not self.save_sharded_model: @@ -515,7 +515,7 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: - # 3. Handling the case of disabling sharding, independent of master_weights, but without considering the save_sharded_model flag. + # 3. Handling the sharding stage3 and non-sharding scenario if not self.save_sharded_model: for file, state_dict in self.cur_rank_loaded_state_dict.items(): (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) From 650404acf9853052244fb34aed3b3da1d2e05cca Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 11:38:05 +0800 Subject: [PATCH 08/30] fix --- paddlenlp/trainer/auto_trainer.py | 69 +- paddlenlp/trainer/checkpoint_converter.py | 1036 --------------------- 2 files changed, 14 insertions(+), 1091 deletions(-) delete mode 100644 paddlenlp/trainer/checkpoint_converter.py diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 4d94ebae0ed3..e9ba8a2eabd4 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -50,7 +50,6 @@ MODEL_NAME = "model" OPTIMIZER_NAME = "optimizer" DIST_CKPT_PATH = "dist_ckpt" -FREE_SVAE_LOAD_KEY_PATTERNS = ["learning_rate_", "gradient_merge_", "@GRAD@MERG", "eager_tmp"] class AutoTrainer(Trainer): @@ -159,43 +158,20 @@ def _split_batches_for_accumulation(self, inputs): if self.args.gradient_accumulation_steps == 1: return [inputs] + # if self.args.to_static: if self.args.to_static and self.args.pipeline_parallel_degree > 1: return [inputs] local_batches = [{} for i in range(self.args.gradient_accumulation_steps)] - assert isinstance(inputs, dict) - def split_dtensor_by_axis(dtensor, axis): - mesh = dtensor.process_mesh - placements = [dist.Replicate() for _ in range(len(mesh.shape))] - replicate_value = dist.reshard(dtensor, mesh, placements) + for key, value in inputs.items(): + ori_mesh, ori_placements = value.process_mesh, value.placements + replicate_value = dist.reshard(value, ori_mesh, [dist.Replicate(), dist.Replicate()]) local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0) - return local_datas - - for key, dtensors in inputs.items(): - if isinstance(dtensors, paddle.Tensor): - mesh, placements = dtensors.process_mesh, 
dtensors.placements - local_datas = split_dtensor_by_axis(dtensors, 0) - for index, data in enumerate(local_datas): - local_batches[index].update({key: dist.reshard(data, mesh, placements)}) - elif isinstance(dtensors, (list, tuple)): - if len(dtensors) == 0: - for i in range(self.args.gradient_accumulation_steps): - local_batches[i].update({key: []}) - else: - for dtensor in dtensors: - if isinstance(dtensor, paddle.Tensor): - mesh, placements = dtensor.process_mesh, dtensor.placements - local_datas = split_dtensor_by_axis(dtensor, 0) - for index, data in enumerate(local_datas): - if key in local_batches[index].keys(): - local_batches[index][key].append(dist.reshard(data, mesh, placements)) - else: - local_batches[index].update({key: [dist.reshard(data, mesh, placements)]}) - else: - raise ValueError(f"unsupported type: {type(dtensor)}") - else: - raise ValueError(f"unsupported type: {type(dtensors)}") + + for index, data in enumerate(local_datas): + local_batches[index].update({key: dist.reshard(data, ori_mesh, ori_placements)}) + return local_batches def _inner_training_loop( @@ -568,15 +544,7 @@ def _save_checkpoint(self, model, metrics=None): if self.args.should_save_model_state: if self.args.to_static: - opt_state_dict = { - key: value - for key, value in model.state_dict("opt").items() - if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS) - } - state_dict = { - MODEL_NAME: model.state_dict("param"), - OPTIMIZER_NAME: opt_state_dict, - } + state_dict = model.state_dict() else: optim_state_dict = self.optimizer.state_dict() optim_state_dict.pop("LR_Scheduler", None) @@ -697,15 +665,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): ) if self.args.to_static: - opt_state_dict = { - key: value - for key, value in self.model_wrapped.state_dict("opt").items() - if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS) - } - state_dict = { - MODEL_NAME: self.model_wrapped.state_dict("param"), - OPTIMIZER_NAME: opt_state_dict, - } + state_dict = self.model_wrapped.state_dict() else: model_state_dict = self.model_wrapped.state_dict() optim_state_dict = self.optimizer.state_dict() @@ -717,11 +677,10 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): optim_state_dict = self.optimizer.state_dict() optim_state_dict.pop("LR_Scheduler", None) - state_dict = {} - for k, v in model_state_dict.items(): - state_dict[k] = v - for k, v in optim_state_dict.items(): - state_dict[k] = v + state_dict = { + MODEL_NAME: model_state_dict, + OPTIMIZER_NAME: optim_state_dict, + } if self.args.resume_form_hybrid_parallel: CheckpointConverter( diff --git a/paddlenlp/trainer/checkpoint_converter.py b/paddlenlp/trainer/checkpoint_converter.py deleted file mode 100644 index 56c7006dbf7b..000000000000 --- a/paddlenlp/trainer/checkpoint_converter.py +++ /dev/null @@ -1,1036 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -import re -from functools import reduce - -import paddle -import paddle.distributed as dist -from paddle.distributed.checkpoint.load_state_dict import ( - _load_state_dict, - get_local_load_files, -) -from paddle.distributed.checkpoint.metadata import ( - LocalTensorIndex, - LocalTensorMetadata, - Metadata, -) - -MODEL_WEIGHT_SUFFIX = ".pdparams" -OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" -SCHEDULER_NAME = "scheduler.pdparams" -MODEL_META_FILE_NAME = "model_meta.json" - - -class CheckpointConverter: - def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structured_name): - self.use_dist = True if paddle.distributed.get_world_size() > 1 else False - self.path = hybrid_parallel_ckpt_path - self.auto_parallel_state_dict = model_state - self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) - model_state_global_shape = {} - for k, v in model_state.items(): - model_state_global_shape[k] = v.shape - self.model_state_global_shape = self.gather_global_object(model_state_global_shape) - self.cur_rank = paddle.distributed.get_rank() - - self.save_sharded_model = self.get_save_sharded_model_flag() - - ( - self.cur_rank_model_state_file_names, - self.cur_rank_optimizer_state_file_names, - ) = self.get_local_checkpoint_file_names() - - self.global_model_state_file_names = self.gather_global_object(self.cur_rank_model_state_file_names) - - self.global_optimizer_state_file_names = self.gather_global_object(self.cur_rank_optimizer_state_file_names) - - self.initial_distributed_configuration() - - def get_save_sharded_model_flag(self): - if self.cur_rank == 1: - save_sharded_model_flag = [os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] - else: - save_sharded_model_flag = [] - save_sharded_model_flag = self.gather_global_object(save_sharded_model_flag) - return save_sharded_model_flag[0] - - def gather_global_object(self, cur_rank_object): - all_rank_objects = [] - if self.use_dist: - paddle.distributed.all_gather_object(all_rank_objects, cur_rank_object) - else: - all_rank_objects = [all_rank_objects] - - if isinstance(cur_rank_object, list): - return [item for sublist in all_rank_objects for item in sublist] - elif isinstance(cur_rank_object, dict): - global_map = {} - for rank_map in all_rank_objects: - global_map.update(rank_map) - return global_map - else: - raise ValueError("cur_rank_object should be either a list or a dict") - - def get_local_checkpoint_file_names(self): - cur_rank_files = os.listdir(self.path) - cur_rank_model_state_file_names = [] - cur_rank_optimizer_state_file_names = [] - for file in cur_rank_files: - if file.endswith(MODEL_WEIGHT_SUFFIX): - cur_rank_model_state_file_names.append(file) - elif file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - cur_rank_optimizer_state_file_names.append(file) - if SCHEDULER_NAME in cur_rank_model_state_file_names: - cur_rank_model_state_file_names.remove(SCHEDULER_NAME) - return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names - - def get_distribution_rank_from_file_name(self, file_name): - pp_degree = 0 - tp_degree = 0 - sharding_degree = 0 - pattern_pp = r"pp(\d+)" - pattern_tp = r"tp(\d+)" - pattern_shard = r"shard(\d+)" - match_pp = re.search(pattern_pp, file_name) - if match_pp: - pp_degree = int(match_pp.group(1)) - match_tp = re.search(pattern_tp, file_name) - if match_tp: - tp_degree = int(match_tp.group(1)) - match_shard = re.search(pattern_shard, file_name) - if match_shard: - sharding_degree = int(match_shard.group(1)) - return 
(tp_degree, pp_degree, sharding_degree) - - def initial_distributed_configuration(self): - self.pp_degree = 0 - self.tp_degree = 0 - self.sharding_degree = 0 - - all_files = self.global_model_state_file_names + self.global_optimizer_state_file_names - - for file in all_files: - (tp_degree, pp_degree, sharding_degree) = self.get_distribution_rank_from_file_name(file) - self.pp_degree = max(self.pp_degree, pp_degree) - self.tp_degree = max(self.tp_degree, tp_degree) - self.sharding_degree = max(self.sharding_degree, sharding_degree) - - self.pp_degree = self.pp_degree + 1 - self.tp_degree = self.tp_degree + 1 - self.sharding_degree = self.sharding_degree + 1 - - def infer_sharding_stage1_v(self): - sharding_stage1_v = [2] - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: - for k, v in state_dict.items(): - # Under shardingv2, the optimizer state is first flattened and then split. - if len(v.shape) != 1: - sharding_stage1_v = [1] - break - - sharding_stage1_v = self.gather_global_object(sharding_stage1_v) - if 1 in sharding_stage1_v: - return 1 - return 2 - - def infer_is_sharding_stage3(self): - if self.sharding_degree == 1: - return False - if self.pp_degree > 1 or self.tp_degree > 1: - # Currently, sharding stage 3 does not support concurrent use with tensor parallelism (TP) and pipeline parallelism (PP). - return False - - is_sharding_stage3 = True - - file_to_state_shape_mapping = {} - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - state_shape_mapping = {} - for k, v in state_dict.items(): - state_shape_mapping[k] = v.shape - if len(v.shape) != 1: - return False - file_to_state_shape_mapping[file] = state_shape_mapping - global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) - - state_dict_std = global_file_to_state_shape_mapping[list(global_file_to_state_shape_mapping.keys())[0]] - - for file, state_dict in global_file_to_state_shape_mapping.items(): - if state_dict != state_dict_std: - is_sharding_stage3 = False - break - return is_sharding_stage3 - - def parse_master_weight_name_by(self, optimizer_state_name): - return optimizer_state_name.split(".")[0] - - def get_model_state_file_from(self, optimizer_state_file_name): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) - for model_state_file in self.global_model_state_file_names: - distributed_rank = self.get_distribution_rank_from_file_name(model_state_file) - if tp_rank == distributed_rank[0] and pp_rank == distributed_rank[1]: - return model_state_file - return None - - def optimizer_key_to_model_state_key(self, optimizer_key): - adamw_optimizer_key_suffix = [ - ".w_0_beta1_pow_acc_0", - ".w_0_beta2_pow_acc_0", - ".w_0_moment1_0", - ".w_0_moment2_0", - ".w_0", - ] - model_state_key = optimizer_key - for suffix in adamw_optimizer_key_suffix: - if model_state_key.endswith(suffix): - # Remove the suffix from model_state_key - model_state_key = model_state_key[: -len(suffix)] - break - return model_state_key - - def partition_parameters(self, model_state_shapes, is_sort, shard_num): - """ - Partitions parameters among sharding ranks. 
- - Return: - Dict[int, List] - """ - # Copy from python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py - mapping = {} - for rank_ in range(shard_num): - mapping[rank_] = [] - sizes = [0] * shard_num - - parameters = model_state_shapes.copy() - - if is_sort: - parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True) - - for param in parameters: - rank = sizes.index(min(sizes)) - mapping[rank].append(param) - numel = reduce(lambda x, y: x * y, param[1], 1) - assert numel > 0, f"param [{param[0]}] should larger than 0, but it is [{numel}]" - sizes[rank] += numel - - return mapping - - def rename_using_model_meta(self, file_name): - if not hasattr(self, "model_meta"): - try: - self.model_meta = json.load(open(os.path.join(self.path, MODEL_META_FILE_NAME))) - except Exception as e: - print(e) - distributed_rank = self.get_distribution_rank_from_file_name(file_name) - dist_strategy_key = ( - "tp" + "{:02d}".format(distributed_rank[0]) + "_" + "pp" + "{:02d}".format(distributed_rank[1]) - ) - # Map model weight names to their corresponding names of master_weights in the optimizer state. - if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] - master_weight_name_to_model_weight_name_mapping = {} - for k, v in structure_name_mapping.items(): - master_weight_name_to_model_weight_name_mapping[v.split(".")[0]] = k - - renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - state_dict = self.cur_rank_loaded_state_dict[file_name] - for k, v in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(k) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - new_k = k.replace(master_weight_name, model_weight_name) - renamed_state_dict[new_k] = v - return renamed_state_dict - else: - return self.cur_rank_loaded_state_dict[file_name] - - def rename_using_optimizer_state_order(self, file_name): - if not hasattr(self, "global_file_to_state_dict_keys_mapping"): - file_to_state_dict_keys_mapping = {} - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - file_to_state_dict_keys_mapping[file] = list(state_dict.keys()) - - self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) - - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - model_state_file_name = self.get_model_state_file_from(file_name) - assert model_state_file_name is not None - model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] - optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file_name] - - master_weight_name_to_model_weight_name_mapping = {} - for i in range(len(model_state_keys)): - master_weight_name = self.parse_master_weight_name_by(optimizer_state_keys[i]) - master_weight_name_to_model_weight_name_mapping[master_weight_name] = model_state_keys[i] - - state_dict = self.cur_rank_loaded_state_dict[file_name] - renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - for k, v in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(k) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - new_k = k.replace(master_weight_name, model_weight_name) - 
renamed_state_dict[new_k] = v - - return renamed_state_dict - else: - return self.cur_rank_loaded_state_dict[file_name] - - def load_state_dict_and_rename(self): - rank_access_files = {} - if self.save_sharded_model: - rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names - else: - rank_access_files[self.cur_rank] = ( - self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names - ) - - need_read_files = get_local_load_files(self.gather_global_object(rank_access_files)) - - self.cur_rank_loaded_state_dict = {} - - for file in need_read_files: - self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file)) - - file_to_master_weights_keys = {} - - self.optimizer_state_with_master_weights = False - - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - state_dict.pop("LR_Scheduler") - if "master_weights" in state_dict: - self.optimizer_state_with_master_weights = True - master_weights = state_dict.pop("master_weights") - file_to_master_weights_keys[file] = list(master_weights.keys()) - for k, v in master_weights.items(): - # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. - k = k.replace("slice@", "") - state_dict[k] = v - - # Standardize the state names of the AdamW optimizer. - adamw_optimizer_param_suffix_name_mapping = { - ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", - ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", - ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - } - - unified_name_state_dict = {} - for k, v in state_dict.items(): - new_k = k - for suffix in adamw_optimizer_param_suffix_name_mapping: - if k.endswith(suffix): - new_k = k.replace(suffix, adamw_optimizer_param_suffix_name_mapping[suffix]) - break - unified_name_state_dict[new_k] = v - - self.cur_rank_loaded_state_dict[file] = unified_name_state_dict - - # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. - self.sharding_stage1_v = self.infer_sharding_stage1_v() - self.is_sharding_stage3 = self.infer_is_sharding_stage3() - - # In sharding stage3, the parameters need to be reordered based on whether they are sliced. - # The threshold for determining whether to slice is segment_size, with a default value of 2**20. - # However, sharding stage3 allows users to specify their own unsliced layers, which seems to be incompatible here. - if self.is_sharding_stage3: - segment_size = 2**20 - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(MODEL_WEIGHT_SUFFIX): - sliced_pramaeters = [] - unseliced_pramaeters = [] - sorted_state_dict = {} - for k, v in state_dict.items(): - if v.numel() > segment_size: - sliced_pramaeters.append(k) - else: - unseliced_pramaeters.append(k) - for k in sliced_pramaeters + unseliced_pramaeters: - sorted_state_dict[k] = state_dict.pop(k) - self.cur_rank_loaded_state_dict[file] = sorted_state_dict - - self.global_file_to_master_weights_keys = self.gather_global_object(file_to_master_weights_keys) - - # rename and record sharded_tensor_info - cur_rank_sharded_tensor_infos = {} - - # 1. Handling the sharding stage1 v2 scenario, where the save_sharded_model flag must be enabled, independent of master_weights. 
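A minimal standalone sketch (outside the diff) of the record structure that the branches below keep appending to: one [placement, local shape, dtype, source file] entry per shard of each renamed key. All names, shapes and file names here are invented for illustration.

    cur_rank_sharded_tensor_infos = {}

    def record(key, tp_rank, sharding_rank, shape, dtype, file):
        # One entry per shard: placement info, local shape, dtype string, and the file it was read from.
        entry = [{"tp_rank": tp_rank, "sharding_rank": sharding_rank}, shape, dtype, file]
        cur_rank_sharded_tensor_infos.setdefault(key, []).append(entry)

    # Two hypothetical sharding-rank slices of one flattened moment tensor.
    record("linear_0.w_0_moment1_0", 0, 0, (2048,), "float32", "optimizer.tp00_pp00_shard00.pdopt")
    record("linear_0.w_0_moment1_0", 0, 1, (2048,), "float32", "optimizer.tp00_pp00_shard01.pdopt")
    print(cur_rank_sharded_tensor_infos["linear_0.w_0_moment1_0"])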
- if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: - assert self.save_sharded_model - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, - # and then append the tp_degree. - renamed_state_dict = self.rename_using_model_meta(file) - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) - - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - # 2. In handling the sharding stage1 v1 and stage2 scenario, the optimizer states are distributed across different ranks. - # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. - elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: - if not self.save_sharded_model: - file_to_state_dict_shapes_mapping = {} - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - shapes = [] - for k, v in state_dict.items(): - shapes.append([k, v.shape]) - file_to_state_dict_shapes_mapping[file] = shapes - - global_file_to_state_dict_shapes_mapping = self.gather_global_object(file_to_state_dict_shapes_mapping) - - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - sharding_optimizer_state_shards = [] - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - for k, v in global_file_to_state_dict_shapes_mapping.items(): - (tp_rank_, pp_rank_, sharding_rank_) = self.get_distribution_rank_from_file_name(k) - if tp_rank == tp_rank_ and pp_rank == pp_rank_ and k.endswith(OPTIMIZER_WEIGHT_SUFFIX): - sharding_optimizer_state_shards.append([v, sharding_rank_]) - model_state_file_name = self.get_model_state_file_from(file) - model_state_shapes = global_file_to_state_dict_shapes_mapping[model_state_file_name] - sharding_optimizer_state_shards.sort(key=lambda x: x[1]) - - partition_result_0 = self.partition_parameters(model_state_shapes, False, self.sharding_degree) - partition_result_1 = self.partition_parameters(model_state_shapes, True, self.sharding_degree) - - for k, v in partition_result_0.items(): - v = sorted(v, key=model_state_shapes.index) - partition_result_0[k] = v - - for k, v in partition_result_1.items(): - v = sorted(v, key=model_state_shapes.index) - partition_result_1[k] = v - - sharding_sort_parameters = False - - for i in range(len(sharding_optimizer_state_shards)): - if not sharding_sort_parameters: - state_shard = sharding_optimizer_state_shards[i][0] - partitioned_shard = partition_result_0[i] - for j in range(len(partitioned_shard)): - if partitioned_shard[j][1] != state_shard[j][1]: - sharding_sort_parameters = True - break - - if sharding_sort_parameters: - for i in range(len(sharding_optimizer_state_shards)): - state_shard = sharding_optimizer_state_shards[i][0] - partitioned_shard = partition_result_1[i] - for j in range(len(partitioned_shard)): - assert partitioned_shard[j][1] == state_shard[j][1] - - if 
sharding_sort_parameters: - partition_result = partition_result_1 - else: - partition_result = partition_result_0 - - master_weight_name_to_model_weight_name_mapping = {} - for i in range(len(sharding_optimizer_state_shards)): - state_shard = sharding_optimizer_state_shards[i][0] - partitioned_shard = partition_result[i] - for j in range(len(partitioned_shard)): - master_weight_name = self.parse_master_weight_name_by(state_shard[j][0]) - master_weight_name_to_model_weight_name_mapping[ - master_weight_name - ] = partitioned_shard[j][0] - - renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - - # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. - # Therefore, the sharding information can now be directly removed. - for k, v in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(k) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - new_k = k.replace(master_weight_name, model_weight_name) - renamed_state_dict[new_k] = v - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ) - - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - else: - for k, v in state_dict.items(): - if k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[k] = [ - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ] - else: - cur_rank_sharded_tensor_infos[k].append( - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ) - else: - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - renamed_state_dict = self.rename_using_model_meta(file) - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ) - - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - else: - # 3. 
Handling the sharding stage3 and non-sharding scenario - if not self.save_sharded_model: - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - - renamed_state_dict = self.rename_using_optimizer_state_order(file) - - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) - - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - else: - for k, v in state_dict.items(): - if k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) - - else: - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, - # and then append the tp_degree. - renamed_state_dict = self.rename_using_model_meta(file) - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) - - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - # gather global sharded tensor infos - sharded_tensor_infos = self.gather_global_object({self.cur_rank: cur_rank_sharded_tensor_infos}) - - self.global_sharded_tensor_infos = {} - for rank, sharded_tensor_info in sharded_tensor_infos.items(): - for k, v in sharded_tensor_info.items(): - if k not in self.global_sharded_tensor_infos: - self.global_sharded_tensor_infos[k] = v - else: - self.global_sharded_tensor_infos[k] += v - - def gen_metadata_for_tp_sharded_tensor(self): - for k, v in self.global_sharded_tensor_infos.items(): - v.sort(key=lambda x: x[0]["tp_rank"]) - - state_dict_metadata = {} - storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
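A standalone sketch (outside the diff) of the offset computation described in the comment above, assuming each shard differs from the global shape along at most one axis; the shapes below are invented for illustration.

    def shard_offsets(local_shapes, global_shape):
        # Find the axis along which the tensor was split (local != global); -1 means replicated.
        axis = -1
        for i, (loc, glob) in enumerate(zip(local_shapes[0], global_shape)):
            if loc != glob:
                axis = i
                break
        offsets = []
        cursor = [0] * len(global_shape)
        for shape in local_shapes:
            offsets.append(tuple(cursor))
            if axis != -1:
                cursor[axis] += shape[axis]  # advance along the split axis only
        return axis, offsets

    # Two tensor-parallel shards of a hypothetical [4, 8] weight split along axis 1.
    print(shard_offsets([(4, 4), (4, 4)], (4, 8)))  # -> (1, [(0, 0), (0, 4)])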
- for k, v in self.global_sharded_tensor_infos.items(): - global_offset = 0 - local_shape = v[0][1] - model_state_name = self.optimizer_key_to_model_state_key(k) - if "_pow_acc_0" not in k: - global_shape = self.model_state_global_shape[model_state_name] - else: - global_shape = (1,) - - assert len(local_shape) == len(global_shape) - axis = -1 - for i in range(len(local_shape)): - if local_shape[i] != global_shape[i]: - axis = i - break - - is_replicated = axis == -1 - global_offset = [0] * len(local_shape) - - if is_replicated: - v = [v[0]] - - for item in v: - local_tensor_meta_data = LocalTensorMetadata(tuple(global_offset), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, tuple(global_offset)) - global_offset[axis] += item[1][axis] - if k not in state_dict_metadata: - state_dict_metadata[k] = [local_tensor_meta_data] - else: - state_dict_metadata[k].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] - - metadata = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = self.cur_rank_loaded_state_dict - - return metadata, source_state_dict - - def gen_metadata_and_prepare_source_state_dict(self): - self.load_state_dict_and_rename() - if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: - for k, v in self.global_sharded_tensor_infos.items(): - v.sort(key=lambda x: x[0]["sharding_rank"]) - - state_dict_metadata = {} - storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. - for k, v in self.global_sharded_tensor_infos.items(): - global_offset = [0] * self.tp_degree - for item in v: - tp_rank = item[0]["tp_rank"] - k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) - local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k_with_tp_rank, (global_offset[tp_rank],)) - global_offset[tp_rank] += item[1][0] - if k_with_tp_rank not in state_dict_metadata: - state_dict_metadata[k_with_tp_rank] = [local_tensor_meta_data] - else: - state_dict_metadata[k_with_tp_rank].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] - - metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - - source_state_dict_for_merge_sharding = {} - for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): - renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - for k, v in state_dict.items(): - if self.global_sharded_tensor_infos[k][0][0]["tp_rank"] != -1: - k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) - renamed_state_dict[k_with_tp_rank] = v - else: - renamed_state_dict[k] = v - - source_state_dict_for_merge_sharding[file_name] = renamed_state_dict - - assert self.model_meta is not None - global_model_state_shapes = [] - sharding_metas_keys = [] - for i in range(self.pp_degree): - for j in range(self.tp_degree): - sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) - for key in sharding_metas_keys: - param_meta = self.model_meta["sharding_metas"][key]["param_meta"] - for k, v in param_meta.items(): - global_model_state_shapes.append([k, v[0]]) - - # Distribute all model parameters evenly across each card for loading - - world_size = paddle.distributed.get_world_size() - - partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) - - partition_model_state_keys = [] - for cur_rank, 
partition_model_state in partition_mapping.items(): - partition_model_state_keys.append([item[0] for item in partition_model_state]) - - param_meta = {} - for i in range(self.tp_degree): - for j in range(self.pp_degree): - key = "tp{:02d}_pp{:02d}".format(i, j) - pm = self.model_meta["sharding_metas"][key]["param_meta"] - for k, v in pm.items(): - param_meta[k] = v - - param_flattened_shapes = {} - for k, v in param_meta.items(): - param_flattened_shapes[k] = reduce(lambda x, y: x * y, v[0]) - - cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] - - # Generate the optimizer states corresponding to the model weights. - optimizer_state_dict = {} - for key in cur_rank_need_load_model_state_keys: - for tp_rank in range(self.tp_degree): - tp_rank_suffix = "_tp{:02d}".format(tp_rank) - optimizer_state_dict[key + ".w_0_moment1_0" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - optimizer_state_dict[key + ".w_0_moment2_0" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - if self.optimizer_state_with_master_weights: - optimizer_state_dict[key + ".w_0" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. - # Later, when these are compared with the global shape, we realize that they are replicated. - - optimizer_state_dict[key + ".w_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") - optimizer_state_dict[key + ".w_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") - - # merge sharding - _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) - - # Reshape - for k, v in optimizer_state_dict.items(): - if v.shape[0] > 1 and "_tp" in k: - param_name = self.optimizer_key_to_model_state_key(k[:-5]) - param_shape = param_meta[param_name][0] - assert v.numel() == reduce(lambda x, y: x * y, param_shape) - reshaped_v = v.reshape(param_shape) - optimizer_state_dict[k] = reshaped_v - concat_optimier_state_dict = {} - - optimizer_state_key_to_tp_keys = {} - for key in optimizer_state_dict.keys(): - # Count how each key is split into keys ending with ‘_tpXX’. - # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} - key_removed_tp_rank = key[:-5] - if key_removed_tp_rank not in optimizer_state_key_to_tp_keys: - optimizer_state_key_to_tp_keys[key_removed_tp_rank] = [key] - else: - optimizer_state_key_to_tp_keys[key_removed_tp_rank].append(key) - - for key, value in optimizer_state_key_to_tp_keys.items(): - value.sort(key=lambda x: int(x[-2:])) - - for key, tp_keys in optimizer_state_key_to_tp_keys.items(): - model_state_name = self.optimizer_key_to_model_state_key(key) - local_shape = optimizer_state_dict[tp_keys[0]].shape - if "_pow_acc_0" not in key: - global_shape = self.model_state_global_shape[model_state_name] - else: - global_shape = (1,) - - assert len(local_shape) == len(global_shape) - - axis = -1 - for i in range(len(local_shape)): - if local_shape[i] != global_shape[i]: - axis = i - break - - is_replicated = axis == -1 - tp_tensors = [] - for tp_key in tp_keys: - tp_tensors.append(optimizer_state_dict[tp_key]) - - if not is_replicated: - # Derive the partition strategy based on the global_shape, then concatenate. 
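A standalone sketch (outside the diff) of the suffix bookkeeping used just above: keys that differ only in their trailing "_tpXX" rank marker are grouped and ordered so their shards can be concatenated in rank order. The key names are invented for illustration.

    # Group keys that only differ in their trailing "_tpXX" rank suffix, then sort each
    # group by rank so the shards are later concatenated in order.
    keys = [
        "linear_0.w_0_moment1_0_tp01",
        "linear_0.w_0_moment1_0_tp00",
        "embedding_0.w_0_tp00",
    ]
    groups = {}
    for key in keys:
        base = key[:-5]  # drop "_tpXX"
        groups.setdefault(base, []).append(key)
    for base, tp_keys in groups.items():
        tp_keys.sort(key=lambda k: int(k[-2:]))
    print(groups)  # two groups, each list sorted by its tp rank suffix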
- concat_optimier_state_dict[key] = paddle.concat(tp_tensors, axis=axis) - else: - concat_optimier_state_dict[key] = tp_tensors[0] - - fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" - local_tensor_meta_data = {} - local_tensor_index = {} - for k, v in concat_optimier_state_dict.items(): - # Generate metadata. - local_shape = v.shape - global_offset = tuple([0] * len(local_shape)) - dtype = str(v.dtype).split(".")[1] - local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) - local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] - - global_local_tensor_meta_data = [] - global_local_tensor_index = [] - - use_dist = True if paddle.distributed.get_world_size() > 1 else False - - if use_dist: - paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) - paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) - else: - global_local_tensor_meta_data = [local_tensor_meta_data] - global_local_tensor_index = [local_tensor_index] - - state_dict_metadata = {} - for tensor_meta_data in global_local_tensor_meta_data: - for k, v in tensor_meta_data.items(): - if k not in state_dict_metadata: - state_dict_metadata[k] = [v] - else: - state_dict_metadata[k].append(v) - - storage_metadata = {} - for tensor_index in global_local_tensor_index: - for k, v in tensor_index.items(): - storage_metadata[v[0]] = v[1] - - meta_data = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = {fake_file_name: concat_optimier_state_dict} - - return meta_data, source_state_dict - - elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: - return self.gen_metadata_for_tp_sharded_tensor() - else: - if self.is_sharding_stage3: - for k, v in self.global_sharded_tensor_infos.items(): - v.sort(key=lambda x: x[0]["sharding_rank"]) - state_dict_metadata = {} - storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
- for k, v in self.global_sharded_tensor_infos.items(): - global_offset = 0 - for item in v: - if len(item[1]) == 1: - local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, (global_offset,)) - global_offset += item[1][0] - else: - global_offset = tuple([0] * len(item[1])) - local_tensor_meta_data = LocalTensorMetadata(global_offset, item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, global_offset) - if k not in state_dict_metadata: - state_dict_metadata[k] = [local_tensor_meta_data] - else: - state_dict_metadata[k].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] - - metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - model_state_shapes = [] - dtype = "" - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(MODEL_WEIGHT_SUFFIX): - for k, v in state_dict.items(): - model_state_shapes.append([k, v.shape]) - dtype = str(v.dtype).split(".")[1] - - dtypes = self.gather_global_object([dtype]) - for dtype_s in dtypes: - if len(dtype_s) > 0: - dtype = dtype_s - - assert len(dtype) > 0 - - global_model_state_shapes = self.gather_global_object(model_state_shapes) - - partition_result = self.partition_parameters( - global_model_state_shapes, True, paddle.distributed.get_world_size() - ) - - cur_rank_merger_model_params = partition_result[self.cur_rank] - target_state_dict = {} - for item in cur_rank_merger_model_params: - key = item[0] - shape = item[1] - flatten_shape = reduce(lambda a, b: a * b, item[1]) - target_state_dict[key] = paddle.zeros(shape, dtype) - target_state_dict[key + ".w_0_moment1_0"] = paddle.zeros((flatten_shape,), "float32") - target_state_dict[key + ".w_0_moment2_0"] = paddle.zeros((flatten_shape,), "float32") - if self.optimizer_state_with_master_weights: - target_state_dict[key + ".w_0"] = paddle.zeros((flatten_shape,), "float32") - # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. - # Later, when these are compared with the global shape, we realize that they are replicated. - - target_state_dict[key + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") - target_state_dict[key + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") - - _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) - - # Reshape - for item in cur_rank_merger_model_params: - key = item[0] - shape = item[1] - for k, v in target_state_dict.items(): - if key == self.optimizer_key_to_model_state_key(k): - if tuple(shape) != tuple(v.shape) and v.numel() == reduce(lambda x, y: x * y, shape): - reshaped_v = v.reshape(shape) - target_state_dict[k] = reshaped_v - - fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" - local_tensor_meta_data = {} - local_tensor_index = {} - for k, v in target_state_dict.items(): - # Generate metadata. 
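A standalone sketch (outside the diff) of the per-tensor metadata records generated in the loop above, reusing the LocalTensorMetadata/LocalTensorIndex/Metadata classes this file already imports and constructing them the same way the surrounding code does. Tensor names, shapes and the fake file name are invented.

    from paddle.distributed.checkpoint.metadata import (
        LocalTensorIndex,
        LocalTensorMetadata,
        Metadata,
    )

    # Each merged tensor is described as a single full shard stored at offset zero
    # inside one fake ".distcp" file owned by the current rank.
    fake_file_name = "00.distcp"
    shapes = {"linear_0.w_0": (4, 8), "linear_0.w_0_moment1_0": (32,)}  # illustrative only

    state_dict_metadata, storage_metadata = {}, {}
    for name, shape in shapes.items():
        offset = tuple([0] * len(shape))
        state_dict_metadata[name] = [LocalTensorMetadata(offset, shape, "float32")]
        storage_metadata[LocalTensorIndex(name, offset)] = fake_file_name

    metadata = Metadata(state_dict_metadata, storage_metadata, None)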
- local_shape = v.shape - global_offset = tuple([0] * len(local_shape)) - dtype = str(v.dtype).split(".")[1] - local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) - local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] - - global_local_tensor_meta_data = [] - global_local_tensor_index = [] - - use_dist = True if paddle.distributed.get_world_size() > 1 else False - - if use_dist: - paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) - paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) - else: - global_local_tensor_meta_data = [local_tensor_meta_data] - global_local_tensor_index = [local_tensor_index] - - state_dict_metadata = {} - for tensor_meta_data in global_local_tensor_meta_data: - for k, v in tensor_meta_data.items(): - if k not in state_dict_metadata: - state_dict_metadata[k] = [v] - else: - state_dict_metadata[k].append(v) - - storage_metadata = {} - for tensor_index in global_local_tensor_index: - for k, v in tensor_index.items(): - storage_metadata[v[0]] = v[1] - - meta_data = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = {fake_file_name: target_state_dict} - - return meta_data, source_state_dict - else: - return self.gen_metadata_for_tp_sharded_tensor() - - def rename_semi_auto_state_dict(self): - need_remove_key_pattern = ["eager_tmp", "learning_rate", "@GRAD@MERG", "gradient_merge_"] - - need_remove_key = set() - for key in self.auto_parallel_state_dict.keys(): - for pattern in need_remove_key_pattern: - if pattern in key: - need_remove_key.add(key) - break - - for key in need_remove_key: - self.auto_parallel_state_dict.pop(key) - - adamw_optimizer_status_name_suffix_mappings = { - "_fp32_master_1_moment1_0": ".w_0_moment1_0", - "_fp32_master_1_moment2_0": ".w_0_moment2_0", - "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - "_fp32_master_1": ".w_0", - "_moment1_0": ".w_0_moment1_0", - "_moment2_0": ".w_0_moment2_0", - "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - } - - def rename(old_name, map1, map2): - for i in range(1, len(old_name)): - str1 = old_name[:i] - str2 = old_name[i:] - if (str1 in map1) and (str2 in map2): - transformed_str1 = map1[str1] - transformed_str2 = map2[str2] - return transformed_str1 + transformed_str2 - return None - - renamed_state_dict = {} - - for key, value in self.auto_parallel_state_dict.items(): - - if key in self.parameter_to_structured_name.values(): - new_name = key - else: - new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_status_name_suffix_mappings) - - assert new_name is not None - renamed_state_dict[new_name] = value - - self.auto_parallel_state_dict = renamed_state_dict - - def load_from_hybrid_parallel_checkpoint(self): - self.rename_semi_auto_state_dict() - metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() - if self.save_sharded_model: - model_params = {} - for k, v in self.auto_parallel_state_dict.items(): - if k in self.parameter_to_structured_name.values(): - model_params[k] = v - for k in model_params.keys(): - self.auto_parallel_state_dict.pop(k) - - appended_master_weight_names = [] - for k, v in model_params.items(): - master_weight = k + ".w_0" - if master_weight not in self.auto_parallel_state_dict: - appended_master_weight_names.append(master_weight) - tmp_tensor = paddle.zeros(v.shape, "float32") - 
dist_tmp_tensor = dist.shard_tensor(tmp_tensor, v.process_mesh, v.placements) - self.auto_parallel_state_dict[master_weight] = dist_tmp_tensor - - _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) - for k, v in model_params.items(): - master_weight = self.auto_parallel_state_dict[k + ".w_0"] - cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") - paddle.assign(cast_master_weight, v._local_value()) - for k in appended_master_weight_names: - self.auto_parallel_state_dict.pop(k) - else: - _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) From 479f23ce9fc682f90004584b6fb12349c8e6a59a Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 11:38:23 +0800 Subject: [PATCH 09/30] fix --- paddlenlp/trainer/ckpt_converter.py | 1046 +++++++++++++++++++++++++++ 1 file changed, 1046 insertions(+) create mode 100644 paddlenlp/trainer/ckpt_converter.py diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py new file mode 100644 index 000000000000..51f6ea27085b --- /dev/null +++ b/paddlenlp/trainer/ckpt_converter.py @@ -0,0 +1,1046 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from functools import reduce + +import paddle +import paddle.distributed as dist +from paddle.distributed.checkpoint.load_state_dict import ( + _load_state_dict, + get_local_load_files, +) +from paddle.distributed.checkpoint.metadata import ( + LocalTensorIndex, + LocalTensorMetadata, + Metadata, +) +from paddle.distributed.checkpoint.utils import flatten_state_dict + +MODEL_WEIGHT_SUFFIX = ".pdparams" +OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" +SCHEDULER_NAME = "scheduler.pdparams" +MODEL_META_FILE_NAME = "model_meta.json" + + +class CheckpointConverter: + def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structured_name): + self.use_dist = True if paddle.distributed.get_world_size() > 1 else False + self.path = hybrid_parallel_ckpt_path + self.auto_parallel_state_dict = self.flatten_state_dict(model_state) + self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) + model_state_global_shape = {} + for k, v in model_state.items(): + model_state_global_shape[k] = v.shape + self.model_state_global_shape = self.gather_global_object(model_state_global_shape) + self.cur_rank = paddle.distributed.get_rank() + + self.save_sharded_model = self.get_save_sharded_model_flag() + + ( + self.cur_rank_model_state_file_names, + self.cur_rank_optimizer_state_file_names, + ) = self.get_local_checkpoint_file_names() + + self.global_model_state_file_names = self.gather_global_object(self.cur_rank_model_state_file_names) + + self.global_optimizer_state_file_names = self.gather_global_object(self.cur_rank_optimizer_state_file_names) + + self.initial_distributed_configuration() + + def get_save_sharded_model_flag(self): + if self.cur_rank == 1: + save_sharded_model_flag = 
[os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] + else: + save_sharded_model_flag = [] + save_sharded_model_flag = self.gather_global_object(save_sharded_model_flag) + return save_sharded_model_flag[0] + + def flatten_state_dict(self, state_dict): + flattened_state_dict = {} + flat_state_dict, mapping = flatten_state_dict(state_dict) + for k, v in flat_state_dict.items(): + last_level_key = mapping[k][-1] + assert last_level_key not in flattened_state_dict + flattened_state_dict[last_level_key] = v + return flattened_state_dict + + def gather_global_object(self, cur_rank_object): + all_rank_objects = [] + if self.use_dist: + paddle.distributed.all_gather_object(all_rank_objects, cur_rank_object) + else: + all_rank_objects = [all_rank_objects] + + if isinstance(cur_rank_object, list): + return [item for sublist in all_rank_objects for item in sublist] + elif isinstance(cur_rank_object, dict): + global_map = {} + for rank_map in all_rank_objects: + global_map.update(rank_map) + return global_map + else: + raise ValueError("cur_rank_object should be either a list or a dict") + + def get_local_checkpoint_file_names(self): + cur_rank_files = os.listdir(self.path) + cur_rank_model_state_file_names = [] + cur_rank_optimizer_state_file_names = [] + for file in cur_rank_files: + if file.endswith(MODEL_WEIGHT_SUFFIX): + cur_rank_model_state_file_names.append(file) + elif file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + cur_rank_optimizer_state_file_names.append(file) + if SCHEDULER_NAME in cur_rank_model_state_file_names: + cur_rank_model_state_file_names.remove(SCHEDULER_NAME) + return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names + + def get_distribution_rank_from_file_name(self, file_name): + pp_degree = 0 + tp_degree = 0 + sharding_degree = 0 + pattern_pp = r"pp(\d+)" + pattern_tp = r"tp(\d+)" + pattern_shard = r"shard(\d+)" + match_pp = re.search(pattern_pp, file_name) + if match_pp: + pp_degree = int(match_pp.group(1)) + match_tp = re.search(pattern_tp, file_name) + if match_tp: + tp_degree = int(match_tp.group(1)) + match_shard = re.search(pattern_shard, file_name) + if match_shard: + sharding_degree = int(match_shard.group(1)) + return (tp_degree, pp_degree, sharding_degree) + + def initial_distributed_configuration(self): + self.pp_degree = 0 + self.tp_degree = 0 + self.sharding_degree = 0 + + all_files = self.global_model_state_file_names + self.global_optimizer_state_file_names + + for file in all_files: + (tp_degree, pp_degree, sharding_degree) = self.get_distribution_rank_from_file_name(file) + self.pp_degree = max(self.pp_degree, pp_degree) + self.tp_degree = max(self.tp_degree, tp_degree) + self.sharding_degree = max(self.sharding_degree, sharding_degree) + + self.pp_degree = self.pp_degree + 1 + self.tp_degree = self.tp_degree + 1 + self.sharding_degree = self.sharding_degree + 1 + + def infer_sharding_stage1_v(self): + sharding_stage1_v = [2] + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: + for k, v in state_dict.items(): + # Under shardingv2, the optimizer state is first flattened and then split. 
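A standalone sketch (outside the diff) of the rule behind the comment above: sharding stage1 v2 flattens and splits optimizer states, so every shard is 1-D, whereas any higher-rank state implies v1. The shapes are invented for illustration.

    def infer_stage1_version(optimizer_state_shapes):
        # v2 flattens and splits states, so every shard is 1-D; any 2-D (or higher)
        # state can only come from v1, which keeps the original parameter shapes.
        for shape in optimizer_state_shapes.values():
            if len(shape) != 1:
                return 1
        return 2

    print(infer_stage1_version({"linear_0.w_0_moment1_0": (4096,)}))   # -> 2
    print(infer_stage1_version({"linear_0.w_0_moment1_0": (64, 64)}))  # -> 1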
+ if len(v.shape) != 1: + sharding_stage1_v = [1] + break + + sharding_stage1_v = self.gather_global_object(sharding_stage1_v) + if 1 in sharding_stage1_v: + return 1 + return 2 + + def infer_is_sharding_stage3(self): + if self.sharding_degree == 1: + return False + if self.pp_degree > 1 or self.tp_degree > 1: + # Currently, sharding stage 3 does not support concurrent use with tensor parallelism (TP) and pipeline parallelism (PP). + return False + + is_sharding_stage3 = True + + file_to_state_shape_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + state_shape_mapping = {} + for k, v in state_dict.items(): + state_shape_mapping[k] = v.shape + if len(v.shape) != 1: + return False + file_to_state_shape_mapping[file] = state_shape_mapping + global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) + + state_dict_std = global_file_to_state_shape_mapping[list(global_file_to_state_shape_mapping.keys())[0]] + + for file, state_dict in global_file_to_state_shape_mapping.items(): + if state_dict != state_dict_std: + is_sharding_stage3 = False + break + return is_sharding_stage3 + + def parse_master_weight_name_by(self, optimizer_state_name): + return optimizer_state_name.split(".")[0] + + def get_model_state_file_from(self, optimizer_state_file_name): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) + for model_state_file in self.global_model_state_file_names: + distributed_rank = self.get_distribution_rank_from_file_name(model_state_file) + if tp_rank == distributed_rank[0] and pp_rank == distributed_rank[1]: + return model_state_file + return None + + def optimizer_key_to_model_state_key(self, optimizer_key): + adamw_optimizer_key_suffix = [ + ".w_0_beta1_pow_acc_0", + ".w_0_beta2_pow_acc_0", + ".w_0_moment1_0", + ".w_0_moment2_0", + ".w_0", + ] + model_state_key = optimizer_key + for suffix in adamw_optimizer_key_suffix: + if model_state_key.endswith(suffix): + # Remove the suffix from model_state_key + model_state_key = model_state_key[: -len(suffix)] + break + return model_state_key + + def partition_parameters(self, model_state_shapes, is_sort, shard_num): + """ + Partitions parameters among sharding ranks. + + Return: + Dict[int, List] + """ + # Copy from python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py + mapping = {} + for rank_ in range(shard_num): + mapping[rank_] = [] + sizes = [0] * shard_num + + parameters = model_state_shapes.copy() + + if is_sort: + parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True) + + for param in parameters: + rank = sizes.index(min(sizes)) + mapping[rank].append(param) + numel = reduce(lambda x, y: x * y, param[1], 1) + assert numel > 0, f"param [{param[0]}] should larger than 0, but it is [{numel}]" + sizes[rank] += numel + + return mapping + + def rename_using_model_meta(self, file_name): + if not hasattr(self, "model_meta"): + try: + self.model_meta = json.load(open(os.path.join(self.path, MODEL_META_FILE_NAME))) + except Exception as e: + print(e) + distributed_rank = self.get_distribution_rank_from_file_name(file_name) + dist_strategy_key = ( + "tp" + "{:02d}".format(distributed_rank[0]) + "_" + "pp" + "{:02d}".format(distributed_rank[1]) + ) + # Map model weight names to their corresponding names of master_weights in the optimizer state. 
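A standalone sketch (outside the diff) of the mapping described in the comment above: the leading master-weight name of each optimizer key is swapped for the structured parameter name recorded in model_meta's structure_name_mapping. All names here are invented.

    # structure_name_mapping: structured parameter name -> internal parameter name (e.g. "linear_0.w_0").
    structure_name_mapping = {"llama.layers.0.mlp.gate_proj.weight": "linear_0.w_0"}

    # Invert it so the master-weight prefix of an optimizer key can be swapped for the structured name.
    master_to_structured = {v.split(".")[0]: k for k, v in structure_name_mapping.items()}

    def rename(optimizer_key):
        master_weight_name = optimizer_key.split(".")[0]        # e.g. "linear_0"
        structured = master_to_structured[master_weight_name]   # e.g. "llama.layers.0.mlp.gate_proj.weight"
        return optimizer_key.replace(master_weight_name, structured)

    print(rename("linear_0.w_0_moment1_0"))
    # -> "llama.layers.0.mlp.gate_proj.weight.w_0_moment1_0"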
+ if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] + master_weight_name_to_model_weight_name_mapping = {} + for k, v in structure_name_mapping.items(): + master_weight_name_to_model_weight_name_mapping[v.split(".")[0]] = k + + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + state_dict = self.cur_rank_loaded_state_dict[file_name] + for k, v in state_dict.items(): + master_weight_name = self.parse_master_weight_name_by(k) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + new_k = k.replace(master_weight_name, model_weight_name) + renamed_state_dict[new_k] = v + return renamed_state_dict + else: + return self.cur_rank_loaded_state_dict[file_name] + + def rename_using_optimizer_state_order(self, file_name): + if not hasattr(self, "global_file_to_state_dict_keys_mapping"): + file_to_state_dict_keys_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + file_to_state_dict_keys_mapping[file] = list(state_dict.keys()) + + self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) + + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + model_state_file_name = self.get_model_state_file_from(file_name) + assert model_state_file_name is not None + model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] + optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file_name] + + master_weight_name_to_model_weight_name_mapping = {} + for i in range(len(model_state_keys)): + master_weight_name = self.parse_master_weight_name_by(optimizer_state_keys[i]) + master_weight_name_to_model_weight_name_mapping[master_weight_name] = model_state_keys[i] + + state_dict = self.cur_rank_loaded_state_dict[file_name] + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + for k, v in state_dict.items(): + master_weight_name = self.parse_master_weight_name_by(k) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + new_k = k.replace(master_weight_name, model_weight_name) + renamed_state_dict[new_k] = v + + return renamed_state_dict + else: + return self.cur_rank_loaded_state_dict[file_name] + + def load_state_dict_and_rename(self): + rank_access_files = {} + if self.save_sharded_model: + rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names + else: + rank_access_files[self.cur_rank] = ( + self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names + ) + + need_read_files = get_local_load_files(self.gather_global_object(rank_access_files)) + + self.cur_rank_loaded_state_dict = {} + + for file in need_read_files: + self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file)) + + file_to_master_weights_keys = {} + + self.optimizer_state_with_master_weights = False + + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + state_dict.pop("LR_Scheduler") + if "master_weights" in state_dict: + self.optimizer_state_with_master_weights = True + master_weights = state_dict.pop("master_weights") + file_to_master_weights_keys[file] = list(master_weights.keys()) + for k, v in master_weights.items(): + # In 
sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. + k = k.replace("slice@", "") + state_dict[k] = v + + # Standardize the state names of the AdamW optimizer. + adamw_optimizer_param_suffix_name_mapping = { + ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", + ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", + ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + } + + unified_name_state_dict = {} + for k, v in state_dict.items(): + new_k = k + for suffix in adamw_optimizer_param_suffix_name_mapping: + if k.endswith(suffix): + new_k = k.replace(suffix, adamw_optimizer_param_suffix_name_mapping[suffix]) + break + unified_name_state_dict[new_k] = v + + self.cur_rank_loaded_state_dict[file] = unified_name_state_dict + + # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. + self.sharding_stage1_v = self.infer_sharding_stage1_v() + self.is_sharding_stage3 = self.infer_is_sharding_stage3() + + # In sharding stage3, the parameters need to be reordered based on whether they are sliced. + # The threshold for determining whether to slice is segment_size, with a default value of 2**20. + # However, sharding stage3 allows users to specify their own unsliced layers, which seems to be incompatible here. + if self.is_sharding_stage3: + segment_size = 2**20 + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(MODEL_WEIGHT_SUFFIX): + sliced_pramaeters = [] + unseliced_pramaeters = [] + sorted_state_dict = {} + for k, v in state_dict.items(): + if v.numel() > segment_size: + sliced_pramaeters.append(k) + else: + unseliced_pramaeters.append(k) + for k in sliced_pramaeters + unseliced_pramaeters: + sorted_state_dict[k] = state_dict.pop(k) + self.cur_rank_loaded_state_dict[file] = sorted_state_dict + + self.global_file_to_master_weights_keys = self.gather_global_object(file_to_master_weights_keys) + + # rename and record sharded_tensor_info + cur_rank_sharded_tensor_infos = {} + + # 1. Handling the sharding stage1 v2 scenario, where the save_sharded_model flag must be enabled, independent of master_weights. + if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: + assert self.save_sharded_model + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, + # and then append the tp_degree. + renamed_state_dict = self.rename_using_model_meta(file) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + # 2. In handling the sharding stage1 v1 and sharding stage2 scenario, the optimizer states are distributed across different ranks. + # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. 
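A standalone sketch (outside the diff) of the simulation mentioned above: the same greedy numel-balancing as partition_parameters earlier in this file (the optional size-sorted pass is omitted), shown here with invented parameter shapes.

    from functools import reduce

    def partition_parameters(model_state_shapes, shard_num):
        # Greedy balancing: each parameter goes to the shard with the smallest total numel so far,
        # so the simulated placement can be matched against the per-rank optimizer shards.
        mapping = {rank: [] for rank in range(shard_num)}
        sizes = [0] * shard_num
        for name, shape in model_state_shapes:
            rank = sizes.index(min(sizes))
            mapping[rank].append((name, shape))
            sizes[rank] += reduce(lambda x, y: x * y, shape, 1)
        return mapping

    params = [("embedding_0.w_0", (1024, 64)), ("linear_0.w_0", (64, 64)), ("linear_1.w_0", (64, 64))]
    print(partition_parameters(params, 2))
    # -> rank 0 holds the large embedding, rank 1 holds both small linear weights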
+ elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + if not self.save_sharded_model: + file_to_state_dict_shapes_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + shapes = [] + for k, v in state_dict.items(): + shapes.append([k, v.shape]) + file_to_state_dict_shapes_mapping[file] = shapes + + global_file_to_state_dict_shapes_mapping = self.gather_global_object(file_to_state_dict_shapes_mapping) + + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + sharding_optimizer_state_shards = [] + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + for k, v in global_file_to_state_dict_shapes_mapping.items(): + (tp_rank_, pp_rank_, sharding_rank_) = self.get_distribution_rank_from_file_name(k) + if tp_rank == tp_rank_ and pp_rank == pp_rank_ and k.endswith(OPTIMIZER_WEIGHT_SUFFIX): + sharding_optimizer_state_shards.append([v, sharding_rank_]) + model_state_file_name = self.get_model_state_file_from(file) + model_state_shapes = global_file_to_state_dict_shapes_mapping[model_state_file_name] + sharding_optimizer_state_shards.sort(key=lambda x: x[1]) + + partition_result_0 = self.partition_parameters(model_state_shapes, False, self.sharding_degree) + partition_result_1 = self.partition_parameters(model_state_shapes, True, self.sharding_degree) + + for k, v in partition_result_0.items(): + v = sorted(v, key=model_state_shapes.index) + partition_result_0[k] = v + + for k, v in partition_result_1.items(): + v = sorted(v, key=model_state_shapes.index) + partition_result_1[k] = v + + sharding_sort_parameters = False + + for i in range(len(sharding_optimizer_state_shards)): + if not sharding_sort_parameters: + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result_0[i] + for j in range(len(partitioned_shard)): + if partitioned_shard[j][1] != state_shard[j][1]: + sharding_sort_parameters = True + break + + if sharding_sort_parameters: + for i in range(len(sharding_optimizer_state_shards)): + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result_1[i] + for j in range(len(partitioned_shard)): + assert partitioned_shard[j][1] == state_shard[j][1] + + if sharding_sort_parameters: + partition_result = partition_result_1 + else: + partition_result = partition_result_0 + + master_weight_name_to_model_weight_name_mapping = {} + for i in range(len(sharding_optimizer_state_shards)): + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result[i] + for j in range(len(partitioned_shard)): + master_weight_name = self.parse_master_weight_name_by(state_shard[j][0]) + master_weight_name_to_model_weight_name_mapping[ + master_weight_name + ] = partitioned_shard[j][0] + + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + + # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. + # Therefore, the sharding information can now be directly removed. 
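The simulation above relies on re-running the sharding optimizer's greedy placement so that each optimizer shard can be matched back to the parameter names it holds; the rename loop that follows then applies the resulting mapping. A self-contained sketch of that placement rule is given below; the parameter list and shard count are invented, and the function mirrors the converter's `partition_parameters` rather than calling it.

```python
from functools import reduce


def partition_parameters(model_state_shapes, is_sort, shard_num):
    # Greedy placement: each parameter is assigned to the rank with the
    # smallest accumulated element count, optionally after sorting the
    # parameters by size (largest first).
    mapping = {rank: [] for rank in range(shard_num)}
    sizes = [0] * shard_num
    parameters = list(model_state_shapes)
    if is_sort:
        parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True)
    for name, shape in parameters:
        rank = sizes.index(min(sizes))
        mapping[rank].append([name, shape])
        sizes[rank] += reduce(lambda x, y: x * y, shape, 1)
    return mapping


shapes = [["w1", [1024, 1024]], ["b1", [1024]], ["w2", [2048, 1024]], ["b2", [2048]]]
print(partition_parameters(shapes, False, 2))  # declaration order
print(partition_parameters(shapes, True, 2))   # largest-first order
```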
+ for k, v in state_dict.items(): + master_weight_name = self.parse_master_weight_name_by(k) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + new_k = k.replace(master_weight_name, model_weight_name) + renamed_state_dict[new_k] = v + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + else: + for k, v in state_dict.items(): + if k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[k] = [ + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ] + else: + cur_rank_sharded_tensor_infos[k].append( + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ) + else: + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + renamed_state_dict = self.rename_using_model_meta(file) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + else: + # 3. Handling the case of disabling sharding, independent of master_weights, but without considering the save_sharded_model flag. + if not self.save_sharded_model: + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + + renamed_state_dict = self.rename_using_optimizer_state_order(file) + + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + else: + for k, v in state_dict.items(): + if k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + + else: + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, + # and then append the tp_degree. 
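For reference, one gathered entry of the sharded-tensor-info mapping built above could look like the following after all ranks' dicts are merged; every name, shape and file name here is invented purely for illustration.

```python
# Invented example of one merged entry.
# Each record is [distribution info, local shape, dtype string, source file name];
# sharding_rank only appears in the branches that record it.
global_sharded_tensor_infos = {
    "llama.layers.0.mlp.up_proj.weight.w_0_moment1_0": [
        [{"tp_rank": 0, "sharding_rank": 0}, [4096, 2752], "float32", "optimizer.tp00_pp00_shard00.pdopt"],
        [{"tp_rank": 1, "sharding_rank": 0}, [4096, 2752], "float32", "optimizer.tp01_pp00_shard00.pdopt"],
    ]
}
```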
+ renamed_state_dict = self.rename_using_model_meta(file) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for new_k, v in renamed_state_dict.items(): + if new_k not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[new_k] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[new_k].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + v.shape, + str(v.dtype).split(".")[1], + file, + ] + ) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + # gather global sharded tensor infos + sharded_tensor_infos = self.gather_global_object({self.cur_rank: cur_rank_sharded_tensor_infos}) + + self.global_sharded_tensor_infos = {} + for rank, sharded_tensor_info in sharded_tensor_infos.items(): + for k, v in sharded_tensor_info.items(): + if k not in self.global_sharded_tensor_infos: + self.global_sharded_tensor_infos[k] = v + else: + self.global_sharded_tensor_infos[k] += v + + def gen_metadata_for_tp_sharded_tensor(self): + for k, v in self.global_sharded_tensor_infos.items(): + v.sort(key=lambda x: x[0]["tp_rank"]) + + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. + for k, v in self.global_sharded_tensor_infos.items(): + global_offset = 0 + local_shape = v[0][1] + model_state_name = self.optimizer_key_to_model_state_key(k) + if "_pow_acc_0" not in k: + global_shape = self.model_state_global_shape[model_state_name] + else: + global_shape = (1,) + + assert len(local_shape) == len(global_shape) + axis = -1 + for i in range(len(local_shape)): + if local_shape[i] != global_shape[i]: + axis = i + break + + is_replicated = axis == -1 + global_offset = [0] * len(local_shape) + + if is_replicated: + v = [v[0]] + + for item in v: + local_tensor_meta_data = LocalTensorMetadata(tuple(global_offset), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, tuple(global_offset)) + global_offset[axis] += item[1][axis] + if k not in state_dict_metadata: + state_dict_metadata[k] = [local_tensor_meta_data] + else: + state_dict_metadata[k].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = self.cur_rank_loaded_state_dict + + return metadata, source_state_dict + + def gen_metadata_and_prepare_source_state_dict(self): + self.load_state_dict_and_rename() + if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: + for k, v in self.global_sharded_tensor_infos.items(): + v.sort(key=lambda x: x[0]["sharding_rank"]) + + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
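The offset computation mentioned in the comment above can be illustrated on its own: the split axis is the first dimension where a shard's local shape differs from the global shape, and each shard's offset along that axis is the running sum of the preceding local extents. A plain-Python sketch with hypothetical shapes:

```python
def build_offsets(local_shapes, global_shape):
    # local_shapes: the per-rank shapes of one tensor, already sorted by tp_rank.
    axis = -1
    for i, (local, global_) in enumerate(zip(local_shapes[0], global_shape)):
        if local != global_:
            axis = i
            break
    if axis == -1:  # replicated: a single copy at offset zero is enough
        return [(0,) * len(global_shape)]
    offsets, accumulated = [], 0
    for shape in local_shapes:
        offset = [0] * len(global_shape)
        offset[axis] = accumulated
        offsets.append(tuple(offset))
        accumulated += shape[axis]
    return offsets


# A column-parallel weight split over two tp ranks along axis 1 (hypothetical shapes).
print(build_offsets([[4096, 2752], [4096, 2752]], [4096, 5504]))  # [(0, 0), (0, 2752)]
```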
+ for k, v in self.global_sharded_tensor_infos.items(): + global_offset = [0] * self.tp_degree + for item in v: + tp_rank = item[0]["tp_rank"] + k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) + local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k_with_tp_rank, (global_offset[tp_rank],)) + global_offset[tp_rank] += item[1][0] + if k_with_tp_rank not in state_dict_metadata: + state_dict_metadata[k_with_tp_rank] = [local_tensor_meta_data] + else: + state_dict_metadata[k_with_tp_rank].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + + source_state_dict_for_merge_sharding = {} + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + for k, v in state_dict.items(): + if self.global_sharded_tensor_infos[k][0][0]["tp_rank"] != -1: + k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) + renamed_state_dict[k_with_tp_rank] = v + else: + renamed_state_dict[k] = v + + source_state_dict_for_merge_sharding[file_name] = renamed_state_dict + + assert self.model_meta is not None + global_model_state_shapes = [] + sharding_metas_keys = [] + for i in range(self.pp_degree): + for j in range(self.tp_degree): + sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) + for key in sharding_metas_keys: + param_meta = self.model_meta["sharding_metas"][key]["param_meta"] + for k, v in param_meta.items(): + global_model_state_shapes.append([k, v[0]]) + + # Distribute all model parameters evenly across each card for loading + + world_size = paddle.distributed.get_world_size() + + partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) + + partition_model_state_keys = [] + for cur_rank, partition_model_state in partition_mapping.items(): + partition_model_state_keys.append([item[0] for item in partition_model_state]) + + param_meta = {} + for i in range(self.tp_degree): + for j in range(self.pp_degree): + key = "tp{:02d}_pp{:02d}".format(i, j) + pm = self.model_meta["sharding_metas"][key]["param_meta"] + for k, v in pm.items(): + param_meta[k] = v + + param_flattened_shapes = {} + for k, v in param_meta.items(): + param_flattened_shapes[k] = reduce(lambda x, y: x * y, v[0]) + + cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] + + # Generate the optimizer states corresponding to the model weights. + optimizer_state_dict = {} + for key in cur_rank_need_load_model_state_keys: + for tp_rank in range(self.tp_degree): + tp_rank_suffix = "_tp{:02d}".format(tp_rank) + optimizer_state_dict[key + ".w_0_moment1_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + optimizer_state_dict[key + ".w_0_moment2_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + if self.optimizer_state_with_master_weights: + optimizer_state_dict[key + ".w_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. 
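The zero tensors allocated in this loop are only placeholders: every optimizer state is requested as a flattened 1-D buffer per tensor-parallel rank and is filled later by the reshard load. A minimal sketch of that layout, assuming paddle is importable and using an invented `param_meta` entry (the real converter stores the metadata slightly differently):

```python
from functools import reduce

import paddle

# Invented parameter metadata: structured name -> (shape, dtype).
param_meta = {"llama.layers.0.mlp.up_proj.weight": ([4096, 2752], "float32")}
tp_degree = 2

placeholders = {}
for name, (shape, _) in param_meta.items():
    flattened = reduce(lambda x, y: x * y, shape)
    for tp_rank in range(tp_degree):
        suffix = "_tp{:02d}".format(tp_rank)
        # Moments are requested flattened; the beta accumulators are scalars.
        # A ".w_0" master-weight buffer would be added the same way when present.
        placeholders[name + ".w_0_moment1_0" + suffix] = paddle.zeros((flattened,), "float32")
        placeholders[name + ".w_0_moment2_0" + suffix] = paddle.zeros((flattened,), "float32")
        placeholders[name + ".w_0_beta1_pow_acc_0" + suffix] = paddle.zeros((1,), "float32")
        placeholders[name + ".w_0_beta2_pow_acc_0" + suffix] = paddle.zeros((1,), "float32")

print({k: tuple(v.shape) for k, v in placeholders.items()})
```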
+ + optimizer_state_dict[key + ".w_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") + optimizer_state_dict[key + ".w_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") + + # merge sharding + _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) + + # Reshape + for k, v in optimizer_state_dict.items(): + if v.shape[0] > 1 and "_tp" in k: + param_name = self.optimizer_key_to_model_state_key(k[:-5]) + param_shape = param_meta[param_name][0] + assert v.numel() == reduce(lambda x, y: x * y, param_shape) + reshaped_v = v.reshape(param_shape) + optimizer_state_dict[k] = reshaped_v + concat_optimier_state_dict = {} + + optimizer_state_key_to_tp_keys = {} + for key in optimizer_state_dict.keys(): + # Count how each key is split into keys ending with ‘_tpXX’. + # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} + key_removed_tp_rank = key[:-5] + if key_removed_tp_rank not in optimizer_state_key_to_tp_keys: + optimizer_state_key_to_tp_keys[key_removed_tp_rank] = [key] + else: + optimizer_state_key_to_tp_keys[key_removed_tp_rank].append(key) + + for key, value in optimizer_state_key_to_tp_keys.items(): + value.sort(key=lambda x: int(x[-2:])) + + for key, tp_keys in optimizer_state_key_to_tp_keys.items(): + model_state_name = self.optimizer_key_to_model_state_key(key) + local_shape = optimizer_state_dict[tp_keys[0]].shape + if "_pow_acc_0" not in key: + global_shape = self.model_state_global_shape[model_state_name] + else: + global_shape = (1,) + + assert len(local_shape) == len(global_shape) + + axis = -1 + for i in range(len(local_shape)): + if local_shape[i] != global_shape[i]: + axis = i + break + + is_replicated = axis == -1 + tp_tensors = [] + for tp_key in tp_keys: + tp_tensors.append(optimizer_state_dict[tp_key]) + + if not is_replicated: + # Derive the partition strategy based on the global_shape, then concatenate. + concat_optimier_state_dict[key] = paddle.concat(tp_tensors, axis=axis) + else: + concat_optimier_state_dict[key] = tp_tensors[0] + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in concat_optimier_state_dict.items(): + # Generate metadata. 
+ local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: concat_optimier_state_dict} + + return meta_data, source_state_dict + + elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + return self.gen_metadata_for_tp_sharded_tensor() + else: + if self.is_sharding_stage3: + for k, v in self.global_sharded_tensor_infos.items(): + v.sort(key=lambda x: x[0]["sharding_rank"]) + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. + for k, v in self.global_sharded_tensor_infos.items(): + global_offset = 0 + for item in v: + if len(item[1]) == 1: + local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, (global_offset,)) + global_offset += item[1][0] + else: + global_offset = tuple([0] * len(item[1])) + local_tensor_meta_data = LocalTensorMetadata(global_offset, item[1], item[2]) + local_tensor_index = LocalTensorIndex(k, global_offset) + if k not in state_dict_metadata: + state_dict_metadata[k] = [local_tensor_meta_data] + else: + state_dict_metadata[k].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + model_state_shapes = [] + dtype = "" + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(MODEL_WEIGHT_SUFFIX): + for k, v in state_dict.items(): + model_state_shapes.append([k, v.shape]) + dtype = str(v.dtype).split(".")[1] + + dtypes = self.gather_global_object([dtype]) + for dtype_s in dtypes: + if len(dtype_s) > 0: + dtype = dtype_s + + assert len(dtype) > 0 + + global_model_state_shapes = self.gather_global_object(model_state_shapes) + + partition_result = self.partition_parameters( + global_model_state_shapes, True, paddle.distributed.get_world_size() + ) + + cur_rank_merger_model_params = partition_result[self.cur_rank] + target_state_dict = {} + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + flatten_shape = reduce(lambda a, b: a * b, item[1]) + target_state_dict[key] = paddle.zeros(shape, dtype) + target_state_dict[key + ".w_0_moment1_0"] = paddle.zeros((flatten_shape,), "float32") + target_state_dict[key + ".w_0_moment2_0"] = 
paddle.zeros((flatten_shape,), "float32") + if self.optimizer_state_with_master_weights: + target_state_dict[key + ".w_0"] = paddle.zeros((flatten_shape,), "float32") + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. + + target_state_dict[key + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") + target_state_dict[key + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") + + _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) + + # Reshape + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + for k, v in target_state_dict.items(): + if key == self.optimizer_key_to_model_state_key(k): + if tuple(shape) != tuple(v.shape) and v.numel() == reduce(lambda x, y: x * y, shape): + reshaped_v = v.reshape(shape) + target_state_dict[k] = reshaped_v + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in target_state_dict.items(): + # Generate metadata. + local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: target_state_dict} + + return meta_data, source_state_dict + else: + return self.gen_metadata_for_tp_sharded_tensor() + + def rename_auto_parallel_state_dict(self): + need_remove_key_pattern = ["eager_tmp", "learning_rate", "@GRAD@MERG", "gradient_merge_"] + + need_remove_key = set() + for key in self.auto_parallel_state_dict.keys(): + for pattern in need_remove_key_pattern: + if pattern in key: + need_remove_key.add(key) + break + + for key in need_remove_key: + self.auto_parallel_state_dict.pop(key) + + adamw_optimizer_status_name_suffix_mappings = { + "_fp32_master_1_moment1_0": ".w_0_moment1_0", + "_fp32_master_1_moment2_0": ".w_0_moment2_0", + "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + "_fp32_master_1": ".w_0", + "_moment1_0": ".w_0_moment1_0", + "_moment2_0": ".w_0_moment2_0", + "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + } + + def rename(old_name, map1, map2): + for i in range(1, len(old_name)): + str1 = old_name[:i] + str2 = old_name[i:] + if (str1 in map1) and (str2 in map2): + transformed_str1 = 
map1[str1] + transformed_str2 = map2[str2] + return transformed_str1 + transformed_str2 + return None + + renamed_state_dict = {} + + for key, value in self.auto_parallel_state_dict.items(): + + if key in self.parameter_to_structured_name.values(): + new_name = key + else: + new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_status_name_suffix_mappings) + + assert new_name is not None + renamed_state_dict[new_name] = value + + self.auto_parallel_state_dict = renamed_state_dict + + def load_from_hybrid_parallel_checkpoint(self): + self.rename_auto_parallel_state_dict() + metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() + if self.save_sharded_model: + model_params = {} + for k, v in self.auto_parallel_state_dict.items(): + if k in self.parameter_to_structured_name.values(): + model_params[k] = v + for k in model_params.keys(): + self.auto_parallel_state_dict.pop(k) + + appended_master_weight_names = [] + for k, v in model_params.items(): + master_weight = k + ".w_0" + if master_weight not in self.auto_parallel_state_dict: + appended_master_weight_names.append(master_weight) + tmp_tensor = paddle.zeros(v.shape, "float32") + dist_tmp_tensor = dist.shard_tensor(tmp_tensor, v.process_mesh, v.placements) + self.auto_parallel_state_dict[master_weight] = dist_tmp_tensor + + _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + for k, v in model_params.items(): + master_weight = self.auto_parallel_state_dict[k + ".w_0"] + cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") + paddle.assign(cast_master_weight, v._local_value()) + for k in appended_master_weight_names: + self.auto_parallel_state_dict.pop(k) + else: + _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) From 97f942019e96d33fa09441d586aaba5ee6963905 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 11:42:07 +0800 Subject: [PATCH 10/30] fix --- paddlenlp/trainer/training_args.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 0563f67b8e0f..c21e5baa62f6 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -349,6 +349,8 @@ class TrainingArguments: The path to a folder with a valid checkpoint for your model. This argument is not directly used by [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details. + resume_form_hybrid_parallel (`bool`, *optional*): + Wether hybrid paralle checkpoints be loaded in auto parallel mode. flatten_param_grads (`bool`, *optional*): Whether use flatten_param_grads method in optimizer, only used on NPU devices. Default is `False`. skip_profile_timer (`bool`, *optional*): @@ -357,8 +359,6 @@ class TrainingArguments: Whether to use distributed dataloader. Default is `False`. release_grads (`bool`, *optional*): Whether to release gradients during training. Default is `False`. - resume_form_hybrid_parallel (`bool`, *optional*): - Wether hybrid paralle checkpoints be loaded in auto parallel mode. 
""" output_dir: str = field( @@ -772,6 +772,10 @@ class TrainingArguments: default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."}, ) + resume_form_hybrid_parallel: Optional[bool] = field( + default=False, + metadata={"help": "Wether hybrid paralle checkpoints be loaded in auto parallel mode."}, + ) skip_memory_metrics: bool = field( default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."} ) @@ -839,10 +843,6 @@ class TrainingArguments: release_grads: Optional[bool] = field( default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."} ) - resume_form_hybrid_parallel: Optional[bool] = field( - default=False, - metadata={"help": "Wether hybrid paralle checkpoints be loaded in auto parallel mode."}, - ) def __post_init__(self): env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) From a69a802f02d40b3d76fa01ed265e38d01f5091ad Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 11:47:35 +0800 Subject: [PATCH 11/30] fix --- paddlenlp/trainer/ckpt_converter.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index 51f6ea27085b..5c8d4aa19fb6 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -28,7 +28,6 @@ LocalTensorMetadata, Metadata, ) -from paddle.distributed.checkpoint.utils import flatten_state_dict MODEL_WEIGHT_SUFFIX = ".pdparams" OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" @@ -40,7 +39,7 @@ class CheckpointConverter: def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structured_name): self.use_dist = True if paddle.distributed.get_world_size() > 1 else False self.path = hybrid_parallel_ckpt_path - self.auto_parallel_state_dict = self.flatten_state_dict(model_state) + self.auto_parallel_state_dict = model_state self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) model_state_global_shape = {} for k, v in model_state.items(): @@ -69,15 +68,6 @@ def get_save_sharded_model_flag(self): save_sharded_model_flag = self.gather_global_object(save_sharded_model_flag) return save_sharded_model_flag[0] - def flatten_state_dict(self, state_dict): - flattened_state_dict = {} - flat_state_dict, mapping = flatten_state_dict(state_dict) - for k, v in flat_state_dict.items(): - last_level_key = mapping[k][-1] - assert last_level_key not in flattened_state_dict - flattened_state_dict[last_level_key] = v - return flattened_state_dict - def gather_global_object(self, cur_rank_object): all_rank_objects = [] if self.use_dist: @@ -258,7 +248,6 @@ def rename_using_model_meta(self, file_name): master_weight_name_to_model_weight_name_mapping[v.split(".")[0]] = k renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) state_dict = self.cur_rank_loaded_state_dict[file_name] for k, v in state_dict.items(): master_weight_name = self.parse_master_weight_name_by(k) @@ -277,7 +266,6 @@ def rename_using_optimizer_state_order(self, file_name): self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): model_state_file_name = self.get_model_state_file_from(file_name) assert model_state_file_name is not None @@ -291,7 +279,6 @@ def 
rename_using_optimizer_state_order(self, file_name): state_dict = self.cur_rank_loaded_state_dict[file_name] renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) for k, v in state_dict.items(): master_weight_name = self.parse_master_weight_name_by(k) model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] @@ -410,7 +397,7 @@ def load_state_dict_and_rename(self): ) self.cur_rank_loaded_state_dict[file] = renamed_state_dict - # 2. In handling the sharding stage1 v1 and sharding stage2 scenario, the optimizer states are distributed across different ranks. + # 2. In handling the sharding stage1 v1 and stage2 scenario, the optimizer states are distributed across different ranks. # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: if not self.save_sharded_model: @@ -525,7 +512,7 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: - # 3. Handling the case of disabling sharding, independent of master_weights, but without considering the save_sharded_model flag. + # 3. Handling the sharding stage3 and non-sharding scenario if not self.save_sharded_model: for file, state_dict in self.cur_rank_loaded_state_dict.items(): (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) @@ -966,7 +953,7 @@ def gen_metadata_and_prepare_source_state_dict(self): else: return self.gen_metadata_for_tp_sharded_tensor() - def rename_auto_parallel_state_dict(self): + def rename_semi_auto_state_dict(self): need_remove_key_pattern = ["eager_tmp", "learning_rate", "@GRAD@MERG", "gradient_merge_"] need_remove_key = set() @@ -1016,7 +1003,7 @@ def rename(old_name, map1, map2): self.auto_parallel_state_dict = renamed_state_dict def load_from_hybrid_parallel_checkpoint(self): - self.rename_auto_parallel_state_dict() + self.rename_semi_auto_state_dict() metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() if self.save_sharded_model: model_params = {} From 6ea0734eea44d76bf52c81ab8d72023ec1ec6cfd Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 11:51:09 +0800 Subject: [PATCH 12/30] fix --- paddlenlp/trainer/ckpt_converter.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index 5c8d4aa19fb6..d6121085d1f3 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -28,6 +28,7 @@ LocalTensorMetadata, Metadata, ) +from paddle.distributed.checkpoint.utils import flatten_state_dict MODEL_WEIGHT_SUFFIX = ".pdparams" OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" @@ -39,7 +40,7 @@ class CheckpointConverter: def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structured_name): self.use_dist = True if paddle.distributed.get_world_size() > 1 else False self.path = hybrid_parallel_ckpt_path - self.auto_parallel_state_dict = model_state + self.auto_parallel_state_dict = self.flatten_state_dict(model_state) self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) model_state_global_shape = {} for k, v in model_state.items(): @@ -68,6 +69,15 @@ def get_save_sharded_model_flag(self): save_sharded_model_flag = self.gather_global_object(save_sharded_model_flag) return 
save_sharded_model_flag[0] + def flatten_state_dict(self, state_dict): + flattened_state_dict = {} + flat_state_dict, mapping = flatten_state_dict(state_dict) + for k, v in flat_state_dict.items(): + last_level_key = mapping[k][-1] + assert last_level_key not in flattened_state_dict + flattened_state_dict[last_level_key] = v + return flattened_state_dict + def gather_global_object(self, cur_rank_object): all_rank_objects = [] if self.use_dist: @@ -953,7 +963,7 @@ def gen_metadata_and_prepare_source_state_dict(self): else: return self.gen_metadata_for_tp_sharded_tensor() - def rename_semi_auto_state_dict(self): + def rename_auto_parallel_state_dict(self): need_remove_key_pattern = ["eager_tmp", "learning_rate", "@GRAD@MERG", "gradient_merge_"] need_remove_key = set() @@ -1003,7 +1013,7 @@ def rename(old_name, map1, map2): self.auto_parallel_state_dict = renamed_state_dict def load_from_hybrid_parallel_checkpoint(self): - self.rename_semi_auto_state_dict() + self.rename_auto_parallel_state_dict() metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() if self.save_sharded_model: model_params = {} From 4d7bf5d38f9905f2e00c423f0913445dccf8c251 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 12:02:26 +0800 Subject: [PATCH 13/30] fix --- paddlenlp/trainer/ckpt_converter.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index d6121085d1f3..dfe1a6259260 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -242,10 +242,11 @@ def partition_parameters(self, model_state_shapes, is_sort, shard_num): def rename_using_model_meta(self, file_name): if not hasattr(self, "model_meta"): - try: - self.model_meta = json.load(open(os.path.join(self.path, MODEL_META_FILE_NAME))) - except Exception as e: - print(e) + meta_file_path = os.path.join(self.path, MODEL_META_FILE_NAME) + assert os.path.exists(meta_file_path) + with open(meta_file_path, "r") as file: + self.model_meta = json.load(file) + distributed_rank = self.get_distribution_rank_from_file_name(file_name) dist_strategy_key = ( "tp" + "{:02d}".format(distributed_rank[0]) + "_" + "pp" + "{:02d}".format(distributed_rank[1]) From 0162d9981ab24a6575245989e4034edcad8e8df9 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 13:35:03 +0800 Subject: [PATCH 14/30] fix --- paddlenlp/trainer/auto_trainer.py | 60 ++++++++++++++++++++++++----- paddlenlp/trainer/ckpt_converter.py | 2 +- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index e9ba8a2eabd4..b3ff932281a9 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -50,6 +50,7 @@ MODEL_NAME = "model" OPTIMIZER_NAME = "optimizer" DIST_CKPT_PATH = "dist_ckpt" +FREE_SVAE_LOAD_KEY_PATTERNS = ["learning_rate_", "gradient_merge_", "@GRAD@MERG", "eager_tmp"] class AutoTrainer(Trainer): @@ -158,20 +159,43 @@ def _split_batches_for_accumulation(self, inputs): if self.args.gradient_accumulation_steps == 1: return [inputs] - # if self.args.to_static: if self.args.to_static and self.args.pipeline_parallel_degree > 1: return [inputs] local_batches = [{} for i in range(self.args.gradient_accumulation_steps)] + assert isinstance(inputs, dict) - for key, value in inputs.items(): - ori_mesh, ori_placements = value.process_mesh, value.placements - replicate_value = dist.reshard(value, ori_mesh, [dist.Replicate(), 
dist.Replicate()]) + def split_dtensor_by_axis(dtensor, axis): + mesh = dtensor.process_mesh + placements = [dist.Replicate() for _ in range(len(mesh.shape))] + replicate_value = dist.reshard(dtensor, mesh, placements) local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0) - - for index, data in enumerate(local_datas): - local_batches[index].update({key: dist.reshard(data, ori_mesh, ori_placements)}) - + return local_datas + + for key, dtensors in inputs.items(): + if isinstance(dtensors, paddle.Tensor): + mesh, placements = dtensors.process_mesh, dtensors.placements + local_datas = split_dtensor_by_axis(dtensors, 0) + for index, data in enumerate(local_datas): + local_batches[index].update({key: dist.reshard(data, mesh, placements)}) + elif isinstance(dtensors, (list, tuple)): + if len(dtensors) == 0: + for i in range(self.args.gradient_accumulation_steps): + local_batches[i].update({key: []}) + else: + for dtensor in dtensors: + if isinstance(dtensor, paddle.Tensor): + mesh, placements = dtensor.process_mesh, dtensor.placements + local_datas = split_dtensor_by_axis(dtensor, 0) + for index, data in enumerate(local_datas): + if key in local_batches[index].keys(): + local_batches[index][key].append(dist.reshard(data, mesh, placements)) + else: + local_batches[index].update({key: [dist.reshard(data, mesh, placements)]}) + else: + raise ValueError(f"unsupported type: {type(dtensor)}") + else: + raise ValueError(f"unsupported type: {type(dtensors)}") return local_batches def _inner_training_loop( @@ -544,7 +568,15 @@ def _save_checkpoint(self, model, metrics=None): if self.args.should_save_model_state: if self.args.to_static: - state_dict = model.state_dict() + opt_state_dict = { + key: value + for key, value in model.state_dict("opt").items() + if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS) + } + state_dict = { + MODEL_NAME: model.state_dict("param"), + OPTIMIZER_NAME: opt_state_dict, + } else: optim_state_dict = self.optimizer.state_dict() optim_state_dict.pop("LR_Scheduler", None) @@ -665,7 +697,15 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): ) if self.args.to_static: - state_dict = self.model_wrapped.state_dict() + opt_state_dict = { + key: value + for key, value in self.model_wrapped.state_dict("opt").items() + if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS) + } + state_dict = { + MODEL_NAME: self.model_wrapped.state_dict("param"), + OPTIMIZER_NAME: opt_state_dict, + } else: model_state_dict = self.model_wrapped.state_dict() optim_state_dict = self.optimizer.state_dict() diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index dfe1a6259260..7d58be041d88 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -43,7 +43,7 @@ def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structur self.auto_parallel_state_dict = self.flatten_state_dict(model_state) self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) model_state_global_shape = {} - for k, v in model_state.items(): + for k, v in self.auto_parallel_state_dict.items(): model_state_global_shape[k] = v.shape self.model_state_global_shape = self.gather_global_object(model_state_global_shape) self.cur_rank = paddle.distributed.get_rank() From 74329f442044691518f0be1057f1cb4961886a6e Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 15:59:45 +0800 Subject: [PATCH 15/30] fix codestyle --- 
paddlenlp/trainer/ckpt_converter.py | 1385 +++++++++++++-------------- 1 file changed, 677 insertions(+), 708 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index 7d58be041d88..9acb8f13fe16 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -61,246 +61,434 @@ def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structur self.initial_distributed_configuration() - def get_save_sharded_model_flag(self): - if self.cur_rank == 1: - save_sharded_model_flag = [os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] - else: - save_sharded_model_flag = [] - save_sharded_model_flag = self.gather_global_object(save_sharded_model_flag) - return save_sharded_model_flag[0] - - def flatten_state_dict(self, state_dict): - flattened_state_dict = {} - flat_state_dict, mapping = flatten_state_dict(state_dict) - for k, v in flat_state_dict.items(): - last_level_key = mapping[k][-1] - assert last_level_key not in flattened_state_dict - flattened_state_dict[last_level_key] = v - return flattened_state_dict + def load_from_hybrid_parallel_checkpoint(self): + """ + Automatically and inplace load the distributed checkpoint stored in hybrid parallel mode into the auto parallel state_dict. + The main logic is as follows: + 1. Callrename_semi_auto_state_dict: Rename the keys of the auto parallel state_dict according to certain rules. + (Why rename? To facilitate the subsequent correspondence between the optimizer state names of the semi-automatic and static optimizers.) + 2. Callgen_metadata_and_prepare_source_state_dict: Automatically parse the manual checkpoint file based on the state_dict information + provided by auto parallel, obtaining the Metadata and state_dict required for auto parallel to load the checkpoint. + 3. Callload_state_dict: Automatically reshard and load. + 4. Special logic adaptation: In the save_sharded_model mode, the weights are obtained through the master_weight cast in the checkpoint. 
+ """ + self.rename_auto_parallel_state_dict() + metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() + if self.save_sharded_model: + model_params = {} + for state_name, state_value in self.auto_parallel_state_dict.items(): + if state_name in self.parameter_to_structured_name.values(): + model_params[state_name] = state_value + for param_name in model_params.keys(): + self.auto_parallel_state_dict.pop(param_name) - def gather_global_object(self, cur_rank_object): - all_rank_objects = [] - if self.use_dist: - paddle.distributed.all_gather_object(all_rank_objects, cur_rank_object) - else: - all_rank_objects = [all_rank_objects] + appended_master_weight_names = [] + for param_name, param_value in model_params.items(): + master_weight = param_name + ".w_0" + if master_weight not in self.auto_parallel_state_dict: + appended_master_weight_names.append(master_weight) + tmp_tensor = paddle.zeros(param_value.shape, "float32") + self.auto_parallel_state_dict[master_weight] = dist.shard_tensor( + tmp_tensor, param_value.process_mesh, param_value.placements + ) - if isinstance(cur_rank_object, list): - return [item for sublist in all_rank_objects for item in sublist] - elif isinstance(cur_rank_object, dict): - global_map = {} - for rank_map in all_rank_objects: - global_map.update(rank_map) - return global_map + _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + for param_name, param_value in model_params.items(): + master_weight = self.auto_parallel_state_dict[param_name + ".w_0"] + cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") + paddle.assign(cast_master_weight, param_value._local_value()) + for master_weight_name in appended_master_weight_names: + self.auto_parallel_state_dict.pop(master_weight_name) else: - raise ValueError("cur_rank_object should be either a list or a dict") + _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) - def get_local_checkpoint_file_names(self): - cur_rank_files = os.listdir(self.path) - cur_rank_model_state_file_names = [] - cur_rank_optimizer_state_file_names = [] - for file in cur_rank_files: - if file.endswith(MODEL_WEIGHT_SUFFIX): - cur_rank_model_state_file_names.append(file) - elif file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - cur_rank_optimizer_state_file_names.append(file) - if SCHEDULER_NAME in cur_rank_model_state_file_names: - cur_rank_model_state_file_names.remove(SCHEDULER_NAME) - return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names + def rename_auto_parallel_state_dict(self): + """ + Rename the keys of the auto parallel state_dict according to certain rules: + 1. 
Rename the suffixes of the optimizer states to a unified format: adamw_optimizer_status_name_suffix_mappings + """ - def get_distribution_rank_from_file_name(self, file_name): - pp_degree = 0 - tp_degree = 0 - sharding_degree = 0 - pattern_pp = r"pp(\d+)" - pattern_tp = r"tp(\d+)" - pattern_shard = r"shard(\d+)" - match_pp = re.search(pattern_pp, file_name) - if match_pp: - pp_degree = int(match_pp.group(1)) - match_tp = re.search(pattern_tp, file_name) - if match_tp: - tp_degree = int(match_tp.group(1)) - match_shard = re.search(pattern_shard, file_name) - if match_shard: - sharding_degree = int(match_shard.group(1)) - return (tp_degree, pp_degree, sharding_degree) + adamw_optimizer_state_name_suffix_mappings = { + "_fp32_master_1_moment1_0": ".w_0_moment1_0", + "_fp32_master_1_moment2_0": ".w_0_moment2_0", + "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + "_fp32_master_1": ".w_0", + "_moment1_0": ".w_0_moment1_0", + "_moment2_0": ".w_0_moment2_0", + "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + } - def initial_distributed_configuration(self): - self.pp_degree = 0 - self.tp_degree = 0 - self.sharding_degree = 0 + def rename(old_name, map1, map2): + for i in range(1, len(old_name)): + str1 = old_name[:i] + str2 = old_name[i:] + if (str1 in map1) and (str2 in map2): + transformed_str1 = map1[str1] + transformed_str2 = map2[str2] + return transformed_str1 + transformed_str2 + return None - all_files = self.global_model_state_file_names + self.global_optimizer_state_file_names + renamed_state_dict = {} - for file in all_files: - (tp_degree, pp_degree, sharding_degree) = self.get_distribution_rank_from_file_name(file) - self.pp_degree = max(self.pp_degree, pp_degree) - self.tp_degree = max(self.tp_degree, tp_degree) - self.sharding_degree = max(self.sharding_degree, sharding_degree) + for key, value in self.auto_parallel_state_dict.items(): - self.pp_degree = self.pp_degree + 1 - self.tp_degree = self.tp_degree + 1 - self.sharding_degree = self.sharding_degree + 1 + if key in self.parameter_to_structured_name.values(): + new_name = key + else: + new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_state_name_suffix_mappings) - def infer_sharding_stage1_v(self): - sharding_stage1_v = [2] - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: - for k, v in state_dict.items(): - # Under shardingv2, the optimizer state is first flattened and then split. - if len(v.shape) != 1: - sharding_stage1_v = [1] - break + assert new_name is not None + renamed_state_dict[new_name] = value - sharding_stage1_v = self.gather_global_object(sharding_stage1_v) - if 1 in sharding_stage1_v: - return 1 - return 2 + self.auto_parallel_state_dict = renamed_state_dict - def infer_is_sharding_stage3(self): - if self.sharding_degree == 1: - return False - if self.pp_degree > 1 or self.tp_degree > 1: - # Currently, sharding stage 3 does not support concurrent use with tensor parallelism (TP) and pipeline parallelism (PP). - return False + def gen_metadata_and_prepare_source_state_dict(self): + """ + Automatically parse the manual checkpoint file based on the state_dict information provided by auto parallel, + obtaining the Metadata and state_dict required for auto parallel to load the checkpoint: + 1. 
Callload_state_dict_and_rename: Parse the distributed information from the names of the checkpoint files, and evenly parse out the distributed + information for each weight/optimizer state into self.global_sharded_tensor_infos(data structure:param_name -> [{tp_rank: 1, sharding_rank: 1}, shape, dtype, file_name]). + Modify the names of the optimizer states in the form ofparameter+suffixand record them in self.cur_rank_loaded_state_dict(data structure:file_name -> renamed_state_dict). + 2. Construct the Metadata and state_dict based on the distributed information obtained in the previous step for the final load. + 3. Special logic adaptation: When sharding is enabled, the optimizer states are also split. In this step, the optimizer states need to be concatenated back according to the sharding dimension: + * Construct the Metadata for concatenating the sharded states back based on the characteristics of sharding. + * Construct a temporaryopt_state_dictand use the_load_state_dictinterface to obtain the state_dict with the sharded states concatenated back. + * Reshape the optimizer states back to the shape of the weights. + """ + self.load_state_dict_and_rename() + if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + shard_info.sort(key=lambda x: x[0]["sharding_rank"]) - is_sharding_stage3 = True + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + global_offset = [0] * self.tp_degree + for item in shard_info: + tp_rank = item[0]["tp_rank"] + state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank) + local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2]) + local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],)) + global_offset[tp_rank] += item[1][0] + if state_name_with_tp_rank not in state_dict_metadata: + state_dict_metadata[state_name_with_tp_rank] = [local_tensor_meta_data] + else: + state_dict_metadata[state_name_with_tp_rank].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] - file_to_state_shape_mapping = {} - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - state_shape_mapping = {} - for k, v in state_dict.items(): - state_shape_mapping[k] = v.shape - if len(v.shape) != 1: - return False - file_to_state_shape_mapping[file] = state_shape_mapping - global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - state_dict_std = global_file_to_state_shape_mapping[list(global_file_to_state_shape_mapping.keys())[0]] + source_state_dict_for_merge_sharding = {} + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + for state_name, state_value in state_dict.items(): + state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank) + renamed_state_dict[state_name_with_tp_rank] = state_value - for file, state_dict in global_file_to_state_shape_mapping.items(): - if state_dict != state_dict_std: - is_sharding_stage3 = False - break - return is_sharding_stage3 + 
source_state_dict_for_merge_sharding[file_name] = renamed_state_dict - def parse_master_weight_name_by(self, optimizer_state_name): - return optimizer_state_name.split(".")[0] + assert self.model_meta is not None + global_model_state_shapes = [] + sharding_metas_keys = [] + for i in range(self.pp_degree): + for j in range(self.tp_degree): + sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) + for key in sharding_metas_keys: + param_meta = self.model_meta["sharding_metas"][key]["param_meta"] + for param_name, param_shape_and_dtype in param_meta.items(): + global_model_state_shapes.append([param_name, param_shape_and_dtype[0]]) - def get_model_state_file_from(self, optimizer_state_file_name): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) - for model_state_file in self.global_model_state_file_names: - distributed_rank = self.get_distribution_rank_from_file_name(model_state_file) - if tp_rank == distributed_rank[0] and pp_rank == distributed_rank[1]: - return model_state_file - return None + # Distribute all model parameters evenly across each card for loading - def optimizer_key_to_model_state_key(self, optimizer_key): - adamw_optimizer_key_suffix = [ - ".w_0_beta1_pow_acc_0", - ".w_0_beta2_pow_acc_0", - ".w_0_moment1_0", - ".w_0_moment2_0", - ".w_0", - ] - model_state_key = optimizer_key - for suffix in adamw_optimizer_key_suffix: - if model_state_key.endswith(suffix): - # Remove the suffix from model_state_key - model_state_key = model_state_key[: -len(suffix)] - break - return model_state_key + world_size = paddle.distributed.get_world_size() - def partition_parameters(self, model_state_shapes, is_sort, shard_num): - """ - Partitions parameters among sharding ranks. + partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) - Return: - Dict[int, List] - """ - # Copy from python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py - mapping = {} - for rank_ in range(shard_num): - mapping[rank_] = [] - sizes = [0] * shard_num + partition_model_state_keys = [] + for cur_rank, partition_model_state in partition_mapping.items(): + partition_model_state_keys.append([item[0] for item in partition_model_state]) - parameters = model_state_shapes.copy() + param_meta = {} + for i in range(self.tp_degree): + for j in range(self.pp_degree): + key = "tp{:02d}_pp{:02d}".format(i, j) + param_meta = self.model_meta["sharding_metas"][key]["param_meta"] + for param_name, param_shape_and_dtype in param_meta.items(): + param_meta[param_name] = param_shape_and_dtype - if is_sort: - parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True) + param_flattened_shapes = {} + for param_meta, param_shape_and_dtype in param_meta.items(): + param_flattened_shapes[param_meta] = reduce(lambda x, y: x * y, param_shape_and_dtype[0]) - for param in parameters: - rank = sizes.index(min(sizes)) - mapping[rank].append(param) - numel = reduce(lambda x, y: x * y, param[1], 1) - assert numel > 0, f"param [{param[0]}] should larger than 0, but it is [{numel}]" - sizes[rank] += numel + cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] - return mapping + # Generate the optimizer states corresponding to the model weights. 
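Once the flattened per-tp placeholders allocated just below have been filled by the merge load, each state is reshaped back to the parameter's local shape and the tensor-parallel slices are concatenated along the split axis. A minimal sketch of those two follow-up steps, with invented shapes and stand-in values:

```python
import paddle

# A [4, 4] weight split along axis 1 over two tp ranks; each rank's merged
# optimizer state arrives as a flattened 8-element buffer (stand-in values).
param_shape = [4, 2]      # local shape per tp rank
global_shape = [4, 4]

flat_tp00 = paddle.arange(8, dtype="float32")
flat_tp01 = paddle.arange(8, 16, dtype="float32")

# 1. Reshape each flattened buffer back to the local parameter shape.
tp_slices = [flat_tp00.reshape(param_shape), flat_tp01.reshape(param_shape)]

# 2. The split axis is where local and global shapes differ; concatenate along it.
axis = next(i for i, (local, global_) in enumerate(zip(param_shape, global_shape)) if local != global_)
merged = paddle.concat(tp_slices, axis=axis)
print(merged.shape)  # [4, 4]
```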
+ optimizer_state_dict = {} + for key in cur_rank_need_load_model_state_keys: + for tp_rank in range(self.tp_degree): + tp_rank_suffix = "_tp{:02d}".format(tp_rank) + optimizer_state_dict[key + ".w_0_moment1_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + optimizer_state_dict[key + ".w_0_moment2_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + if self.optimizer_state_with_master_weights: + optimizer_state_dict[key + ".w_0" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. - def rename_using_model_meta(self, file_name): - if not hasattr(self, "model_meta"): - meta_file_path = os.path.join(self.path, MODEL_META_FILE_NAME) - assert os.path.exists(meta_file_path) - with open(meta_file_path, "r") as file: - self.model_meta = json.load(file) + optimizer_state_dict[key + ".w_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") + optimizer_state_dict[key + ".w_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") - distributed_rank = self.get_distribution_rank_from_file_name(file_name) - dist_strategy_key = ( - "tp" + "{:02d}".format(distributed_rank[0]) + "_" + "pp" + "{:02d}".format(distributed_rank[1]) - ) - # Map model weight names to their corresponding names of master_weights in the optimizer state. - if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] - master_weight_name_to_model_weight_name_mapping = {} - for k, v in structure_name_mapping.items(): - master_weight_name_to_model_weight_name_mapping[v.split(".")[0]] = k + # merge sharding + _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) - renamed_state_dict = {} - state_dict = self.cur_rank_loaded_state_dict[file_name] - for k, v in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(k) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - new_k = k.replace(master_weight_name, model_weight_name) - renamed_state_dict[new_k] = v - return renamed_state_dict - else: - return self.cur_rank_loaded_state_dict[file_name] + # Reshape + for opt_state_name, opt_state_value in optimizer_state_dict.items(): + if opt_state_value.shape[0] > 1 and "_tp" in opt_state_name: + param_name = self.optimizer_key_to_model_state_key(opt_state_name[:-5]) + param_shape = param_meta[param_name][0] + assert opt_state_value.numel() == reduce(lambda x, y: x * y, param_shape) + reshaped_opt_state_value = opt_state_value.reshape(param_shape) + optimizer_state_dict[opt_state_name] = reshaped_opt_state_value + concat_optimier_state_dict = {} - def rename_using_optimizer_state_order(self, file_name): - if not hasattr(self, "global_file_to_state_dict_keys_mapping"): - file_to_state_dict_keys_mapping = {} - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - file_to_state_dict_keys_mapping[file] = list(state_dict.keys()) + optimizer_state_key_to_tp_keys = {} + for opt_state_name in optimizer_state_dict.keys(): + # Count how each key is split into keys ending with ‘_tpXX’. 
+ # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} + opt_state_name_removed_tp_rank = opt_state_name[:-5] + if opt_state_name_removed_tp_rank not in optimizer_state_key_to_tp_keys: + optimizer_state_key_to_tp_keys[opt_state_name_removed_tp_rank] = [opt_state_name] + else: + optimizer_state_key_to_tp_keys[opt_state_name_removed_tp_rank].append(opt_state_name) - self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) + for opt_state_name_removed_tp_rank, opt_state_name in optimizer_state_key_to_tp_keys.items(): + opt_state_name.sort(key=lambda x: int(x[-2:])) - if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - model_state_file_name = self.get_model_state_file_from(file_name) - assert model_state_file_name is not None - model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] - optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file_name] + for opt_state_name_removed_tp_rank, opt_state_name in optimizer_state_key_to_tp_keys.items(): + model_state_name = self.optimizer_key_to_model_state_key(opt_state_name_removed_tp_rank) + local_shape = optimizer_state_dict[opt_state_name[0]].shape + if "_pow_acc_0" not in key: + global_shape = self.model_state_global_shape[model_state_name] + else: + global_shape = (1,) - master_weight_name_to_model_weight_name_mapping = {} - for i in range(len(model_state_keys)): - master_weight_name = self.parse_master_weight_name_by(optimizer_state_keys[i]) - master_weight_name_to_model_weight_name_mapping[master_weight_name] = model_state_keys[i] + assert len(local_shape) == len(global_shape) - state_dict = self.cur_rank_loaded_state_dict[file_name] - renamed_state_dict = {} - for k, v in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(k) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - new_k = k.replace(master_weight_name, model_weight_name) - renamed_state_dict[new_k] = v + axis = -1 + for i in range(len(local_shape)): + if local_shape[i] != global_shape[i]: + axis = i + break - return renamed_state_dict + is_replicated = axis == -1 + tp_tensors = [] + for opt_state_name_with_tp_rank in opt_state_name: + tp_tensors.append(optimizer_state_dict[opt_state_name_with_tp_rank]) + + if not is_replicated: + # Derive the partition strategy based on the global_shape, then concatenate. + concat_optimier_state_dict[opt_state_name_removed_tp_rank] = paddle.concat(tp_tensors, axis=axis) + else: + concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0] + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in concat_optimier_state_dict.items(): + # Generate metadata. 
+ local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: concat_optimier_state_dict} + + return meta_data, source_state_dict + + elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + return self.gen_metadata_for_tp_sharded_tensor() else: - return self.cur_rank_loaded_state_dict[file_name] + if self.is_sharding_stage3: + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + shard_info.sort(key=lambda x: x[0]["sharding_rank"]) + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
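[Editorial note] The comment above is the key idea of this branch. A simplified sketch of the offset bookkeeping, assuming 1-D shards; the shapes and file names here are illustrative only:

    # shard_info items look like [rank_info, local_shape, dtype, file_name], sorted by sharding_rank.
    shard_info = [
        [{"sharding_rank": 0}, (30,), "float32", "optimizer.tp00_pp00_shard00.pdopt"],
        [{"sharding_rank": 1}, (34,), "float32", "optimizer.tp00_pp00_shard01.pdopt"],
    ]
    global_offset = 0
    for _, local_shape, _, file_name in shard_info:
        # The shard stored in file_name covers [global_offset, global_offset + local_shape[0])
        # of the flattened global tensor.
        print(file_name, "offset", (global_offset,), "shape", local_shape)
        global_offset += local_shape[0]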
+ for state_name, shard_info in self.global_sharded_tensor_infos.items(): + global_offset = 0 + for item in shard_info: + if len(item[1]) == 1: + local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) + local_tensor_index = LocalTensorIndex(state_name, (global_offset,)) + global_offset += item[1][0] + else: + global_offset = tuple([0] * len(item[1])) + local_tensor_meta_data = LocalTensorMetadata(global_offset, item[1], item[2]) + local_tensor_index = LocalTensorIndex(state_name, global_offset) + if state_name not in state_dict_metadata: + state_dict_metadata[state_name] = [local_tensor_meta_data] + else: + state_dict_metadata[state_name].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + model_state_shapes = [] + dtype = "" + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(MODEL_WEIGHT_SUFFIX): + for k, v in state_dict.items(): + model_state_shapes.append([k, v.shape]) + dtype = str(v.dtype).split(".")[1] + + dtypes = self.gather_global_object([dtype]) + for dtype_s in dtypes: + if len(dtype_s) > 0: + dtype = dtype_s + + assert len(dtype) > 0 + + global_model_state_shapes = self.gather_global_object(model_state_shapes) + + partition_result = self.partition_parameters( + global_model_state_shapes, True, paddle.distributed.get_world_size() + ) + + cur_rank_merger_model_params = partition_result[self.cur_rank] + target_state_dict = {} + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + flatten_shape = reduce(lambda a, b: a * b, item[1]) + target_state_dict[key] = paddle.zeros(shape, dtype) + target_state_dict[key + ".w_0_moment1_0"] = paddle.zeros((flatten_shape,), "float32") + target_state_dict[key + ".w_0_moment2_0"] = paddle.zeros((flatten_shape,), "float32") + if self.optimizer_state_with_master_weights: + target_state_dict[key + ".w_0"] = paddle.zeros((flatten_shape,), "float32") + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. + + target_state_dict[key + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") + target_state_dict[key + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") + + _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) + + # Reshape + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + for k, v in target_state_dict.items(): + if key == self.optimizer_key_to_model_state_key(k): + if tuple(shape) != tuple(v.shape) and v.numel() == reduce(lambda x, y: x * y, shape): + reshaped_v = v.reshape(shape) + target_state_dict[k] = reshaped_v + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in target_state_dict.items(): + # Generate metadata. 
+ local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: target_state_dict} + + return meta_data, source_state_dict + else: + return self.gen_metadata_for_tp_sharded_tensor() def load_state_dict_and_rename(self): + """ + Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state + into self.global_sharded_tensor_infos (data structure: param_name -> [{tp_rank: 1, sharding_rank: 1}, shape, dtype, file_name]). Modify the names of the + optimizer states in the form of parameter+suffix and record them in self.cur_rank_loaded_state_dict (data structure: file_name -> renamed_state_dict). + 1. Load balancing: Each rank parses a portion of the checkpoint files. + 2. Flatten master_weights in opt_state into opt_state. + 3. Rename the keys in opt_state according to the rule: adamw_optimizer_param_suffix_name_mapping. + 4. Optimizer state renaming and distributed information extraction: + * If it is sharding_stage1/2_v2 version: + * Renaming: rename_using_model_meta: In this case, a model_meta file is required. According to this file, + obtain the name mapping of weights and optimizer parameters, so that the optimizer states of manual and static partitions can correspond. + * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank, sharding_rank}, shape, dtype, file_name]. + * If it is sharding_stage1/2_v1 version: + * Renaming: + * If a model_meta file exists: + * rename_using_model_meta + * If a model_meta file does not exist: + * According to the characteristics of v1 partitioning, infer the mapping relationship between optimizer states and weights (partition_result): master_weight_name_to_model_weight_name_mapping. + * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank}, shape, dtype, file_name] (parameters will not be sharded). + * If it is sharding_stage3: + * Renaming: + * If a model_meta file exists: + * rename_using_model_meta + * If a model_meta file does not exist: + * Establish the mapping between weights and optimizer names according to the order of optimizer states and weights: rename_using_optimizer_state_order. 
+ * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank, sharding_rank}, shape, dtype, file_name]. + """ rank_access_files = {} if self.save_sharded_model: rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names @@ -327,13 +515,12 @@ def load_state_dict_and_rename(self): self.optimizer_state_with_master_weights = True master_weights = state_dict.pop("master_weights") file_to_master_weights_keys[file] = list(master_weights.keys()) - for k, v in master_weights.items(): + for master_weight_name, master_weight_value in master_weights.items(): # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. - k = k.replace("slice@", "") - state_dict[k] = v + state_dict[master_weight_name.replace("slice@", "")] = master_weight_value # Standardize the state names of the AdamW optimizer. - adamw_optimizer_param_suffix_name_mapping = { + adamw_opt_state_suffix_name_mapping = { ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", @@ -341,13 +528,15 @@ def load_state_dict_and_rename(self): } unified_name_state_dict = {} - for k, v in state_dict.items(): - new_k = k - for suffix in adamw_optimizer_param_suffix_name_mapping: - if k.endswith(suffix): - new_k = k.replace(suffix, adamw_optimizer_param_suffix_name_mapping[suffix]) + for opt_state_name, opt_state_value in state_dict.items(): + new_opt_state_name = opt_state_name + for suffix in adamw_opt_state_suffix_name_mapping: + if opt_state_name.endswith(suffix): + new_opt_state_name = opt_state_name.replace( + suffix, adamw_opt_state_suffix_name_mapping[suffix] + ) break - unified_name_state_dict[new_k] = v + unified_name_state_dict[new_opt_state_name] = opt_state_value self.cur_rank_loaded_state_dict[file] = unified_name_state_dict @@ -386,27 +575,7 @@ def load_state_dict_and_rename(self): # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, # and then append the tp_degree. renamed_state_dict = self.rename_using_model_meta(file) - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) - + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) self.cur_rank_loaded_state_dict[file] = renamed_state_dict # 2. In handling the sharding stage1 v1 and stage2 scenario, the optimizer states are distributed across different ranks. # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. 
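[Editorial note] The end product of the simulation described in the comment above is a master_weight-name to model-weight-name mapping; once that mapping exists, the renaming itself is just a prefix swap. A rough sketch with hypothetical names (the exact matching logic follows in the next hunk):

    master_to_weight = {"linear_0": "llama.layers.0.mlp.up_proj.weight"}  # hypothetical, recovered by replaying the partition

    def rename_opt_key(opt_state_name, master_to_weight):
        master = opt_state_name.split(".")[0]   # same rule as parse_master_weight_name_by
        return opt_state_name.replace(master, master_to_weight[master])

    print(rename_opt_key("linear_0.w_0_moment1_0", master_to_weight))
    # -> "llama.layers.0.mlp.up_proj.weight.w_0_moment1_0"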
@@ -415,8 +584,8 @@ def load_state_dict_and_rename(self): file_to_state_dict_shapes_mapping = {} for file, state_dict in self.cur_rank_loaded_state_dict.items(): shapes = [] - for k, v in state_dict.items(): - shapes.append([k, v.shape]) + for state_name, state_value in state_dict.items(): + shapes.append([state_name, state_value.shape]) file_to_state_dict_shapes_mapping[file] = shapes global_file_to_state_dict_shapes_mapping = self.gather_global_object(file_to_state_dict_shapes_mapping) @@ -436,13 +605,13 @@ def load_state_dict_and_rename(self): partition_result_0 = self.partition_parameters(model_state_shapes, False, self.sharding_degree) partition_result_1 = self.partition_parameters(model_state_shapes, True, self.sharding_degree) - for k, v in partition_result_0.items(): - v = sorted(v, key=model_state_shapes.index) - partition_result_0[k] = v + for rank, portion in partition_result_0.items(): + portion = sorted(portion, key=model_state_shapes.index) + partition_result_0[rank] = portion - for k, v in partition_result_1.items(): - v = sorted(v, key=model_state_shapes.index) - partition_result_1[k] = v + for rank, portion in partition_result_1.items(): + portion = sorted(portion, key=model_state_shapes.index) + partition_result_1[rank] = portion sharding_sort_parameters = False @@ -478,152 +647,92 @@ def load_state_dict_and_rename(self): ] = partitioned_shard[j][0] renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. # Therefore, the sharding information can now be directly removed. - for k, v in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(k) + for opt_state_name, opt_state_value in state_dict.items(): + master_weight_name = self.parse_master_weight_name_by(opt_state_name) model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - new_k = k.replace(master_weight_name, model_weight_name) - renamed_state_dict[new_k] = v - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ) + new_opt_state_name = opt_state_name.replace(master_weight_name, model_weight_name) + renamed_state_dict[new_opt_state_name] = opt_state_value + + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: - for k, v in state_dict.items(): - if k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[k] = [ - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ] - else: - cur_rank_sharded_tensor_infos[k].append( - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ) + self.get_sharded_tensor_infos(file, state_dict, cur_rank_sharded_tensor_infos) else: for file, state_dict in self.cur_rank_loaded_state_dict.items(): renamed_state_dict = self.rename_using_model_meta(file) - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - 
[{"tp_rank": tp_rank}, v.shape, str(v.dtype).split(".")[1], file] - ) + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: # 3. Handling the sharding stage3 and non-sharding scenario if not self.save_sharded_model: for file, state_dict in self.cur_rank_loaded_state_dict.items(): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - renamed_state_dict = self.rename_using_optimizer_state_order(file) - - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) - + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: - for k, v in state_dict.items(): - if k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) + self.get_sharded_tensor_infos(file, state_dict, cur_rank_sharded_tensor_infos) else: for file, state_dict in self.cur_rank_loaded_state_dict.items(): # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, # and then append the tp_degree. 
renamed_state_dict = self.rename_using_model_meta(file) - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for new_k, v in renamed_state_dict.items(): - if new_k not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[new_k] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[new_k].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - v.shape, - str(v.dtype).split(".")[1], - file, - ] - ) - + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) self.cur_rank_loaded_state_dict[file] = renamed_state_dict # gather global sharded tensor infos sharded_tensor_infos = self.gather_global_object({self.cur_rank: cur_rank_sharded_tensor_infos}) self.global_sharded_tensor_infos = {} for rank, sharded_tensor_info in sharded_tensor_infos.items(): - for k, v in sharded_tensor_info.items(): - if k not in self.global_sharded_tensor_infos: - self.global_sharded_tensor_infos[k] = v + for state_name, shard_info in sharded_tensor_info.items(): + if state_name not in self.global_sharded_tensor_infos: + self.global_sharded_tensor_infos[state_name] = shard_info else: - self.global_sharded_tensor_infos[k] += v + self.global_sharded_tensor_infos[state_name] += shard_info + + def get_sharded_tensor_infos(self, file, state_dict, cur_rank_sharded_tensor_infos): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for state_name, state_value in state_dict.items(): + if state_name not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[state_name] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + state_value.shape, + str(state_value.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[state_name].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + state_value.shape, + str(state_value.dtype).split(".")[1], + file, + ] + ) def gen_metadata_for_tp_sharded_tensor(self): - for k, v in self.global_sharded_tensor_infos.items(): - v.sort(key=lambda x: x[0]["tp_rank"]) + """ + Based on the distributed information of each weight/optimizer state (global_sharded_tensor_infos), construct Metadata + information: LocalTensorMetadata,LocalTensorIndex + """ + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + shard_info.sort(key=lambda x: x[0]["tp_rank"]) state_dict_metadata = {} storage_metadata = {} # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. 
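[Editorial note] For reference, the split axis used below is inferred purely by comparing local and global shapes; if every dimension matches, the tensor is treated as replicated across ranks and only one copy is kept. A small illustration with made-up shapes:

    def infer_split_axis(local_shape, global_shape):
        for i, (local_dim, global_dim) in enumerate(zip(local_shape, global_shape)):
            if local_dim != global_dim:
                return i      # first differing dimension is the split axis
        return -1             # shapes identical -> replicated

    print(infer_split_axis((1024, 2048), (1024, 4096)))  # 1 (split along the second dimension)
    print(infer_split_axis((1024, 4096), (1024, 4096)))  # -1 (replicated)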
- for k, v in self.global_sharded_tensor_infos.items(): + for state_name, shard_info in self.global_sharded_tensor_infos.items(): global_offset = 0 - local_shape = v[0][1] - model_state_name = self.optimizer_key_to_model_state_key(k) - if "_pow_acc_0" not in k: + local_shape = shard_info[0][1] + model_state_name = self.optimizer_key_to_model_state_key(state_name) + if "_pow_acc_0" not in state_name: global_shape = self.model_state_global_shape[model_state_name] else: global_shape = (1,) @@ -639,16 +748,16 @@ def gen_metadata_for_tp_sharded_tensor(self): global_offset = [0] * len(local_shape) if is_replicated: - v = [v[0]] + shard_info = [shard_info[0]] - for item in v: + for item in shard_info: local_tensor_meta_data = LocalTensorMetadata(tuple(global_offset), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, tuple(global_offset)) + local_tensor_index = LocalTensorIndex(state_name, tuple(global_offset)) global_offset[axis] += item[1][axis] - if k not in state_dict_metadata: - state_dict_metadata[k] = [local_tensor_meta_data] + if state_name not in state_dict_metadata: + state_dict_metadata[state_name] = [local_tensor_meta_data] else: - state_dict_metadata[k].append(local_tensor_meta_data) + state_dict_metadata[state_name].append(local_tensor_meta_data) storage_metadata[local_tensor_index] = item[3] metadata = Metadata(state_dict_metadata, storage_metadata, None) @@ -656,389 +765,249 @@ def gen_metadata_for_tp_sharded_tensor(self): return metadata, source_state_dict - def gen_metadata_and_prepare_source_state_dict(self): - self.load_state_dict_and_rename() - if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: - for k, v in self.global_sharded_tensor_infos.items(): - v.sort(key=lambda x: x[0]["sharding_rank"]) - - state_dict_metadata = {} - storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. - for k, v in self.global_sharded_tensor_infos.items(): - global_offset = [0] * self.tp_degree - for item in v: - tp_rank = item[0]["tp_rank"] - k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) - local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k_with_tp_rank, (global_offset[tp_rank],)) - global_offset[tp_rank] += item[1][0] - if k_with_tp_rank not in state_dict_metadata: - state_dict_metadata[k_with_tp_rank] = [local_tensor_meta_data] - else: - state_dict_metadata[k_with_tp_rank].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] + def rename_using_model_meta(self, file_name): + """ + Rename the keys in opt_state_dict based on the following rule: model_meta records a mapping of parameter names to optimizer names. + Here, we unify the optimizer state names to parameter names directly. For example: + * model_meta: linear0 -> param0 + * opt_state: param0.w0 + * Renamed opt_state: linear0.w0 + NOTE:The reason for renaming is that there is a difference in the naming of optimizer parameters between dynamic and static partitions, + making it difficult to match optimizer parameters directly by name. Therefore, we unify them to the weight names. 
+ """ + if not hasattr(self, "model_meta"): + meta_file_path = os.path.join(self.path, MODEL_META_FILE_NAME) + assert os.path.exists(meta_file_path) + with open(meta_file_path, "r") as file: + self.model_meta = json.load(file) - metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank) + # Map model weight names to their corresponding names of master_weights in the optimizer state. + if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] + master_weight_name_to_model_weight_name_mapping = {} + for model_weight_name, master_weight_name in structure_name_mapping.items(): + master_weight_name_to_model_weight_name_mapping[master_weight_name.split(".")[0]] = model_weight_name - source_state_dict_for_merge_sharding = {} - for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): - renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - for k, v in state_dict.items(): - if self.global_sharded_tensor_infos[k][0][0]["tp_rank"] != -1: - k_with_tp_rank = k + "_tp" + "{:02d}".format(tp_rank) - renamed_state_dict[k_with_tp_rank] = v - else: - renamed_state_dict[k] = v - - source_state_dict_for_merge_sharding[file_name] = renamed_state_dict - - assert self.model_meta is not None - global_model_state_shapes = [] - sharding_metas_keys = [] - for i in range(self.pp_degree): - for j in range(self.tp_degree): - sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) - for key in sharding_metas_keys: - param_meta = self.model_meta["sharding_metas"][key]["param_meta"] - for k, v in param_meta.items(): - global_model_state_shapes.append([k, v[0]]) - - # Distribute all model parameters evenly across each card for loading - - world_size = paddle.distributed.get_world_size() - - partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) - - partition_model_state_keys = [] - for cur_rank, partition_model_state in partition_mapping.items(): - partition_model_state_keys.append([item[0] for item in partition_model_state]) - - param_meta = {} - for i in range(self.tp_degree): - for j in range(self.pp_degree): - key = "tp{:02d}_pp{:02d}".format(i, j) - pm = self.model_meta["sharding_metas"][key]["param_meta"] - for k, v in pm.items(): - param_meta[k] = v - - param_flattened_shapes = {} - for k, v in param_meta.items(): - param_flattened_shapes[k] = reduce(lambda x, y: x * y, v[0]) - - cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] - - # Generate the optimizer states corresponding to the model weights. 
- optimizer_state_dict = {} - for key in cur_rank_need_load_model_state_keys: - for tp_rank in range(self.tp_degree): - tp_rank_suffix = "_tp{:02d}".format(tp_rank) - optimizer_state_dict[key + ".w_0_moment1_0" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - optimizer_state_dict[key + ".w_0_moment2_0" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - if self.optimizer_state_with_master_weights: - optimizer_state_dict[key + ".w_0" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. - # Later, when these are compared with the global shape, we realize that they are replicated. - - optimizer_state_dict[key + ".w_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") - optimizer_state_dict[key + ".w_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") - - # merge sharding - _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) - - # Reshape - for k, v in optimizer_state_dict.items(): - if v.shape[0] > 1 and "_tp" in k: - param_name = self.optimizer_key_to_model_state_key(k[:-5]) - param_shape = param_meta[param_name][0] - assert v.numel() == reduce(lambda x, y: x * y, param_shape) - reshaped_v = v.reshape(param_shape) - optimizer_state_dict[k] = reshaped_v - concat_optimier_state_dict = {} - - optimizer_state_key_to_tp_keys = {} - for key in optimizer_state_dict.keys(): - # Count how each key is split into keys ending with ‘_tpXX’. - # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} - key_removed_tp_rank = key[:-5] - if key_removed_tp_rank not in optimizer_state_key_to_tp_keys: - optimizer_state_key_to_tp_keys[key_removed_tp_rank] = [key] - else: - optimizer_state_key_to_tp_keys[key_removed_tp_rank].append(key) - - for key, value in optimizer_state_key_to_tp_keys.items(): - value.sort(key=lambda x: int(x[-2:])) - - for key, tp_keys in optimizer_state_key_to_tp_keys.items(): - model_state_name = self.optimizer_key_to_model_state_key(key) - local_shape = optimizer_state_dict[tp_keys[0]].shape - if "_pow_acc_0" not in key: - global_shape = self.model_state_global_shape[model_state_name] - else: - global_shape = (1,) - - assert len(local_shape) == len(global_shape) - - axis = -1 - for i in range(len(local_shape)): - if local_shape[i] != global_shape[i]: - axis = i - break - - is_replicated = axis == -1 - tp_tensors = [] - for tp_key in tp_keys: - tp_tensors.append(optimizer_state_dict[tp_key]) - - if not is_replicated: - # Derive the partition strategy based on the global_shape, then concatenate. - concat_optimier_state_dict[key] = paddle.concat(tp_tensors, axis=axis) - else: - concat_optimier_state_dict[key] = tp_tensors[0] - - fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" - local_tensor_meta_data = {} - local_tensor_index = {} - for k, v in concat_optimier_state_dict.items(): - # Generate metadata. 
- local_shape = v.shape - global_offset = tuple([0] * len(local_shape)) - dtype = str(v.dtype).split(".")[1] - local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) - local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] - - global_local_tensor_meta_data = [] - global_local_tensor_index = [] - - use_dist = True if paddle.distributed.get_world_size() > 1 else False - - if use_dist: - paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) - paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) - else: - global_local_tensor_meta_data = [local_tensor_meta_data] - global_local_tensor_index = [local_tensor_index] - - state_dict_metadata = {} - for tensor_meta_data in global_local_tensor_meta_data: - for k, v in tensor_meta_data.items(): - if k not in state_dict_metadata: - state_dict_metadata[k] = [v] - else: - state_dict_metadata[k].append(v) - - storage_metadata = {} - for tensor_index in global_local_tensor_index: - for k, v in tensor_index.items(): - storage_metadata[v[0]] = v[1] - - meta_data = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = {fake_file_name: concat_optimier_state_dict} - - return meta_data, source_state_dict - - elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: - return self.gen_metadata_for_tp_sharded_tensor() + renamed_state_dict = {} + state_dict = self.cur_rank_loaded_state_dict[file_name] + for opt_state_name, opt_state_value in state_dict.items(): + master_weight_name = self.parse_master_weight_name_by(opt_state_name) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + renamed_state_dict[opt_state_name.replace(master_weight_name, model_weight_name)] = opt_state_value + return renamed_state_dict else: - if self.is_sharding_stage3: - for k, v in self.global_sharded_tensor_infos.items(): - v.sort(key=lambda x: x[0]["sharding_rank"]) - state_dict_metadata = {} - storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. - for k, v in self.global_sharded_tensor_infos.items(): - global_offset = 0 - for item in v: - if len(item[1]) == 1: - local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, (global_offset,)) - global_offset += item[1][0] - else: - global_offset = tuple([0] * len(item[1])) - local_tensor_meta_data = LocalTensorMetadata(global_offset, item[1], item[2]) - local_tensor_index = LocalTensorIndex(k, global_offset) - if k not in state_dict_metadata: - state_dict_metadata[k] = [local_tensor_meta_data] - else: - state_dict_metadata[k].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] + return self.cur_rank_loaded_state_dict[file_name] - metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - model_state_shapes = [] - dtype = "" - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(MODEL_WEIGHT_SUFFIX): - for k, v in state_dict.items(): - model_state_shapes.append([k, v.shape]) - dtype = str(v.dtype).split(".")[1] + def partition_parameters(self, model_state_shapes, is_sort, shard_num): + """ + In sharding_stage3 and sharding_stage1_v1, parameters and optimizer states will be assigned to different ranks. This function defines the allocation rules. 
+ For details, refer to: python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py. + """ + mapping = {} + for rank_ in range(shard_num): + mapping[rank_] = [] + sizes = [0] * shard_num - dtypes = self.gather_global_object([dtype]) - for dtype_s in dtypes: - if len(dtype_s) > 0: - dtype = dtype_s + parameters = model_state_shapes.copy() - assert len(dtype) > 0 + if is_sort: + parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True) - global_model_state_shapes = self.gather_global_object(model_state_shapes) + for param in parameters: + rank = sizes.index(min(sizes)) + mapping[rank].append(param) + numel = reduce(lambda x, y: x * y, param[1], 1) + assert numel > 0, f"param [{param[0]}] should larger than 0, but it is [{numel}]" + sizes[rank] += numel - partition_result = self.partition_parameters( - global_model_state_shapes, True, paddle.distributed.get_world_size() - ) + return mapping - cur_rank_merger_model_params = partition_result[self.cur_rank] - target_state_dict = {} - for item in cur_rank_merger_model_params: - key = item[0] - shape = item[1] - flatten_shape = reduce(lambda a, b: a * b, item[1]) - target_state_dict[key] = paddle.zeros(shape, dtype) - target_state_dict[key + ".w_0_moment1_0"] = paddle.zeros((flatten_shape,), "float32") - target_state_dict[key + ".w_0_moment2_0"] = paddle.zeros((flatten_shape,), "float32") - if self.optimizer_state_with_master_weights: - target_state_dict[key + ".w_0"] = paddle.zeros((flatten_shape,), "float32") - # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. - # Later, when these are compared with the global shape, we realize that they are replicated. + def rename_using_optimizer_state_order(self, file_name): + """ + Rename the keys in opt_state_dict based on the following rule: The order of weights recorded in the weight file is consistent with the order of optimizer states recorded in the optimizer file. + By using this order, we can obtain the correspondence between the names of weights and optimizer states and rename the optimizer accordingly. For example: + * model_state: linear0, linear1 + * opt_state: param0.w0, param1.w0 + * Renamed opt_state: linear0.w0, linear1.w0 + NOTE:The reason for renaming is that there is a difference in the naming of optimizer parameters between dynamic and static partitions, making it difficult to match optimizer parameters directly by name. + Therefore, we unify them to the weight names. 
+ """ + if not hasattr(self, "global_file_to_state_dict_keys_mapping"): + file_to_state_dict_keys_mapping = {} + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + file_to_state_dict_keys_mapping[file_name] = list(state_dict.keys()) - target_state_dict[key + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") - target_state_dict[key + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") + self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) - _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) + if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + model_state_file_name = self.get_model_state_file_from(file_name) + assert model_state_file_name is not None + model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] + optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file_name] - # Reshape - for item in cur_rank_merger_model_params: - key = item[0] - shape = item[1] - for k, v in target_state_dict.items(): - if key == self.optimizer_key_to_model_state_key(k): - if tuple(shape) != tuple(v.shape) and v.numel() == reduce(lambda x, y: x * y, shape): - reshaped_v = v.reshape(shape) - target_state_dict[k] = reshaped_v + master_weight_name_to_model_weight_name_mapping = {} + for i in range(len(model_state_keys)): + master_weight_name = self.parse_master_weight_name_by(optimizer_state_keys[i]) + master_weight_name_to_model_weight_name_mapping[master_weight_name] = model_state_keys[i] - fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" - local_tensor_meta_data = {} - local_tensor_index = {} - for k, v in target_state_dict.items(): - # Generate metadata. - local_shape = v.shape - global_offset = tuple([0] * len(local_shape)) - dtype = str(v.dtype).split(".")[1] - local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) - local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + state_dict = self.cur_rank_loaded_state_dict[file_name] + renamed_state_dict = {} + for opt_state_name, opt_state_value in state_dict.items(): + master_weight_name = self.parse_master_weight_name_by(opt_state_name) + model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] + renamed_state_dict[opt_state_name.replace(master_weight_name, model_weight_name)] = opt_state_value + return renamed_state_dict + else: + return self.cur_rank_loaded_state_dict[file_name] - global_local_tensor_meta_data = [] - global_local_tensor_index = [] + def get_save_sharded_model_flag(self): + save_sharded_model_flag = self.gather_global_object( + [os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] + ) + return True in save_sharded_model_flag - use_dist = True if paddle.distributed.get_world_size() > 1 else False + def flatten_state_dict(self, state_dict): + flattened_state_dict = {} + flat_state_dict, mapping = flatten_state_dict(state_dict) + for k, v in flat_state_dict.items(): + last_level_key = mapping[k][-1] + assert last_level_key not in flattened_state_dict + flattened_state_dict[last_level_key] = v + return flattened_state_dict - if use_dist: - paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) - paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) - else: - global_local_tensor_meta_data = [local_tensor_meta_data] - global_local_tensor_index = [local_tensor_index] + def gather_global_object(self, 
cur_rank_object): + all_rank_objects = [] + if self.use_dist: + paddle.distributed.all_gather_object(all_rank_objects, cur_rank_object) + else: + all_rank_objects = [all_rank_objects] - state_dict_metadata = {} - for tensor_meta_data in global_local_tensor_meta_data: - for k, v in tensor_meta_data.items(): - if k not in state_dict_metadata: - state_dict_metadata[k] = [v] - else: - state_dict_metadata[k].append(v) + if isinstance(cur_rank_object, list): + return [item for sublist in all_rank_objects for item in sublist] + elif isinstance(cur_rank_object, dict): + global_map = {} + for rank_map in all_rank_objects: + global_map.update(rank_map) + return global_map + else: + raise ValueError("cur_rank_object should be either a list or a dict") - storage_metadata = {} - for tensor_index in global_local_tensor_index: - for k, v in tensor_index.items(): - storage_metadata[v[0]] = v[1] + def get_local_checkpoint_file_names(self): + cur_rank_files = os.listdir(self.path) + cur_rank_model_state_file_names = [] + cur_rank_optimizer_state_file_names = [] + for file_name in cur_rank_files: + if file_name.endswith(MODEL_WEIGHT_SUFFIX): + cur_rank_model_state_file_names.append(file_name) + elif file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + cur_rank_optimizer_state_file_names.append(file_name) + if SCHEDULER_NAME in cur_rank_model_state_file_names: + cur_rank_model_state_file_names.remove(SCHEDULER_NAME) + return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names - meta_data = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = {fake_file_name: target_state_dict} + def get_distribution_rank_from_file_name(self, file_name): + pp_degree = 0 + tp_degree = 0 + sharding_degree = 0 + pattern_pp = r"pp(\d+)" + pattern_tp = r"tp(\d+)" + pattern_shard = r"shard(\d+)" + match_pp = re.search(pattern_pp, file_name) + if match_pp: + pp_degree = int(match_pp.group(1)) + match_tp = re.search(pattern_tp, file_name) + if match_tp: + tp_degree = int(match_tp.group(1)) + match_shard = re.search(pattern_shard, file_name) + if match_shard: + sharding_degree = int(match_shard.group(1)) + return (tp_degree, pp_degree, sharding_degree) - return meta_data, source_state_dict - else: - return self.gen_metadata_for_tp_sharded_tensor() + def initial_distributed_configuration(self): + self.pp_degree = 0 + self.tp_degree = 0 + self.sharding_degree = 0 - def rename_auto_parallel_state_dict(self): - need_remove_key_pattern = ["eager_tmp", "learning_rate", "@GRAD@MERG", "gradient_merge_"] + all_files = self.global_model_state_file_names + self.global_optimizer_state_file_names - need_remove_key = set() - for key in self.auto_parallel_state_dict.keys(): - for pattern in need_remove_key_pattern: - if pattern in key: - need_remove_key.add(key) - break + for file in all_files: + (tp_degree, pp_degree, sharding_degree) = self.get_distribution_rank_from_file_name(file) + self.pp_degree = max(self.pp_degree, pp_degree) + self.tp_degree = max(self.tp_degree, tp_degree) + self.sharding_degree = max(self.sharding_degree, sharding_degree) - for key in need_remove_key: - self.auto_parallel_state_dict.pop(key) + self.pp_degree = self.pp_degree + 1 + self.tp_degree = self.tp_degree + 1 + self.sharding_degree = self.sharding_degree + 1 - adamw_optimizer_status_name_suffix_mappings = { - "_fp32_master_1_moment1_0": ".w_0_moment1_0", - "_fp32_master_1_moment2_0": ".w_0_moment2_0", - "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - 
"_fp32_master_1": ".w_0", - "_moment1_0": ".w_0_moment1_0", - "_moment2_0": ".w_0_moment2_0", - "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - } + def infer_sharding_stage1_v(self): + sharding_stage1_v = [2] + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: + for k, v in state_dict.items(): + # Under shardingv2, the optimizer state is first flattened and then split. + if len(v.shape) != 1: + sharding_stage1_v = [1] + break - def rename(old_name, map1, map2): - for i in range(1, len(old_name)): - str1 = old_name[:i] - str2 = old_name[i:] - if (str1 in map1) and (str2 in map2): - transformed_str1 = map1[str1] - transformed_str2 = map2[str2] - return transformed_str1 + transformed_str2 - return None + sharding_stage1_v = self.gather_global_object(sharding_stage1_v) + if 1 in sharding_stage1_v: + return 1 + return 2 - renamed_state_dict = {} + def infer_is_sharding_stage3(self): + if self.sharding_degree == 1: + return False + if self.pp_degree > 1 or self.tp_degree > 1: + # Currently, sharding stage 3 does not support concurrent use with tensor parallelism (TP) and pipeline parallelism (PP). + return False - for key, value in self.auto_parallel_state_dict.items(): + is_sharding_stage3 = True - if key in self.parameter_to_structured_name.values(): - new_name = key - else: - new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_status_name_suffix_mappings) + file_to_state_shape_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + state_shape_mapping = {} + for k, v in state_dict.items(): + state_shape_mapping[k] = v.shape + if len(v.shape) != 1: + return False + file_to_state_shape_mapping[file] = state_shape_mapping + global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) - assert new_name is not None - renamed_state_dict[new_name] = value + state_dict_std = global_file_to_state_shape_mapping[list(global_file_to_state_shape_mapping.keys())[0]] - self.auto_parallel_state_dict = renamed_state_dict + for file, state_dict in global_file_to_state_shape_mapping.items(): + if state_dict != state_dict_std: + is_sharding_stage3 = False + break + return is_sharding_stage3 - def load_from_hybrid_parallel_checkpoint(self): - self.rename_auto_parallel_state_dict() - metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() - if self.save_sharded_model: - model_params = {} - for k, v in self.auto_parallel_state_dict.items(): - if k in self.parameter_to_structured_name.values(): - model_params[k] = v - for k in model_params.keys(): - self.auto_parallel_state_dict.pop(k) + def parse_master_weight_name_by(self, optimizer_state_name): + return optimizer_state_name.split(".")[0] - appended_master_weight_names = [] - for k, v in model_params.items(): - master_weight = k + ".w_0" - if master_weight not in self.auto_parallel_state_dict: - appended_master_weight_names.append(master_weight) - tmp_tensor = paddle.zeros(v.shape, "float32") - dist_tmp_tensor = dist.shard_tensor(tmp_tensor, v.process_mesh, v.placements) - self.auto_parallel_state_dict[master_weight] = dist_tmp_tensor + def get_model_state_file_from(self, optimizer_state_file_name): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) + for model_state_file in self.global_model_state_file_names: + distributed_rank = 
self.get_distribution_rank_from_file_name(model_state_file) + if tp_rank == distributed_rank[0] and pp_rank == distributed_rank[1]: + return model_state_file + return None - _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) - for k, v in model_params.items(): - master_weight = self.auto_parallel_state_dict[k + ".w_0"] - cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") - paddle.assign(cast_master_weight, v._local_value()) - for k in appended_master_weight_names: - self.auto_parallel_state_dict.pop(k) - else: - _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + def optimizer_key_to_model_state_key(self, optimizer_key): + adamw_optimizer_key_suffix = [ + ".w_0_beta1_pow_acc_0", + ".w_0_beta2_pow_acc_0", + ".w_0_moment1_0", + ".w_0_moment2_0", + ".w_0", + ] + model_state_key = optimizer_key + for suffix in adamw_optimizer_key_suffix: + if model_state_key.endswith(suffix): + # Remove the suffix from model_state_key + model_state_key = model_state_key[: -len(suffix)] + break + return model_state_key From 243d2804a7da8c6b4ef5ba8b0534bd6503044436 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 17:01:40 +0800 Subject: [PATCH 16/30] fix codestyle --- paddlenlp/trainer/ckpt_converter.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index 9acb8f13fe16..ad453f038646 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -212,17 +212,17 @@ def gen_metadata_and_prepare_source_state_dict(self): for cur_rank, partition_model_state in partition_mapping.items(): partition_model_state_keys.append([item[0] for item in partition_model_state]) - param_meta = {} + all_param_meta = {} for i in range(self.tp_degree): for j in range(self.pp_degree): key = "tp{:02d}_pp{:02d}".format(i, j) param_meta = self.model_meta["sharding_metas"][key]["param_meta"] for param_name, param_shape_and_dtype in param_meta.items(): - param_meta[param_name] = param_shape_and_dtype + all_param_meta[param_name] = param_shape_and_dtype param_flattened_shapes = {} - for param_meta, param_shape_and_dtype in param_meta.items(): - param_flattened_shapes[param_meta] = reduce(lambda x, y: x * y, param_shape_and_dtype[0]) + for param_name, param_shape_and_dtype in all_param_meta.items(): + param_flattened_shapes[param_name] = reduce(lambda x, y: x * y, param_shape_and_dtype[0]) cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] @@ -254,7 +254,7 @@ def gen_metadata_and_prepare_source_state_dict(self): for opt_state_name, opt_state_value in optimizer_state_dict.items(): if opt_state_value.shape[0] > 1 and "_tp" in opt_state_name: param_name = self.optimizer_key_to_model_state_key(opt_state_name[:-5]) - param_shape = param_meta[param_name][0] + param_shape = all_param_meta[param_name][0] assert opt_state_value.numel() == reduce(lambda x, y: x * y, param_shape) reshaped_opt_state_value = opt_state_value.reshape(param_shape) optimizer_state_dict[opt_state_name] = reshaped_opt_state_value @@ -281,7 +281,8 @@ def gen_metadata_and_prepare_source_state_dict(self): else: global_shape = (1,) - assert len(local_shape) == len(global_shape) + if len(local_shape) != 1: + assert len(local_shape) == len(global_shape) axis = -1 for i in range(len(local_shape)): @@ -498,7 +499,6 @@ def load_state_dict_and_rename(self): ) need_read_files = 
get_local_load_files(self.gather_global_object(rank_access_files)) - self.cur_rank_loaded_state_dict = {} for file in need_read_files: @@ -540,6 +540,10 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = unified_name_state_dict + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + for k, v in state_dict.items(): + print(k, v.shape) + # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. self.sharding_stage1_v = self.infer_sharding_stage1_v() self.is_sharding_stage3 = self.infer_is_sharding_stage3() From 99ea6b8e9ddede6910db34f3855b5636e1cbfa83 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 12 Aug 2024 20:21:27 +0800 Subject: [PATCH 17/30] fix codestyle --- paddlenlp/trainer/auto_trainer.py | 19 ++++++++++--------- paddlenlp/trainer/ckpt_converter.py | 6 +----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index b3ff932281a9..c7a171ec0b30 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -697,14 +697,15 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): ) if self.args.to_static: - opt_state_dict = { + model_state_dict = { key: value - for key, value in self.model_wrapped.state_dict("opt").items() + for key, value in self.model_wrapped.state_dict("param").items() if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS) } - state_dict = { - MODEL_NAME: self.model_wrapped.state_dict("param"), - OPTIMIZER_NAME: opt_state_dict, + optim_state_dict = { + key: value + for key, value in self.model_wrapped.state_dict("opt").items() + if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS) } else: model_state_dict = self.model_wrapped.state_dict() @@ -717,10 +718,10 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): optim_state_dict = self.optimizer.state_dict() optim_state_dict.pop("LR_Scheduler", None) - state_dict = { - MODEL_NAME: model_state_dict, - OPTIMIZER_NAME: optim_state_dict, - } + state_dict = { + MODEL_NAME: model_state_dict, + OPTIMIZER_NAME: optim_state_dict, + } if self.args.resume_form_hybrid_parallel: CheckpointConverter( diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index ad453f038646..b1ffdf02cb0b 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -65,7 +65,7 @@ def load_from_hybrid_parallel_checkpoint(self): """ Automatically and inplace load the distributed checkpoint stored in hybrid parallel mode into the auto parallel state_dict. The main logic is as follows: - 1. Callrename_semi_auto_state_dict: Rename the keys of the auto parallel state_dict according to certain rules. + 1. Call rename_semi_auto_state_dict: Rename the keys of the auto parallel state_dict according to certain rules. (Why rename? To facilitate the subsequent correspondence between the optimizer state names of the semi-automatic and static optimizers.) 2. Callgen_metadata_and_prepare_source_state_dict: Automatically parse the manual checkpoint file based on the state_dict information provided by auto parallel, obtaining the Metadata and state_dict required for auto parallel to load the checkpoint. 
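[Editorial note] For readers following the docstring change above: the renaming in rename_auto_parallel_state_dict works by splitting each key into a known parameter name plus a known optimizer-state suffix. A self-contained sketch mirroring the rename helper defined in this file, with hypothetical names:

    parameter_to_structured_name = {"linear_0.w_0": "llama.embed_tokens.weight"}  # hypothetical
    suffix_mapping = {"_moment1_0": ".w_0_moment1_0"}

    def rename(old_name, param_map, suffix_map):
        # Try every split point until the prefix is a known parameter and the suffix a known state name.
        for i in range(1, len(old_name)):
            if old_name[:i] in param_map and old_name[i:] in suffix_map:
                return param_map[old_name[:i]] + suffix_map[old_name[i:]]
        return None

    print(rename("linear_0.w_0_moment1_0", parameter_to_structured_name, suffix_mapping))
    # -> "llama.embed_tokens.weight.w_0_moment1_0"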
@@ -540,10 +540,6 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = unified_name_state_dict - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - for k, v in state_dict.items(): - print(k, v.shape) - # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. self.sharding_stage1_v = self.infer_sharding_stage1_v() self.is_sharding_stage3 = self.infer_is_sharding_stage3() From 12de58e2e734613dadce5c23e8f19dd76cdb6ebf Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 13 Aug 2024 11:06:50 +0800 Subject: [PATCH 18/30] fix --- paddlenlp/trainer/auto_trainer.py | 10 +++++- paddlenlp/trainer/ckpt_converter.py | 55 +++++++++++++++-------------- 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index c7a171ec0b30..576d50402363 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -723,9 +723,17 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): OPTIMIZER_NAME: optim_state_dict, } + parameter_to_structured_name = {} + if self.args.to_static: + parameter_to_structured_name = self.model_wrapped._parameter_to_structured_name + else: + for state_dict_name, sub_state_dict in state_dict.items(): + for state_name, state_value in sub_state_dict.items(): + parameter_to_structured_name[state_value.name] = state_name + if self.args.resume_form_hybrid_parallel: CheckpointConverter( - resume_from_checkpoint, state_dict, self.model_wrapped._parameter_to_structured_name + resume_from_checkpoint, state_dict, parameter_to_structured_name ).load_from_hybrid_parallel_checkpoint() else: ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index b1ffdf02cb0b..84cb95833f2b 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -15,6 +15,7 @@ import json import os import re +from collections import OrderedDict from functools import reduce import paddle @@ -36,6 +37,27 @@ MODEL_META_FILE_NAME = "model_meta.json" +OPTIMIZER_STATE_NAME_SUFFIX_MAPPING = { + "_fp32_master_1_moment1_0": ".w_0_moment1_0", + "_fp32_master_1_moment2_0": ".w_0_moment2_0", + "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + "_fp32_master_1": ".w_0", + "_moment1_0": ".w_0_moment1_0", + "_moment2_0": ".w_0_moment2_0", + "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", + ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", + ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", +} + +OPTIMIZER_STATE_NAME_SUFFIX_MAPPING = OrderedDict( + sorted(OPTIMIZER_STATE_NAME_SUFFIX_MAPPING.items(), key=lambda x: len(x[0]), reverse=True) +) + + class CheckpointConverter: def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structured_name): self.use_dist = True if paddle.distributed.get_world_size() > 1 else False @@ -67,9 +89,9 @@ def load_from_hybrid_parallel_checkpoint(self): The main logic is as follows: 1. Call rename_semi_auto_state_dict: Rename the keys of the auto parallel state_dict according to certain rules. (Why rename? 
To facilitate the subsequent correspondence between the optimizer state names of the semi-automatic and static optimizers.) - 2. Callgen_metadata_and_prepare_source_state_dict: Automatically parse the manual checkpoint file based on the state_dict information + 2. Call gen_metadata_and_prepare_source_state_dict: Automatically parse the manual checkpoint file based on the state_dict information provided by auto parallel, obtaining the Metadata and state_dict required for auto parallel to load the checkpoint. - 3. Callload_state_dict: Automatically reshard and load. + 3. Call load_state_dict: Automatically reshard and load. 4. Special logic adaptation: In the save_sharded_model mode, the weights are obtained through the master_weight cast in the checkpoint. """ self.rename_auto_parallel_state_dict() @@ -108,18 +130,6 @@ def rename_auto_parallel_state_dict(self): 1. Rename the suffixes of the optimizer states to a unified format: adamw_optimizer_status_name_suffix_mappings """ - adamw_optimizer_state_name_suffix_mappings = { - "_fp32_master_1_moment1_0": ".w_0_moment1_0", - "_fp32_master_1_moment2_0": ".w_0_moment2_0", - "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - "_fp32_master_1": ".w_0", - "_moment1_0": ".w_0_moment1_0", - "_moment2_0": ".w_0_moment2_0", - "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - } - def rename(old_name, map1, map2): for i in range(1, len(old_name)): str1 = old_name[:i] @@ -137,7 +147,7 @@ def rename(old_name, map1, map2): if key in self.parameter_to_structured_name.values(): new_name = key else: - new_name = rename(key, self.parameter_to_structured_name, adamw_optimizer_state_name_suffix_mappings) + new_name = rename(key, self.parameter_to_structured_name, OPTIMIZER_STATE_NAME_SUFFIX_MAPPING) assert new_name is not None renamed_state_dict[new_name] = value @@ -148,7 +158,7 @@ def gen_metadata_and_prepare_source_state_dict(self): """ Automatically parse the manual checkpoint file based on the state_dict information provided by auto parallel, obtaining the Metadata and state_dict required for auto parallel to load the checkpoint: - 1. Callload_state_dict_and_rename: Parse the distributed information from the names of the checkpoint files, and evenly parse out the distributed + 1. Call load_state_dict_and_rename: Parse the distributed information from the names of the checkpoint files, and evenly parse out the distributed information for each weight/optimizer state into self.global_sharded_tensor_infos(data structure:param_name -> [{tp_rank: 1, sharding_rank: 1}, shape, dtype, file_name]). Modify the names of the optimizer states in the form ofparameter+suffixand record them in self.cur_rank_loaded_state_dict(data structure:file_name -> renamed_state_dict). 2. Construct the Metadata and state_dict based on the distributed information obtained in the previous step for the final load. @@ -339,7 +349,6 @@ def gen_metadata_and_prepare_source_state_dict(self): meta_data = Metadata(state_dict_metadata, storage_metadata, None) source_state_dict = {fake_file_name: concat_optimier_state_dict} - return meta_data, source_state_dict elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: @@ -519,21 +528,13 @@ def load_state_dict_and_rename(self): # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. 
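# A compact, self-contained sketch of the two-mapping rename used in the converter
# (toy mappings below are hypothetical, not real model names): an optimizer-state key is
# split into a parameter-name prefix and a state suffix, and each part is looked up in
# its own table. The converter additionally keeps its suffix table sorted longest-first
# so that endswith-style matching elsewhere never picks a short suffix embedded in a
# longer one (e.g. "_moment1_0" inside "_fp32_master_1_moment1_0").
param_map = {"linear_0.w_0": "llama.decoder_0.q_proj.weight"}   # hypothetical names
suffix_map = {"_fp32_master_1_moment1_0": ".w_0_moment1_0",
              "_moment1_0": ".w_0_moment1_0"}

def rename(old_name):
    for i in range(1, len(old_name)):
        prefix, suffix = old_name[:i], old_name[i:]
        if prefix in param_map and suffix in suffix_map:
            return param_map[prefix] + suffix_map[suffix]
    return None

print(rename("linear_0.w_0_moment1_0"))  # llama.decoder_0.q_proj.weight.w_0_moment1_0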
state_dict[master_weight_name.replace("slice@", "")] = master_weight_value - # Standardize the state names of the AdamW optimizer. - adamw_opt_state_suffix_name_mapping = { - ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", - ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", - ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - } - unified_name_state_dict = {} for opt_state_name, opt_state_value in state_dict.items(): new_opt_state_name = opt_state_name - for suffix in adamw_opt_state_suffix_name_mapping: + for suffix in OPTIMIZER_STATE_NAME_SUFFIX_MAPPING: if opt_state_name.endswith(suffix): new_opt_state_name = opt_state_name.replace( - suffix, adamw_opt_state_suffix_name_mapping[suffix] + suffix, OPTIMIZER_STATE_NAME_SUFFIX_MAPPING[suffix] ) break unified_name_state_dict[new_opt_state_name] = opt_state_value From db34712e86070fb21046396bec6d112fc0c4d3f9 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 13 Aug 2024 11:46:40 +0800 Subject: [PATCH 19/30] fix hang --- paddlenlp/trainer/ckpt_converter.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index 84cb95833f2b..c7c2e60f93cf 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -19,7 +19,6 @@ from functools import reduce import paddle -import paddle.distributed as dist from paddle.distributed.checkpoint.load_state_dict import ( _load_state_dict, get_local_load_files, @@ -109,10 +108,13 @@ def load_from_hybrid_parallel_checkpoint(self): master_weight = param_name + ".w_0" if master_weight not in self.auto_parallel_state_dict: appended_master_weight_names.append(master_weight) - tmp_tensor = paddle.zeros(param_value.shape, "float32") - self.auto_parallel_state_dict[master_weight] = dist.shard_tensor( - tmp_tensor, param_value.process_mesh, param_value.placements - ) + tmp_tensor = paddle.zeros(param_value._local_value().shape, "float32") + with paddle.base.dygraph.guard(): + self.auto_parallel_state_dict[ + master_weight + ] = paddle.distributed.auto_parallel.api.dtensor_from_local( + tmp_tensor, param_value.process_mesh, param_value.placements + ) _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) for param_name, param_value in model_params.items(): From d74ae0dad051e643d6e77442e5e4b25d735866c2 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 13 Aug 2024 11:47:43 +0800 Subject: [PATCH 20/30] fix hang --- paddlenlp/trainer/ckpt_converter.py | 36 +++++++++++++++-------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index c7c2e60f93cf..81042f80e01c 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -36,24 +36,26 @@ MODEL_META_FILE_NAME = "model_meta.json" -OPTIMIZER_STATE_NAME_SUFFIX_MAPPING = { - "_fp32_master_1_moment1_0": ".w_0_moment1_0", - "_fp32_master_1_moment2_0": ".w_0_moment2_0", - "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - "_fp32_master_1": ".w_0", - "_moment1_0": ".w_0_moment1_0", - "_moment2_0": ".w_0_moment2_0", - "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", - ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", - ".w_0_fp32_master_0_beta1_pow_acc_0": 
".w_0_beta1_pow_acc_0", - ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", -} - OPTIMIZER_STATE_NAME_SUFFIX_MAPPING = OrderedDict( - sorted(OPTIMIZER_STATE_NAME_SUFFIX_MAPPING.items(), key=lambda x: len(x[0]), reverse=True) + sorted( + { + "_fp32_master_1_moment1_0": ".w_0_moment1_0", + "_fp32_master_1_moment2_0": ".w_0_moment2_0", + "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_fp32_master_1_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + "_fp32_master_1": ".w_0", + "_moment1_0": ".w_0_moment1_0", + "_moment2_0": ".w_0_moment2_0", + "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", + ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", + ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", + ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", + }.items(), + key=lambda x: len(x[0]), + reverse=True, + ) ) From 877e7c41c5cb15356a5fc9fc0950a406ff1b089a Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 13 Aug 2024 15:11:55 +0800 Subject: [PATCH 21/30] add loger --- paddlenlp/trainer/ckpt_converter.py | 45 +++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index 81042f80e01c..5391883712a4 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -29,13 +29,13 @@ Metadata, ) from paddle.distributed.checkpoint.utils import flatten_state_dict +from paddle.distributed.fleet.utils.log_util import logger MODEL_WEIGHT_SUFFIX = ".pdparams" OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" SCHEDULER_NAME = "scheduler.pdparams" MODEL_META_FILE_NAME = "model_meta.json" - OPTIMIZER_STATE_NAME_SUFFIX_MAPPING = OrderedDict( sorted( { @@ -83,6 +83,9 @@ def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structur self.global_optimizer_state_file_names = self.gather_global_object(self.cur_rank_optimizer_state_file_names) self.initial_distributed_configuration() + logger.debug( + f"The current checkpoint’s distributed strategy is tp{self.tp_degree}, pp{self.pp_degree}, sharding{self.sharding_degree}" + ) def load_from_hybrid_parallel_checkpoint(self): """ @@ -97,6 +100,8 @@ def load_from_hybrid_parallel_checkpoint(self): """ self.rename_auto_parallel_state_dict() metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() + logger.info("Generated the checkpoint’s metadata.") + logger.debug(f"The checkpoint's metadata is {metadata}.") if self.save_sharded_model: model_params = {} for state_name, state_value in self.auto_parallel_state_dict.items(): @@ -105,6 +110,7 @@ def load_from_hybrid_parallel_checkpoint(self): for param_name in model_params.keys(): self.auto_parallel_state_dict.pop(param_name) + logger.info("Requesting GPU memory space to load master_weights.") appended_master_weight_names = [] for param_name, param_value in model_params.items(): master_weight = param_name + ".w_0" @@ -118,7 +124,10 @@ def load_from_hybrid_parallel_checkpoint(self): tmp_tensor, param_value.process_mesh, param_value.placements ) + logger.info("Calling _load_state_dict to load the required weights.") _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + logger.info("Calling _load_state_dict completed, restored the required weights.") + for param_name, param_value in model_params.items(): master_weight = self.auto_parallel_state_dict[param_name + ".w_0"] cast_master_weight = 
paddle.cast(master_weight._local_value(), "bfloat16") @@ -126,7 +135,9 @@ def load_from_hybrid_parallel_checkpoint(self): for master_weight_name in appended_master_weight_names: self.auto_parallel_state_dict.pop(master_weight_name) else: + logger.info("Calling _load_state_dict to load the required weights.") _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + logger.info("Calling _load_state_dict completed, restored the required weights.") def rename_auto_parallel_state_dict(self): """ @@ -172,6 +183,7 @@ def gen_metadata_and_prepare_source_state_dict(self): * Reshape the optimizer states back to the shape of the weights. """ self.load_state_dict_and_rename() + logger.info("Complete the loading and renaming of state_dict.") if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: for state_name, shard_info in self.global_sharded_tensor_infos.items(): shard_info.sort(key=lambda x: x[0]["sharding_rank"]) @@ -195,6 +207,8 @@ def gen_metadata_and_prepare_source_state_dict(self): metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + logger.debug(f"The metadata for merge sharding is: {metadata_for_merge_sharding}") + source_state_dict_for_merge_sharding = {} for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): renamed_state_dict = {} @@ -219,7 +233,6 @@ def gen_metadata_and_prepare_source_state_dict(self): # Distribute all model parameters evenly across each card for loading world_size = paddle.distributed.get_world_size() - partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) partition_model_state_keys = [] @@ -239,8 +252,8 @@ def gen_metadata_and_prepare_source_state_dict(self): param_flattened_shapes[param_name] = reduce(lambda x, y: x * y, param_shape_and_dtype[0]) cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] - # Generate the optimizer states corresponding to the model weights. 
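# A simplified sketch of the buffer set-up that follows (hypothetical parameter shapes;
# the real code derives them from model_meta and also appends a per-tp-rank suffix).
# Sharding stage1 v2 stores moment1/moment2 flattened, so 1-D zero buffers of the
# flattened length are allocated here and reshaped back to the parameter shape after
# the merge.
from functools import reduce

import paddle

param_shapes = {"linear_0.w_0": [1024, 4096], "linear_0.b_0": [4096]}  # hypothetical
optimizer_buffers = {}
for name, shape in param_shapes.items():
    flat = reduce(lambda x, y: x * y, shape)
    optimizer_buffers[name + ".w_0_moment1_0"] = paddle.zeros((flat,), "float32")
    optimizer_buffers[name + ".w_0_moment2_0"] = paddle.zeros((flat,), "float32")
    # The beta1/beta2 pow accumulators are scalars regardless of the parameter shape.
    optimizer_buffers[name + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32")
    optimizer_buffers[name + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32")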
+ logger.info("Requesting GPU memory space to concatenate tensors split by sharding1 v2.") optimizer_state_dict = {} for key in cur_rank_need_load_model_state_keys: for tp_rank in range(self.tp_degree): @@ -261,8 +274,16 @@ def gen_metadata_and_prepare_source_state_dict(self): optimizer_state_dict[key + ".w_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") optimizer_state_dict[key + ".w_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") + malloc_size = 0 + for opt_state_name, opt_state_value in optimizer_state_dict.items(): + malloc_size += opt_state_value.numel() * opt_state_value.element_size() + malloc_size = malloc_size.numpy() / 2**20 + logger.debug(f"{malloc_size} MB of GPU memory were allocated.") + # merge sharding + logger.info("First call _load_state_dict to stitch back the tensors split by sharding1 v2.") _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) + logger.info("Completed the call _load_state_dict, concating back the tensors split by sharding.") # Reshape for opt_state_name, opt_state_value in optimizer_state_dict.items(): @@ -512,6 +533,7 @@ def load_state_dict_and_rename(self): ) need_read_files = get_local_load_files(self.gather_global_object(rank_access_files)) + logger.info(f"The file(s) to be loaded for the current rank are: {need_read_files}") self.cur_rank_loaded_state_dict = {} for file in need_read_files: @@ -545,6 +567,16 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = unified_name_state_dict + memory_size = 0 + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + for k, v in state_dict.items(): + memory_size += v.numel() * v.element_size() + + memory_size = memory_size.numpy() / 2**20 + logger.debug( + f"The current rank has finished loading the checkpoint file and has allocated {memory_size} MB of GPU memory." + ) + # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. self.sharding_stage1_v = self.infer_sharding_stage1_v() self.is_sharding_stage3 = self.infer_is_sharding_stage3() @@ -553,6 +585,7 @@ def load_state_dict_and_rename(self): # The threshold for determining whether to slice is segment_size, with a default value of 2**20. # However, sharding stage3 allows users to specify their own unsliced layers, which seems to be incompatible here. if self.is_sharding_stage3: + logger.info("The currently loaded checkpoint file comes from sharding stage 3.") segment_size = 2**20 for file, state_dict in self.cur_rank_loaded_state_dict.items(): if file.endswith(MODEL_WEIGHT_SUFFIX): @@ -573,8 +606,10 @@ def load_state_dict_and_rename(self): # rename and record sharded_tensor_info cur_rank_sharded_tensor_infos = {} + logger.info(f"save_sharded_model is {self.save_sharded_model}.") # 1. Handling the sharding stage1 v2 scenario, where the save_sharded_model flag must be enabled, independent of master_weights. if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: + logger.info("The currently loaded checkpoint file comes from sharding stage1 v2.") assert self.save_sharded_model for file, state_dict in self.cur_rank_loaded_state_dict.items(): # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, @@ -585,6 +620,7 @@ def load_state_dict_and_rename(self): # 2. In handling the sharding stage1 v1 and stage2 scenario, the optimizer states are distributed across different ranks. 
# We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + logger.info("The currently loaded checkpoint file comes from sharding stage1/2 v1.") if not self.save_sharded_model: file_to_state_dict_shapes_mapping = {} for file, state_dict in self.cur_rank_loaded_state_dict.items(): @@ -673,6 +709,7 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: # 3. Handling the sharding stage3 and non-sharding scenario + logger.info("The current checkpoint comes from either sharding stage 3 or non-sharding.") if not self.save_sharded_model: for file, state_dict in self.cur_rank_loaded_state_dict.items(): if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): @@ -700,6 +737,8 @@ def load_state_dict_and_rename(self): else: self.global_sharded_tensor_infos[state_name] += shard_info + logger.info(f"global_sharded_tensor_infos: {self.global_sharded_tensor_infos}") + def get_sharded_tensor_infos(self, file, state_dict, cur_rank_sharded_tensor_infos): (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) for state_name, state_value in state_dict.items(): From e59a8b7c42e9df563633ee35a5b4a5a33d555773 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 16 Aug 2024 16:02:18 +0800 Subject: [PATCH 22/30] fix bug --- paddlenlp/trainer/auto_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 576d50402363..69c3f358f247 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -727,9 +727,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): if self.args.to_static: parameter_to_structured_name = self.model_wrapped._parameter_to_structured_name else: - for state_dict_name, sub_state_dict in state_dict.items(): - for state_name, state_value in sub_state_dict.items(): - parameter_to_structured_name[state_value.name] = state_name + for state_name, state_value in self.model_wrapped.state_dict().items(): + parameter_to_structured_name[state_value.name] = state_name if self.args.resume_form_hybrid_parallel: CheckpointConverter( From 906066568030aaebc1603accce9317c98d7ee724 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 19 Aug 2024 19:18:02 +0800 Subject: [PATCH 23/30] fix rename --- paddlenlp/trainer/ckpt_converter.py | 369 +++++++++++++++------------- 1 file changed, 202 insertions(+), 167 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index 5391883712a4..b196b3661836 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -15,13 +15,12 @@ import json import os import re -from collections import OrderedDict from functools import reduce import paddle from paddle.distributed.checkpoint.load_state_dict import ( _load_state_dict, - get_local_load_files, + get_rank_to_read_files, ) from paddle.distributed.checkpoint.metadata import ( LocalTensorIndex, @@ -35,35 +34,14 @@ OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" SCHEDULER_NAME = "scheduler.pdparams" MODEL_META_FILE_NAME = "model_meta.json" - -OPTIMIZER_STATE_NAME_SUFFIX_MAPPING = OrderedDict( - sorted( - { - "_fp32_master_1_moment1_0": ".w_0_moment1_0", - "_fp32_master_1_moment2_0": ".w_0_moment2_0", - "_fp32_master_1_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_fp32_master_1_beta2_pow_acc_0": 
".w_0_beta2_pow_acc_0", - "_fp32_master_1": ".w_0", - "_moment1_0": ".w_0_moment1_0", - "_moment2_0": ".w_0_moment2_0", - "_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - "_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - ".w_0_fp32_master_0_moment1_0": ".w_0_moment1_0", - ".w_0_fp32_master_0_moment2_0": ".w_0_moment2_0", - ".w_0_fp32_master_0_beta1_pow_acc_0": ".w_0_beta1_pow_acc_0", - ".w_0_fp32_master_0_beta2_pow_acc_0": ".w_0_beta2_pow_acc_0", - }.items(), - key=lambda x: len(x[0]), - reverse=True, - ) -) +OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"] class CheckpointConverter: - def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structured_name): + def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, patch_dict=None): self.use_dist = True if paddle.distributed.get_world_size() > 1 else False self.path = hybrid_parallel_ckpt_path - self.auto_parallel_state_dict = self.flatten_state_dict(model_state) + self.auto_parallel_state_dict = self.flatten_state_dict(state_dict) self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) model_state_global_shape = {} for k, v in self.auto_parallel_state_dict.items(): @@ -87,6 +65,19 @@ def __init__(self, hybrid_parallel_ckpt_path, model_state, parameter_to_structur f"The current checkpoint’s distributed strategy is tp{self.tp_degree}, pp{self.pp_degree}, sharding{self.sharding_degree}" ) + self.patch_dict = patch_dict + for k, v in self.parameter_to_structured_name.items(): + if v in self.patch_dict: + self.parameter_to_structured_name[k] = self.patch_dict[v] + + del_keys = [] + for k, v in self.auto_parallel_state_dict.items(): + if k in self.patch_dict: + del_keys.append(k) + self.auto_parallel_state_dict[self.patch_dict[k]] = v + for k in del_keys: + self.auto_parallel_state_dict.pop(k) + def load_from_hybrid_parallel_checkpoint(self): """ Automatically and inplace load the distributed checkpoint stored in hybrid parallel mode into the auto parallel state_dict. @@ -99,10 +90,11 @@ def load_from_hybrid_parallel_checkpoint(self): 4. Special logic adaptation: In the save_sharded_model mode, the weights are obtained through the master_weight cast in the checkpoint. 
""" self.rename_auto_parallel_state_dict() + metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() logger.info("Generated the checkpoint’s metadata.") logger.debug(f"The checkpoint's metadata is {metadata}.") - if self.save_sharded_model: + if self.save_sharded_model and False: model_params = {} for state_name, state_value in self.auto_parallel_state_dict.items(): if state_name in self.parameter_to_structured_name.values(): @@ -113,62 +105,58 @@ def load_from_hybrid_parallel_checkpoint(self): logger.info("Requesting GPU memory space to load master_weights.") appended_master_weight_names = [] for param_name, param_value in model_params.items(): - master_weight = param_name + ".w_0" + master_weight = param_name + ".master_weight" if master_weight not in self.auto_parallel_state_dict: appended_master_weight_names.append(master_weight) - tmp_tensor = paddle.zeros(param_value._local_value().shape, "float32") + if param_value.is_dist(): + param_shape = param_value._local_value().shape + else: + param_shape = param_value.shape + + tmp_tensor = paddle.zeros(param_shape, dtype="float32") with paddle.base.dygraph.guard(): - self.auto_parallel_state_dict[ - master_weight - ] = paddle.distributed.auto_parallel.api.dtensor_from_local( - tmp_tensor, param_value.process_mesh, param_value.placements - ) + if param_value.is_dist(): + self.auto_parallel_state_dict[ + master_weight + ] = paddle.distributed.auto_parallel.api.dtensor_from_local( + tmp_tensor, param_value.process_mesh, param_value.placements + ) + else: + self.auto_parallel_state_dict[master_weight] = tmp_tensor logger.info("Calling _load_state_dict to load the required weights.") + state_dict_in_cpu = [] + for k, v in self.auto_parallel_state_dict.items(): + if v.place.is_cpu_place(): + print("v is cpu place", flush=1) + state_dict_in_cpu.append(k) + self.auto_parallel_state_dict[k] = v.cuda() _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + for k, v in self.auto_parallel_state_dict.items(): + if k in state_dict_in_cpu: + v = v.cpu() logger.info("Calling _load_state_dict completed, restored the required weights.") for param_name, param_value in model_params.items(): - master_weight = self.auto_parallel_state_dict[param_name + ".w_0"] + master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") paddle.assign(cast_master_weight, param_value._local_value()) for master_weight_name in appended_master_weight_names: self.auto_parallel_state_dict.pop(master_weight_name) else: logger.info("Calling _load_state_dict to load the required weights.") + state_dict_in_cpu = [] + for k, v in self.auto_parallel_state_dict.items(): + if v.place.is_cpu_place(): + print("v is cpu place", flush=1) + state_dict_in_cpu.append(k) + self.auto_parallel_state_dict[k] = v.cuda() _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + for k, v in self.auto_parallel_state_dict.items(): + if k in state_dict_in_cpu: + v = v.cpu() logger.info("Calling _load_state_dict completed, restored the required weights.") - def rename_auto_parallel_state_dict(self): - """ - Rename the keys of the auto parallel state_dict according to certain rules: - 1. 
Rename the suffixes of the optimizer states to a unified format: adamw_optimizer_status_name_suffix_mappings - """ - - def rename(old_name, map1, map2): - for i in range(1, len(old_name)): - str1 = old_name[:i] - str2 = old_name[i:] - if (str1 in map1) and (str2 in map2): - transformed_str1 = map1[str1] - transformed_str2 = map2[str2] - return transformed_str1 + transformed_str2 - return None - - renamed_state_dict = {} - - for key, value in self.auto_parallel_state_dict.items(): - - if key in self.parameter_to_structured_name.values(): - new_name = key - else: - new_name = rename(key, self.parameter_to_structured_name, OPTIMIZER_STATE_NAME_SUFFIX_MAPPING) - - assert new_name is not None - renamed_state_dict[new_name] = value - - self.auto_parallel_state_dict = renamed_state_dict - def gen_metadata_and_prepare_source_state_dict(self): """ Automatically parse the manual checkpoint file based on the state_dict information provided by auto parallel, @@ -258,21 +246,21 @@ def gen_metadata_and_prepare_source_state_dict(self): for key in cur_rank_need_load_model_state_keys: for tp_rank in range(self.tp_degree): tp_rank_suffix = "_tp{:02d}".format(tp_rank) - optimizer_state_dict[key + ".w_0_moment1_0" + tp_rank_suffix] = paddle.zeros( + optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros( (param_flattened_shapes[key],), "float32" ) - optimizer_state_dict[key + ".w_0_moment2_0" + tp_rank_suffix] = paddle.zeros( + optimizer_state_dict[key + ".moment2" + tp_rank_suffix] = paddle.zeros( (param_flattened_shapes[key],), "float32" ) if self.optimizer_state_with_master_weights: - optimizer_state_dict[key + ".w_0" + tp_rank_suffix] = paddle.zeros( + optimizer_state_dict[key + ".master_weight" + tp_rank_suffix] = paddle.zeros( (param_flattened_shapes[key],), "float32" ) # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. # Later, when these are compared with the global shape, we realize that they are replicated. 
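# A small sketch of the shape comparison the comment above refers to (toy shapes): a
# tensor is treated as TP-partitioned along the first axis whose local extent differs
# from the global extent; if every axis matches, it was actually replicated across
# tp ranks rather than split.
def infer_tp_split_axis(local_shape, global_shape):
    assert len(local_shape) == len(global_shape)
    for axis, (local_dim, global_dim) in enumerate(zip(local_shape, global_shape)):
        if local_dim != global_dim:
            return axis   # partitioned along this axis
    return -1             # identical shapes -> replicated

print(infer_tp_split_axis((1024, 2048), (4096, 2048)))  # 0
print(infer_tp_split_axis((512,), (512,)))              # -1 (replicated)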
- optimizer_state_dict[key + ".w_0_beta1_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") - optimizer_state_dict[key + ".w_0_beta2_pow_acc_0" + tp_rank_suffix] = paddle.zeros((1,), "float32") + optimizer_state_dict[key + ".beta1_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32") + optimizer_state_dict[key + ".beta2_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32") malloc_size = 0 for opt_state_name, opt_state_value in optimizer_state_dict.items(): @@ -311,7 +299,10 @@ def gen_metadata_and_prepare_source_state_dict(self): for opt_state_name_removed_tp_rank, opt_state_name in optimizer_state_key_to_tp_keys.items(): model_state_name = self.optimizer_key_to_model_state_key(opt_state_name_removed_tp_rank) local_shape = optimizer_state_dict[opt_state_name[0]].shape - if "_pow_acc_0" not in key: + if ( + ".beta1_pow_acc" not in opt_state_name_removed_tp_rank + and ".beta2_pow_acc" not in opt_state_name_removed_tp_rank + ): global_shape = self.model_state_global_shape[model_state_name] else: global_shape = (1,) @@ -431,15 +422,15 @@ def gen_metadata_and_prepare_source_state_dict(self): shape = item[1] flatten_shape = reduce(lambda a, b: a * b, item[1]) target_state_dict[key] = paddle.zeros(shape, dtype) - target_state_dict[key + ".w_0_moment1_0"] = paddle.zeros((flatten_shape,), "float32") - target_state_dict[key + ".w_0_moment2_0"] = paddle.zeros((flatten_shape,), "float32") + target_state_dict[key + ".moment1"] = paddle.zeros((flatten_shape,), "float32") + target_state_dict[key + ".moment2"] = paddle.zeros((flatten_shape,), "float32") if self.optimizer_state_with_master_weights: - target_state_dict[key + ".w_0"] = paddle.zeros((flatten_shape,), "float32") + target_state_dict[key + ".master_weight"] = paddle.zeros((flatten_shape,), "float32") # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. # Later, when these are compared with the global shape, we realize that they are replicated. - target_state_dict[key + ".w_0_beta1_pow_acc_0"] = paddle.zeros((1,), "float32") - target_state_dict[key + ".w_0_beta2_pow_acc_0"] = paddle.zeros((1,), "float32") + target_state_dict[key + ".beta1_pow_acc"] = paddle.zeros((1,), "float32") + target_state_dict[key + ".beta2_pow_acc"] = paddle.zeros((1,), "float32") _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) @@ -526,13 +517,16 @@ def load_state_dict_and_rename(self): """ rank_access_files = {} if self.save_sharded_model: - rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names + rank_access_files[self.cur_rank] = ( + self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names + ) else: rank_access_files[self.cur_rank] = ( self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names ) - need_read_files = get_local_load_files(self.gather_global_object(rank_access_files)) + global_rank_access_files = self.gather_global_object(rank_access_files) + need_read_files = get_rank_to_read_files(global_rank_access_files, global_rank_access_files) logger.info(f"The file(s) to be loaded for the current rank are: {need_read_files}") self.cur_rank_loaded_state_dict = {} @@ -552,20 +546,9 @@ def load_state_dict_and_rename(self): file_to_master_weights_keys[file] = list(master_weights.keys()) for master_weight_name, master_weight_value in master_weights.items(): # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. 
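# A toy stand-in (not Paddle's get_rank_to_read_files) illustrating the kind of result
# the converter needs at this step: each checkpoint file is assigned to exactly one rank
# that can access it, keeping the per-rank load roughly balanced. File names below are
# hypothetical.
def assign_files_to_ranks(rank_to_files):
    assignment = {rank: [] for rank in rank_to_files}
    assigned = set()
    for rank, files in sorted(rank_to_files.items()):
        for f in files:
            if f in assigned:
                continue
            # Greedy: give the file to the least-loaded rank that lists it.
            candidates = [r for r, fs in rank_to_files.items() if f in fs]
            target = min(candidates, key=lambda r: len(assignment[r]))
            assignment[target].append(f)
            assigned.add(f)
    return assignment

print(assign_files_to_ranks({0: ["a.pdopt", "b.pdopt"], 1: ["b.pdopt", "c.pdopt"]}))
# {0: ['a.pdopt'], 1: ['b.pdopt', 'c.pdopt']}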
- state_dict[master_weight_name.replace("slice@", "")] = master_weight_value - - unified_name_state_dict = {} - for opt_state_name, opt_state_value in state_dict.items(): - new_opt_state_name = opt_state_name - for suffix in OPTIMIZER_STATE_NAME_SUFFIX_MAPPING: - if opt_state_name.endswith(suffix): - new_opt_state_name = opt_state_name.replace( - suffix, OPTIMIZER_STATE_NAME_SUFFIX_MAPPING[suffix] - ) - break - unified_name_state_dict[new_opt_state_name] = opt_state_value + state_dict[master_weight_name.replace("slice@", "") + ".master_weight"] = master_weight_value - self.cur_rank_loaded_state_dict[file] = unified_name_state_dict + self.cur_rank_loaded_state_dict[file] = state_dict memory_size = 0 for file, state_dict in self.cur_rank_loaded_state_dict.items(): @@ -677,24 +660,36 @@ def load_state_dict_and_rename(self): else: partition_result = partition_result_0 - master_weight_name_to_model_weight_name_mapping = {} + name_mapping = {} for i in range(len(sharding_optimizer_state_shards)): state_shard = sharding_optimizer_state_shards[i][0] partitioned_shard = partition_result[i] - for j in range(len(partitioned_shard)): - master_weight_name = self.parse_master_weight_name_by(state_shard[j][0]) - master_weight_name_to_model_weight_name_mapping[ - master_weight_name - ] = partitioned_shard[j][0] + suffix_bucket = {} + for suffix in OPTIMIZER_STATE_NAME_SUFFIX: + suffix_bucket[suffix] = [] + for j in range(len(state_shard)): + optimizer_state_name = state_shard[j][0] + if "moment1" in optimizer_state_name: + suffix_bucket[".moment1"].append(optimizer_state_name) + elif "moment2" in optimizer_state_name: + suffix_bucket[".moment2"].append(optimizer_state_name) + elif "beta1_pow_acc" in optimizer_state_name: + suffix_bucket[".beta1_pow_acc"].append(optimizer_state_name) + elif "beta2_pow_acc" in optimizer_state_name: + suffix_bucket[".beta2_pow_acc"].append(optimizer_state_name) + else: + suffix_bucket[".master_weight"].append(optimizer_state_name) + + for suffix, old_names in suffix_bucket.items(): + assert len(old_names) == len(partitioned_shard) + for k in range(len(old_names)): + name_mapping[old_names[k]] = partitioned_shard[k][0] + suffix renamed_state_dict = {} # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. # Therefore, the sharding information can now be directly removed. for opt_state_name, opt_state_value in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(opt_state_name) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - new_opt_state_name = opt_state_name.replace(master_weight_name, model_weight_name) - renamed_state_dict[new_opt_state_name] = opt_state_value + renamed_state_dict[name_mapping[opt_state_name]] = opt_state_value self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) @@ -709,15 +704,25 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: # 3. 
Handling the sharding stage3 and non-sharding scenario + if not hasattr(self, "global_file_to_state_dict_keys_mapping"): + file_to_state_dict_keys_mapping = {} + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + file_to_state_dict_keys_mapping[file_name] = list(state_dict.keys()) + + self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) + logger.info("The current checkpoint comes from either sharding stage 3 or non-sharding.") if not self.save_sharded_model: - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - renamed_state_dict = self.rename_using_optimizer_state_order(file) + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + model_state_file_name = self.get_model_state_file_from(file_name) + assert model_state_file_name is not None + model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] + renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict) self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) - self.cur_rank_loaded_state_dict[file] = renamed_state_dict + self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict else: - self.get_sharded_tensor_infos(file, state_dict, cur_rank_sharded_tensor_infos) + self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos) else: for file, state_dict in self.cur_rank_loaded_state_dict.items(): @@ -726,9 +731,9 @@ def load_state_dict_and_rename(self): renamed_state_dict = self.rename_using_model_meta(file) self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) self.cur_rank_loaded_state_dict[file] = renamed_state_dict + # gather global sharded tensor infos sharded_tensor_infos = self.gather_global_object({self.cur_rank: cur_rank_sharded_tensor_infos}) - self.global_sharded_tensor_infos = {} for rank, sharded_tensor_info in sharded_tensor_infos.items(): for state_name, shard_info in sharded_tensor_info.items(): @@ -771,12 +776,16 @@ def gen_metadata_for_tp_sharded_tensor(self): state_dict_metadata = {} storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. for state_name, shard_info in self.global_sharded_tensor_infos.items(): + global_offset = 0 local_shape = shard_info[0][1] + model_state_name = self.optimizer_key_to_model_state_key(state_name) - if "_pow_acc_0" not in state_name: + + if ".beta1_pow_acc" not in state_name and ".beta2_pow_acc" not in state_name: global_shape = self.model_state_global_shape[model_state_name] else: global_shape = (1,) @@ -830,20 +839,90 @@ def rename_using_model_meta(self, file_name): # Map model weight names to their corresponding names of master_weights in the optimizer state. 
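# A minimal sketch of the mapping inversion performed here (hypothetical names; the real
# mapping is read from model_meta.json under sharding_metas[...]["structure_name_mapping"],
# keyed by structured name with the static parameter name as value).
structure_name_mapping = {"embedding.weight": "embedding_0.w_0",
                          "lm_head.weight": "linear_31.w_0"}
parameter_to_structured_name = {param: structured
                                for structured, param in structure_name_mapping.items()}
# With the inverted table, a key such as "linear_31.w_0_moment1_0" can be renamed to
# "lm_head.weight.moment1" by finding the prefix that is a known parameter name and
# classifying the remaining suffix.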
if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] - master_weight_name_to_model_weight_name_mapping = {} - for model_weight_name, master_weight_name in structure_name_mapping.items(): - master_weight_name_to_model_weight_name_mapping[master_weight_name.split(".")[0]] = model_weight_name - - renamed_state_dict = {} + parameter_to_structured_name = {} + for k, v in structure_name_mapping.items(): + parameter_to_structured_name[v] = k state_dict = self.cur_rank_loaded_state_dict[file_name] - for opt_state_name, opt_state_value in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(opt_state_name) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - renamed_state_dict[opt_state_name.replace(master_weight_name, model_weight_name)] = opt_state_value - return renamed_state_dict + return self.rename_using_parameter_to_structured_name_mapping(state_dict, parameter_to_structured_name) else: return self.cur_rank_loaded_state_dict[file_name] + def rename_auto_parallel_state_dict(self): + """ + Rename the keys of the auto parallel state_dict according to certain rules: + 1. Rename the suffixes of the optimizer states to a unified format: adamw_optimizer_status_name_suffix_mappings + """ + self.auto_parallel_state_dict = self.rename_using_parameter_to_structured_name_mapping( + self.auto_parallel_state_dict, self.parameter_to_structured_name + ) + + def rename_using_parameter_to_structured_name_mapping(self, state_dict, parameter_to_structured_name): + renamed_state_dict = {} + + print("parameter_to_structured_name: ", parameter_to_structured_name, flush=1) + + def rename(old_name, parameter_to_structured_name): + for i in range(1, len(old_name) + 1): + param_name = old_name[:i] # param_name + suffix = old_name[i:] # suffix + if param_name in parameter_to_structured_name: + print("param_name: ", param_name, flush=1) + structure_name = parameter_to_structured_name[param_name] + print("structure_name: ", structure_name, flush=1) + if "moment1" in suffix: + return structure_name + ".moment1" + elif "moment2" in suffix: + return structure_name + ".moment2" + elif "beta1_pow_acc" in suffix: + return structure_name + ".beta1_pow_acc" + elif "beta2_pow_acc" in suffix: + return structure_name + ".beta2_pow_acc" + else: + return structure_name + ".master_weight" + return None + + for key, value in state_dict.items(): + print("key: ", key, flush=1) + if key in parameter_to_structured_name.values(): + print("key is a param") + new_name = key + else: + print("key is a opt", flush=1) + new_name = rename(key, parameter_to_structured_name) + assert new_name is not None + renamed_state_dict[new_name] = value + + return renamed_state_dict + + def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict): + + name_mapping = {} + suffix_bucket = {} + assert len(optimizer_state_dict) % len(model_state_keys) == 0 + for suffix in OPTIMIZER_STATE_NAME_SUFFIX: + suffix_bucket[suffix] = [] + for satte_name, satte_value in optimizer_state_dict.items(): + if "moment1" in satte_name: + suffix_bucket[".moment1"].append(satte_name) + elif "moment2" in satte_name: + suffix_bucket[".moment2"].append(satte_name) + elif "beta1_pow_acc" in satte_name: + suffix_bucket[".beta1_pow_acc"].append(satte_name) + elif "beta2_pow_acc" in satte_name: + suffix_bucket[".beta2_pow_acc"].append(satte_name) + else: + suffix_bucket[".master_weight"].append(satte_name) 
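# A condensed sketch of the positional renaming above (toy names): optimizer states are
# grouped by suffix, then matched to model weights by their order in the checkpoint,
# relying on the weight file and the optimizer file recording parameters in the same order.
model_state_keys = ["embedding.weight", "lm_head.weight"]
optimizer_state_keys = ["param_0.w_0_moment1_0", "param_1.w_0_moment1_0",
                        "param_0.w_0_moment2_0", "param_1.w_0_moment2_0"]

suffix_bucket = {".moment1": [], ".moment2": []}
for name in optimizer_state_keys:
    suffix_bucket[".moment1" if "moment1" in name else ".moment2"].append(name)

name_mapping = {}
for suffix, old_names in suffix_bucket.items():
    assert len(old_names) == len(model_state_keys)
    for old_name, model_key in zip(old_names, model_state_keys):
        name_mapping[old_name] = model_key + suffix

print(name_mapping["param_1.w_0_moment2_0"])  # lm_head.weight.moment2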
+ + for suffix, old_names in suffix_bucket.items(): + assert len(old_names) == len(model_state_keys) + for i in range(len(old_names)): + name_mapping[old_names[i]] = model_state_keys[i] + suffix + + renamed_state_dict = {} + for k, v in optimizer_state_dict.items(): + renamed_state_dict[name_mapping[k]] = v + return renamed_state_dict + def partition_parameters(self, model_state_shapes, is_sort, shard_num): """ In sharding_stage3 and sharding_stage1_v1, parameters and optimizer states will be assigned to different ranks. This function defines the allocation rules. @@ -868,44 +947,6 @@ def partition_parameters(self, model_state_shapes, is_sort, shard_num): return mapping - def rename_using_optimizer_state_order(self, file_name): - """ - Rename the keys in opt_state_dict based on the following rule: The order of weights recorded in the weight file is consistent with the order of optimizer states recorded in the optimizer file. - By using this order, we can obtain the correspondence between the names of weights and optimizer states and rename the optimizer accordingly. For example: - * model_state: linear0, linear1 - * opt_state: param0.w0, param1.w0 - * Renamed opt_state: linear0.w0, linear1.w0 - NOTE:The reason for renaming is that there is a difference in the naming of optimizer parameters between dynamic and static partitions, making it difficult to match optimizer parameters directly by name. - Therefore, we unify them to the weight names. - """ - if not hasattr(self, "global_file_to_state_dict_keys_mapping"): - file_to_state_dict_keys_mapping = {} - for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): - file_to_state_dict_keys_mapping[file_name] = list(state_dict.keys()) - - self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) - - if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - model_state_file_name = self.get_model_state_file_from(file_name) - assert model_state_file_name is not None - model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] - optimizer_state_keys = self.global_file_to_state_dict_keys_mapping[file_name] - - master_weight_name_to_model_weight_name_mapping = {} - for i in range(len(model_state_keys)): - master_weight_name = self.parse_master_weight_name_by(optimizer_state_keys[i]) - master_weight_name_to_model_weight_name_mapping[master_weight_name] = model_state_keys[i] - - state_dict = self.cur_rank_loaded_state_dict[file_name] - renamed_state_dict = {} - for opt_state_name, opt_state_value in state_dict.items(): - master_weight_name = self.parse_master_weight_name_by(opt_state_name) - model_weight_name = master_weight_name_to_model_weight_name_mapping[master_weight_name] - renamed_state_dict[opt_state_name.replace(master_weight_name, model_weight_name)] = opt_state_value - return renamed_state_dict - else: - return self.cur_rank_loaded_state_dict[file_name] - def get_save_sharded_model_flag(self): save_sharded_model_flag = self.gather_global_object( [os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] @@ -929,8 +970,12 @@ def gather_global_object(self, cur_rank_object): all_rank_objects = [all_rank_objects] if isinstance(cur_rank_object, list): + for obj in all_rank_objects: + assert isinstance(obj, list) return [item for sublist in all_rank_objects for item in sublist] elif isinstance(cur_rank_object, dict): + for obj in all_rank_objects: + assert isinstance(obj, dict) global_map = {} for rank_map in all_rank_objects: global_map.update(rank_map) @@ -1029,9 
+1074,6 @@ def infer_is_sharding_stage3(self): break return is_sharding_stage3 - def parse_master_weight_name_by(self, optimizer_state_name): - return optimizer_state_name.split(".")[0] - def get_model_state_file_from(self, optimizer_state_file_name): (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) for model_state_file in self.global_model_state_file_names: @@ -1041,15 +1083,8 @@ def get_model_state_file_from(self, optimizer_state_file_name): return None def optimizer_key_to_model_state_key(self, optimizer_key): - adamw_optimizer_key_suffix = [ - ".w_0_beta1_pow_acc_0", - ".w_0_beta2_pow_acc_0", - ".w_0_moment1_0", - ".w_0_moment2_0", - ".w_0", - ] model_state_key = optimizer_key - for suffix in adamw_optimizer_key_suffix: + for suffix in OPTIMIZER_STATE_NAME_SUFFIX: if model_state_key.endswith(suffix): # Remove the suffix from model_state_key model_state_key = model_state_key[: -len(suffix)] From 05bc0906b6a7f0b14773dfe726eed1aae2ec6b16 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 19 Aug 2024 19:21:46 +0800 Subject: [PATCH 24/30] fix rename --- paddlenlp/trainer/ckpt_converter.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index b196b3661836..f169fc5b0bfa 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -128,7 +128,6 @@ def load_from_hybrid_parallel_checkpoint(self): state_dict_in_cpu = [] for k, v in self.auto_parallel_state_dict.items(): if v.place.is_cpu_place(): - print("v is cpu place", flush=1) state_dict_in_cpu.append(k) self.auto_parallel_state_dict[k] = v.cuda() _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) @@ -148,7 +147,6 @@ def load_from_hybrid_parallel_checkpoint(self): state_dict_in_cpu = [] for k, v in self.auto_parallel_state_dict.items(): if v.place.is_cpu_place(): - print("v is cpu place", flush=1) state_dict_in_cpu.append(k) self.auto_parallel_state_dict[k] = v.cuda() _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) @@ -859,16 +857,12 @@ def rename_auto_parallel_state_dict(self): def rename_using_parameter_to_structured_name_mapping(self, state_dict, parameter_to_structured_name): renamed_state_dict = {} - print("parameter_to_structured_name: ", parameter_to_structured_name, flush=1) - def rename(old_name, parameter_to_structured_name): for i in range(1, len(old_name) + 1): param_name = old_name[:i] # param_name suffix = old_name[i:] # suffix if param_name in parameter_to_structured_name: - print("param_name: ", param_name, flush=1) structure_name = parameter_to_structured_name[param_name] - print("structure_name: ", structure_name, flush=1) if "moment1" in suffix: return structure_name + ".moment1" elif "moment2" in suffix: @@ -882,12 +876,9 @@ def rename(old_name, parameter_to_structured_name): return None for key, value in state_dict.items(): - print("key: ", key, flush=1) if key in parameter_to_structured_name.values(): - print("key is a param") new_name = key else: - print("key is a opt", flush=1) new_name = rename(key, parameter_to_structured_name) assert new_name is not None renamed_state_dict[new_name] = value From 03240687da8376d8099f8576a526dd92353fc0cb Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 20 Aug 2024 15:19:25 +0800 Subject: [PATCH 25/30] fix rename --- paddlenlp/trainer/ckpt_converter.py | 173 +++++++++++++++++----------- 1 file changed, 108 insertions(+), 65 deletions(-) diff --git 
a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index f169fc5b0bfa..e9ea2bed8552 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -35,6 +35,7 @@ SCHEDULER_NAME = "scheduler.pdparams" MODEL_META_FILE_NAME = "model_meta.json" OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"] +MODEL_STATE_FILE_MIN_SIZE = 512 class CheckpointConverter: @@ -49,8 +50,6 @@ def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structure self.model_state_global_shape = self.gather_global_object(model_state_global_shape) self.cur_rank = paddle.distributed.get_rank() - self.save_sharded_model = self.get_save_sharded_model_flag() - ( self.cur_rank_model_state_file_names, self.cur_rank_optimizer_state_file_names, @@ -60,23 +59,34 @@ def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structure self.global_optimizer_state_file_names = self.gather_global_object(self.cur_rank_optimizer_state_file_names) + self.is_model_meta_exists = self.get_is_model_meta_exists_flag() + self.is_model_state_stored = self.get_is_model_state_stored_flag() + self.initial_distributed_configuration() - logger.debug( - f"The current checkpoint’s distributed strategy is tp{self.tp_degree}, pp{self.pp_degree}, sharding{self.sharding_degree}" - ) - self.patch_dict = patch_dict - for k, v in self.parameter_to_structured_name.items(): - if v in self.patch_dict: - self.parameter_to_structured_name[k] = self.patch_dict[v] + if patch_dict is not None: + self.patch_dict = patch_dict + for k, v in self.parameter_to_structured_name.items(): + if v in self.patch_dict: + self.parameter_to_structured_name[k] = self.patch_dict[v] - del_keys = [] - for k, v in self.auto_parallel_state_dict.items(): - if k in self.patch_dict: - del_keys.append(k) - self.auto_parallel_state_dict[self.patch_dict[k]] = v - for k in del_keys: - self.auto_parallel_state_dict.pop(k) + del_keys = [] + for k, v in self.auto_parallel_state_dict.items(): + if k in self.patch_dict: + del_keys.append(k) + + for k in del_keys: + self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k] + self.auto_parallel_state_dict.pop(k) + + flags = [ + ["tp degree", self.tp_degree], + ["pp degree", self.pp_degree], + ["sharding degree", self.sharding_degree], + ["is model_meta exists", self.is_model_meta_exists], + ["is model_state stored", self.is_model_state_stored], + ] + self.print_checkpoint_file_info(flags) def load_from_hybrid_parallel_checkpoint(self): """ @@ -94,7 +104,8 @@ def load_from_hybrid_parallel_checkpoint(self): metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() logger.info("Generated the checkpoint’s metadata.") logger.debug(f"The checkpoint's metadata is {metadata}.") - if self.save_sharded_model and False: + if not self.is_model_state_stored: + assert self.optimizer_state_with_master_weights model_params = {} for state_name, state_value in self.auto_parallel_state_dict.items(): if state_name in self.parameter_to_structured_name.values(): @@ -125,35 +136,25 @@ def load_from_hybrid_parallel_checkpoint(self): self.auto_parallel_state_dict[master_weight] = tmp_tensor logger.info("Calling _load_state_dict to load the required weights.") - state_dict_in_cpu = [] - for k, v in self.auto_parallel_state_dict.items(): - if v.place.is_cpu_place(): - state_dict_in_cpu.append(k) - self.auto_parallel_state_dict[k] = v.cuda() _load_state_dict(self.auto_parallel_state_dict, 
source_state_dict, [metadata]) - for k, v in self.auto_parallel_state_dict.items(): - if k in state_dict_in_cpu: - v = v.cpu() logger.info("Calling _load_state_dict completed, restored the required weights.") for param_name, param_value in model_params.items(): - master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] - cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") - paddle.assign(cast_master_weight, param_value._local_value()) + if param_value.is_dist(): + master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] + cast_master_weight = paddle.cast(master_weight._local_value(), param_value.dtype) + paddle.assign(cast_master_weight, param_value._local_value()) + else: + master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] + cast_master_weight = paddle.cast(master_weight, param_value.dtype) + paddle.assign(cast_master_weight, param_value) for master_weight_name in appended_master_weight_names: self.auto_parallel_state_dict.pop(master_weight_name) else: logger.info("Calling _load_state_dict to load the required weights.") - state_dict_in_cpu = [] - for k, v in self.auto_parallel_state_dict.items(): - if v.place.is_cpu_place(): - state_dict_in_cpu.append(k) - self.auto_parallel_state_dict[k] = v.cuda() _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) - for k, v in self.auto_parallel_state_dict.items(): - if k in state_dict_in_cpu: - v = v.cpu() logger.info("Calling _load_state_dict completed, restored the required weights.") + logger.info("Successfully loaded hybrid_parallel checkpoint!") def gen_metadata_and_prepare_source_state_dict(self): """ @@ -514,14 +515,12 @@ def load_state_dict_and_rename(self): * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank, sharding_rank}, shape, dtype, file_name]. """ rank_access_files = {} - if self.save_sharded_model: + if self.is_model_state_stored: rank_access_files[self.cur_rank] = ( self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names ) else: - rank_access_files[self.cur_rank] = ( - self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names - ) + rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names global_rank_access_files = self.gather_global_object(rank_access_files) need_read_files = get_rank_to_read_files(global_rank_access_files, global_rank_access_files) @@ -531,8 +530,6 @@ def load_state_dict_and_rename(self): for file in need_read_files: self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file)) - file_to_master_weights_keys = {} - self.optimizer_state_with_master_weights = False for file, state_dict in self.cur_rank_loaded_state_dict.items(): @@ -541,7 +538,6 @@ def load_state_dict_and_rename(self): if "master_weights" in state_dict: self.optimizer_state_with_master_weights = True master_weights = state_dict.pop("master_weights") - file_to_master_weights_keys[file] = list(master_weights.keys()) for master_weight_name, master_weight_value in master_weights.items(): # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. 
state_dict[master_weight_name.replace("slice@", "") + ".master_weight"] = master_weight_value @@ -562,6 +558,14 @@ def load_state_dict_and_rename(self): self.sharding_stage1_v = self.infer_sharding_stage1_v() self.is_sharding_stage3 = self.infer_is_sharding_stage3() + flags = [ + ["is sharding stage1/2", (not self.is_sharding_stage3) and self.sharding_degree > 1], + ["sharding stage1/2 version", self.sharding_stage1_v], + ["is sharding stage3", self.is_sharding_stage3], + ["master_weight", self.optimizer_state_with_master_weights], + ] + self.print_checkpoint_file_info(flags) + # In sharding stage3, the parameters need to be reordered based on whether they are sliced. # The threshold for determining whether to slice is segment_size, with a default value of 2**20. # However, sharding stage3 allows users to specify their own unsliced layers, which seems to be incompatible here. @@ -570,28 +574,25 @@ def load_state_dict_and_rename(self): segment_size = 2**20 for file, state_dict in self.cur_rank_loaded_state_dict.items(): if file.endswith(MODEL_WEIGHT_SUFFIX): - sliced_pramaeters = [] - unseliced_pramaeters = [] + sliced_prameters = [] + unsliced_parameters = [] sorted_state_dict = {} for k, v in state_dict.items(): if v.numel() > segment_size: - sliced_pramaeters.append(k) + sliced_prameters.append(k) else: - unseliced_pramaeters.append(k) - for k in sliced_pramaeters + unseliced_pramaeters: + unsliced_parameters.append(k) + for k in sliced_prameters + unsliced_parameters: sorted_state_dict[k] = state_dict.pop(k) self.cur_rank_loaded_state_dict[file] = sorted_state_dict - self.global_file_to_master_weights_keys = self.gather_global_object(file_to_master_weights_keys) - # rename and record sharded_tensor_info cur_rank_sharded_tensor_infos = {} - logger.info(f"save_sharded_model is {self.save_sharded_model}.") # 1. Handling the sharding stage1 v2 scenario, where the save_sharded_model flag must be enabled, independent of master_weights. if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: logger.info("The currently loaded checkpoint file comes from sharding stage1 v2.") - assert self.save_sharded_model + assert self.is_model_meta_exists for file, state_dict in self.cur_rank_loaded_state_dict.items(): # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, # and then append the tp_degree. @@ -602,7 +603,7 @@ def load_state_dict_and_rename(self): # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: logger.info("The currently loaded checkpoint file comes from sharding stage1/2 v1.") - if not self.save_sharded_model: + if not self.is_model_meta_exists: file_to_state_dict_shapes_mapping = {} for file, state_dict in self.cur_rank_loaded_state_dict.items(): shapes = [] @@ -678,6 +679,19 @@ def load_state_dict_and_rename(self): else: suffix_bucket[".master_weight"].append(optimizer_state_name) + # In this scenario, the order of master_weights might differ from the order of the regular optimizer states and needs to be reordered. 
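# A toy version of the reordering described above (hypothetical names): master_weight
# keys are put back in the same order as their corresponding optimizer states in this
# shard, by locating the first shard entry containing each key's base name and sorting
# on that index.
state_shard = [("param_1.w_0_moment1_0", 8), ("param_0.w_0_moment1_0", 8)]
master_weight_keys = ["param_0.master_weight", "param_1.master_weight"]

indexed = []
for mw_key in master_weight_keys:
    base = mw_key[: -len(".master_weight")]
    for idx, (state_name, _) in enumerate(state_shard):
        if base in state_name:
            indexed.append((mw_key, idx))
            break

reordered = [mw_key for mw_key, _ in sorted(indexed, key=lambda x: x[1])]
print(reordered)  # ['param_1.master_weight', 'param_0.master_weight']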
+ if len(suffix_bucket[".master_weight"]) != 0: + master_weight_keys = [] + for master_weight_key in suffix_bucket[".master_weight"]: + for index in range(len(state_shard)): + if master_weight_key[: -len(".master_weight")] in state_shard[index][0]: + # Find the first match + master_weight_keys.append([master_weight_key, index]) + break + + master_weight_keys = sorted(master_weight_keys, key=lambda x: x[1]) + suffix_bucket[".master_weight"] = [x[0] for x in master_weight_keys] + for suffix, old_names in suffix_bucket.items(): assert len(old_names) == len(partitioned_shard) for k in range(len(old_names)): @@ -702,26 +716,24 @@ def load_state_dict_and_rename(self): self.cur_rank_loaded_state_dict[file] = renamed_state_dict else: # 3. Handling the sharding stage3 and non-sharding scenario - if not hasattr(self, "global_file_to_state_dict_keys_mapping"): - file_to_state_dict_keys_mapping = {} + + file_to_state_dict_keys_mapping = {} for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): file_to_state_dict_keys_mapping[file_name] = list(state_dict.keys()) - - self.global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) + global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) logger.info("The current checkpoint comes from either sharding stage 3 or non-sharding.") - if not self.save_sharded_model: + if not self.is_model_meta_exists: for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): model_state_file_name = self.get_model_state_file_from(file_name) assert model_state_file_name is not None - model_state_keys = self.global_file_to_state_dict_keys_mapping[model_state_file_name] + model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name] renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict) self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict else: self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos) - else: for file, state_dict in self.cur_rank_loaded_state_dict.items(): # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, @@ -739,8 +751,7 @@ def load_state_dict_and_rename(self): self.global_sharded_tensor_infos[state_name] = shard_info else: self.global_sharded_tensor_infos[state_name] += shard_info - - logger.info(f"global_sharded_tensor_infos: {self.global_sharded_tensor_infos}") + logger.debug(f"global_sharded_tensor_infos: {self.global_sharded_tensor_infos}") def get_sharded_tensor_infos(self, file, state_dict, cur_rank_sharded_tensor_infos): (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) @@ -782,12 +793,10 @@ def gen_metadata_for_tp_sharded_tensor(self): local_shape = shard_info[0][1] model_state_name = self.optimizer_key_to_model_state_key(state_name) - if ".beta1_pow_acc" not in state_name and ".beta2_pow_acc" not in state_name: global_shape = self.model_state_global_shape[model_state_name] else: global_shape = (1,) - assert len(local_shape) == len(global_shape) axis = -1 for i in range(len(local_shape)): @@ -938,12 +947,28 @@ def partition_parameters(self, model_state_shapes, is_sort, shard_num): return mapping - def get_save_sharded_model_flag(self): + def get_is_model_meta_exists_flag(self): save_sharded_model_flag = self.gather_global_object( 
[os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] ) return True in save_sharded_model_flag + def get_is_model_state_stored_flag(self): + if len(self.global_model_state_file_names) == 0: + return False + model_state_file_name = self.global_model_state_file_names[0] + file_readable = model_state_file_name in self.cur_rank_model_state_file_names + file_readables = self.gather_global_object([file_readable]) + coordinator_rank = file_readables.index(True) + is_model_state_stored = False + if self.cur_rank == coordinator_rank: + model_state_file_size = os.path.getsize(os.path.join(self.path, model_state_file_name)) + if model_state_file_size > MODEL_STATE_FILE_MIN_SIZE: + is_model_state_stored = True + + is_model_state_stored_flags = self.gather_global_object([is_model_state_stored]) + return True in is_model_state_stored_flags + def flatten_state_dict(self, state_dict): flattened_state_dict = {} flat_state_dict, mapping = flatten_state_dict(state_dict) @@ -1081,3 +1106,21 @@ def optimizer_key_to_model_state_key(self, optimizer_key): model_state_key = model_state_key[: -len(suffix)] break return model_state_key + + def print_checkpoint_file_info(self, flags): + processed_flags = [ + [str(item) if not isinstance(item, bool) else "True" if item else "False" for item in row] for row in flags + ] + + logger.info("Checkpoint file info:") + headers = ["Flag", "Value"] + col_widths = [max(len(str(item)) for item in column) for column in zip(headers, *flags)] + format_str = "| " + " | ".join(f"{{:<{width}}}" for width in col_widths) + " |" + separator_line = "+-" + "-+-".join("-" * width for width in col_widths) + "-+" + + logger.info(separator_line) + logger.info(format_str.format(*headers)) + logger.info(separator_line) + for row in processed_flags: + logger.info(format_str.format(*row)) + logger.info(separator_line) From 07a299bbb478ad87eb47669f5467a6ea4985e967 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 20 Aug 2024 15:40:24 +0800 Subject: [PATCH 26/30] fix rename --- paddlenlp/trainer/ckpt_converter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index e9ea2bed8552..c7d12139ca10 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -139,14 +139,15 @@ def load_from_hybrid_parallel_checkpoint(self): _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) logger.info("Calling _load_state_dict completed, restored the required weights.") + # In this scenario, the data type of the model state is bfloat16. 
for param_name, param_value in model_params.items(): if param_value.is_dist(): master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] - cast_master_weight = paddle.cast(master_weight._local_value(), param_value.dtype) + cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") paddle.assign(cast_master_weight, param_value._local_value()) else: master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] - cast_master_weight = paddle.cast(master_weight, param_value.dtype) + cast_master_weight = paddle.cast(master_weight, "bfloat16") paddle.assign(cast_master_weight, param_value) for master_weight_name in appended_master_weight_names: self.auto_parallel_state_dict.pop(master_weight_name) From 62f16b240b1c46d2406e9d90189eb842e5cb3f16 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 20 Aug 2024 15:43:27 +0800 Subject: [PATCH 27/30] fix rename --- paddlenlp/trainer/ckpt_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index c7d12139ca10..b1d1ae275507 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -155,7 +155,7 @@ def load_from_hybrid_parallel_checkpoint(self): logger.info("Calling _load_state_dict to load the required weights.") _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) logger.info("Calling _load_state_dict completed, restored the required weights.") - logger.info("Successfully loaded hybrid_parallel checkpoint!") + logger.info("Successfully loaded hybrid_parallel checkpoint!") def gen_metadata_and_prepare_source_state_dict(self): """ From 07a31f08a3a3a83a006b300184e51cf57d8ed47f Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 20 Aug 2024 15:55:47 +0800 Subject: [PATCH 28/30] fix rename --- paddlenlp/trainer/ckpt_converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py index b1d1ae275507..88b5b7c474df 100644 --- a/paddlenlp/trainer/ckpt_converter.py +++ b/paddlenlp/trainer/ckpt_converter.py @@ -143,11 +143,11 @@ def load_from_hybrid_parallel_checkpoint(self): for param_name, param_value in model_params.items(): if param_value.is_dist(): master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] - cast_master_weight = paddle.cast(master_weight._local_value(), "bfloat16") + cast_master_weight = paddle.cast(master_weight._local_value(), param_value.dtype) paddle.assign(cast_master_weight, param_value._local_value()) else: master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] - cast_master_weight = paddle.cast(master_weight, "bfloat16") + cast_master_weight = paddle.cast(master_weight, param_value.dtype) paddle.assign(cast_master_weight, param_value) for master_weight_name in appended_master_weight_names: self.auto_parallel_state_dict.pop(master_weight_name) From af73d74ff0a298566478f5d5f75efe87e3e57ef5 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 21 Aug 2024 15:43:35 +0800 Subject: [PATCH 29/30] fix --- paddlenlp/trainer/auto_trainer.py | 4 +- paddlenlp/trainer/ckpt_converter.py | 1127 --------------------------- paddlenlp/trainer/training_args.py | 4 +- 3 files changed, 4 insertions(+), 1131 deletions(-) delete mode 100644 paddlenlp/trainer/ckpt_converter.py diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 69c3f358f247..19513050674e 100644 --- 
a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -28,7 +28,6 @@ from ..utils.log import logger from .argparser import strtobool -from .ckpt_converter import CheckpointConverter from .trainer import SCALER_NAME, SCHEDULER_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME from .trainer_callback import TrainerState from .trainer_utils import ( # set_hyrbid_parallel_seed, @@ -40,6 +39,7 @@ has_length, speed_metrics, ) +from .utils.ckpt_converter import CheckpointConverter from .utils.helper import distributed_file, distributed_isfile # nested_truncate, try: @@ -730,7 +730,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): for state_name, state_value in self.model_wrapped.state_dict().items(): parameter_to_structured_name[state_value.name] = state_name - if self.args.resume_form_hybrid_parallel: + if self.args.auto_parallel_resume_form_hybrid_parallel: CheckpointConverter( resume_from_checkpoint, state_dict, parameter_to_structured_name ).load_from_hybrid_parallel_checkpoint() diff --git a/paddlenlp/trainer/ckpt_converter.py b/paddlenlp/trainer/ckpt_converter.py deleted file mode 100644 index 88b5b7c474df..000000000000 --- a/paddlenlp/trainer/ckpt_converter.py +++ /dev/null @@ -1,1127 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -import re -from functools import reduce - -import paddle -from paddle.distributed.checkpoint.load_state_dict import ( - _load_state_dict, - get_rank_to_read_files, -) -from paddle.distributed.checkpoint.metadata import ( - LocalTensorIndex, - LocalTensorMetadata, - Metadata, -) -from paddle.distributed.checkpoint.utils import flatten_state_dict -from paddle.distributed.fleet.utils.log_util import logger - -MODEL_WEIGHT_SUFFIX = ".pdparams" -OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" -SCHEDULER_NAME = "scheduler.pdparams" -MODEL_META_FILE_NAME = "model_meta.json" -OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"] -MODEL_STATE_FILE_MIN_SIZE = 512 - - -class CheckpointConverter: - def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, patch_dict=None): - self.use_dist = True if paddle.distributed.get_world_size() > 1 else False - self.path = hybrid_parallel_ckpt_path - self.auto_parallel_state_dict = self.flatten_state_dict(state_dict) - self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) - model_state_global_shape = {} - for k, v in self.auto_parallel_state_dict.items(): - model_state_global_shape[k] = v.shape - self.model_state_global_shape = self.gather_global_object(model_state_global_shape) - self.cur_rank = paddle.distributed.get_rank() - - ( - self.cur_rank_model_state_file_names, - self.cur_rank_optimizer_state_file_names, - ) = self.get_local_checkpoint_file_names() - - self.global_model_state_file_names = self.gather_global_object(self.cur_rank_model_state_file_names) - - self.global_optimizer_state_file_names = self.gather_global_object(self.cur_rank_optimizer_state_file_names) - - self.is_model_meta_exists = self.get_is_model_meta_exists_flag() - self.is_model_state_stored = self.get_is_model_state_stored_flag() - - self.initial_distributed_configuration() - - if patch_dict is not None: - self.patch_dict = patch_dict - for k, v in self.parameter_to_structured_name.items(): - if v in self.patch_dict: - self.parameter_to_structured_name[k] = self.patch_dict[v] - - del_keys = [] - for k, v in self.auto_parallel_state_dict.items(): - if k in self.patch_dict: - del_keys.append(k) - - for k in del_keys: - self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k] - self.auto_parallel_state_dict.pop(k) - - flags = [ - ["tp degree", self.tp_degree], - ["pp degree", self.pp_degree], - ["sharding degree", self.sharding_degree], - ["is model_meta exists", self.is_model_meta_exists], - ["is model_state stored", self.is_model_state_stored], - ] - self.print_checkpoint_file_info(flags) - - def load_from_hybrid_parallel_checkpoint(self): - """ - Automatically and inplace load the distributed checkpoint stored in hybrid parallel mode into the auto parallel state_dict. - The main logic is as follows: - 1. Call rename_semi_auto_state_dict: Rename the keys of the auto parallel state_dict according to certain rules. - (Why rename? To facilitate the subsequent correspondence between the optimizer state names of the semi-automatic and static optimizers.) - 2. Call gen_metadata_and_prepare_source_state_dict: Automatically parse the manual checkpoint file based on the state_dict information - provided by auto parallel, obtaining the Metadata and state_dict required for auto parallel to load the checkpoint. - 3. Call load_state_dict: Automatically reshard and load. - 4. 
Special logic adaptation: In the save_sharded_model mode, the weights are obtained through the master_weight cast in the checkpoint. - """ - self.rename_auto_parallel_state_dict() - - metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() - logger.info("Generated the checkpoint’s metadata.") - logger.debug(f"The checkpoint's metadata is {metadata}.") - if not self.is_model_state_stored: - assert self.optimizer_state_with_master_weights - model_params = {} - for state_name, state_value in self.auto_parallel_state_dict.items(): - if state_name in self.parameter_to_structured_name.values(): - model_params[state_name] = state_value - for param_name in model_params.keys(): - self.auto_parallel_state_dict.pop(param_name) - - logger.info("Requesting GPU memory space to load master_weights.") - appended_master_weight_names = [] - for param_name, param_value in model_params.items(): - master_weight = param_name + ".master_weight" - if master_weight not in self.auto_parallel_state_dict: - appended_master_weight_names.append(master_weight) - if param_value.is_dist(): - param_shape = param_value._local_value().shape - else: - param_shape = param_value.shape - - tmp_tensor = paddle.zeros(param_shape, dtype="float32") - with paddle.base.dygraph.guard(): - if param_value.is_dist(): - self.auto_parallel_state_dict[ - master_weight - ] = paddle.distributed.auto_parallel.api.dtensor_from_local( - tmp_tensor, param_value.process_mesh, param_value.placements - ) - else: - self.auto_parallel_state_dict[master_weight] = tmp_tensor - - logger.info("Calling _load_state_dict to load the required weights.") - _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) - logger.info("Calling _load_state_dict completed, restored the required weights.") - - # In this scenario, the data type of the model state is bfloat16. - for param_name, param_value in model_params.items(): - if param_value.is_dist(): - master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] - cast_master_weight = paddle.cast(master_weight._local_value(), param_value.dtype) - paddle.assign(cast_master_weight, param_value._local_value()) - else: - master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"] - cast_master_weight = paddle.cast(master_weight, param_value.dtype) - paddle.assign(cast_master_weight, param_value) - for master_weight_name in appended_master_weight_names: - self.auto_parallel_state_dict.pop(master_weight_name) - else: - logger.info("Calling _load_state_dict to load the required weights.") - _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) - logger.info("Calling _load_state_dict completed, restored the required weights.") - logger.info("Successfully loaded hybrid_parallel checkpoint!") - - def gen_metadata_and_prepare_source_state_dict(self): - """ - Automatically parse the manual checkpoint file based on the state_dict information provided by auto parallel, - obtaining the Metadata and state_dict required for auto parallel to load the checkpoint: - 1. Call load_state_dict_and_rename: Parse the distributed information from the names of the checkpoint files, and evenly parse out the distributed - information for each weight/optimizer state into self.global_sharded_tensor_infos(data structure:param_name -> [{tp_rank: 1, sharding_rank: 1}, shape, dtype, file_name]). 
- Modify the names of the optimizer states in the form ofparameter+suffixand record them in self.cur_rank_loaded_state_dict(data structure:file_name -> renamed_state_dict). - 2. Construct the Metadata and state_dict based on the distributed information obtained in the previous step for the final load. - 3. Special logic adaptation: When sharding is enabled, the optimizer states are also split. In this step, the optimizer states need to be concatenated back according to the sharding dimension: - * Construct the Metadata for concatenating the sharded states back based on the characteristics of sharding. - * Construct a temporaryopt_state_dictand use the_load_state_dictinterface to obtain the state_dict with the sharded states concatenated back. - * Reshape the optimizer states back to the shape of the weights. - """ - self.load_state_dict_and_rename() - logger.info("Complete the loading and renaming of state_dict.") - if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: - for state_name, shard_info in self.global_sharded_tensor_infos.items(): - shard_info.sort(key=lambda x: x[0]["sharding_rank"]) - - state_dict_metadata = {} - storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. - for state_name, shard_info in self.global_sharded_tensor_infos.items(): - global_offset = [0] * self.tp_degree - for item in shard_info: - tp_rank = item[0]["tp_rank"] - state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank) - local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2]) - local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],)) - global_offset[tp_rank] += item[1][0] - if state_name_with_tp_rank not in state_dict_metadata: - state_dict_metadata[state_name_with_tp_rank] = [local_tensor_meta_data] - else: - state_dict_metadata[state_name_with_tp_rank].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] - - metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - - logger.debug(f"The metadata for merge sharding is: {metadata_for_merge_sharding}") - - source_state_dict_for_merge_sharding = {} - for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): - renamed_state_dict = {} - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - for state_name, state_value in state_dict.items(): - state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank) - renamed_state_dict[state_name_with_tp_rank] = state_value - - source_state_dict_for_merge_sharding[file_name] = renamed_state_dict - - assert self.model_meta is not None - global_model_state_shapes = [] - sharding_metas_keys = [] - for i in range(self.pp_degree): - for j in range(self.tp_degree): - sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) - for key in sharding_metas_keys: - param_meta = self.model_meta["sharding_metas"][key]["param_meta"] - for param_name, param_shape_and_dtype in param_meta.items(): - global_model_state_shapes.append([param_name, param_shape_and_dtype[0]]) - - # Distribute all model parameters evenly across each card for loading - - world_size = paddle.distributed.get_world_size() - partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) - - partition_model_state_keys = [] - for cur_rank, partition_model_state in partition_mapping.items(): - 
partition_model_state_keys.append([item[0] for item in partition_model_state]) - - all_param_meta = {} - for i in range(self.tp_degree): - for j in range(self.pp_degree): - key = "tp{:02d}_pp{:02d}".format(i, j) - param_meta = self.model_meta["sharding_metas"][key]["param_meta"] - for param_name, param_shape_and_dtype in param_meta.items(): - all_param_meta[param_name] = param_shape_and_dtype - - param_flattened_shapes = {} - for param_name, param_shape_and_dtype in all_param_meta.items(): - param_flattened_shapes[param_name] = reduce(lambda x, y: x * y, param_shape_and_dtype[0]) - - cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] - # Generate the optimizer states corresponding to the model weights. - logger.info("Requesting GPU memory space to concatenate tensors split by sharding1 v2.") - optimizer_state_dict = {} - for key in cur_rank_need_load_model_state_keys: - for tp_rank in range(self.tp_degree): - tp_rank_suffix = "_tp{:02d}".format(tp_rank) - optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - optimizer_state_dict[key + ".moment2" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - if self.optimizer_state_with_master_weights: - optimizer_state_dict[key + ".master_weight" + tp_rank_suffix] = paddle.zeros( - (param_flattened_shapes[key],), "float32" - ) - # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. - # Later, when these are compared with the global shape, we realize that they are replicated. - - optimizer_state_dict[key + ".beta1_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32") - optimizer_state_dict[key + ".beta2_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32") - - malloc_size = 0 - for opt_state_name, opt_state_value in optimizer_state_dict.items(): - malloc_size += opt_state_value.numel() * opt_state_value.element_size() - malloc_size = malloc_size.numpy() / 2**20 - logger.debug(f"{malloc_size} MB of GPU memory were allocated.") - - # merge sharding - logger.info("First call _load_state_dict to stitch back the tensors split by sharding1 v2.") - _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) - logger.info("Completed the call _load_state_dict, concating back the tensors split by sharding.") - - # Reshape - for opt_state_name, opt_state_value in optimizer_state_dict.items(): - if opt_state_value.shape[0] > 1 and "_tp" in opt_state_name: - param_name = self.optimizer_key_to_model_state_key(opt_state_name[:-5]) - param_shape = all_param_meta[param_name][0] - assert opt_state_value.numel() == reduce(lambda x, y: x * y, param_shape) - reshaped_opt_state_value = opt_state_value.reshape(param_shape) - optimizer_state_dict[opt_state_name] = reshaped_opt_state_value - concat_optimier_state_dict = {} - - optimizer_state_key_to_tp_keys = {} - for opt_state_name in optimizer_state_dict.keys(): - # Count how each key is split into keys ending with ‘_tpXX’. 
- # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} - opt_state_name_removed_tp_rank = opt_state_name[:-5] - if opt_state_name_removed_tp_rank not in optimizer_state_key_to_tp_keys: - optimizer_state_key_to_tp_keys[opt_state_name_removed_tp_rank] = [opt_state_name] - else: - optimizer_state_key_to_tp_keys[opt_state_name_removed_tp_rank].append(opt_state_name) - - for opt_state_name_removed_tp_rank, opt_state_name in optimizer_state_key_to_tp_keys.items(): - opt_state_name.sort(key=lambda x: int(x[-2:])) - - for opt_state_name_removed_tp_rank, opt_state_name in optimizer_state_key_to_tp_keys.items(): - model_state_name = self.optimizer_key_to_model_state_key(opt_state_name_removed_tp_rank) - local_shape = optimizer_state_dict[opt_state_name[0]].shape - if ( - ".beta1_pow_acc" not in opt_state_name_removed_tp_rank - and ".beta2_pow_acc" not in opt_state_name_removed_tp_rank - ): - global_shape = self.model_state_global_shape[model_state_name] - else: - global_shape = (1,) - - if len(local_shape) != 1: - assert len(local_shape) == len(global_shape) - - axis = -1 - for i in range(len(local_shape)): - if local_shape[i] != global_shape[i]: - axis = i - break - - is_replicated = axis == -1 - tp_tensors = [] - for opt_state_name_with_tp_rank in opt_state_name: - tp_tensors.append(optimizer_state_dict[opt_state_name_with_tp_rank]) - - if not is_replicated: - # Derive the partition strategy based on the global_shape, then concatenate. - concat_optimier_state_dict[opt_state_name_removed_tp_rank] = paddle.concat(tp_tensors, axis=axis) - else: - concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0] - - fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" - local_tensor_meta_data = {} - local_tensor_index = {} - for k, v in concat_optimier_state_dict.items(): - # Generate metadata. 
- local_shape = v.shape - global_offset = tuple([0] * len(local_shape)) - dtype = str(v.dtype).split(".")[1] - local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) - local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] - - global_local_tensor_meta_data = [] - global_local_tensor_index = [] - - use_dist = True if paddle.distributed.get_world_size() > 1 else False - - if use_dist: - paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) - paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) - else: - global_local_tensor_meta_data = [local_tensor_meta_data] - global_local_tensor_index = [local_tensor_index] - - state_dict_metadata = {} - for tensor_meta_data in global_local_tensor_meta_data: - for k, v in tensor_meta_data.items(): - if k not in state_dict_metadata: - state_dict_metadata[k] = [v] - else: - state_dict_metadata[k].append(v) - - storage_metadata = {} - for tensor_index in global_local_tensor_index: - for k, v in tensor_index.items(): - storage_metadata[v[0]] = v[1] - - meta_data = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = {fake_file_name: concat_optimier_state_dict} - return meta_data, source_state_dict - - elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: - return self.gen_metadata_for_tp_sharded_tensor() - else: - if self.is_sharding_stage3: - for state_name, shard_info in self.global_sharded_tensor_infos.items(): - shard_info.sort(key=lambda x: x[0]["sharding_rank"]) - state_dict_metadata = {} - storage_metadata = {} - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. - for state_name, shard_info in self.global_sharded_tensor_infos.items(): - global_offset = 0 - for item in shard_info: - if len(item[1]) == 1: - local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) - local_tensor_index = LocalTensorIndex(state_name, (global_offset,)) - global_offset += item[1][0] - else: - global_offset = tuple([0] * len(item[1])) - local_tensor_meta_data = LocalTensorMetadata(global_offset, item[1], item[2]) - local_tensor_index = LocalTensorIndex(state_name, global_offset) - if state_name not in state_dict_metadata: - state_dict_metadata[state_name] = [local_tensor_meta_data] - else: - state_dict_metadata[state_name].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] - - metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) - model_state_shapes = [] - dtype = "" - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(MODEL_WEIGHT_SUFFIX): - for k, v in state_dict.items(): - model_state_shapes.append([k, v.shape]) - dtype = str(v.dtype).split(".")[1] - - dtypes = self.gather_global_object([dtype]) - for dtype_s in dtypes: - if len(dtype_s) > 0: - dtype = dtype_s - - assert len(dtype) > 0 - - global_model_state_shapes = self.gather_global_object(model_state_shapes) - - partition_result = self.partition_parameters( - global_model_state_shapes, True, paddle.distributed.get_world_size() - ) - - cur_rank_merger_model_params = partition_result[self.cur_rank] - target_state_dict = {} - for item in cur_rank_merger_model_params: - key = item[0] - shape = item[1] - flatten_shape = reduce(lambda a, b: a * b, item[1]) - target_state_dict[key] = paddle.zeros(shape, dtype) - target_state_dict[key + ".moment1"] = 
paddle.zeros((flatten_shape,), "float32") - target_state_dict[key + ".moment2"] = paddle.zeros((flatten_shape,), "float32") - if self.optimizer_state_with_master_weights: - target_state_dict[key + ".master_weight"] = paddle.zeros((flatten_shape,), "float32") - # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. - # Later, when these are compared with the global shape, we realize that they are replicated. - - target_state_dict[key + ".beta1_pow_acc"] = paddle.zeros((1,), "float32") - target_state_dict[key + ".beta2_pow_acc"] = paddle.zeros((1,), "float32") - - _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) - - # Reshape - for item in cur_rank_merger_model_params: - key = item[0] - shape = item[1] - for k, v in target_state_dict.items(): - if key == self.optimizer_key_to_model_state_key(k): - if tuple(shape) != tuple(v.shape) and v.numel() == reduce(lambda x, y: x * y, shape): - reshaped_v = v.reshape(shape) - target_state_dict[k] = reshaped_v - - fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" - local_tensor_meta_data = {} - local_tensor_index = {} - for k, v in target_state_dict.items(): - # Generate metadata. - local_shape = v.shape - global_offset = tuple([0] * len(local_shape)) - dtype = str(v.dtype).split(".")[1] - local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) - local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] - - global_local_tensor_meta_data = [] - global_local_tensor_index = [] - - use_dist = True if paddle.distributed.get_world_size() > 1 else False - - if use_dist: - paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) - paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) - else: - global_local_tensor_meta_data = [local_tensor_meta_data] - global_local_tensor_index = [local_tensor_index] - - state_dict_metadata = {} - for tensor_meta_data in global_local_tensor_meta_data: - for k, v in tensor_meta_data.items(): - if k not in state_dict_metadata: - state_dict_metadata[k] = [v] - else: - state_dict_metadata[k].append(v) - - storage_metadata = {} - for tensor_index in global_local_tensor_index: - for k, v in tensor_index.items(): - storage_metadata[v[0]] = v[1] - - meta_data = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = {fake_file_name: target_state_dict} - - return meta_data, source_state_dict - else: - return self.gen_metadata_for_tp_sharded_tensor() - - def load_state_dict_and_rename(self): - """ - Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state - into self.global_sharded_tensor_infos (data structure: param_name -> [{tp_rank: 1, sharding_rank: 1}, shape, dtype, file_name]). Modify the names of the - optimizer states in the form of parameter+suffix and record them in self.cur_rank_loaded_state_dict (data structure: file_name -> renamed_state_dict). - 1. Load balancing: Each rank parses a portion of the checkpoint files. - 2. Flatten master_weights in opt_state into opt_state. - 3. Rename the keys in opt_state according to the rule: adamw_optimizer_param_suffix_name_mapping. - 4. Optimizer state renaming and distributed information extraction: - * If it is sharding_stage1/2_v2 version: - * Renaming: rename_using_model_meta: In this case, a model_meta file is required. 
According to this file, - obtain the name mapping of weights and optimizer parameters, so that the optimizer states of manual and static partitions can correspond. - * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank, sharding_rank}, shape, dtype, file_name]. - * If it is sharding_stage1/2_v1 version: - * Renaming: - * If a model_meta file exists: - * rename_using_model_meta - * If a model_meta file does not exist: - * According to the characteristics of v1 partitioning, infer the mapping relationship between optimizer states and weights (partition_result): master_weight_name_to_model_weight_name_mapping. - * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank}, shape, dtype, file_name] (parameters will not be sharded). - * If it is sharding_stage3: - * Renaming: - * If a model_meta file exists: - * rename_using_model_meta - * If a model_meta file does not exist: - * Establish the mapping between weights and optimizer names according to the order of optimizer states and weights: rename_using_optimizer_state_order. - * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank, sharding_rank}, shape, dtype, file_name]. - """ - rank_access_files = {} - if self.is_model_state_stored: - rank_access_files[self.cur_rank] = ( - self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names - ) - else: - rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names - - global_rank_access_files = self.gather_global_object(rank_access_files) - need_read_files = get_rank_to_read_files(global_rank_access_files, global_rank_access_files) - logger.info(f"The file(s) to be loaded for the current rank are: {need_read_files}") - self.cur_rank_loaded_state_dict = {} - - for file in need_read_files: - self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file)) - - self.optimizer_state_with_master_weights = False - - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - state_dict.pop("LR_Scheduler") - if "master_weights" in state_dict: - self.optimizer_state_with_master_weights = True - master_weights = state_dict.pop("master_weights") - for master_weight_name, master_weight_value in master_weights.items(): - # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. - state_dict[master_weight_name.replace("slice@", "") + ".master_weight"] = master_weight_value - - self.cur_rank_loaded_state_dict[file] = state_dict - - memory_size = 0 - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - for k, v in state_dict.items(): - memory_size += v.numel() * v.element_size() - - memory_size = memory_size.numpy() / 2**20 - logger.debug( - f"The current rank has finished loading the checkpoint file and has allocated {memory_size} MB of GPU memory." - ) - - # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. 
- self.sharding_stage1_v = self.infer_sharding_stage1_v() - self.is_sharding_stage3 = self.infer_is_sharding_stage3() - - flags = [ - ["is sharding stage1/2", (not self.is_sharding_stage3) and self.sharding_degree > 1], - ["sharding stage1/2 version", self.sharding_stage1_v], - ["is sharding stage3", self.is_sharding_stage3], - ["master_weight", self.optimizer_state_with_master_weights], - ] - self.print_checkpoint_file_info(flags) - - # In sharding stage3, the parameters need to be reordered based on whether they are sliced. - # The threshold for determining whether to slice is segment_size, with a default value of 2**20. - # However, sharding stage3 allows users to specify their own unsliced layers, which seems to be incompatible here. - if self.is_sharding_stage3: - logger.info("The currently loaded checkpoint file comes from sharding stage 3.") - segment_size = 2**20 - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(MODEL_WEIGHT_SUFFIX): - sliced_prameters = [] - unsliced_parameters = [] - sorted_state_dict = {} - for k, v in state_dict.items(): - if v.numel() > segment_size: - sliced_prameters.append(k) - else: - unsliced_parameters.append(k) - for k in sliced_prameters + unsliced_parameters: - sorted_state_dict[k] = state_dict.pop(k) - self.cur_rank_loaded_state_dict[file] = sorted_state_dict - - # rename and record sharded_tensor_info - cur_rank_sharded_tensor_infos = {} - - # 1. Handling the sharding stage1 v2 scenario, where the save_sharded_model flag must be enabled, independent of master_weights. - if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: - logger.info("The currently loaded checkpoint file comes from sharding stage1 v2.") - assert self.is_model_meta_exists - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, - # and then append the tp_degree. - renamed_state_dict = self.rename_using_model_meta(file) - self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - # 2. In handling the sharding stage1 v1 and stage2 scenario, the optimizer states are distributed across different ranks. - # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. 
- elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: - logger.info("The currently loaded checkpoint file comes from sharding stage1/2 v1.") - if not self.is_model_meta_exists: - file_to_state_dict_shapes_mapping = {} - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - shapes = [] - for state_name, state_value in state_dict.items(): - shapes.append([state_name, state_value.shape]) - file_to_state_dict_shapes_mapping[file] = shapes - - global_file_to_state_dict_shapes_mapping = self.gather_global_object(file_to_state_dict_shapes_mapping) - - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - sharding_optimizer_state_shards = [] - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - for k, v in global_file_to_state_dict_shapes_mapping.items(): - (tp_rank_, pp_rank_, sharding_rank_) = self.get_distribution_rank_from_file_name(k) - if tp_rank == tp_rank_ and pp_rank == pp_rank_ and k.endswith(OPTIMIZER_WEIGHT_SUFFIX): - sharding_optimizer_state_shards.append([v, sharding_rank_]) - model_state_file_name = self.get_model_state_file_from(file) - model_state_shapes = global_file_to_state_dict_shapes_mapping[model_state_file_name] - sharding_optimizer_state_shards.sort(key=lambda x: x[1]) - - partition_result_0 = self.partition_parameters(model_state_shapes, False, self.sharding_degree) - partition_result_1 = self.partition_parameters(model_state_shapes, True, self.sharding_degree) - - for rank, portion in partition_result_0.items(): - portion = sorted(portion, key=model_state_shapes.index) - partition_result_0[rank] = portion - - for rank, portion in partition_result_1.items(): - portion = sorted(portion, key=model_state_shapes.index) - partition_result_1[rank] = portion - - sharding_sort_parameters = False - - for i in range(len(sharding_optimizer_state_shards)): - if not sharding_sort_parameters: - state_shard = sharding_optimizer_state_shards[i][0] - partitioned_shard = partition_result_0[i] - for j in range(len(partitioned_shard)): - if partitioned_shard[j][1] != state_shard[j][1]: - sharding_sort_parameters = True - break - - if sharding_sort_parameters: - for i in range(len(sharding_optimizer_state_shards)): - state_shard = sharding_optimizer_state_shards[i][0] - partitioned_shard = partition_result_1[i] - for j in range(len(partitioned_shard)): - assert partitioned_shard[j][1] == state_shard[j][1] - - if sharding_sort_parameters: - partition_result = partition_result_1 - else: - partition_result = partition_result_0 - - name_mapping = {} - for i in range(len(sharding_optimizer_state_shards)): - state_shard = sharding_optimizer_state_shards[i][0] - partitioned_shard = partition_result[i] - suffix_bucket = {} - for suffix in OPTIMIZER_STATE_NAME_SUFFIX: - suffix_bucket[suffix] = [] - for j in range(len(state_shard)): - optimizer_state_name = state_shard[j][0] - if "moment1" in optimizer_state_name: - suffix_bucket[".moment1"].append(optimizer_state_name) - elif "moment2" in optimizer_state_name: - suffix_bucket[".moment2"].append(optimizer_state_name) - elif "beta1_pow_acc" in optimizer_state_name: - suffix_bucket[".beta1_pow_acc"].append(optimizer_state_name) - elif "beta2_pow_acc" in optimizer_state_name: - suffix_bucket[".beta2_pow_acc"].append(optimizer_state_name) - else: - suffix_bucket[".master_weight"].append(optimizer_state_name) - - # In this scenario, the order of master_weights might differ from the order of the regular optimizer 
states and needs to be reordered. - if len(suffix_bucket[".master_weight"]) != 0: - master_weight_keys = [] - for master_weight_key in suffix_bucket[".master_weight"]: - for index in range(len(state_shard)): - if master_weight_key[: -len(".master_weight")] in state_shard[index][0]: - # Find the first match - master_weight_keys.append([master_weight_key, index]) - break - - master_weight_keys = sorted(master_weight_keys, key=lambda x: x[1]) - suffix_bucket[".master_weight"] = [x[0] for x in master_weight_keys] - - for suffix, old_names in suffix_bucket.items(): - assert len(old_names) == len(partitioned_shard) - for k in range(len(old_names)): - name_mapping[old_names[k]] = partitioned_shard[k][0] + suffix - - renamed_state_dict = {} - # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. - # Therefore, the sharding information can now be directly removed. - for opt_state_name, opt_state_value in state_dict.items(): - renamed_state_dict[name_mapping[opt_state_name]] = opt_state_value - - self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) - - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - else: - self.get_sharded_tensor_infos(file, state_dict, cur_rank_sharded_tensor_infos) - else: - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - renamed_state_dict = self.rename_using_model_meta(file) - self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) - - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - else: - # 3. Handling the sharding stage3 and non-sharding scenario - - file_to_state_dict_keys_mapping = {} - for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): - file_to_state_dict_keys_mapping[file_name] = list(state_dict.keys()) - global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) - - logger.info("The current checkpoint comes from either sharding stage 3 or non-sharding.") - if not self.is_model_meta_exists: - for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): - if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - model_state_file_name = self.get_model_state_file_from(file_name) - assert model_state_file_name is not None - model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name] - renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict) - self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) - self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict - else: - self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos) - else: - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, - # and then append the tp_degree. 
- renamed_state_dict = self.rename_using_model_meta(file) - self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) - self.cur_rank_loaded_state_dict[file] = renamed_state_dict - - # gather global sharded tensor infos - sharded_tensor_infos = self.gather_global_object({self.cur_rank: cur_rank_sharded_tensor_infos}) - self.global_sharded_tensor_infos = {} - for rank, sharded_tensor_info in sharded_tensor_infos.items(): - for state_name, shard_info in sharded_tensor_info.items(): - if state_name not in self.global_sharded_tensor_infos: - self.global_sharded_tensor_infos[state_name] = shard_info - else: - self.global_sharded_tensor_infos[state_name] += shard_info - logger.debug(f"global_sharded_tensor_infos: {self.global_sharded_tensor_infos}") - - def get_sharded_tensor_infos(self, file, state_dict, cur_rank_sharded_tensor_infos): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) - for state_name, state_value in state_dict.items(): - if state_name not in cur_rank_sharded_tensor_infos: - cur_rank_sharded_tensor_infos[state_name] = [ - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - state_value.shape, - str(state_value.dtype).split(".")[1], - file, - ] - ] - else: - cur_rank_sharded_tensor_infos[state_name].append( - [ - {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, - state_value.shape, - str(state_value.dtype).split(".")[1], - file, - ] - ) - - def gen_metadata_for_tp_sharded_tensor(self): - """ - Based on the distributed information of each weight/optimizer state (global_sharded_tensor_infos), construct Metadata - information: LocalTensorMetadata,LocalTensorIndex - """ - for state_name, shard_info in self.global_sharded_tensor_infos.items(): - shard_info.sort(key=lambda x: x[0]["tp_rank"]) - - state_dict_metadata = {} - storage_metadata = {} - - # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. - for state_name, shard_info in self.global_sharded_tensor_infos.items(): - - global_offset = 0 - local_shape = shard_info[0][1] - - model_state_name = self.optimizer_key_to_model_state_key(state_name) - if ".beta1_pow_acc" not in state_name and ".beta2_pow_acc" not in state_name: - global_shape = self.model_state_global_shape[model_state_name] - else: - global_shape = (1,) - assert len(local_shape) == len(global_shape) - axis = -1 - for i in range(len(local_shape)): - if local_shape[i] != global_shape[i]: - axis = i - break - - is_replicated = axis == -1 - global_offset = [0] * len(local_shape) - - if is_replicated: - shard_info = [shard_info[0]] - - for item in shard_info: - local_tensor_meta_data = LocalTensorMetadata(tuple(global_offset), item[1], item[2]) - local_tensor_index = LocalTensorIndex(state_name, tuple(global_offset)) - global_offset[axis] += item[1][axis] - if state_name not in state_dict_metadata: - state_dict_metadata[state_name] = [local_tensor_meta_data] - else: - state_dict_metadata[state_name].append(local_tensor_meta_data) - storage_metadata[local_tensor_index] = item[3] - - metadata = Metadata(state_dict_metadata, storage_metadata, None) - source_state_dict = self.cur_rank_loaded_state_dict - - return metadata, source_state_dict - - def rename_using_model_meta(self, file_name): - """ - Rename the keys in opt_state_dict based on the following rule: model_meta records a mapping of parameter names to optimizer names. - Here, we unify the optimizer state names to parameter names directly. 
For example: - * model_meta: linear0 -> param0 - * opt_state: param0.w0 - * Renamed opt_state: linear0.w0 - NOTE:The reason for renaming is that there is a difference in the naming of optimizer parameters between dynamic and static partitions, - making it difficult to match optimizer parameters directly by name. Therefore, we unify them to the weight names. - """ - if not hasattr(self, "model_meta"): - meta_file_path = os.path.join(self.path, MODEL_META_FILE_NAME) - assert os.path.exists(meta_file_path) - with open(meta_file_path, "r") as file: - self.model_meta = json.load(file) - - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) - dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank) - # Map model weight names to their corresponding names of master_weights in the optimizer state. - if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] - parameter_to_structured_name = {} - for k, v in structure_name_mapping.items(): - parameter_to_structured_name[v] = k - state_dict = self.cur_rank_loaded_state_dict[file_name] - return self.rename_using_parameter_to_structured_name_mapping(state_dict, parameter_to_structured_name) - else: - return self.cur_rank_loaded_state_dict[file_name] - - def rename_auto_parallel_state_dict(self): - """ - Rename the keys of the auto parallel state_dict according to certain rules: - 1. Rename the suffixes of the optimizer states to a unified format: adamw_optimizer_status_name_suffix_mappings - """ - self.auto_parallel_state_dict = self.rename_using_parameter_to_structured_name_mapping( - self.auto_parallel_state_dict, self.parameter_to_structured_name - ) - - def rename_using_parameter_to_structured_name_mapping(self, state_dict, parameter_to_structured_name): - renamed_state_dict = {} - - def rename(old_name, parameter_to_structured_name): - for i in range(1, len(old_name) + 1): - param_name = old_name[:i] # param_name - suffix = old_name[i:] # suffix - if param_name in parameter_to_structured_name: - structure_name = parameter_to_structured_name[param_name] - if "moment1" in suffix: - return structure_name + ".moment1" - elif "moment2" in suffix: - return structure_name + ".moment2" - elif "beta1_pow_acc" in suffix: - return structure_name + ".beta1_pow_acc" - elif "beta2_pow_acc" in suffix: - return structure_name + ".beta2_pow_acc" - else: - return structure_name + ".master_weight" - return None - - for key, value in state_dict.items(): - if key in parameter_to_structured_name.values(): - new_name = key - else: - new_name = rename(key, parameter_to_structured_name) - assert new_name is not None - renamed_state_dict[new_name] = value - - return renamed_state_dict - - def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict): - - name_mapping = {} - suffix_bucket = {} - assert len(optimizer_state_dict) % len(model_state_keys) == 0 - for suffix in OPTIMIZER_STATE_NAME_SUFFIX: - suffix_bucket[suffix] = [] - for satte_name, satte_value in optimizer_state_dict.items(): - if "moment1" in satte_name: - suffix_bucket[".moment1"].append(satte_name) - elif "moment2" in satte_name: - suffix_bucket[".moment2"].append(satte_name) - elif "beta1_pow_acc" in satte_name: - suffix_bucket[".beta1_pow_acc"].append(satte_name) - elif "beta2_pow_acc" in satte_name: - suffix_bucket[".beta2_pow_acc"].append(satte_name) - else: - suffix_bucket[".master_weight"].append(satte_name) - - 
for suffix, old_names in suffix_bucket.items(): - assert len(old_names) == len(model_state_keys) - for i in range(len(old_names)): - name_mapping[old_names[i]] = model_state_keys[i] + suffix - - renamed_state_dict = {} - for k, v in optimizer_state_dict.items(): - renamed_state_dict[name_mapping[k]] = v - return renamed_state_dict - - def partition_parameters(self, model_state_shapes, is_sort, shard_num): - """ - In sharding_stage3 and sharding_stage1_v1, parameters and optimizer states will be assigned to different ranks. This function defines the allocation rules. - For details, refer to: python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py. - """ - mapping = {} - for rank_ in range(shard_num): - mapping[rank_] = [] - sizes = [0] * shard_num - - parameters = model_state_shapes.copy() - - if is_sort: - parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True) - - for param in parameters: - rank = sizes.index(min(sizes)) - mapping[rank].append(param) - numel = reduce(lambda x, y: x * y, param[1], 1) - assert numel > 0, f"param [{param[0]}] should larger than 0, but it is [{numel}]" - sizes[rank] += numel - - return mapping - - def get_is_model_meta_exists_flag(self): - save_sharded_model_flag = self.gather_global_object( - [os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] - ) - return True in save_sharded_model_flag - - def get_is_model_state_stored_flag(self): - if len(self.global_model_state_file_names) == 0: - return False - model_state_file_name = self.global_model_state_file_names[0] - file_readable = model_state_file_name in self.cur_rank_model_state_file_names - file_readables = self.gather_global_object([file_readable]) - coordinator_rank = file_readables.index(True) - is_model_state_stored = False - if self.cur_rank == coordinator_rank: - model_state_file_size = os.path.getsize(os.path.join(self.path, model_state_file_name)) - if model_state_file_size > MODEL_STATE_FILE_MIN_SIZE: - is_model_state_stored = True - - is_model_state_stored_flags = self.gather_global_object([is_model_state_stored]) - return True in is_model_state_stored_flags - - def flatten_state_dict(self, state_dict): - flattened_state_dict = {} - flat_state_dict, mapping = flatten_state_dict(state_dict) - for k, v in flat_state_dict.items(): - last_level_key = mapping[k][-1] - assert last_level_key not in flattened_state_dict - flattened_state_dict[last_level_key] = v - return flattened_state_dict - - def gather_global_object(self, cur_rank_object): - all_rank_objects = [] - if self.use_dist: - paddle.distributed.all_gather_object(all_rank_objects, cur_rank_object) - else: - all_rank_objects = [all_rank_objects] - - if isinstance(cur_rank_object, list): - for obj in all_rank_objects: - assert isinstance(obj, list) - return [item for sublist in all_rank_objects for item in sublist] - elif isinstance(cur_rank_object, dict): - for obj in all_rank_objects: - assert isinstance(obj, dict) - global_map = {} - for rank_map in all_rank_objects: - global_map.update(rank_map) - return global_map - else: - raise ValueError("cur_rank_object should be either a list or a dict") - - def get_local_checkpoint_file_names(self): - cur_rank_files = os.listdir(self.path) - cur_rank_model_state_file_names = [] - cur_rank_optimizer_state_file_names = [] - for file_name in cur_rank_files: - if file_name.endswith(MODEL_WEIGHT_SUFFIX): - cur_rank_model_state_file_names.append(file_name) - elif file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): - 
cur_rank_optimizer_state_file_names.append(file_name) - if SCHEDULER_NAME in cur_rank_model_state_file_names: - cur_rank_model_state_file_names.remove(SCHEDULER_NAME) - return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names - - def get_distribution_rank_from_file_name(self, file_name): - pp_degree = 0 - tp_degree = 0 - sharding_degree = 0 - pattern_pp = r"pp(\d+)" - pattern_tp = r"tp(\d+)" - pattern_shard = r"shard(\d+)" - match_pp = re.search(pattern_pp, file_name) - if match_pp: - pp_degree = int(match_pp.group(1)) - match_tp = re.search(pattern_tp, file_name) - if match_tp: - tp_degree = int(match_tp.group(1)) - match_shard = re.search(pattern_shard, file_name) - if match_shard: - sharding_degree = int(match_shard.group(1)) - return (tp_degree, pp_degree, sharding_degree) - - def initial_distributed_configuration(self): - self.pp_degree = 0 - self.tp_degree = 0 - self.sharding_degree = 0 - - all_files = self.global_model_state_file_names + self.global_optimizer_state_file_names - - for file in all_files: - (tp_degree, pp_degree, sharding_degree) = self.get_distribution_rank_from_file_name(file) - self.pp_degree = max(self.pp_degree, pp_degree) - self.tp_degree = max(self.tp_degree, tp_degree) - self.sharding_degree = max(self.sharding_degree, sharding_degree) - - self.pp_degree = self.pp_degree + 1 - self.tp_degree = self.tp_degree + 1 - self.sharding_degree = self.sharding_degree + 1 - - def infer_sharding_stage1_v(self): - sharding_stage1_v = [2] - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: - for k, v in state_dict.items(): - # Under shardingv2, the optimizer state is first flattened and then split. - if len(v.shape) != 1: - sharding_stage1_v = [1] - break - - sharding_stage1_v = self.gather_global_object(sharding_stage1_v) - if 1 in sharding_stage1_v: - return 1 - return 2 - - def infer_is_sharding_stage3(self): - if self.sharding_degree == 1: - return False - if self.pp_degree > 1 or self.tp_degree > 1: - # Currently, sharding stage 3 does not support concurrent use with tensor parallelism (TP) and pipeline parallelism (PP). 
- return False - - is_sharding_stage3 = True - - file_to_state_shape_mapping = {} - for file, state_dict in self.cur_rank_loaded_state_dict.items(): - if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): - state_shape_mapping = {} - for k, v in state_dict.items(): - state_shape_mapping[k] = v.shape - if len(v.shape) != 1: - return False - file_to_state_shape_mapping[file] = state_shape_mapping - global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) - - state_dict_std = global_file_to_state_shape_mapping[list(global_file_to_state_shape_mapping.keys())[0]] - - for file, state_dict in global_file_to_state_shape_mapping.items(): - if state_dict != state_dict_std: - is_sharding_stage3 = False - break - return is_sharding_stage3 - - def get_model_state_file_from(self, optimizer_state_file_name): - (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) - for model_state_file in self.global_model_state_file_names: - distributed_rank = self.get_distribution_rank_from_file_name(model_state_file) - if tp_rank == distributed_rank[0] and pp_rank == distributed_rank[1]: - return model_state_file - return None - - def optimizer_key_to_model_state_key(self, optimizer_key): - model_state_key = optimizer_key - for suffix in OPTIMIZER_STATE_NAME_SUFFIX: - if model_state_key.endswith(suffix): - # Remove the suffix from model_state_key - model_state_key = model_state_key[: -len(suffix)] - break - return model_state_key - - def print_checkpoint_file_info(self, flags): - processed_flags = [ - [str(item) if not isinstance(item, bool) else "True" if item else "False" for item in row] for row in flags - ] - - logger.info("Checkpoint file info:") - headers = ["Flag", "Value"] - col_widths = [max(len(str(item)) for item in column) for column in zip(headers, *flags)] - format_str = "| " + " | ".join(f"{{:<{width}}}" for width in col_widths) + " |" - separator_line = "+-" + "-+-".join("-" * width for width in col_widths) + "-+" - - logger.info(separator_line) - logger.info(format_str.format(*headers)) - logger.info(separator_line) - for row in processed_flags: - logger.info(format_str.format(*row)) - logger.info(separator_line) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index c21e5baa62f6..70cf3ae27af7 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -349,7 +349,7 @@ class TrainingArguments: The path to a folder with a valid checkpoint for your model. This argument is not directly used by [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details. - resume_form_hybrid_parallel (`bool`, *optional*): + auto_parallel_resume_form_hybrid_parallel (`bool`, *optional*): Wether hybrid paralle checkpoints be loaded in auto parallel mode. flatten_param_grads (`bool`, *optional*): Whether use flatten_param_grads method in optimizer, only used on NPU devices. Default is `False`. 
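The renamed switch above behaves like any other TrainingArguments field; a minimal, hypothetical sketch of toggling it when the arguments are built directly in Python (the output_dir value is a placeholder, and scripts that parse arguments with PdArgumentParser accept the same name as a command-line flag):

    from paddlenlp.trainer import TrainingArguments

    # Hypothetical usage: opt in to loading hybrid-parallel checkpoints under auto parallel.
    args = TrainingArguments(
        output_dir="./checkpoints",                       # placeholder path
        auto_parallel_resume_form_hybrid_parallel=True,   # renamed flag from the hunk above
    )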
@@ -772,7 +772,7 @@ class TrainingArguments: default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."}, ) - resume_form_hybrid_parallel: Optional[bool] = field( + auto_parallel_resume_form_hybrid_parallel: Optional[bool] = field( default=False, metadata={"help": "Wether hybrid paralle checkpoints be loaded in auto parallel mode."}, ) From dfb16b70943f3dc2f4cf675b64ee429f36680d39 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 21 Aug 2024 15:43:49 +0800 Subject: [PATCH 30/30] fix --- paddlenlp/trainer/utils/ckpt_converter.py | 1127 +++++++++++++++++++++ 1 file changed, 1127 insertions(+) create mode 100644 paddlenlp/trainer/utils/ckpt_converter.py diff --git a/paddlenlp/trainer/utils/ckpt_converter.py b/paddlenlp/trainer/utils/ckpt_converter.py new file mode 100644 index 000000000000..88b5b7c474df --- /dev/null +++ b/paddlenlp/trainer/utils/ckpt_converter.py @@ -0,0 +1,1127 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from functools import reduce + +import paddle +from paddle.distributed.checkpoint.load_state_dict import ( + _load_state_dict, + get_rank_to_read_files, +) +from paddle.distributed.checkpoint.metadata import ( + LocalTensorIndex, + LocalTensorMetadata, + Metadata, +) +from paddle.distributed.checkpoint.utils import flatten_state_dict +from paddle.distributed.fleet.utils.log_util import logger + +MODEL_WEIGHT_SUFFIX = ".pdparams" +OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" +SCHEDULER_NAME = "scheduler.pdparams" +MODEL_META_FILE_NAME = "model_meta.json" +OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"] +MODEL_STATE_FILE_MIN_SIZE = 512 + + +class CheckpointConverter: + def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, patch_dict=None): + self.use_dist = True if paddle.distributed.get_world_size() > 1 else False + self.path = hybrid_parallel_ckpt_path + self.auto_parallel_state_dict = self.flatten_state_dict(state_dict) + self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name) + model_state_global_shape = {} + for k, v in self.auto_parallel_state_dict.items(): + model_state_global_shape[k] = v.shape + self.model_state_global_shape = self.gather_global_object(model_state_global_shape) + self.cur_rank = paddle.distributed.get_rank() + + ( + self.cur_rank_model_state_file_names, + self.cur_rank_optimizer_state_file_names, + ) = self.get_local_checkpoint_file_names() + + self.global_model_state_file_names = self.gather_global_object(self.cur_rank_model_state_file_names) + + self.global_optimizer_state_file_names = self.gather_global_object(self.cur_rank_optimizer_state_file_names) + + self.is_model_meta_exists = self.get_is_model_meta_exists_flag() + self.is_model_state_stored = self.get_is_model_state_stored_flag() + + self.initial_distributed_configuration() + + if patch_dict is not None: + self.patch_dict = 
patch_dict + for k, v in self.parameter_to_structured_name.items(): + if v in self.patch_dict: + self.parameter_to_structured_name[k] = self.patch_dict[v] + + del_keys = [] + for k, v in self.auto_parallel_state_dict.items(): + if k in self.patch_dict: + del_keys.append(k) + + for k in del_keys: + self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k] + self.auto_parallel_state_dict.pop(k) + + flags = [ + ["tp degree", self.tp_degree], + ["pp degree", self.pp_degree], + ["sharding degree", self.sharding_degree], + ["is model_meta exists", self.is_model_meta_exists], + ["is model_state stored", self.is_model_state_stored], + ] + self.print_checkpoint_file_info(flags) + + def load_from_hybrid_parallel_checkpoint(self): + """ + Automatically and inplace load the distributed checkpoint stored in hybrid parallel mode into the auto parallel state_dict. + The main logic is as follows: + 1. Call rename_semi_auto_state_dict: Rename the keys of the auto parallel state_dict according to certain rules. + (Why rename? To facilitate the subsequent correspondence between the optimizer state names of the semi-automatic and static optimizers.) + 2. Call gen_metadata_and_prepare_source_state_dict: Automatically parse the manual checkpoint file based on the state_dict information + provided by auto parallel, obtaining the Metadata and state_dict required for auto parallel to load the checkpoint. + 3. Call load_state_dict: Automatically reshard and load. + 4. Special logic adaptation: In the save_sharded_model mode, the weights are obtained through the master_weight cast in the checkpoint. + """ + self.rename_auto_parallel_state_dict() + + metadata, source_state_dict = self.gen_metadata_and_prepare_source_state_dict() + logger.info("Generated the checkpoint’s metadata.") + logger.debug(f"The checkpoint's metadata is {metadata}.") + if not self.is_model_state_stored: + assert self.optimizer_state_with_master_weights + model_params = {} + for state_name, state_value in self.auto_parallel_state_dict.items(): + if state_name in self.parameter_to_structured_name.values(): + model_params[state_name] = state_value + for param_name in model_params.keys(): + self.auto_parallel_state_dict.pop(param_name) + + logger.info("Requesting GPU memory space to load master_weights.") + appended_master_weight_names = [] + for param_name, param_value in model_params.items(): + master_weight = param_name + ".master_weight" + if master_weight not in self.auto_parallel_state_dict: + appended_master_weight_names.append(master_weight) + if param_value.is_dist(): + param_shape = param_value._local_value().shape + else: + param_shape = param_value.shape + + tmp_tensor = paddle.zeros(param_shape, dtype="float32") + with paddle.base.dygraph.guard(): + if param_value.is_dist(): + self.auto_parallel_state_dict[ + master_weight + ] = paddle.distributed.auto_parallel.api.dtensor_from_local( + tmp_tensor, param_value.process_mesh, param_value.placements + ) + else: + self.auto_parallel_state_dict[master_weight] = tmp_tensor + + logger.info("Calling _load_state_dict to load the required weights.") + _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata]) + logger.info("Calling _load_state_dict completed, restored the required weights.") + + # In this scenario, the data type of the model state is bfloat16. 
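The loop that follows performs exactly this cast back from the fp32 master weight; a minimal standalone sketch of the pattern, assuming a single made-up bf16 parameter and no distributed tensors:

    import paddle

    # A bf16 parameter whose value was not saved, plus its fp32 master weight from the checkpoint.
    param = paddle.zeros([4, 8], dtype="bfloat16")
    master_weight = paddle.rand([4, 8], dtype="float32")

    # Recover the parameter by casting the master weight down to the parameter dtype in place.
    paddle.assign(paddle.cast(master_weight, param.dtype), param)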
+ for param_name, param_value in model_params.items():
+ if param_value.is_dist():
+ master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"]
+ cast_master_weight = paddle.cast(master_weight._local_value(), param_value.dtype)
+ paddle.assign(cast_master_weight, param_value._local_value())
+ else:
+ master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"]
+ cast_master_weight = paddle.cast(master_weight, param_value.dtype)
+ paddle.assign(cast_master_weight, param_value)
+ for master_weight_name in appended_master_weight_names:
+ self.auto_parallel_state_dict.pop(master_weight_name)
+ else:
+ logger.info("Calling _load_state_dict to load the required weights.")
+ _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata])
+ logger.info("Calling _load_state_dict completed, restored the required weights.")
+ logger.info("Successfully loaded hybrid_parallel checkpoint!")
+
+ def gen_metadata_and_prepare_source_state_dict(self):
+ """
+ Automatically parse the manual checkpoint file based on the state_dict information provided by auto parallel,
+ obtaining the Metadata and state_dict required for auto parallel to load the checkpoint:
+ 1. Call load_state_dict_and_rename: Parse the distributed information from the names of the checkpoint files, and evenly parse out the distributed
+ information for each weight/optimizer state into self.global_sharded_tensor_infos (data structure: param_name -> [{tp_rank: 1, sharding_rank: 1}, shape, dtype, file_name]).
+ Modify the names of the optimizer states in the form of parameter+suffix and record them in self.cur_rank_loaded_state_dict (data structure: file_name -> renamed_state_dict).
+ 2. Construct the Metadata and state_dict based on the distributed information obtained in the previous step for the final load.
+ 3. Special logic adaptation: When sharding is enabled, the optimizer states are also split. In this step, the optimizer states need to be concatenated back according to the sharding dimension:
+ * Construct the Metadata for concatenating the sharded states back based on the characteristics of sharding.
+ * Construct a temporary opt_state_dict and use the _load_state_dict interface to obtain the state_dict with the sharded states concatenated back.
+ * Reshape the optimizer states back to the shape of the weights.
+ """
+ self.load_state_dict_and_rename()
+ logger.info("Complete the loading and renaming of state_dict.")
+ if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3:
+ for state_name, shard_info in self.global_sharded_tensor_infos.items():
+ shard_info.sort(key=lambda x: x[0]["sharding_rank"])
+
+ state_dict_metadata = {}
+ storage_metadata = {}
+ # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated.
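A small pure-Python walk-through of that offset computation, using hypothetical shard entries in the same [{rank info}, shape, dtype, file_name] layout as global_sharded_tensor_infos; the real loop follows below:

    # Two sharding ranks hold pieces of one flattened optimizer state.
    shard_info = [
        [{"tp_rank": 0, "sharding_rank": 1}, (3072,), "float32", "optimizer.tp00_pp00_shard01.pdopt"],
        [{"tp_rank": 0, "sharding_rank": 0}, (4096,), "float32", "optimizer.tp00_pp00_shard00.pdopt"],
    ]
    shard_info.sort(key=lambda x: x[0]["sharding_rank"])

    offset, offsets = 0, []
    for item in shard_info:
        offsets.append((offset,))  # global offset of this shard along the flattened axis
        offset += item[1][0]

    print(offsets)  # [(0,), (4096,)] -> shard00 starts at 0, shard01 directly after it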
+ for state_name, shard_info in self.global_sharded_tensor_infos.items(): + global_offset = [0] * self.tp_degree + for item in shard_info: + tp_rank = item[0]["tp_rank"] + state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank) + local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2]) + local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],)) + global_offset[tp_rank] += item[1][0] + if state_name_with_tp_rank not in state_dict_metadata: + state_dict_metadata[state_name_with_tp_rank] = [local_tensor_meta_data] + else: + state_dict_metadata[state_name_with_tp_rank].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + + logger.debug(f"The metadata for merge sharding is: {metadata_for_merge_sharding}") + + source_state_dict_for_merge_sharding = {} + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + renamed_state_dict = {} + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + for state_name, state_value in state_dict.items(): + state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank) + renamed_state_dict[state_name_with_tp_rank] = state_value + + source_state_dict_for_merge_sharding[file_name] = renamed_state_dict + + assert self.model_meta is not None + global_model_state_shapes = [] + sharding_metas_keys = [] + for i in range(self.pp_degree): + for j in range(self.tp_degree): + sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j)) + for key in sharding_metas_keys: + param_meta = self.model_meta["sharding_metas"][key]["param_meta"] + for param_name, param_shape_and_dtype in param_meta.items(): + global_model_state_shapes.append([param_name, param_shape_and_dtype[0]]) + + # Distribute all model parameters evenly across each card for loading + + world_size = paddle.distributed.get_world_size() + partition_mapping = self.partition_parameters(global_model_state_shapes, True, world_size) + + partition_model_state_keys = [] + for cur_rank, partition_model_state in partition_mapping.items(): + partition_model_state_keys.append([item[0] for item in partition_model_state]) + + all_param_meta = {} + for i in range(self.tp_degree): + for j in range(self.pp_degree): + key = "tp{:02d}_pp{:02d}".format(i, j) + param_meta = self.model_meta["sharding_metas"][key]["param_meta"] + for param_name, param_shape_and_dtype in param_meta.items(): + all_param_meta[param_name] = param_shape_and_dtype + + param_flattened_shapes = {} + for param_name, param_shape_and_dtype in all_param_meta.items(): + param_flattened_shapes[param_name] = reduce(lambda x, y: x * y, param_shape_and_dtype[0]) + + cur_rank_need_load_model_state_keys = partition_model_state_keys[self.cur_rank] + # Generate the optimizer states corresponding to the model weights. 
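The even split referred to above comes from partition_parameters (defined later in this file); a compact, self-contained rendition of its greedy rule with made-up shapes, before the optimizer placeholders are materialised below:

    from functools import reduce

    # Assign each parameter to the rank whose total element count is currently smallest.
    model_state_shapes = [["embed.weight", [1024, 512]], ["linear.weight", [512, 512]], ["linear.bias", [512]]]
    shard_num = 2
    mapping, sizes = {rank: [] for rank in range(shard_num)}, [0] * shard_num

    # is_sort=True path: place the largest parameters first for better balance.
    parameters = sorted(model_state_shapes, key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True)
    for param in parameters:
        rank = sizes.index(min(sizes))
        mapping[rank].append(param)
        sizes[rank] += reduce(lambda x, y: x * y, param[1], 1)

    print(mapping)
    # {0: [['embed.weight', [1024, 512]]], 1: [['linear.weight', [512, 512]], ['linear.bias', [512]]]}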
+ logger.info("Requesting GPU memory space to concatenate tensors split by sharding1 v2.") + optimizer_state_dict = {} + for key in cur_rank_need_load_model_state_keys: + for tp_rank in range(self.tp_degree): + tp_rank_suffix = "_tp{:02d}".format(tp_rank) + optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + optimizer_state_dict[key + ".moment2" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + if self.optimizer_state_with_master_weights: + optimizer_state_dict[key + ".master_weight" + tp_rank_suffix] = paddle.zeros( + (param_flattened_shapes[key],), "float32" + ) + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. + + optimizer_state_dict[key + ".beta1_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32") + optimizer_state_dict[key + ".beta2_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32") + + malloc_size = 0 + for opt_state_name, opt_state_value in optimizer_state_dict.items(): + malloc_size += opt_state_value.numel() * opt_state_value.element_size() + malloc_size = malloc_size.numpy() / 2**20 + logger.debug(f"{malloc_size} MB of GPU memory were allocated.") + + # merge sharding + logger.info("First call _load_state_dict to stitch back the tensors split by sharding1 v2.") + _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding]) + logger.info("Completed the call _load_state_dict, concating back the tensors split by sharding.") + + # Reshape + for opt_state_name, opt_state_value in optimizer_state_dict.items(): + if opt_state_value.shape[0] > 1 and "_tp" in opt_state_name: + param_name = self.optimizer_key_to_model_state_key(opt_state_name[:-5]) + param_shape = all_param_meta[param_name][0] + assert opt_state_value.numel() == reduce(lambda x, y: x * y, param_shape) + reshaped_opt_state_value = opt_state_value.reshape(param_shape) + optimizer_state_dict[opt_state_name] = reshaped_opt_state_value + concat_optimier_state_dict = {} + + optimizer_state_key_to_tp_keys = {} + for opt_state_name in optimizer_state_dict.keys(): + # Count how each key is split into keys ending with ‘_tpXX’. 
+ # optimizer_state_key_to_tp_keys : {key:[key_tp00,key_tp01]} + opt_state_name_removed_tp_rank = opt_state_name[:-5] + if opt_state_name_removed_tp_rank not in optimizer_state_key_to_tp_keys: + optimizer_state_key_to_tp_keys[opt_state_name_removed_tp_rank] = [opt_state_name] + else: + optimizer_state_key_to_tp_keys[opt_state_name_removed_tp_rank].append(opt_state_name) + + for opt_state_name_removed_tp_rank, opt_state_name in optimizer_state_key_to_tp_keys.items(): + opt_state_name.sort(key=lambda x: int(x[-2:])) + + for opt_state_name_removed_tp_rank, opt_state_name in optimizer_state_key_to_tp_keys.items(): + model_state_name = self.optimizer_key_to_model_state_key(opt_state_name_removed_tp_rank) + local_shape = optimizer_state_dict[opt_state_name[0]].shape + if ( + ".beta1_pow_acc" not in opt_state_name_removed_tp_rank + and ".beta2_pow_acc" not in opt_state_name_removed_tp_rank + ): + global_shape = self.model_state_global_shape[model_state_name] + else: + global_shape = (1,) + + if len(local_shape) != 1: + assert len(local_shape) == len(global_shape) + + axis = -1 + for i in range(len(local_shape)): + if local_shape[i] != global_shape[i]: + axis = i + break + + is_replicated = axis == -1 + tp_tensors = [] + for opt_state_name_with_tp_rank in opt_state_name: + tp_tensors.append(optimizer_state_dict[opt_state_name_with_tp_rank]) + + if not is_replicated: + # Derive the partition strategy based on the global_shape, then concatenate. + concat_optimier_state_dict[opt_state_name_removed_tp_rank] = paddle.concat(tp_tensors, axis=axis) + else: + concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0] + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in concat_optimier_state_dict.items(): + # Generate metadata. 
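Before the metadata generation continues, the axis inference and concatenation above in isolation, with two hypothetical TP shards of an [8, 4] weight:

    import paddle

    global_shape = (8, 4)                                    # shape recorded for the full weight
    tp_tensors = [paddle.rand([4, 4]), paddle.rand([4, 4])]  # local pieces from tp00 and tp01
    local_shape = tuple(tp_tensors[0].shape)

    # The first dimension where local and global shapes differ is the TP split axis;
    # if none differs, the tensor is replicated and a single copy suffices.
    axis = next((i for i in range(len(local_shape)) if local_shape[i] != global_shape[i]), -1)
    merged = tp_tensors[0] if axis == -1 else paddle.concat(tp_tensors, axis=axis)
    print(axis, merged.shape)  # 0 [8, 4]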
+ local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: concat_optimier_state_dict} + return meta_data, source_state_dict + + elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + return self.gen_metadata_for_tp_sharded_tensor() + else: + if self.is_sharding_stage3: + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + shard_info.sort(key=lambda x: x[0]["sharding_rank"]) + state_dict_metadata = {} + storage_metadata = {} + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + global_offset = 0 + for item in shard_info: + if len(item[1]) == 1: + local_tensor_meta_data = LocalTensorMetadata((global_offset,), item[1], item[2]) + local_tensor_index = LocalTensorIndex(state_name, (global_offset,)) + global_offset += item[1][0] + else: + global_offset = tuple([0] * len(item[1])) + local_tensor_meta_data = LocalTensorMetadata(global_offset, item[1], item[2]) + local_tensor_index = LocalTensorIndex(state_name, global_offset) + if state_name not in state_dict_metadata: + state_dict_metadata[state_name] = [local_tensor_meta_data] + else: + state_dict_metadata[state_name].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata_for_merge_sharding = Metadata(state_dict_metadata, storage_metadata, None) + model_state_shapes = [] + dtype = "" + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(MODEL_WEIGHT_SUFFIX): + for k, v in state_dict.items(): + model_state_shapes.append([k, v.shape]) + dtype = str(v.dtype).split(".")[1] + + dtypes = self.gather_global_object([dtype]) + for dtype_s in dtypes: + if len(dtype_s) > 0: + dtype = dtype_s + + assert len(dtype) > 0 + + global_model_state_shapes = self.gather_global_object(model_state_shapes) + + partition_result = self.partition_parameters( + global_model_state_shapes, True, paddle.distributed.get_world_size() + ) + + cur_rank_merger_model_params = partition_result[self.cur_rank] + target_state_dict = {} + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + flatten_shape = reduce(lambda a, b: a * b, item[1]) + target_state_dict[key] = paddle.zeros(shape, dtype) + target_state_dict[key + ".moment1"] = 
paddle.zeros((flatten_shape,), "float32") + target_state_dict[key + ".moment2"] = paddle.zeros((flatten_shape,), "float32") + if self.optimizer_state_with_master_weights: + target_state_dict[key + ".master_weight"] = paddle.zeros((flatten_shape,), "float32") + # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned. + # Later, when these are compared with the global shape, we realize that they are replicated. + + target_state_dict[key + ".beta1_pow_acc"] = paddle.zeros((1,), "float32") + target_state_dict[key + ".beta2_pow_acc"] = paddle.zeros((1,), "float32") + + _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding]) + + # Reshape + for item in cur_rank_merger_model_params: + key = item[0] + shape = item[1] + for k, v in target_state_dict.items(): + if key == self.optimizer_key_to_model_state_key(k): + if tuple(shape) != tuple(v.shape) and v.numel() == reduce(lambda x, y: x * y, shape): + reshaped_v = v.reshape(shape) + target_state_dict[k] = reshaped_v + + fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp" + local_tensor_meta_data = {} + local_tensor_index = {} + for k, v in target_state_dict.items(): + # Generate metadata. + local_shape = v.shape + global_offset = tuple([0] * len(local_shape)) + dtype = str(v.dtype).split(".")[1] + local_tensor_meta_data[k] = LocalTensorMetadata(global_offset, local_shape, dtype) + local_tensor_index[k] = [LocalTensorIndex(k, global_offset), fake_file_name] + + global_local_tensor_meta_data = [] + global_local_tensor_index = [] + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist: + paddle.distributed.all_gather_object(global_local_tensor_meta_data, local_tensor_meta_data) + paddle.distributed.all_gather_object(global_local_tensor_index, local_tensor_index) + else: + global_local_tensor_meta_data = [local_tensor_meta_data] + global_local_tensor_index = [local_tensor_index] + + state_dict_metadata = {} + for tensor_meta_data in global_local_tensor_meta_data: + for k, v in tensor_meta_data.items(): + if k not in state_dict_metadata: + state_dict_metadata[k] = [v] + else: + state_dict_metadata[k].append(v) + + storage_metadata = {} + for tensor_index in global_local_tensor_index: + for k, v in tensor_index.items(): + storage_metadata[v[0]] = v[1] + + meta_data = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = {fake_file_name: target_state_dict} + + return meta_data, source_state_dict + else: + return self.gen_metadata_for_tp_sharded_tensor() + + def load_state_dict_and_rename(self): + """ + Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state + into self.global_sharded_tensor_infos (data structure: param_name -> [{tp_rank: 1, sharding_rank: 1}, shape, dtype, file_name]). Modify the names of the + optimizer states in the form of parameter+suffix and record them in self.cur_rank_loaded_state_dict (data structure: file_name -> renamed_state_dict). + 1. Load balancing: Each rank parses a portion of the checkpoint files. + 2. Flatten master_weights in opt_state into opt_state. + 3. Rename the keys in opt_state according to the rule: adamw_optimizer_param_suffix_name_mapping. + 4. Optimizer state renaming and distributed information extraction: + * If it is sharding_stage1/2_v2 version: + * Renaming: rename_using_model_meta: In this case, a model_meta file is required. 
According to this file, + obtain the name mapping of weights and optimizer parameters, so that the optimizer states of manual and static partitions can correspond. + * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank, sharding_rank}, shape, dtype, file_name]. + * If it is sharding_stage1/2_v1 version: + * Renaming: + * If a model_meta file exists: + * rename_using_model_meta + * If a model_meta file does not exist: + * According to the characteristics of v1 partitioning, infer the mapping relationship between optimizer states and weights (partition_result): master_weight_name_to_model_weight_name_mapping. + * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank}, shape, dtype, file_name] (parameters will not be sharded). + * If it is sharding_stage3: + * Renaming: + * If a model_meta file exists: + * rename_using_model_meta + * If a model_meta file does not exist: + * Establish the mapping between weights and optimizer names according to the order of optimizer states and weights: rename_using_optimizer_state_order. + * Distributed information extraction: Record the distributed information of parameters: name -> [{tp_rank, sharding_rank}, shape, dtype, file_name]. + """ + rank_access_files = {} + if self.is_model_state_stored: + rank_access_files[self.cur_rank] = ( + self.cur_rank_model_state_file_names + self.cur_rank_optimizer_state_file_names + ) + else: + rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names + + global_rank_access_files = self.gather_global_object(rank_access_files) + need_read_files = get_rank_to_read_files(global_rank_access_files, global_rank_access_files) + logger.info(f"The file(s) to be loaded for the current rank are: {need_read_files}") + self.cur_rank_loaded_state_dict = {} + + for file in need_read_files: + self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file)) + + self.optimizer_state_with_master_weights = False + + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + state_dict.pop("LR_Scheduler") + if "master_weights" in state_dict: + self.optimizer_state_with_master_weights = True + master_weights = state_dict.pop("master_weights") + for master_weight_name, master_weight_value in master_weights.items(): + # In sharding stage3, ‘@slice’ will be added in front of the key for master_weight, which is removed here. + state_dict[master_weight_name.replace("slice@", "") + ".master_weight"] = master_weight_value + + self.cur_rank_loaded_state_dict[file] = state_dict + + memory_size = 0 + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + for k, v in state_dict.items(): + memory_size += v.numel() * v.element_size() + + memory_size = memory_size.numpy() / 2**20 + logger.debug( + f"The current rank has finished loading the checkpoint file and has allocated {memory_size} MB of GPU memory." + ) + + # After the rank has finished loading the files it needs, it can infer sharding_stage1_v and is_sharding_stage3. 
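A toy version of the renaming applied to each loaded .pdopt file above, with hypothetical keys and stand-in values; the inference mentioned in the comment comes next:

    # A loaded optimizer state_dict before flattening.
    opt_state = {
        "linear_0.w_0_moment1_0": "moment1 tensor",
        "linear_0.w_0_moment2_0": "moment2 tensor",
        "LR_Scheduler": {"last_epoch": 10},
        "master_weights": {"slice@linear_0.w_0": "fp32 master tensor"},
    }

    opt_state.pop("LR_Scheduler")
    for name, value in opt_state.pop("master_weights").items():
        # Strip the sharding stage 3 "slice@" prefix and attach the ".master_weight" suffix.
        opt_state[name.replace("slice@", "") + ".master_weight"] = value

    print(sorted(opt_state))
    # ['linear_0.w_0.master_weight', 'linear_0.w_0_moment1_0', 'linear_0.w_0_moment2_0']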
+ self.sharding_stage1_v = self.infer_sharding_stage1_v() + self.is_sharding_stage3 = self.infer_is_sharding_stage3() + + flags = [ + ["is sharding stage1/2", (not self.is_sharding_stage3) and self.sharding_degree > 1], + ["sharding stage1/2 version", self.sharding_stage1_v], + ["is sharding stage3", self.is_sharding_stage3], + ["master_weight", self.optimizer_state_with_master_weights], + ] + self.print_checkpoint_file_info(flags) + + # In sharding stage3, the parameters need to be reordered based on whether they are sliced. + # The threshold for determining whether to slice is segment_size, with a default value of 2**20. + # However, sharding stage3 allows users to specify their own unsliced layers, which seems to be incompatible here. + if self.is_sharding_stage3: + logger.info("The currently loaded checkpoint file comes from sharding stage 3.") + segment_size = 2**20 + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(MODEL_WEIGHT_SUFFIX): + sliced_prameters = [] + unsliced_parameters = [] + sorted_state_dict = {} + for k, v in state_dict.items(): + if v.numel() > segment_size: + sliced_prameters.append(k) + else: + unsliced_parameters.append(k) + for k in sliced_prameters + unsliced_parameters: + sorted_state_dict[k] = state_dict.pop(k) + self.cur_rank_loaded_state_dict[file] = sorted_state_dict + + # rename and record sharded_tensor_info + cur_rank_sharded_tensor_infos = {} + + # 1. Handling the sharding stage1 v2 scenario, where the save_sharded_model flag must be enabled, independent of master_weights. + if self.sharding_degree > 1 and self.sharding_stage1_v == 2 and not self.is_sharding_stage3: + logger.info("The currently loaded checkpoint file comes from sharding stage1 v2.") + assert self.is_model_meta_exists + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, + # and then append the tp_degree. + renamed_state_dict = self.rename_using_model_meta(file) + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + # 2. In handling the sharding stage1 v1 and stage2 scenario, the optimizer states are distributed across different ranks. + # We need to obtain the name mapping by simulating the partitioning method, without concern for the presence of master_weights. 
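The v2 branch above and the v1 branch that follows are told apart by infer_sharding_stage1_v, defined further down; its heuristic fits in a few lines, shown here with hypothetical shapes: stage1 v2 flattens optimizer moments to 1-D before splitting them, so any multi-dimensional moment implies v1.

    # Shapes of the moment tensors found in one .pdopt file.
    moment_shapes = [(4096,), (1024, 512)]

    sharding_stage1_v = 1 if any(len(shape) != 1 for shape in moment_shapes) else 2
    print(sharding_stage1_v)  # 1 -> the states were not flattened, so this is stage1 v1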
+ elif self.sharding_degree > 1 and self.sharding_stage1_v == 1 and not self.is_sharding_stage3: + logger.info("The currently loaded checkpoint file comes from sharding stage1/2 v1.") + if not self.is_model_meta_exists: + file_to_state_dict_shapes_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + shapes = [] + for state_name, state_value in state_dict.items(): + shapes.append([state_name, state_value.shape]) + file_to_state_dict_shapes_mapping[file] = shapes + + global_file_to_state_dict_shapes_mapping = self.gather_global_object(file_to_state_dict_shapes_mapping) + + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + sharding_optimizer_state_shards = [] + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + for k, v in global_file_to_state_dict_shapes_mapping.items(): + (tp_rank_, pp_rank_, sharding_rank_) = self.get_distribution_rank_from_file_name(k) + if tp_rank == tp_rank_ and pp_rank == pp_rank_ and k.endswith(OPTIMIZER_WEIGHT_SUFFIX): + sharding_optimizer_state_shards.append([v, sharding_rank_]) + model_state_file_name = self.get_model_state_file_from(file) + model_state_shapes = global_file_to_state_dict_shapes_mapping[model_state_file_name] + sharding_optimizer_state_shards.sort(key=lambda x: x[1]) + + partition_result_0 = self.partition_parameters(model_state_shapes, False, self.sharding_degree) + partition_result_1 = self.partition_parameters(model_state_shapes, True, self.sharding_degree) + + for rank, portion in partition_result_0.items(): + portion = sorted(portion, key=model_state_shapes.index) + partition_result_0[rank] = portion + + for rank, portion in partition_result_1.items(): + portion = sorted(portion, key=model_state_shapes.index) + partition_result_1[rank] = portion + + sharding_sort_parameters = False + + for i in range(len(sharding_optimizer_state_shards)): + if not sharding_sort_parameters: + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result_0[i] + for j in range(len(partitioned_shard)): + if partitioned_shard[j][1] != state_shard[j][1]: + sharding_sort_parameters = True + break + + if sharding_sort_parameters: + for i in range(len(sharding_optimizer_state_shards)): + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result_1[i] + for j in range(len(partitioned_shard)): + assert partitioned_shard[j][1] == state_shard[j][1] + + if sharding_sort_parameters: + partition_result = partition_result_1 + else: + partition_result = partition_result_0 + + name_mapping = {} + for i in range(len(sharding_optimizer_state_shards)): + state_shard = sharding_optimizer_state_shards[i][0] + partitioned_shard = partition_result[i] + suffix_bucket = {} + for suffix in OPTIMIZER_STATE_NAME_SUFFIX: + suffix_bucket[suffix] = [] + for j in range(len(state_shard)): + optimizer_state_name = state_shard[j][0] + if "moment1" in optimizer_state_name: + suffix_bucket[".moment1"].append(optimizer_state_name) + elif "moment2" in optimizer_state_name: + suffix_bucket[".moment2"].append(optimizer_state_name) + elif "beta1_pow_acc" in optimizer_state_name: + suffix_bucket[".beta1_pow_acc"].append(optimizer_state_name) + elif "beta2_pow_acc" in optimizer_state_name: + suffix_bucket[".beta2_pow_acc"].append(optimizer_state_name) + else: + suffix_bucket[".master_weight"].append(optimizer_state_name) + + # In this scenario, the order of master_weights might differ from the order of the regular optimizer 
states and needs to be reordered. + if len(suffix_bucket[".master_weight"]) != 0: + master_weight_keys = [] + for master_weight_key in suffix_bucket[".master_weight"]: + for index in range(len(state_shard)): + if master_weight_key[: -len(".master_weight")] in state_shard[index][0]: + # Find the first match + master_weight_keys.append([master_weight_key, index]) + break + + master_weight_keys = sorted(master_weight_keys, key=lambda x: x[1]) + suffix_bucket[".master_weight"] = [x[0] for x in master_weight_keys] + + for suffix, old_names in suffix_bucket.items(): + assert len(old_names) == len(partitioned_shard) + for k in range(len(old_names)): + name_mapping[old_names[k]] = partitioned_shard[k][0] + suffix + + renamed_state_dict = {} + # In this branch, sharding does not split the optimizer states; it merely relocates them to different cards. + # Therefore, the sharding information can now be directly removed. + for opt_state_name, opt_state_value in state_dict.items(): + renamed_state_dict[name_mapping[opt_state_name]] = opt_state_value + + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + else: + self.get_sharded_tensor_infos(file, state_dict, cur_rank_sharded_tensor_infos) + else: + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + renamed_state_dict = self.rename_using_model_meta(file) + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) + + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + else: + # 3. Handling the sharding stage3 and non-sharding scenario + + file_to_state_dict_keys_mapping = {} + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + file_to_state_dict_keys_mapping[file_name] = list(state_dict.keys()) + global_file_to_state_dict_keys_mapping = self.gather_global_object(file_to_state_dict_keys_mapping) + + logger.info("The current checkpoint comes from either sharding stage 3 or non-sharding.") + if not self.is_model_meta_exists: + for file_name, state_dict in self.cur_rank_loaded_state_dict.items(): + if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + model_state_file_name = self.get_model_state_file_from(file_name) + assert model_state_file_name is not None + model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name] + renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict) + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) + self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict + else: + self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos) + else: + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name, + # and then append the tp_degree. 
+ renamed_state_dict = self.rename_using_model_meta(file) + self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos) + self.cur_rank_loaded_state_dict[file] = renamed_state_dict + + # gather global sharded tensor infos + sharded_tensor_infos = self.gather_global_object({self.cur_rank: cur_rank_sharded_tensor_infos}) + self.global_sharded_tensor_infos = {} + for rank, sharded_tensor_info in sharded_tensor_infos.items(): + for state_name, shard_info in sharded_tensor_info.items(): + if state_name not in self.global_sharded_tensor_infos: + self.global_sharded_tensor_infos[state_name] = shard_info + else: + self.global_sharded_tensor_infos[state_name] += shard_info + logger.debug(f"global_sharded_tensor_infos: {self.global_sharded_tensor_infos}") + + def get_sharded_tensor_infos(self, file, state_dict, cur_rank_sharded_tensor_infos): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file) + for state_name, state_value in state_dict.items(): + if state_name not in cur_rank_sharded_tensor_infos: + cur_rank_sharded_tensor_infos[state_name] = [ + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + state_value.shape, + str(state_value.dtype).split(".")[1], + file, + ] + ] + else: + cur_rank_sharded_tensor_infos[state_name].append( + [ + {"tp_rank": tp_rank, "sharding_rank": sharding_rank}, + state_value.shape, + str(state_value.dtype).split(".")[1], + file, + ] + ) + + def gen_metadata_for_tp_sharded_tensor(self): + """ + Based on the distributed information of each weight/optimizer state (global_sharded_tensor_infos), construct Metadata + information: LocalTensorMetadata,LocalTensorIndex + """ + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + shard_info.sort(key=lambda x: x[0]["tp_rank"]) + + state_dict_metadata = {} + storage_metadata = {} + + # After obtaining the local_shape and sharding rank of each tensor, the global offset of each tensor can be calculated. + for state_name, shard_info in self.global_sharded_tensor_infos.items(): + + global_offset = 0 + local_shape = shard_info[0][1] + + model_state_name = self.optimizer_key_to_model_state_key(state_name) + if ".beta1_pow_acc" not in state_name and ".beta2_pow_acc" not in state_name: + global_shape = self.model_state_global_shape[model_state_name] + else: + global_shape = (1,) + assert len(local_shape) == len(global_shape) + axis = -1 + for i in range(len(local_shape)): + if local_shape[i] != global_shape[i]: + axis = i + break + + is_replicated = axis == -1 + global_offset = [0] * len(local_shape) + + if is_replicated: + shard_info = [shard_info[0]] + + for item in shard_info: + local_tensor_meta_data = LocalTensorMetadata(tuple(global_offset), item[1], item[2]) + local_tensor_index = LocalTensorIndex(state_name, tuple(global_offset)) + global_offset[axis] += item[1][axis] + if state_name not in state_dict_metadata: + state_dict_metadata[state_name] = [local_tensor_meta_data] + else: + state_dict_metadata[state_name].append(local_tensor_meta_data) + storage_metadata[local_tensor_index] = item[3] + + metadata = Metadata(state_dict_metadata, storage_metadata, None) + source_state_dict = self.cur_rank_loaded_state_dict + + return metadata, source_state_dict + + def rename_using_model_meta(self, file_name): + """ + Rename the keys in opt_state_dict based on the following rule: model_meta records a mapping of parameter names to optimizer names. + Here, we unify the optimizer state names to parameter names directly. 
For example: + * model_meta: linear0 -> param0 + * opt_state: param0.w0 + * Renamed opt_state: linear0.w0 + NOTE:The reason for renaming is that there is a difference in the naming of optimizer parameters between dynamic and static partitions, + making it difficult to match optimizer parameters directly by name. Therefore, we unify them to the weight names. + """ + if not hasattr(self, "model_meta"): + meta_file_path = os.path.join(self.path, MODEL_META_FILE_NAME) + assert os.path.exists(meta_file_path) + with open(meta_file_path, "r") as file: + self.model_meta = json.load(file) + + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name) + dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank) + # Map model weight names to their corresponding names of master_weights in the optimizer state. + if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"] + parameter_to_structured_name = {} + for k, v in structure_name_mapping.items(): + parameter_to_structured_name[v] = k + state_dict = self.cur_rank_loaded_state_dict[file_name] + return self.rename_using_parameter_to_structured_name_mapping(state_dict, parameter_to_structured_name) + else: + return self.cur_rank_loaded_state_dict[file_name] + + def rename_auto_parallel_state_dict(self): + """ + Rename the keys of the auto parallel state_dict according to certain rules: + 1. Rename the suffixes of the optimizer states to a unified format: adamw_optimizer_status_name_suffix_mappings + """ + self.auto_parallel_state_dict = self.rename_using_parameter_to_structured_name_mapping( + self.auto_parallel_state_dict, self.parameter_to_structured_name + ) + + def rename_using_parameter_to_structured_name_mapping(self, state_dict, parameter_to_structured_name): + renamed_state_dict = {} + + def rename(old_name, parameter_to_structured_name): + for i in range(1, len(old_name) + 1): + param_name = old_name[:i] # param_name + suffix = old_name[i:] # suffix + if param_name in parameter_to_structured_name: + structure_name = parameter_to_structured_name[param_name] + if "moment1" in suffix: + return structure_name + ".moment1" + elif "moment2" in suffix: + return structure_name + ".moment2" + elif "beta1_pow_acc" in suffix: + return structure_name + ".beta1_pow_acc" + elif "beta2_pow_acc" in suffix: + return structure_name + ".beta2_pow_acc" + else: + return structure_name + ".master_weight" + return None + + for key, value in state_dict.items(): + if key in parameter_to_structured_name.values(): + new_name = key + else: + new_name = rename(key, parameter_to_structured_name) + assert new_name is not None + renamed_state_dict[new_name] = value + + return renamed_state_dict + + def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict): + + name_mapping = {} + suffix_bucket = {} + assert len(optimizer_state_dict) % len(model_state_keys) == 0 + for suffix in OPTIMIZER_STATE_NAME_SUFFIX: + suffix_bucket[suffix] = [] + for satte_name, satte_value in optimizer_state_dict.items(): + if "moment1" in satte_name: + suffix_bucket[".moment1"].append(satte_name) + elif "moment2" in satte_name: + suffix_bucket[".moment2"].append(satte_name) + elif "beta1_pow_acc" in satte_name: + suffix_bucket[".beta1_pow_acc"].append(satte_name) + elif "beta2_pow_acc" in satte_name: + suffix_bucket[".beta2_pow_acc"].append(satte_name) + else: + suffix_bucket[".master_weight"].append(satte_name) + + 
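The suffix buckets built here are consumed by the loop that follows; separately, the prefix-matching helper inside rename_using_parameter_to_structured_name_mapping above can be traced with a toy input (names are hypothetical, and only the ".moment1" case is shown):

    # The optimizer key starts with the parameter name "param_0", so the structured
    # name plus the matching ".moment1" suffix is produced.
    parameter_to_structured_name = {"param_0": "llama.norm.weight"}
    old_name = "param_0_fp32_master_0_moment1_0"

    for i in range(1, len(old_name) + 1):
        if old_name[:i] in parameter_to_structured_name:
            print(parameter_to_structured_name[old_name[:i]] + ".moment1")  # llama.norm.weight.moment1
            break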
for suffix, old_names in suffix_bucket.items(): + assert len(old_names) == len(model_state_keys) + for i in range(len(old_names)): + name_mapping[old_names[i]] = model_state_keys[i] + suffix + + renamed_state_dict = {} + for k, v in optimizer_state_dict.items(): + renamed_state_dict[name_mapping[k]] = v + return renamed_state_dict + + def partition_parameters(self, model_state_shapes, is_sort, shard_num): + """ + In sharding_stage3 and sharding_stage1_v1, parameters and optimizer states will be assigned to different ranks. This function defines the allocation rules. + For details, refer to: python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py. + """ + mapping = {} + for rank_ in range(shard_num): + mapping[rank_] = [] + sizes = [0] * shard_num + + parameters = model_state_shapes.copy() + + if is_sort: + parameters.sort(key=lambda p: reduce(lambda x, y: x * y, p[1]), reverse=True) + + for param in parameters: + rank = sizes.index(min(sizes)) + mapping[rank].append(param) + numel = reduce(lambda x, y: x * y, param[1], 1) + assert numel > 0, f"param [{param[0]}] should larger than 0, but it is [{numel}]" + sizes[rank] += numel + + return mapping + + def get_is_model_meta_exists_flag(self): + save_sharded_model_flag = self.gather_global_object( + [os.path.exists(os.path.join(self.path, MODEL_META_FILE_NAME))] + ) + return True in save_sharded_model_flag + + def get_is_model_state_stored_flag(self): + if len(self.global_model_state_file_names) == 0: + return False + model_state_file_name = self.global_model_state_file_names[0] + file_readable = model_state_file_name in self.cur_rank_model_state_file_names + file_readables = self.gather_global_object([file_readable]) + coordinator_rank = file_readables.index(True) + is_model_state_stored = False + if self.cur_rank == coordinator_rank: + model_state_file_size = os.path.getsize(os.path.join(self.path, model_state_file_name)) + if model_state_file_size > MODEL_STATE_FILE_MIN_SIZE: + is_model_state_stored = True + + is_model_state_stored_flags = self.gather_global_object([is_model_state_stored]) + return True in is_model_state_stored_flags + + def flatten_state_dict(self, state_dict): + flattened_state_dict = {} + flat_state_dict, mapping = flatten_state_dict(state_dict) + for k, v in flat_state_dict.items(): + last_level_key = mapping[k][-1] + assert last_level_key not in flattened_state_dict + flattened_state_dict[last_level_key] = v + return flattened_state_dict + + def gather_global_object(self, cur_rank_object): + all_rank_objects = [] + if self.use_dist: + paddle.distributed.all_gather_object(all_rank_objects, cur_rank_object) + else: + all_rank_objects = [all_rank_objects] + + if isinstance(cur_rank_object, list): + for obj in all_rank_objects: + assert isinstance(obj, list) + return [item for sublist in all_rank_objects for item in sublist] + elif isinstance(cur_rank_object, dict): + for obj in all_rank_objects: + assert isinstance(obj, dict) + global_map = {} + for rank_map in all_rank_objects: + global_map.update(rank_map) + return global_map + else: + raise ValueError("cur_rank_object should be either a list or a dict") + + def get_local_checkpoint_file_names(self): + cur_rank_files = os.listdir(self.path) + cur_rank_model_state_file_names = [] + cur_rank_optimizer_state_file_names = [] + for file_name in cur_rank_files: + if file_name.endswith(MODEL_WEIGHT_SUFFIX): + cur_rank_model_state_file_names.append(file_name) + elif file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX): + 
cur_rank_optimizer_state_file_names.append(file_name) + if SCHEDULER_NAME in cur_rank_model_state_file_names: + cur_rank_model_state_file_names.remove(SCHEDULER_NAME) + return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names + + def get_distribution_rank_from_file_name(self, file_name): + pp_degree = 0 + tp_degree = 0 + sharding_degree = 0 + pattern_pp = r"pp(\d+)" + pattern_tp = r"tp(\d+)" + pattern_shard = r"shard(\d+)" + match_pp = re.search(pattern_pp, file_name) + if match_pp: + pp_degree = int(match_pp.group(1)) + match_tp = re.search(pattern_tp, file_name) + if match_tp: + tp_degree = int(match_tp.group(1)) + match_shard = re.search(pattern_shard, file_name) + if match_shard: + sharding_degree = int(match_shard.group(1)) + return (tp_degree, pp_degree, sharding_degree) + + def initial_distributed_configuration(self): + self.pp_degree = 0 + self.tp_degree = 0 + self.sharding_degree = 0 + + all_files = self.global_model_state_file_names + self.global_optimizer_state_file_names + + for file in all_files: + (tp_degree, pp_degree, sharding_degree) = self.get_distribution_rank_from_file_name(file) + self.pp_degree = max(self.pp_degree, pp_degree) + self.tp_degree = max(self.tp_degree, tp_degree) + self.sharding_degree = max(self.sharding_degree, sharding_degree) + + self.pp_degree = self.pp_degree + 1 + self.tp_degree = self.tp_degree + 1 + self.sharding_degree = self.sharding_degree + 1 + + def infer_sharding_stage1_v(self): + sharding_stage1_v = [2] + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX) and sharding_stage1_v[0] == 2: + for k, v in state_dict.items(): + # Under shardingv2, the optimizer state is first flattened and then split. + if len(v.shape) != 1: + sharding_stage1_v = [1] + break + + sharding_stage1_v = self.gather_global_object(sharding_stage1_v) + if 1 in sharding_stage1_v: + return 1 + return 2 + + def infer_is_sharding_stage3(self): + if self.sharding_degree == 1: + return False + if self.pp_degree > 1 or self.tp_degree > 1: + # Currently, sharding stage 3 does not support concurrent use with tensor parallelism (TP) and pipeline parallelism (PP). 
+ return False + + is_sharding_stage3 = True + + file_to_state_shape_mapping = {} + for file, state_dict in self.cur_rank_loaded_state_dict.items(): + if file.endswith(OPTIMIZER_WEIGHT_SUFFIX): + state_shape_mapping = {} + for k, v in state_dict.items(): + state_shape_mapping[k] = v.shape + if len(v.shape) != 1: + return False + file_to_state_shape_mapping[file] = state_shape_mapping + global_file_to_state_shape_mapping = self.gather_global_object(file_to_state_shape_mapping) + + state_dict_std = global_file_to_state_shape_mapping[list(global_file_to_state_shape_mapping.keys())[0]] + + for file, state_dict in global_file_to_state_shape_mapping.items(): + if state_dict != state_dict_std: + is_sharding_stage3 = False + break + return is_sharding_stage3 + + def get_model_state_file_from(self, optimizer_state_file_name): + (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(optimizer_state_file_name) + for model_state_file in self.global_model_state_file_names: + distributed_rank = self.get_distribution_rank_from_file_name(model_state_file) + if tp_rank == distributed_rank[0] and pp_rank == distributed_rank[1]: + return model_state_file + return None + + def optimizer_key_to_model_state_key(self, optimizer_key): + model_state_key = optimizer_key + for suffix in OPTIMIZER_STATE_NAME_SUFFIX: + if model_state_key.endswith(suffix): + # Remove the suffix from model_state_key + model_state_key = model_state_key[: -len(suffix)] + break + return model_state_key + + def print_checkpoint_file_info(self, flags): + processed_flags = [ + [str(item) if not isinstance(item, bool) else "True" if item else "False" for item in row] for row in flags + ] + + logger.info("Checkpoint file info:") + headers = ["Flag", "Value"] + col_widths = [max(len(str(item)) for item in column) for column in zip(headers, *flags)] + format_str = "| " + " | ".join(f"{{:<{width}}}" for width in col_widths) + " |" + separator_line = "+-" + "-+-".join("-" * width for width in col_widths) + "-+" + + logger.info(separator_line) + logger.info(format_str.format(*headers)) + logger.info(separator_line) + for row in processed_flags: + logger.info(format_str.format(*row)) + logger.info(separator_line)
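Taken together, the converter is driven in two steps: construct it with the hybrid-parallel checkpoint directory plus the auto parallel state_dict, then let it reshard in place. A rough, hypothetical usage sketch (the directory, optimizer object and name mapping are placeholders supplied by the Trainer, not values from this patch):

    from paddlenlp.trainer.utils.ckpt_converter import CheckpointConverter

    # Placeholders for what the caller would supply.
    ckpt_path = "./hybrid_parallel_checkpoint/checkpoint-1000"
    state_dict = auto_parallel_optimizer.state_dict()       # flattened by the converter itself
    parameter_to_structured_name = {"param_0": "llama.embed_tokens.weight"}

    CheckpointConverter(ckpt_path, state_dict, parameter_to_structured_name).load_from_hybrid_parallel_checkpoint()
    # Afterwards the tensors inside state_dict hold the resharded checkpoint values.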