Skip to content

Commit

Permalink
[Typing][B-60] Add type annotations for `python/paddle/text/datasets/…
Browse files Browse the repository at this point in the history
…conll05.py` (#65993)
  • Loading branch information
enkilee authored Jul 19, 2024
1 parent f624a93 commit f747bbf
Showing 1 changed file with 49 additions and 18 deletions.
67 changes: 49 additions & 18 deletions python/paddle/text/datasets/conll05.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt

import gzip
import tarfile
Expand Down Expand Up @@ -45,15 +52,15 @@ class Conll05st(Dataset):
only test dataset of Conll05st is public.
Args:
data_file(str): path to data tar file, can be set None if
data_file(str|None): path to data tar file, can be set None if
:attr:`download` is True. Default None
word_dict_file(str): path to word dictionary file, can be set None if
word_dict_file(str|None): path to word dictionary file, can be set None if
:attr:`download` is True. Default None
verb_dict_file(str): path to verb dictionary file, can be set None if
verb_dict_file(str|None): path to verb dictionary file, can be set None if
:attr:`download` is True. Default None
target_dict_file(str): path to target dictionary file, can be set None if
target_dict_file(str|None): path to target dictionary file, can be set None if
:attr:`download` is True. Default None
emb_file(str): path to embedding dictionary file, only used for
emb_file(str|None): path to embedding dictionary file, only used for
:code:`get_embedding` can be set None if :attr:`download` is
True. Default None
download(bool): whether to download dataset automatically if
Expand Down Expand Up @@ -103,14 +110,26 @@ class Conll05st(Dataset):
"""

data_file: str | None
word_dict_file: str | None
verb_dict_file: str | None
target_dict_file: str | None
emb_file: str | None
word_dict: dict[str, int]
predicate_dict: dict[str, int]
label_dict: dict[str, int]
sentences: list
predicates: list
labels: list

def __init__(
self,
data_file=None,
word_dict_file=None,
verb_dict_file=None,
target_dict_file=None,
emb_file=None,
download=True,
data_file: str | None = None,
word_dict_file: str | None = None,
verb_dict_file: str | None = None,
target_dict_file: str | None = None,
emb_file: str | None = None,
download: bool = True,
):
self.data_file = data_file
if self.data_file is None:
Expand Down Expand Up @@ -176,7 +195,7 @@ def __init__(
# read dataset into memory
self._load_anno()

def _load_label_dict(self, filename):
def _load_label_dict(self, filename: str) -> dict[str, int]:
d = {}
tag_dict = set()
with open(filename, 'r') as f:
Expand All @@ -195,14 +214,14 @@ def _load_label_dict(self, filename):
d["O"] = index
return d

def _load_dict(self, filename):
def _load_dict(self, filename: str) -> dict[str, int]:
d = {}
with open(filename, 'r') as f:
for i, line in enumerate(f):
d[line.strip()] = i
return d

def _load_anno(self):
def _load_anno(self) -> None:
tf = tarfile.open(self.data_file)
wf = tf.extractfile(
"conll05st-release/test.wsj/words/test.wsj.words.gz"
Expand Down Expand Up @@ -273,7 +292,19 @@ def _load_anno(self):
wf.close()
tf.close()

def __getitem__(self, idx):
def __getitem__(
self, idx: int
) -> tuple[
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
]:
sentence = self.sentences[idx]
predicate = self.predicates[idx]
labels = self.labels[idx]
Expand Down Expand Up @@ -332,10 +363,10 @@ def __getitem__(self, idx):
np.array(label_idx),
)

def __len__(self):
def __len__(self) -> int:
return len(self.sentences)

def get_dict(self):
def get_dict(self) -> tuple[dict[str, int], dict[str, int], dict[str, int]]:
"""
Get the word, verb and label dictionary of Wikipedia corpus.
Expand All @@ -351,7 +382,7 @@ def get_dict(self):
"""
return self.word_dict, self.predicate_dict, self.label_dict

def get_embedding(self):
def get_embedding(self) -> str:
"""
Get the embedding dictionary file.
Expand Down

0 comments on commit f747bbf

Please sign in to comment.