
Add model LUKE #1677

Merged
merged 18 commits on Mar 10, 2022
91 changes: 91 additions & 0 deletions examples/language_model/luke/README.md
@@ -0,0 +1,91 @@
# LUKE with PaddleNLP

[LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057)

**Model overview:**
Many NLP tasks involve entities, e.g. relation classification, entity typing, named entity recognition (NER), and question answering (QA). The key to solving such entity-related tasks is to learn effective entity representations. Conventional entity representations assign a fixed embedding vector to each entity, storing information about the entity from a knowledge base (KB). They require entity linking to represent entities in text and cannot represent entities that are absent from the KB.

In contrast, large transformer-based pre-trained models that produce contextualized word representations (CWRs), such as BERT and RoBERTa, provide effective general-purpose word representations trained with language modeling. However, the architecture of CWRs is not well suited to representing entities, for the following two reasons:

- Because CWRs do not output span-level representations of entities, models typically have to learn such representations from downstream datasets that are usually small.

- Many entity-related tasks, such as relation classification and question answering (QA), involve reasoning about relationships between entities. Although transformers can capture complex relationships between words by relating them to each other through self-attention, performing relational reasoning between entities is difficult because many entities are split into multiple tokens inside the model. Furthermore, the word-level pre-training objective of CWRs is not well suited to learning entity representations: predicting a masked word within an entity, e.g. predicting "Rings" given the sentence "The Lord of the [MASK]", splits a complete entity apart.

An important difference between LUKE and existing CWRs is that LUKE treats not only words but also entities as independent tokens, and computes intermediate and output representations for all tokens with a transformer. Because entities are treated as tokens, LUKE can directly model the relationships between entities.
This project is an open-source implementation of LUKE on Paddle 2.x.
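
To make the "entities as tokens" idea concrete, the sketch below shows LUKE's two parallel input sequences: word token ids for the sentence, plus entity ids with per-entity position lists pointing back at the word positions each entity spans. This is only a sketch under assumptions: the `LukeTokenizer` class name and the `luke-base` shortcut are assumed from how this PR registers the model, and the position values are illustrative; the PR's actual preprocessing is in `open_entity_processor.py` further down.

```python
# Minimal sketch (assumed API) of LUKE's word + entity input sequences.
from paddlenlp.transformers import LukeTokenizer  # assumed import path

tokenizer = LukeTokenizer.from_pretrained("luke-base")  # assumed shortcut name

text = "Beyoncé lives in Los Angeles."
tokens = [tokenizer.cls_token] + tokenizer.tokenize(
    text, add_prefix_space=True) + [tokenizer.sep_token]
word_ids = tokenizer.convert_tokens_to_ids(tokens)  # word token sequence

# Entity side: one id per entity plus the word positions that entity covers,
# padded to a fixed mention length with -1 (positions here are illustrative).
entity_ids = [2]                             # [MASK] entity placeholder
entity_position_ids = [[4, 5] + [-1] * 28]   # e.g. the span "Los Angeles"

print(len(word_ids), entity_ids, entity_position_ids[0][:4])
```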

## Quick start

### Fine-tuning on downstream tasks

Datasets

Download the Open Entity dataset from the [download link](https://cloud.tsinghua.edu.cn/f/6ec98dbd931b4da9a7f0/).
Unzip the downloaded archive; `train.json`, `dev.json`, and `test.json` under the extracted Open Entity directory serve as the training, validation, and test sets respectively.
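
Each of these JSON files is a list of records; the snippet below shows the fields that `open_entity_processor.py` (added later in this PR) actually reads. The example values are made up for illustration only.

```python
# One made-up record from train.json / dev.json / test.json, with exactly
# the fields consumed by DatasetProcessor._create_examples.
example_record = {
    "sent": "He was born in Hamamatsu, Shizuoka Prefecture.",
    "start": 15,             # character offset where the mention begins
    "end": 24,               # character offset where the mention ends
    "labels": ["location"],  # one or more entity types for the mention
}
```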

Download the SQuAD1.1 dataset, a widely used machine reading comprehension benchmark, from the [download link](https://data.deepai.org/squad1.1.zip).

#### 1. SQuAD1.1

Taking the SQuAD1.1 dataset as an example, run the following command to train LUKE and evaluate its accuracy on SQuAD1.1:

```shell
python -m paddle.distributed.launch examples/language_model/luke/run_squad.py \
    --model_type luke \
    --device gpu \
    --learning_rate 15e-6 \
    --num_train_epochs 2 \
    --batch_size 8 \
    --do_predict \
    --do_train \
    --model_name_or_path luke-large
```
The parameters are as follows:
- `model_type`: the model type; currently `luke` is supported.
- `batch_size`: the number of samples **per card** in each iteration.
- `learning_rate`: the base learning rate, which is multiplied by the value produced by the learning rate scheduler to obtain the current learning rate.
- `device`: the device to use. Defaults to GPU and can be set to CPU, GPU, or XPU. For multi-GPU training, set it to GPU and set the environment variable CUDA_VISIBLE_DEVICES to the ids of the GPUs to use.
- `num_train_epochs`: the number of training epochs.
- `do_train`: whether to run training.
- `do_predict`: whether to run evaluation.
- `model_name_or_path`: the model name or path; `luke-base` and `luke-large` are supported.

After training finishes, the model is evaluated on the validation set, and you should see results similar to the following:
```text
{"exact_match": 89.75691579943235, "f1": 94.95702001984502}
```

#### 2. Open Entity

```shell
python -m paddle.distributed.launch examples/language_model/luke/run_open_entity.py \
    --model_type luke-large \
    --data_dir data/ \
    --output_dir output/ \
    --device gpu \
    --learning_rate 1e-5 \
    --num_train_epochs 3 \
    --train_batch_size 2
```
After training finishes, the model is evaluated on the test set, and you should see results similar to the following:
```text
Results: {
"test_f1": 0.7815726767275616,
"test_precision": 0.7880405766150561,
"test_recall": 0.7752100840336135
}
```


# Reference

```bibtex
@inproceedings{yamada2020luke,
title={LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention},
author={Ikuya Yamada and Akari Asai and Hiroyuki Shindo and Hideaki Takeda and Yuji Matsumoto},
booktitle={EMNLP},
year={2020}
}
```
148 changes: 148 additions & 0 deletions examples/language_model/luke/args.py
@@ -0,0 +1,148 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--train_file",
        type=str,
        required=False,
        default=None,
        help="Train data path.")
    parser.add_argument(
        "--predict_file",
        type=str,
        required=False,
        default=None,
        help="Predict data path.")
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
        help="Type of pre-trained model.")
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-uncased",
        type=str,
        help="Path to pre-trained model or shortcut name of model.")
    parser.add_argument(
        "--output_dir",
        default="outputs",
        type=str,
        help="The output directory where the model predictions and checkpoints will be written. "
        "Default as `outputs`")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument(
        "--batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        help="Weight decay if we apply some.")
    parser.add_argument(
        "--adam_epsilon",
        default=1e-8,
        type=float,
        help="Epsilon for Adam optimizer.")
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs",
        default=3,
        type=int,
        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.0,
        type=float,
        help="Proportion of training steps to perform linear learning rate warmup for."
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=500,
        help="Log every X updates steps.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        '--device',
        choices=['cpu', 'gpu'],
        default="gpu",
        help="Select which device to train model, defaults to gpu.")
    parser.add_argument(
        "--doc_stride",
        type=int,
        default=128,
        help="When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--n_best_size",
        type=int,
        default=20,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file."
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null."
    )
    parser.add_argument(
        "--max_query_length", type=int, default=64, help="Max query length.")
    parser.add_argument(
        "--max_answer_length", type=int, default=30, help="Max answer length.")
    parser.add_argument(
        "--do_lower_case",
        action='store_false',
        help="Whether to lower case the input text. Should be True for uncased models and False for cased models."
    )
    parser.add_argument(
        "--verbose", action='store_true', help="Whether to output verbose log.")
    parser.add_argument(
        "--version_2_with_negative",
        action='store_true',
        help="If true, the SQuAD examples contain some that do not have an answer. If using squad v2.0, it should be set true."
    )
    parser.add_argument(
        "--do_train", action='store_true', help="Whether to train the model.")
    parser.add_argument(
        "--do_predict", action='store_true', help="Whether to predict.")
    args = parser.parse_args()
    return args
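
For reference, here is a minimal way to exercise `parse_args` outside the launch scripts; overriding `sys.argv` below is purely for illustration, the real scripts read the flags from the command line.

```python
# Quick sanity check of the argument parser defined above (illustrative only).
import sys

sys.argv = ["run_squad.py", "--model_type", "luke",
            "--model_name_or_path", "luke-large", "--do_train", "--do_predict"]
args = parse_args()
print(args.model_type, args.learning_rate, args.do_train)  # luke 5e-05 True
```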
139 changes: 139 additions & 0 deletions examples/language_model/luke/open_entity_processor.py
@@ -0,0 +1,139 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

from tqdm import tqdm

ENTITY_TOKEN = "[ENTITY]"


class InputExample(object):
    def __init__(self, id_, text, span, labels):
        self.id = id_
        self.text = text
        self.span = span
        self.labels = labels


class InputFeatures(object):
    def __init__(
            self,
            word_ids,
            word_segment_ids,
            word_attention_mask,
            entity_ids,
            entity_position_ids,
            entity_segment_ids,
            entity_attention_mask,
            labels, ):
        self.word_ids = word_ids
        self.word_segment_ids = word_segment_ids
        self.word_attention_mask = word_attention_mask
        self.entity_ids = entity_ids
        self.entity_position_ids = entity_position_ids
        self.entity_segment_ids = entity_segment_ids
        self.entity_attention_mask = entity_attention_mask
        self.labels = labels


class DatasetProcessor(object):
    def get_train_examples(self, data_dir):
        return self._create_examples(data_dir, "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(data_dir, "dev")

    def get_test_examples(self, data_dir):
        return self._create_examples(data_dir, "test")

    def get_label_list(self, data_dir):
        labels = set()
        for example in self.get_train_examples(data_dir):
            labels.update(example.labels)
        return sorted(labels)

    def _create_examples(self, data_dir, set_type):
        with open(os.path.join(data_dir, set_type + ".json"), "r") as f:
            data = json.load(f)
        return [
            InputExample(i, item["sent"], (item["start"], item["end"]),
                         item["labels"]) for i, item in enumerate(data)
        ]


def convert_examples_to_features(examples, label_list, tokenizer,
                                 max_mention_length):
    label_map = {label: i for i, label in enumerate(label_list)}

    conv_tables = (
        ("-LRB-", "("),
        ("-LCB-", "("),
        ("-LSB-", "("),
        ("-RRB-", ")"),
        ("-RCB-", ")"),
        ("-RSB-", ")"), )
    features = []
    for example in tqdm(examples):

        def preprocess_and_tokenize(text, start, end=None):
            target_text = text[start:end].rstrip()
            for a, b in conv_tables:
                target_text = target_text.replace(a, b)

            return tokenizer.tokenize(target_text, add_prefix_space=True)

        tokens = [tokenizer.cls_token]
        tokens += preprocess_and_tokenize(example.text, 0, example.span[0])
        mention_start = len(tokens)
        tokens.append(ENTITY_TOKEN)
        tokens += preprocess_and_tokenize(example.text, example.span[0],
                                          example.span[1])
        tokens.append(ENTITY_TOKEN)
        mention_end = len(tokens)

        tokens += preprocess_and_tokenize(example.text, example.span[1])
        tokens.append(tokenizer.sep_token)

        word_ids = tokenizer.convert_tokens_to_ids(tokens)
        word_attention_mask = [1] * len(tokens)
        word_segment_ids = [0] * len(tokens)

        entity_ids = [2, 0]
        entity_attention_mask = [1, 0]
        entity_segment_ids = [0, 0]
        entity_position_ids = list(range(mention_start,
                                         mention_end))[:max_mention_length]
        entity_position_ids += [-1] * (
            max_mention_length - mention_end + mention_start)
        entity_position_ids = [entity_position_ids, [-1] * max_mention_length]

        labels = [0] * len(label_map)

        for label in example.labels:
            labels[label_map[label]] = 1

        features.append(
            InputFeatures(
                word_ids=word_ids,
                word_segment_ids=word_segment_ids,
                word_attention_mask=word_attention_mask,
                entity_ids=entity_ids,
                entity_position_ids=entity_position_ids,
                entity_segment_ids=entity_segment_ids,
                entity_attention_mask=entity_attention_mask,
                labels=labels, ))

    return features
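
The sketch below shows how these helpers fit together for the Open Entity example. The `LukeTokenizer` import is an assumption about the tokenizer this PR registers elsewhere, and `max_mention_length=30` is an illustrative value; everything else uses only the classes and functions defined above. In each feature, `entity_ids = [2, 0]` appears to be the special `[MASK]` entity (whose output representation is classified) followed by one padding entity, which is why `entity_attention_mask` is `[1, 0]`.

```python
# Illustrative end-to-end use of the helpers above. The data/ layout follows
# the README (train.json / dev.json / test.json in the extracted directory).
from paddlenlp.transformers import LukeTokenizer  # assumed import path

tokenizer = LukeTokenizer.from_pretrained("luke-large")  # assumed shortcut name

processor = DatasetProcessor()
label_list = processor.get_label_list("data/")        # sorted label names
train_examples = processor.get_train_examples("data/")

features = convert_examples_to_features(
    train_examples, label_list, tokenizer, max_mention_length=30)

f = features[0]
# Word side: one id / mask / segment per subword token, mention wrapped in [ENTITY].
# Entity side: a single [MASK] entity over the mention span, plus padding.
print(len(f.word_ids), f.entity_ids, f.entity_attention_mask)
```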