Skip to content

Commit

Permalink
fix DatasetTuple identifier bug (PaddlePaddle#1941)
Browse files Browse the repository at this point in the history
  • Loading branch information
smallv0221 authored Apr 12, 2022
1 parent 0b9f2bd commit 3a98c9c
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions paddlenlp/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,26 @@ def load_from_ppnlp(path, *args, **kwargs):

class DatasetTuple:
    """Fixed-size container of datasets addressable by split name or position.

    The values are stored in a namedtuple whose field names are generated
    (``splits_0``, ``splits_1``, ...) rather than taken from the split names,
    so a split name does not have to be a valid Python identifier. The
    ``identifier_map`` attribute translates each split name to its generated
    field name.
    """

    def __init__(self, splits):
        """Create one empty (``None``) slot per entry of ``splits``.

        Args:
            splits: Sequence of hashable split names, in slot order.
        """
        self.identifier_map, identifiers = self._gen_identifier_map(splits)
        self.tuple_cls = namedtuple('datasets', identifiers)
        # Size the tuple from the generated identifiers so ``splits`` is
        # only traversed once.
        self.tuple = self.tuple_cls(*([None] * len(identifiers)))

    def __getitem__(self, key):
        """Return the slot(s) selected by ``key``.

        Args:
            key: An ``int`` or ``slice`` for positional access, or a ``str``
                split name.

        Returns:
            The stored value (or a tuple of values for a slice). ``None`` is
            returned when ``key`` is of any other type.

        Raises:
            IndexError: If an integer ``key`` is out of range.
            KeyError: If a string ``key`` is not a known split name.
        """
        if isinstance(key, (int, slice)):
            return self.tuple[key]
        if isinstance(key, str):
            return getattr(self.tuple, self.identifier_map[key])
        # Unsupported key types fall through and yield None.
        return None

    def __repr__(self):
        """Return the repr of the underlying namedtuple."""
        return repr(self.tuple)

    def __setitem__(self, key, value):
        """Store ``value`` in the slot of the split named ``key``.

        namedtuples are immutable, so a new tuple is built with the one
        field replaced and rebound to ``self.tuple``.

        Raises:
            KeyError: If ``key`` is not a known split name.
        """
        self.tuple = self.tuple._replace(**{self.identifier_map[key]: value})

    def _gen_identifier_map(self, splits):
        """Generate a namedtuple-safe field name for every split.

        Args:
            splits: Iterable of hashable split names.

        Returns:
            A pair ``(identifier_map, identifiers)``: a dict mapping each
            split name to its generated field name, and the list of
            generated field names in split order.
        """
        identifier_map = {}
        identifiers = []
        for index, split in enumerate(splits):
            identifier = 'splits_' + str(index)
            identifiers.append(identifier)
            identifier_map[split] = identifier
        return identifier_map, identifiers

    def __len__(self):
        """Return the number of slots."""
        return len(self.tuple)
Expand Down

0 comments on commit 3a98c9c

Please sign in to comment.