-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add model LUKE #1677
Add model LUKE #1677
Changes from 12 commits
9b7f955
df5981f
f7ec4f4
23fcc3f
8e3c285
8d492e5
ced44a6
0f729ef
abd989c
3e2a455
ce88391
8bf00b9
5ff11c9
01eb3c2
308cf45
a368ee3
2923237
9f498c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# LUKE with PaddleNLP | ||
|
||
[LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) | ||
|
||
**模型简介:** | ||
许多NLP任务都涉及实体,例如:关系分类、实体类型、命名实体识别(NER)和问答(QA)。解决此类实体相关任务的关键是学习实体有效表示。传统的实体表示为每个实体分配一个固定的Embedding向量,该向量将有关实体的信息存储在知识库(KB)中。它们需要实体链接(entity linking)来表示文本中的实体,而不能表示KB中不存在的实体。 | ||
|
||
相比之下,基于contextualized word representations(CWRs) transformer的大型预训练模型,如BERT和RoBERTa,提供了基于语言建模的有效通用词语表征。然而,由于以下两个原因,CWRs的体系结构不适合表示实体: | ||
|
||
- 由于CWR不输出实体的跨级(span-level)表示,因此它们通常需要学习如何基于通常较小的下游数据集计算此类表征。 | ||
|
||
- 许多与实体相关的任务,如关系分类和问答(QA)涉及实体之间关系的推理。尽管transformer可以通过使用self-attention机制将单词相互关联来捕捉单词之间的复杂关系。在实体之间执行关系推理是困难的,因为许多实体在模型中被分割成多个词。此外,基于单词的CWRs预训练任务不适合学习实体的表征,因为在实体中预测一个被MASK的单词,例如预测“Rings”, 给予句子“The Lord of the [MASK]”,一个完整的实体就这样被拆分。 | ||
|
||
LUKE和现有CWRs之间的一个重要区别在于,它不仅将单词视为独立的token,还将实体视为独立的token,并使用transformer计算所有token的中间表征和输出表征。由于实体被视为token,LUKE可以直接建模实体之间的关系。 | ||
本项目是 LUKE 在 Paddle 2.x上的开源实现。 | ||
|
||
## 快速开始 | ||
|
||
### 下游任务微调 | ||
|
||
数据集 | ||
下载Open Entity数据集 | ||
[下载地址](https://cloud.tsinghua.edu.cn/f/6ec98dbd931b4da9a7f0/) | ||
把下载好的文件解压,并把解压后的Open Entity目录下的`train.json`、`test.json`和`dev.json`分别为训练集、验证集和测试集 | ||
|
||
下载SQuAD1.1数据集,主流机器阅读理解数据集 | ||
[下载地址](https://data.deepai.org/squad1.1.zip) | ||
|
||
同时需要下载由LUKE官方提供的维基百科(实体)数据集 | ||
[下载地址](https://drive.google.com/file/d/129tDJ3ev6IdbJiKOmO6GTgNANunhO_vt/view) | ||
|
||
#### 1、SQuAD1.1 | ||
以SQuAD1.1数据集为例 | ||
LUKE做阅读理解较为特殊,需要提供能够自动感知实体的数据集文件并需要安装以下环境: | ||
|
||
```shell | ||
pip install wikipedia2vec==1.0.5 | ||
pip install regex | ||
``` | ||
|
||
运行以下两个命令即可训练并评估LUKE在SQuAD1.1数据集的精度 | ||
|
||
```shell | ||
python -m paddle.distributed.launch examples/language_model/luke/run_squad.py \ | ||
--model_type luke-base \ | ||
--data_dir data/ | ||
--output_dir output/ \ | ||
--device gpu | ||
--learning_rate 12e-6 \ | ||
--num_train_epochs 2 \ | ||
--train_batch_size 8 \ | ||
--do_train \ | ||
--do_eval | ||
``` | ||
其中参数释义如下: | ||
- `model_type` 指示了模型类型,当前支持`luke-base`和`luke-large`模型。 | ||
- `data_dir` 数据集路径。 | ||
- `train_batch_size` 表示每次迭代**每张卡**上的样本数目。 | ||
- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。 | ||
- `output_dir` 表示模型保存路径。 | ||
- `device` 表示使用的设备类型。默认为GPU,可以配置为CPU、GPU、XPU。若希望使用多GPU训练,将其设置为GPU,同时环境变量CUDA_VISIBLE_DEVICES配置要使用的GPU id。 | ||
- `num_train_epochs` 表示需要训练的epoch数量 | ||
- `do_train` 表示是否开启训练 | ||
- `do_eval` 表示是否开启评估 | ||
|
||
训练结束后模型会对模型进行评估,其评估在验证集上完成, 训练完成后你将看到如下结果: | ||
```text | ||
{"exact_match": 89.75691579943235, "f1": 94.95702001984502} | ||
``` | ||
|
||
#### 2、Open Entity | ||
|
||
```shell | ||
python -m paddle.distributed.launch examples/language_model/luke/run_open_entity.py \ | ||
--model_type luke-base \ | ||
--data_dir data/ \ | ||
--output_dir output/ \ | ||
--device gpu \ | ||
--learning_rate 1e-5 \ | ||
--num_train_epochs 3 \ | ||
--train_batch_size 2 | ||
``` | ||
训练结束后模型会对模型进行评估,其评估在测试集上完成, 训练完成后你将看到如下结果: | ||
```text | ||
Results: { | ||
"test_f1": 0.7815726767275616, | ||
"test_precision": 0.7880405766150561, | ||
"test_recall": 0.7752100840336135 | ||
} | ||
``` | ||
|
||
|
||
# Reference | ||
|
||
```bibtex | ||
@inproceedings{yamada2020luke, | ||
title={LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention}, | ||
author={Ikuya Yamada and Akari Asai and Hiroyuki Shindo and Hideaki Takeda and Yuji Matsumoto}, | ||
booktitle={EMNLP}, | ||
year={2020} | ||
} | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
import argparse | ||
from paddle.io import Dataset, DataLoader | ||
import numpy as np | ||
from paddlenlp.transformers import LukeTokenizer | ||
from paddlenlp.transformers import LukeForEntityClassification | ||
from utils.open_entity_processor import convert_examples_to_features, DatasetProcessor | ||
import paddle | ||
import json | ||
from tqdm import tqdm | ||
from utils.trainer import Trainer | ||
import os | ||
|
||
ENTITY_TOKEN = "[ENTITY]" | ||
|
||
parser = argparse.ArgumentParser(description="LUKE FOR OPEN ENTITY") | ||
|
||
parser.add_argument( | ||
"--output_dir", | ||
type=str, | ||
required=True, | ||
help="Use to store all outputs during training and evaluation.") | ||
parser.add_argument( | ||
"--data_dir", type=str, required=True, help="Dataset folder") | ||
parser.add_argument( | ||
"--eval_batch_size", | ||
type=int, | ||
default=32, | ||
help="Batch size per GPU/CPU for evaluating.") | ||
parser.add_argument( | ||
"--num_train_epochs", type=int, default=2, help="Number of training cycles") | ||
parser.add_argument( | ||
"--seed", type=int, default=42, help="random seed for initialization") | ||
parser.add_argument( | ||
"--train_batch_size", | ||
type=int, | ||
default=8, | ||
help="Batch size per GPU/CPU for training.") | ||
parser.add_argument( | ||
"--device", | ||
type=str, | ||
default='gpu', | ||
help="Batch size per GPU/CPU for training.") | ||
parser.add_argument( | ||
"--gradient_accumulation_steps", | ||
type=int, | ||
default=3, | ||
help="Gradient accumulated before each parameter update.") | ||
parser.add_argument( | ||
"--weight_decay", | ||
type=float, | ||
default=0.01, | ||
help="Weight decay if we apply some") | ||
parser.add_argument( | ||
"--warmup_proportion", | ||
type=float, | ||
default=0.06, | ||
help="Proportion of training steps to perform linear learning rate warmup for." | ||
) | ||
parser.add_argument( | ||
"--learning_rate", | ||
type=float, | ||
default=1e-5, | ||
help="The initial learning rate for Adam.") | ||
parser.add_argument( | ||
"--model_type", | ||
type=str, | ||
default='luke-base', | ||
help="Type of pre-trained model.") | ||
parser.add_argument( | ||
"--max_mention_length", | ||
type=int, | ||
default=30, | ||
help="Max entity position's length") | ||
|
||
args = parser.parse_args() | ||
|
||
|
||
class DataGenerator(Dataset): | ||
def __init__(self, features): | ||
super(DataGenerator, self).__init__() | ||
self.features = features | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个数据集多大? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 两个数据集,open_entity数据集不到1MB,SQuAD1.1数据集不到30MB. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已根据您的建议修改 |
||
def __getitem__(self, item): | ||
word_ids = self.features[item].word_segment_ids | ||
word_segment_ids = self.features[item].word_segment_ids | ||
word_attention_mask = self.features[item].word_attention_mask | ||
entity_ids = self.features[item].entity_ids | ||
entity_position_ids = self.features[item].entity_position_ids | ||
entity_segment_ids = self.features[item].entity_segment_ids | ||
entity_attention_mask = self.features[item].entity_attention_mask | ||
labels = self.features[item].labels | ||
|
||
return (word_ids, word_segment_ids, word_attention_mask, entity_ids, | ||
entity_position_ids, entity_segment_ids, entity_attention_mask, | ||
labels) | ||
|
||
def __len__(self): | ||
return len(self.features) | ||
|
||
|
||
@paddle.no_grad() | ||
def evaluate(args, model, fold="dev", output_file=None): | ||
dataloader, _, _, label_list = load_examples(args, fold=fold) | ||
model.eval() | ||
|
||
all_logits = [] | ||
all_labels = [] | ||
|
||
for batch in tqdm(dataloader, desc=fold): | ||
logits = model( | ||
input_ids=batch[0], | ||
token_type_ids=batch[1], | ||
attention_mask=batch[2], | ||
entity_ids=batch[3], | ||
entity_position_ids=batch[4], | ||
entity_segment_ids=batch[5], | ||
entity_attention_mask=batch[6]) | ||
|
||
logits = logits.tolist() | ||
labels = batch[7].tolist() | ||
|
||
all_logits.extend(logits) | ||
all_labels.extend(labels) | ||
|
||
all_predicted_indexes = [] | ||
all_label_indexes = [] | ||
for logits, labels in zip(all_logits, all_labels): | ||
all_predicted_indexes.append([i for i, v in enumerate(logits) if v > 0]) | ||
all_label_indexes.append([i for i, v in enumerate(labels) if v > 0]) | ||
|
||
if output_file: | ||
with open(output_file, "w") as f: | ||
for predicted_indexes, label_indexes in zip(all_predicted_indexes, | ||
all_label_indexes): | ||
data = dict( | ||
predictions=[label_list[ind] for ind in predicted_indexes], | ||
labels=[label_list[ind] for ind in label_indexes], ) | ||
f.write(json.dumps(data) + "\n") | ||
|
||
num_predicted_labels = 0 | ||
num_gold_labels = 0 | ||
num_correct_labels = 0 | ||
|
||
for predicted_indexes, label_indexes in zip(all_predicted_indexes, | ||
all_label_indexes): | ||
num_predicted_labels += len(predicted_indexes) | ||
num_gold_labels += len(label_indexes) | ||
num_correct_labels += len( | ||
frozenset(predicted_indexes).intersection( | ||
frozenset(label_indexes))) | ||
|
||
if num_predicted_labels > 0: | ||
precision = num_correct_labels / num_predicted_labels | ||
else: | ||
precision = 0.0 | ||
|
||
recall = num_correct_labels / num_gold_labels | ||
if precision + recall == 0.0: | ||
f1 = 0.0 | ||
else: | ||
f1 = 2 * precision * recall / (precision + recall) | ||
|
||
return dict(precision=precision, recall=recall, f1=f1) | ||
|
||
|
||
def load_examples(args, fold="train"): | ||
Beacontownfc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tokenizer = LukeTokenizer.from_pretrained(args.model_type) | ||
tokenizer.add_special_tokens(dict(additional_special_tokens=[ENTITY_TOKEN])) | ||
processor = DatasetProcessor() | ||
if fold == "train": | ||
examples = processor.get_train_examples(args.data_dir) | ||
elif fold == "dev": | ||
examples = processor.get_dev_examples(args.data_dir) | ||
else: | ||
examples = processor.get_test_examples(args.data_dir) | ||
|
||
label_list = processor.get_label_list(args.data_dir) | ||
|
||
logging.info("Creating features from the dataset...") | ||
features = convert_examples_to_features(examples, label_list, tokenizer, | ||
args.max_mention_length) | ||
|
||
dataset = DataGenerator(features) | ||
|
||
def collate_fn(batch): | ||
def create_padded_sequence(k, padding_value): | ||
"""Pad sequence to maximum length""" | ||
new_data = [] | ||
max_len = 0 | ||
for each_batch in batch: | ||
if len(each_batch[k]) > max_len: | ||
max_len = len(each_batch[k]) | ||
for each_batch in batch: | ||
new_data.append(each_batch[k] + [padding_value] * ( | ||
max_len - len(each_batch[k]))) | ||
return np.array(new_data, dtype='int64') | ||
|
||
return ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 给下注释,解释189 - 196 行代码的意义。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已解释 |
||
create_padded_sequence(0, 1), # pad word_ids | ||
create_padded_sequence(1, 0), # pad word_segment_ids | ||
create_padded_sequence(2, 0), # pad word_attention_mask | ||
create_padded_sequence(3, 0), # pad entity_ids | ||
create_padded_sequence(4, 0), # pad entity_position_ids | ||
create_padded_sequence(5, 0), # pad entity_segment_ids | ||
create_padded_sequence(6, 0), # pad entity_attention_mask | ||
create_padded_sequence(7, 0), ) # convert to numpy array | ||
|
||
if fold in ("dev", "test"): | ||
dataloader = DataLoader( | ||
dataset, | ||
batch_size=args.eval_batch_size, | ||
shuffle=False, | ||
collate_fn=collate_fn) | ||
else: | ||
dataloader = DataLoader( | ||
dataset, | ||
shuffle=True, | ||
batch_size=args.train_batch_size, | ||
collate_fn=collate_fn) | ||
|
||
return dataloader, examples, features, label_list | ||
|
||
|
||
if __name__ == '__main__': | ||
results = {} | ||
train_dataloader, _, features, _ = load_examples(args, fold="train") | ||
num_labels = len(features[0].labels) | ||
num_train_steps_per_epoch = len( | ||
train_dataloader) // args.gradient_accumulation_steps | ||
num_train_steps = int(num_train_steps_per_epoch * args.num_train_epochs) | ||
model = LukeForEntityClassification.from_pretrained( | ||
args.model_type, num_classes=num_labels) | ||
trainer = Trainer( | ||
args, | ||
model=model, | ||
dataloader=train_dataloader, | ||
num_train_steps=num_train_steps) | ||
trainer.train(is_op=True) | ||
output_file = os.path.join(args.output_dir, f"test_predictions.jsonl") | ||
results.update({ | ||
f"test_{k}": v | ||
for k, v in evaluate(args, model, 'test', output_file).items() | ||
}) | ||
|
||
print("Results: %s", json.dumps(results, indent=2, sort_keys=True)) | ||
with open(os.path.join(args.output_dir, "results.json"), "w") as f: | ||
json.dump(results, f) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议统一用batch_size,不用区分
eval_batch_size
和train_batch_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改