Add model LUKE #1677
@@ -0,0 +1,30 @@
#encoding=utf8
PaddleNLP currently supports Python 3.6 and above, so there is no need to specify the encoding.
Done
@@ -0,0 +1,237 @@
# encoding=utf8
Same as above.
Done
parser = argparse.ArgumentParser(description="LUKE FOR OPEN ENTITY")

parser.add_argument("--output_dir", type=str, required=True)
Please add a description for each argument.
Done
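As a minimal sketch of what the requested change could look like, each argument can carry a `help` string. The help texts and the `int` type for `--max_mention_length` below are assumptions for illustration, not the wording or types actually used in the PR.

```python
import argparse


def build_parser():
    # Help strings are illustrative assumptions, not the PR's actual text.
    parser = argparse.ArgumentParser(description="LUKE FOR OPEN ENTITY")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where checkpoints and results are written.")
    parser.add_argument(
        "--max_mention_length",
        type=int,  # assumption: an integer length, matching the default of 30
        default=30,
        help="Maximum number of tokens in an entity mention.")
    return parser


args = build_parser().parse_args(["--output_dir", "./output"])
print(args.output_dir, args.max_mention_length)
```

With the help strings in place, running the script with `--help` lists every option together with its description.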
parser.add_argument("--max_mention_length", type=str, default=30)

args = parser.parse_args()
args.tokenizer = LukeTokenizer.from_pretrained(args.model_type)
Does the tokenizer have to be used as a global variable?
Fixed.
f.entity_attention_mask for f in features
]
self.all_labels = [f.labels for f in features]

How large is this dataset?
Loading the entire dataset at once is not recommended, because it may exhaust memory.
You can instead subclass paddle.io.Dataset and return the data iteratively.
There are two datasets: the open_entity dataset is under 1 MB, and the SQuAD1.1 dataset is under 30 MB.
Revised according to your suggestion.
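The suggestion above can be sketched roughly as follows: a dataset that builds each sample on access instead of precomputing every field for every example. The class name `LazyFeatureDataset` and the `convert_fn` parameter are hypothetical, and the fallback base class exists only so the sketch runs without PaddlePaddle installed; this is not the code the PR ended up with.

```python
try:
    from paddle.io import Dataset  # real base class when PaddlePaddle is installed
except ImportError:  # fallback so the sketch runs without PaddlePaddle
    Dataset = object


class LazyFeatureDataset(Dataset):
    """Hypothetical dataset that builds each sample on access."""

    def __init__(self, examples, convert_fn):
        super().__init__()
        self.examples = examples      # raw, lightweight examples
        self.convert_fn = convert_fn  # turns one example into model inputs

    def __getitem__(self, idx):
        # Features are computed per item instead of being stored for all items.
        return self.convert_fn(self.examples[idx])

    def __len__(self):
        return len(self.examples)


dataset = LazyFeatureDataset(["a b", "c d e"], lambda text: text.split())
print(len(dataset), dataset[1])
```

For datasets as small as the two mentioned in the reply, the memory saving is modest; the pattern matters more once the raw data no longer fits comfortably in memory.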
max_len - len(each_batch[k])))
return np.array(new_data, dtype='int64')

return (
Please add a comment explaining what the code on lines 189-196 does.
Explanation added.
class DataGenerator(Dataset):
    def __init__(self, features, args):
        super(DataGenerator, self).__init__()
Same as above.
Done
@@ -0,0 +1,187 @@
# encoding=utf8
Same as above.
Done
def load_examples(args, evaluate=False):
    args.evaluate = evaluate
    features = []
    if not evaluate:
Consider adding a data file parameter to control which dataset is loaded. That removes the need for the if-else check.
Fixed.
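One way to read the suggestion, as a sketch: the caller passes the path of the file to load, so the loader no longer needs an `evaluate` flag. The one-argument signature and the JSON layout below are assumptions for illustration, not the PR's final code.

```python
import json
import os
import tempfile


def load_examples(data_file):
    """Load raw examples from the given JSON file (hypothetical signature)."""
    with open(data_file, "r", encoding="utf-8") as reader:
        return json.load(reader)


# Usage: the caller picks the split by path, so no evaluate flag is needed.
with tempfile.TemporaryDirectory() as tmp_dir:
    train_file = os.path.join(tmp_dir, "train.json")
    with open(train_file, "w", encoding="utf-8") as writer:
        json.dump([{"sent": "hello", "labels": ["person"]}], writer)
    train_examples = load_examples(train_file)

print(len(train_examples))
```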
input_data = json.load(reader)["data"]
return self._create_examples(input_data)

# def __init__(self, qas_id, title, question_text, context_text, answers, is_impossible=False):
Remove the unused comment.
Done
parser.add_argument(
    "--data_dir", type=str, required=True, help="Dataset folder")
parser.add_argument(
    "--eval_batch_size",
Consider using a single batch_size instead of separate eval_batch_size and train_batch_size arguments.
Fixed.
all_mentions = mentions_a + mentions_b
if all_mentions:
    print(all_mentions)
    exit()
What is the purpose of the code on lines 192-194?
This is the entity detection code provided by the authors of the LUKE model. We used it to detect entities, but it found none in the SQuAD1.1 dataset, so we have removed the entity detection code.
min_mention_link_prob=args.min_mention_link_prob,
segment_b_id=0,
add_extra_sep_token=True,
is_training='train' in data_file)
Can the parameters above be simplified further? Please check whether any of them have default values and can be omitted. The API currently takes too many parameters, which hurts readability.
The code has been optimized.
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
#NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is
# NOTE
Fixed.
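The code comment quoted in this thread describes splitting a long context into overlapping features. A minimal sketch of that idea on plain token id lists follows; the function name `split_with_stride` is hypothetical, and `stride` is taken here to mean the number of tokens shared by consecutive windows. It illustrates the windowing only and is not the tokenizer's implementation.

```python
def split_with_stride(token_ids, max_len, stride):
    """Split token_ids into windows of at most max_len tokens.

    Consecutive windows overlap by `stride` tokens.
    """
    if not 0 <= stride < max_len:
        raise ValueError("stride must satisfy 0 <= stride < max_len")
    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + max_len])
        if start + max_len >= len(token_ids):
            break  # the last window reached the end of the sequence
        start += max_len - stride
    return windows


windows = split_with_stride(list(range(10)), max_len=4, stride=2)
print(windows)
```

A sequence shorter than `max_len` yields a single window, which matches the "one example possibly giving several features" behaviour only for long contexts.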
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
#NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is
# NOTE:
Done
set_seed(args)
if rank == 0:
    if os.path.exists(args.model_name_or_path):
        print("init checkpoint from %s" % args.model_name_or_path)
"Loads checkpoints from %s."
Done
input_ids, token_type_ids, attention_mask, start_positions, end_positions = batch
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask, )
Remove the trailing comma.
Done
dev_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    "attention_mask": Pad(axis=0, pad_val=tokenizer.pad_token_type_id), }): fn(samples)
Does attention_mask really need to be passed in explicitly? BERT, for example, can derive the attention_mask from the pad token id.
It does not need to be specified; attention_mask has been removed.
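The reviewer's point can be illustrated with a small sketch on plain Python lists: once a batch is padded with a known pad id, the mask follows from the ids alone. The helper names and `PAD_ID = 0` are assumptions for this toy example and do not reflect how the PaddleNLP model computes its mask internally.

```python
def pad_batch(batch, pad_token_id):
    """Right-pad every sequence in the batch to the longest length."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_token_id] * (max_len - len(seq)) for seq in batch]


def attention_mask_from_ids(input_ids, pad_token_id):
    """Return 1 for real tokens and 0 for padding positions."""
    return [[int(tok != pad_token_id) for tok in seq] for seq in input_ids]


PAD_ID = 0  # assumption: the padding id for this toy vocabulary
input_ids = pad_batch([[5, 6, 7], [8]], PAD_ID)
mask = attention_mask_from_ids(input_ids, PAD_ID)
print(input_ids, mask)
```

This only works when the pad id never appears as a real token, which is why deriving the mask inside the model is safe for vocabularies with a dedicated padding token.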
@@ -0,0 +1,113 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
What is this trainer.py for? It does not appear to be used anywhere.
It is used for training on the Open Entity dataset. I have merged it into run_open_entity.py.
Description
Add new model LUKE
Model weights:
Link: https://pan.baidu.com/s/17aC-27kjJdEaGT6nZt5T_Q
Extraction code: i4p2