Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][A-29] Add type annotations for paddle/nn/initializer/constant.py #65095

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions python/paddle/nn/initializer/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import paddle
from paddle import _C_ops

from ...base import core, framework
Expand All @@ -35,13 +37,17 @@ class ConstantInitializer(Initializer):

"""

def __init__(self, value=0.0, force_cpu=False):
def __init__(self, value: float = 0.0, force_cpu: bool = False) -> None:
assert value is not None
super().__init__()
self._value = value
self._force_cpu = force_cpu

def forward(self, var, block=None):
def forward(
self,
var: paddle.Tensor,
block: framework.Block | paddle.pir.Block | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
block: framework.Block | paddle.pir.Block | None = None,
block: paddle.pir.Block | None = None,

统一下吧,静态图逻辑里不暴露老 IR 的概念,动静统一逻辑里不暴露静态图的概念

):
"""Initialize the input tensor with constant.

Args:
Expand All @@ -52,7 +58,6 @@ def forward(self, var, block=None):
Returns:
The initialization op
"""
import paddle

block = self._check_block(block)

Expand Down Expand Up @@ -135,7 +140,7 @@ class Constant(ConstantInitializer):

"""

def __init__(self, value=0.0):
def __init__(self, value: float = 0.0) -> None:
if value is None:
raise ValueError("value must not be none.")
super().__init__(value=value, force_cpu=False)