[PIR-Auto-Parallel] Add sync shared param pass #71167
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
Force-pushed from dd15992 to 8749a03
Force-pushed from 8749a03 to a3f58a5
Force-pushed from 14d4059 to 848e00c
    return self.num_samples

class TestSimpleNetForSemiAutoParallel:
Inherit from unittest.TestCase, change the asserts accordingly, and add a check for the allreduce op.
done
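
A minimal sketch of the suggested test structure, assuming a hypothetical check_allreduce_in_program helper and PIR's global_block()/ops accessors; this is not the test from the PR:

    import unittest


    class TestSimpleNetForSemiAutoParallel(unittest.TestCase):
        def check_allreduce_in_program(self, main_program):
            # Collect op names from the lowered PIR program and assert that
            # the sync-shared-param pass inserted an allreduce.
            op_names = [op.name() for op in main_program.global_block().ops]
            self.assertTrue(
                any("all_reduce" in name for name in op_names),
                msg=f"expected an allreduce op in the program, got: {op_names}",
            )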
ranks = sorted(ranks)
if tuple(ranks) in self.comm_group:
    return self.comm_group[tuple(ranks)].id
group = new_process_group(ranks, force_new_group=True)
Suggest adding a comment explaining the problem that exists when force_new_group is False.
done
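
A sketch of the kind of comment the reviewer asked for. The helper name is hypothetical, the cache stores the group object so the lookup's .id access stays consistent, and the stated reason for force_new_group=True is my assumption, not confirmed in this thread:

    from paddle.distributed.auto_parallel.static.process_group import (
        new_process_group,
    )


    def get_or_create_comm_group(self, ranks):
        # self.comm_group caches created groups keyed by the sorted rank tuple.
        ranks = sorted(ranks)
        if tuple(ranks) in self.comm_group:
            return self.comm_group[tuple(ranks)].id
        # force_new_group=True: with the default force_new_group=False,
        # new_process_group may reuse an existing group that covers the same
        # ranks (assumption), so a fresh, dedicated group is created here.
        group = new_process_group(ranks, force_new_group=True)
        self.comm_group[tuple(ranks)] = group
        return group.id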
params, _ = get_pir_parameters(main_program)
for param in params:
    users = param.all_used_ops()
    for reshard_op in users:
The ops in users here are not necessarily reshard ops; suggest changing this.
done
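
A sketch of the suggested change, reusing get_pir_parameters from the diff above; the "dist_op.reshard" op name used for filtering is an assumption:

    params, _ = get_pir_parameters(main_program)
    for param in params:
        # all_used_ops() returns every op that consumes the parameter,
        # so filter explicitly instead of assuming each user is a reshard op.
        for used_op in param.all_used_ops():
            if used_op.name() != "dist_op.reshard":  # assumed op name
                continue
            ...  # handle the reshard op here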
if tmp_param.name == param_name:
    dy_param = tmp_param
    break
assert dy_param is not None
There are multiple asserts in the code; suggest adding messages to them.
done
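
For example, the bare assert quoted above could carry a message like this (param_name comes from the surrounding loop in the diff):

    assert dy_param is not None, (
        f"shared parameter '{param_name}' was not found among the "
        "dygraph parameters"
    )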
self.comm_group[tuple(ranks)] = group.id
return group.id

def init_shared_params(self, main_program, startup_program):
The purpose here is to synchronize parameters; suggest renaming the function to sync_shared_params.
done
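
Conceptually, sync_shared_params makes every rank in the group start from identical parameter values. A dynamic-mode stand-in using paddle.distributed.broadcast; the pass itself inserts the corresponding ops into the static program, so this is a sketch, not its implementation:

    import paddle.distributed as dist


    def sync_shared_params(params, group, src_rank):
        # Broadcast each shared parameter from src_rank so all ranks
        # holding a copy start from the same values.
        for param in params:
            dist.broadcast(param, src=src_rank, group=group)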
return new_shared_params

def allreduce_shared_param_gradient(
The purpose here is to synchronize gradients; suggest renaming the function to sync_shared_param_gradient.
done
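
Likewise, sync_shared_param_gradient sums the shared parameter's gradient across the ranks that hold a copy. A dynamic-mode stand-in using paddle.distributed.all_reduce, again a sketch rather than the pass's implementation:

    import paddle.distributed as dist


    def sync_shared_param_gradient(params_grads, group):
        # Sum gradients of shared parameters across the group so every
        # replica applies the same update.
        for _, grad in params_grads:
            if grad is not None:
                dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
        return params_grads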
return params_grads

# Only support one shared param.
assert len(self.params_maybe_shared) == 1
Is this because only a single shared parameter is supported for now?
Yes.
Support for scenarios with more shared parameters can be added later.
LGTM
PR Category
Auto Parallel
PR Types
Improvements
Description
Optimize performance for syncing shared params in PIR
Reference PR: https://github.com/PaddlePaddle/Paddle/pull/65321
Pcard-89509