Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model LUKE #1677

Merged
merged 18 commits into from
Mar 10, 2022
Merged

Add model LUKE #1677

merged 18 commits into from
Mar 10, 2022

Conversation

Beacontownfc
Copy link
Contributor

@Beacontownfc Beacontownfc commented Feb 15, 2022

Description
Add new model LUKE
The model weight:
链接:https://pan.baidu.com/s/17aC-27kjJdEaGT6nZt5T_Q
提取码:i4p2

@@ -0,0 +1,30 @@
#encoding=utf8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前PaddleNLP支持python3.6以上版本,无须指定编码

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -0,0 +1,237 @@
# encoding=utf8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


parser = argparse.ArgumentParser(description="LUKE FOR OPEN ENTITY")

parser.add_argument("--output_dir", type=str, required=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args请天假每个argument的description

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

parser.add_argument("--max_mention_length", type=str, default=30)

args = parser.parse_args()
args.tokenizer = LukeTokenizer.from_pretrained(args.model_type)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tokenizer一定要作为全局变量使用吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

f.entity_attention_mask for f in features
]
self.all_labels = [f.labels for f in features]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个数据集多大?
不建议一次性将所有数据集加载进来,这样有可能造成占用内存溢出。
可以通过继承自paddle.io.Dataset, 以迭代的方式返回数据。

Copy link
Contributor Author

@Beacontownfc Beacontownfc Feb 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

两个数据集,open_entity数据集不到1MB,SQuAD1.1数据集不到30MB.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已根据您的建议修改

max_len - len(each_batch[k])))
return np.array(new_data, dtype='int64')

return (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

给下注释,解释189 - 196 行代码的意义。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已解释


class DataGenerator(Dataset):
def __init__(self, features, args):
super(DataGenerator, self).__init__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -0,0 +1,187 @@
# encoding=utf8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

def load_examples(args, evaluate=False):
args.evaluate = evaluate
features = []
if not evaluate:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议增加data file参数,控制加载的数据集。省去if-else的判断。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

input_data = json.load(reader)["data"]
return self._create_examples(input_data)

# def __init__(self, qas_id, title, question_text, context_text, answers, is_impossible=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删去无用注释

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

parser.add_argument(
"--data_dir", type=str, required=True, help="Dataset folder")
parser.add_argument(
"--eval_batch_size",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议统一用batch_size,不用区分eval_batch_sizetrain_batch_size

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

all_mentions = mentions_a + mentions_b
if all_mentions:
print(all_mentions)
exit()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

192 - 194 行 代码作用?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此代码是LUKE模型作者提供的实体检测代码,我们使用他的代码检测实体,但在SQuAD1.1数据集上并没有检测到实体,我们已经把实体检测相关代码删除。

min_mention_link_prob=args.min_mention_link_prob,
segment_b_id=0,
add_extra_sep_token=True,
is_training='train' in data_file)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

以上参数是否还有简化的空间呢?可以看看是否有默认参数可以省去。目前API参数量太多了,可读性差。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码已优化

# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
#NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# NOTE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
#NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# NOTE:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

set_seed(args)
if rank == 0:
if os.path.exists(args.model_name_or_path):
print("init checkpoint from %s" % args.model_name_or_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Loads checkpoints from %s."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

input_ids, token_type_ids, attention_mask, start_positions, end_positions = batch
logits = model(
input_ids=input_ids,
attention_mask=attention_mask, )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉末尾逗号

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

dev_batchify_fn = lambda samples, fn=Dict({
"input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
"token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
"attention_mask": Pad(axis=0, pad_val=tokenizer.pad_token_type_id), }): fn(samples)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attention_mask是否一定需要传入指定吗?比如BERT可以根据pad token id确定是否attention_mask

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无需指定,已删除attention_mask

@@ -0,0 +1,113 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个trainer.py的作用是?似乎没看见用到。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用于在open entity数据集上训练,我已将其与run_open_entity.py合并

@yingyibiao yingyibiao merged commit 30ab253 into PaddlePaddle:develop Mar 10, 2022
@Beacontownfc Beacontownfc deleted the luke branch March 10, 2022 23:51
@Beacontownfc Beacontownfc restored the luke branch March 10, 2022 23:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants