-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Typing][C-34][BUAA] Add type annotations for python/paddle/distributed/fleet/fleet.py
#67624
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有逐个 review, 几个问题:
- 类型的标注尽量简洁,如
paddle.Tensor
改为Tensor
,相应的导入语句放到TYPE_CHECKING
里面 - 泛型的参数化,如
list
需要写为list[xxx]
,以及set
dict
tuple
等 - 不是
callable
,是Callable
,从collections.abc
导入 - 注意
paddle._typing
中通用类型的使用,如place: paddle.CUDAPlace,
是否可以使用paddle._typing
中的类型 Literal
的使用,如def all_reduce(self, input, mode="sum")
中的mode
- 类中实例参数的标注,如
Fleet
的strategy_compiler, user_defined_optimizer
等 - 返回的
Self
的使用 - 私有函数/方法不需要标注
def all_reduce( | ||
self, | ||
input: Tensor, | ||
mode: str = mode_Literal, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Literal
的用法看一下其他 PR 是如何实现的 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要删掉默认值???
def apply_ir_passes( | ||
main_program: Program, | ||
startup_program: Program, | ||
config: BuildStrategy, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
config: BuildStrategy, | |
config: Fleet, |
self._role_maker = None | ||
self.strategy_compiler = None | ||
self.strategy_compiler: StrategyCompiler = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.strategy_compiler: StrategyCompiler = None | |
self.strategy_compiler: StrategyCompiler | None = None |
is_collective: bool = False, | ||
strategy: DistributedStrategy | None = None, | ||
log_level: int | str = "INFO", | ||
) -> Fleet: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
) -> Fleet: | |
) -> Self: |
): | ||
iteration: int, | ||
x: Tensor, | ||
group: HybridCommunicateGroup, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
group: HybridCommunicateGroup, | |
group: Group, |
self, | ||
iteration: int, | ||
x: Tensor, | ||
group: HybridCommunicateGroup, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
group: HybridCommunicateGroup, | |
group: Group, |
@@ -500,7 +546,7 @@ def reduce_scatter_perf( | |||
f"[Perf Warning] ReduceScatter Test Timeout! {ret} > {perf_threshold_time}" | |||
) | |||
|
|||
def _collective_perf_impl(self, round=50, context={}, hcg=None): | |||
def _collective_perf_impl(self, round, context, hcg=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认值呢?
self, | ||
optimizer: Optimizer, | ||
strategy: DistributedStrategy | None = None, | ||
) -> Fleet: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
) -> Fleet: | |
) -> Self: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Self在python3.11才引入,此处使用TypeVar实现
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯 好问题 ~ 但是,为什么项目中都在用 Self
呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
抱歉,资料没有找全,python3.11之前Self
由typing_extensions
引入
@@ -1319,10 +1394,14 @@ def set_date(self, table_id, day_id): | |||
|
|||
@is_non_distributed_check | |||
@inited_runtime_handler | |||
def shrink(self, threshold=None): | |||
def shrink(self, threshold: int) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def shrink(self, threshold: int) -> None: | |
def shrink(self, threshold: int | None = None) -> None: |
self, | ||
loss, | ||
startup_program, | ||
parameter_list, | ||
no_grad_set, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认值呢?
startup_programs, | ||
parameter_list, | ||
no_grad_set, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认值呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
另外,
_inited_runtime_handler_
_is_non_distributed_check_
这两个函数也要标一下,参考:
Any, | ||
Iterable, | ||
Literal, | ||
Sequence, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从 collections.abc
导入 Sequence
Iterable
from paddle import ( | ||
Tensor, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from paddle import ( | |
Tensor, | |
) | |
from paddle import Tensor |
def all_reduce(self, input, mode="sum"): | ||
def all_reduce( | ||
self, | ||
input: int, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input: int, | |
input: Any, |
应该不只是 int
Operator, | ||
Parameter, | ||
Program, | ||
Scope, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不是这个 Scope
,是 from paddle.base.core import _Scope
TypedDict, | ||
) | ||
|
||
if TYPE_CHECKING: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把 if TYPE_CHECKING
这部分放到 80 行 __all__
上面吧(注意前后留一个空行~),其他应该没啥问题了 ~ 辛苦 🤟
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM ~
PR Category
User Experience
PR Types
Improvements
Description
#65008
C-34