Skip to content

Commit

Permalink
Fix qa transformer (#34)
Browse files Browse the repository at this point in the history
* fix qa transformer, update transformers/tokenizers

* add fix model files

* rename model checkpoint
  • Loading branch information
ronakice authored May 27, 2020
1 parent add59ba commit f3485ac
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
10 changes: 4 additions & 6 deletions pygaggle/run/evaluate_kaggle_highlighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import BaseModel, validator
from transformers import (AutoModel,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoTokenizer,
BertForQuestionAnswering,
Expand Down Expand Up @@ -130,16 +131,13 @@ def construct_qa_transformer(options: KaggleEvaluationOptions) -> Reranker:
# We load a sequence classification model first -- again, as a workaround.
# Refactor
try:
model = AutoModelForSequenceClassification.from_pretrained(
model = AutoModelForQuestionAnswering.from_pretrained(
options.model_name)
except OSError:
model = AutoModelForSequenceClassification.from_pretrained(
model = AutoModelForQuestionAnswering.from_pretrained(
options.model_name, from_tf=True)
fixed_model = BertForQuestionAnswering(model.config)
fixed_model.qa_outputs = model.classifier
fixed_model.bert = model.bert
device = torch.device(options.device)
model = fixed_model.to(device).eval()
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
options.tokenizer_name, do_lower_case=options.do_lower_case)
return QuestionAnsweringTransformerReranker(model, tokenizer)
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
coloredlogs==14.0
dataclasses;python_version<"3.7"
numpy==1.18.2
numpy>=1.18
pydantic==1.5
pyserini==0.9.0.0
scikit-learn>=0.22
scipy>=1.4
spacy==2.2.4
tensorboard>=2.1.0
tensorflow>=2.2.0rc1
tokenizers==0.5.2
tokenizers>=0.7
tqdm==4.45.0
transformers==2.7.0
transformers>=2.9.0
6 changes: 6 additions & 0 deletions scripts/rename-checkpoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash

mv bert_config.json config.json
for filename in model.ckpt*; do
mv $filename $(python -c "import re; print(re.sub(r'ckpt-\\d+', 'ckpt', '$filename'))");
done

0 comments on commit f3485ac

Please sign in to comment.