[AutoParallel] Fix ViT embed #1
Conversation
```diff
@@ -622,7 +623,8 @@ def __init__(self, *args, **kwargs):
        # self.is_pretraining = True

    def _wrap_for_dist_loader(self, train_dataloader):
        dist_loader = super()._wrap_for_dist_loader(train_dataloader)
        dtensor_idx = [2, 3]
```
Would `dense_tensor_idx` be a better name? `dtensor_idx` is ambiguous and easy to misread.
Yes, I'll rename it and add a description of its usage.
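A minimal sketch of how the rename might look, assuming the index list marks which positions in each dataloader batch stay as ordinary dense tensors rather than being converted to distributed tensors; the `dense_tensor_idx` attribute assignment below is hypothetical, for illustration only:

```python
def _wrap_for_dist_loader(self, train_dataloader):
    dist_loader = super()._wrap_for_dist_loader(train_dataloader)
    # Renamed from dtensor_idx: positions 2 and 3 of each batch are assumed
    # to remain plain dense tensors instead of becoming dist tensors
    # (semantics inferred from the review, not confirmed by the diff).
    dense_tensor_idx = [2, 3]
    dist_loader.dense_tensor_idx = dense_tensor_idx  # hypothetical attribute
    return dist_loader
```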
```diff
            num_attention_heads=config.num_attention_heads,
        )

        def get_tensor_parallel_split_mappings(num_layers):
```
This function is only used by the manual (hand-written) parallelism path, so its body can be removed and replaced with `pass`. Deleting the function entirely may cause `from_pretrain` to fail.
OK, thanks!
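A sketch of the suggested stub: the function stays defined so that code paths which look it up (e.g. during `from_pretrain`) still succeed, while the manual-parallelism body is dropped:

```python
def get_tensor_parallel_split_mappings(num_layers):
    # The split mappings are only needed for manual (hand-written) tensor
    # parallelism; under auto parallel the body can be a no-op. Keeping the
    # function defined avoids errors from callers that expect it to exist.
    pass
```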
```diff
        super().__init__(config)
        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
        # mesh = fleet.auto.get_mesh()
        mesh = dist.ProcessMesh([[0], [1], [2], [3], [4], [5], [6], [7]], dim_names=["dp", "mp"])
```
Would this be more general?

```python
mesh = dist.get_mesh()
if "pp" in mesh.dim_names:
    mesh = mesh.get_mesh_with_dim("pp")[0]
```
Yes, I'll change it, thanks!
Should this be changed as well?
Yes, this needs to be changed too; I missed it.
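Putting the suggestion together, the hardcoded 8-GPU mesh above might be replaced by something like the following sketch, assuming `dist.get_mesh()` returns the globally registered `ProcessMesh`:

```python
import paddle.distributed as dist

# Derive the mesh from the global configuration instead of hardcoding a
# fixed "dp"/"mp" layout, so the code works across device counts.
mesh = dist.get_mesh()
# When a pipeline-parallel ("pp") dimension exists, take the sub-mesh of the
# first pipeline stage so only the data/model parallel dimensions remain.
if "pp" in mesh.dim_names:
    mesh = mesh.get_mesh_with_dim("pp")[0]
```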
```diff
        ColumnParallelLinear = linear_utils.ColumnParallelLinear
        RowParallelLinear = linear_utils.RowParallelLinear

        if config.tensor_parallel_degree > 1:
```
Branches that test `config.tensor_parallel_degree > 1` like this can all be removed; these branches are only needed for manual parallelism.
OK.
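For illustration, a sketch of what removing such a branch could look like; the layer and parameter names below are assumptions, not the PR's actual code:

```python
import paddle.nn as nn


class VisionMlp(nn.Layer):  # hypothetical layer for illustration
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # No tensor_parallel_degree branch and no ColumnParallelLinear /
        # RowParallelLinear variants: under auto parallel, plain nn.Linear
        # layers are used and sharding is expressed via the process mesh.
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        return self.fc2(nn.functional.gelu(self.fc1(x)))
```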
Merged commit 5f32352 (Fix ViT embed) into jeff41404:verify_auto_parallel_intermediate_api_in_paddlemix.