Skip to content

Commit

Permalink
Simplify boilerplate for monoT5 and monoBERT (#83)
Browse files Browse the repository at this point in the history
* Simplify boilerplate for monoT5 and monoBERT

* Fold into constructors

* Capitalize class names
  • Loading branch information
yuxuan-ji authored Sep 13, 2020
1 parent a258c13 commit 41513a9
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 53 deletions.
30 changes: 4 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,45 +35,23 @@ Currently, this repo contains implementations of the rerankers for [CovidQA](htt
Here's how to initialize the T5 reranker from [Document Ranking with a Pretrained Sequence-to-Sequence Model](https://arxiv.org/pdf/2003.06713.pdf):

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from pygaggle.model import T5BatchTokenizer
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import T5Reranker
from pygaggle.rerank.transformer import MonoT5

model_name = 'castorini/monot5-base-msmarco'
tokenizer_name = 't5-base'
batch_size = 8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = T5BatchTokenizer(tokenizer, batch_size)
reranker = T5Reranker(model, tokenizer)
reranker = MonoT5(model_name, tokenizer_name)
```

Alternatively, here's the BERT reranker from [Passage Re-ranking with BERT](https://arxiv.org/pdf/1901.04085.pdf), which isn't as good as the T5 reranker:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pygaggle.model import BatchTokenizer
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import SequenceClassificationTransformerReranker
from pygaggle.rerank.transformer import MonoBERT

model_name = 'castorini/monobert-large-msmarco'
tokenizer_name = 'bert-large-uncased'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
reranker = SequenceClassificationTransformerReranker(model, tokenizer)
reranker = MonoBERT(model_name, tokenizer_name)
```

Either way, continue with a complete reranking example:
Expand Down
47 changes: 32 additions & 15 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from copy import deepcopy
from typing import List
from typing import List, Union

from transformers import (PreTrainedModel,
from transformers import (AutoTokenizer,
AutoModelForSequenceClassification,
PreTrainedModel,
PreTrainedTokenizer,
T5ForConditionalGeneration)
import torch
Expand All @@ -13,21 +15,29 @@
QueryDocumentBatch,
QueryDocumentBatchTokenizer,
SpecialTokensCleaner,
T5BatchTokenizer,
greedy_decode)


__all__ = ['T5Reranker',
__all__ = ['MonoT5',
'UnsupervisedTransformerReranker',
'SequenceClassificationTransformerReranker',
'MonoBERT',
'QuestionAnsweringTransformerReranker']


class T5Reranker(Reranker):
class MonoT5(Reranker):
def __init__(self,
             model_name_or_instance: Union[str, T5ForConditionalGeneration] = 'castorini/monot5-base-msmarco',
             tokenizer_name_or_instance: Union[str, QueryDocumentBatchTokenizer] = 't5-base'):
    """Construct a monoT5 reranker.

    Each argument may be either an already-built instance or a name string;
    strings are resolved via ``from_pretrained`` with sensible defaults.

    :param model_name_or_instance: Hugging Face model id, or a ready
        ``T5ForConditionalGeneration`` (used as-is, device untouched).
    :param tokenizer_name_or_instance: tokenizer name, or a ready
        ``QueryDocumentBatchTokenizer``.
    """
    if isinstance(model_name_or_instance, str):
        # When given a name, load onto GPU if available and freeze for inference.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE: the hub id is all-lowercase ('castorini/monot5-base-msmarco',
        # as in the README); a mixed-case id would fail to resolve.
        model_name_or_instance = T5ForConditionalGeneration.from_pretrained(model_name_or_instance).to(device).eval()
    self.model = model_name_or_instance

    if isinstance(tokenizer_name_or_instance, str):
        # Wrap the raw tokenizer in the batch tokenizer the reranker consumes.
        # batch_size=8 mirrors the previous README boilerplate default.
        tokenizer_name_or_instance = T5BatchTokenizer(AutoTokenizer.from_pretrained(tokenizer_name_or_instance), batch_size=8)
    self.tokenizer = tokenizer_name_or_instance

    # Infer the device from the model's parameters; presumably every model
    # passed in has at least one parameter — next() returns None otherwise
    # and this line would raise AttributeError (TODO confirm upstream intent).
    self.device = next(self.model.parameters(), None).device

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
Expand Down Expand Up @@ -97,13 +107,20 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
return texts


class SequenceClassificationTransformerReranker(Reranker):
class MonoBERT(Reranker):
def __init__(self,
             model_name_or_instance: Union[str, PreTrainedModel] = 'castorini/monobert-large-msmarco',
             tokenizer_name_or_instance: Union[str, PreTrainedTokenizer] = 'bert-large-uncased'):
    """Construct a monoBERT reranker.

    Each argument may be either an already-built instance or a name string;
    strings are resolved via ``from_pretrained`` with sensible defaults.

    :param model_name_or_instance: Hugging Face model id, or a ready
        sequence-classification ``PreTrainedModel`` (used as-is).
    :param tokenizer_name_or_instance: tokenizer name, or a ready
        ``PreTrainedTokenizer``.
    """
    if isinstance(model_name_or_instance, str):
        # When given a name, load onto GPU if available and freeze for inference.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE: the hub id is all-lowercase ('castorini/monobert-large-msmarco',
        # as in the README); a mixed-case id would fail to resolve.
        model_name_or_instance = AutoModelForSequenceClassification.from_pretrained(model_name_or_instance).to(device).eval()
    self.model = model_name_or_instance

    if isinstance(tokenizer_name_or_instance, str):
        tokenizer_name_or_instance = AutoTokenizer.from_pretrained(tokenizer_name_or_instance)
    self.tokenizer = tokenizer_name_or_instance

    # Infer the device from the model's parameters; presumably every model
    # passed in has at least one parameter — next() returns None otherwise
    # and this line would raise AttributeError (TODO confirm upstream intent).
    self.device = next(self.model.parameters(), None).device

@torch.no_grad()
def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
Expand Down
8 changes: 4 additions & 4 deletions pygaggle/run/evaluate_document_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from pygaggle.rerank.bm25 import Bm25Reranker
from pygaggle.rerank.transformer import (
UnsupervisedTransformerReranker,
T5Reranker,
SequenceClassificationTransformerReranker
MonoT5,
MonoBERT
)
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
Expand Down Expand Up @@ -85,7 +85,7 @@ def construct_t5(options: DocumentRankingEvaluationOptions) -> Reranker:
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
return T5Reranker(model, tokenizer)
return MonoT5(model, tokenizer)


def construct_transformer(options:
Expand All @@ -106,7 +106,7 @@ def construct_seq_class_transformer(options: DocumentRankingEvaluationOptions
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name)
return SequenceClassificationTransformerReranker(model, tokenizer)
return MonoBERT(model, tokenizer)


def construct_bm25(options: DocumentRankingEvaluationOptions) -> Reranker:
Expand Down
8 changes: 4 additions & 4 deletions pygaggle/run/evaluate_kaggle_highlighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from pygaggle.rerank.bm25 import Bm25Reranker
from pygaggle.rerank.transformer import (
QuestionAnsweringTransformerReranker,
SequenceClassificationTransformerReranker,
T5Reranker,
MonoBERT,
MonoT5,
UnsupervisedTransformerReranker
)
from pygaggle.rerank.random import RandomReranker
Expand Down Expand Up @@ -82,7 +82,7 @@ def construct_t5(options: KaggleEvaluationOptions) -> Reranker:
tokenizer = AutoTokenizer.from_pretrained(
options.model_name, do_lower_case=options.do_lower_case)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
return T5Reranker(model, tokenizer)
return MonoT5(model, tokenizer)


def construct_transformer(options: KaggleEvaluationOptions) -> Reranker:
Expand Down Expand Up @@ -124,7 +124,7 @@ def construct_seq_class_transformer(options:
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
options.tokenizer_name, do_lower_case=options.do_lower_case)
return SequenceClassificationTransformerReranker(model, tokenizer)
return MonoBERT(model, tokenizer)


def construct_qa_transformer(options: KaggleEvaluationOptions) -> Reranker:
Expand Down
8 changes: 4 additions & 4 deletions pygaggle/run/evaluate_passage_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from pygaggle.rerank.bm25 import Bm25Reranker
from pygaggle.rerank.transformer import (
UnsupervisedTransformerReranker,
T5Reranker,
SequenceClassificationTransformerReranker
MonoT5,
MonoBERT
)
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
Expand Down Expand Up @@ -83,7 +83,7 @@ def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
return T5Reranker(model, tokenizer)
return MonoT5(model, tokenizer)


def construct_transformer(options:
Expand Down Expand Up @@ -116,7 +116,7 @@ def construct_seq_class_transformer(options: PassageRankingEvaluationOptions
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name)
return SequenceClassificationTransformerReranker(model, tokenizer)
return MonoBERT(model, tokenizer)


def construct_bm25(options: PassageRankingEvaluationOptions) -> Reranker:
Expand Down

0 comments on commit 41513a9

Please sign in to comment.