Flake8 (#19)
* make flake8 happy

* fix1

* fix2

* consistency

* flake8 compliant
ronakice authored May 13, 2020
1 parent 58f0244 commit 2905235
Showing 25 changed files with 342 additions and 171 deletions.
4 changes: 4 additions & 0 deletions pygaggle/__init__.py
@@ -1 +1,5 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
from .logger import *
4 changes: 4 additions & 0 deletions pygaggle/data/__init__.py
@@ -1,3 +1,7 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
from .kaggle import *
from .relevance import *
from .msmarco import *
18 changes: 12 additions & 6 deletions pygaggle/data/kaggle.py
@@ -12,7 +12,8 @@
from pygaggle.rerank.base import Query, Text


-__all__ = ['MISSING_ID', 'LitReviewCategory', 'LitReviewAnswer', 'LitReviewDataset', 'LitReviewSubcategory']
+__all__ = ['MISSING_ID', 'LitReviewCategory', 'LitReviewAnswer',
+           'LitReviewDataset', 'LitReviewSubcategory']


MISSING_ID = '<missing>'
@@ -45,7 +46,8 @@ def from_file(cls, filename: str) -> 'LitReviewDataset':
return cls(**json.load(f))

def query_answer_pairs(self, split: str = 'nq'):
-        return ((subcat.nq_name if split == 'nq' else subcat.kq_name, ans) for cat in self.categories
+        return ((subcat.nq_name if split == 'nq' else subcat.kq_name, ans)
+                for cat in self.categories
for subcat in cat.sub_categories
for ans in subcat.answers)

@@ -80,8 +82,10 @@ def to_senticized_dataset(self,
mean_stats['Random P@1'].append(np.mean(int_rels))
n = len(int_rels) - p
N = len(int_rels)
-            mean_stats['Random R@3'].append(1 - (n * (n - 1) * (n - 2)) / (N * (N - 1) * (N - 2)))
-            numer = np.array([sp.comb(n, i) / (N - i) for i in range(0, n + 1)]) * p
+            mean_stats['Random R@3'].append(1 - (n * (n - 1) * (n - 2)) / (N *
+                                            (N - 1) * (N - 2)))
+            numer = np.array([sp.comb(n, i) / (N - i)
+                              for i in range(0, n + 1)]) * p
denom = np.array([sp.comb(N, i) for i in range(0, n + 1)])
rr = 1 / np.arange(1, n + 2)
rmrr = np.sum(numer * rr / denom)
@@ -90,5 +94,7 @@ def to_senticized_dataset(self,
logging.warning(f'{doc_id} has no relevant answers')
for k, v in mean_stats.items():
logging.info(f'{k}: {np.mean(v)}')
-        return [RelevanceExample(Query(query), list(map(lambda s: Text(s, dict(docid=docid)), sents)), rels)
-                for ((query, docid), sents), (_, rels) in zip(example_map.items(), rel_map.items())]
+        return [RelevanceExample(Query(query), list(map(lambda s: Text(s,
+                dict(docid=docid)), sents)), rels)
+                for ((query, docid), sents), (_, rels) in
+                zip(example_map.items(), rel_map.items())]
65 changes: 36 additions & 29 deletions pygaggle/data/msmarco.py
@@ -1,7 +1,6 @@
import os
from collections import OrderedDict, defaultdict
from typing import List, Set, DefaultDict
-import json
import logging
from itertools import permutations

@@ -10,7 +9,6 @@
import numpy as np

from .relevance import RelevanceExample, MsMarcoPassageLoader
-from pygaggle.model.tokenize import SpacySenticizer
from pygaggle.rerank.base import Query, Text
from pygaggle.data.unicode import convert_to_unicode

@@ -24,6 +22,7 @@ class MsMarcoExample(BaseModel):
candidates: List[str]
relevant_candidates: Set[str]

+
class MsMarcoDataset(BaseModel):
examples: List[MsMarcoExample]

@@ -55,40 +54,38 @@ def load_run(cls, path: str):
return sorted_run

@classmethod
-    def load_queries(cls,
-                     path: str,
-                     qrels: DefaultDict[str, Set[str]],
+    def load_queries(cls,
+                     path: str,
+                     qrels: DefaultDict[str, Set[str]],
run) -> List[MsMarcoExample]:
queries = []
with open(path) as f:
for i, line in enumerate(f):
qid, query = line.rstrip().split('\t')
-                candidates = run[qid]
-                queries.append(MsMarcoExample(qid = qid,
-                                              text = query,
-                                              candidates = run[qid],
-                                              relevant_candidates = qrels[qid]))
+                queries.append(MsMarcoExample(qid=qid,
+                                              text=query,
+                                              candidates=run[qid],
+                                              relevant_candidates=qrels[qid]))
return queries

@classmethod
-    def from_folder(cls,
-                    folder: str,
-                    split: str = 'dev',
+    def from_folder(cls,
+                    folder: str,
+                    split: str = 'dev',
is_duo: bool = False) -> 'MsMarcoDataset':
run_mono = "mono." if is_duo else ""
query_path = os.path.join(folder, f"queries.{split}.small.tsv")
qrels_path = os.path.join(folder, f"qrels.{split}.small.tsv")
run_path = os.path.join(folder, f"run.{run_mono}{split}.small.tsv")
-        return cls(examples = cls.load_queries(query_path,
-                               cls.load_qrels(qrels_path),
-                               cls.load_run(run_path)))
-
+        return cls(examples=cls.load_queries(query_path,
+                                             cls.load_qrels(qrels_path),
+                                             cls.load_run(run_path)))

def query_passage_tuples(self, is_duo: bool = False):
-        return (((ex.qid, ex.text, ex.relevant_candidates), perm_pas) for ex in self.examples
+        return (((ex.qid, ex.text, ex.relevant_candidates), perm_pas)
+                for ex in self.examples
for perm_pas in permutations(ex.candidates, r=1+int(is_duo)))

-
def to_relevance_examples(self,
index_path: str,
is_duo: bool = False) -> List[RelevanceExample]:
@@ -100,21 +97,25 @@ def to_relevance_examples(self,
example_map[qid][1].append([cand for cand in cands][0])
try:
passages = [loader.load_passage(cand) for cand in cands]
-                example_map[qid][2].append([convert_to_unicode(passage.all_text) for passage in passages][0])
-            except ValueError as e:
+                example_map[qid][2].append(
+                    [convert_to_unicode(passage.all_text)
+                     for passage in passages][0])
+            except ValueError:
logging.warning(f'Skipping {passages}')
continue
example_map[qid][3].append(cands[0] in rel_cands)
mean_stats = defaultdict(list)
for ex in self.examples:
int_rels = np.array(list(map(int, example_map[ex.qid][3])))
-            p = int_rels.sum()/(len(ex.candidates) - 1) if is_duo else int_rels.sum()
+            p = int_rels.sum()/(len(ex.candidates) - 1) if is_duo \
+                else int_rels.sum()
mean_stats['Random P@1'].append(np.mean(int_rels))
n = len(ex.candidates) - p
N = len(ex.candidates)
if len(ex.candidates) <= 1000:
mean_stats['Random R@1000'].append(1 if 1 in int_rels else 0)
-            numer = np.array([sp.comb(n, i) / (N - i) for i in range(0, n + 1) if i!=N]) * p
+            numer = np.array([sp.comb(n, i) / (N - i) for i in range(0, n + 1)
+                              if i != N]) * p
if n == N:
numer = np.append(numer, 0)
denom = np.array([sp.comb(N, i) for i in range(0, n + 1)])
@@ -127,11 +128,17 @@ def to_relevance_examples(self,
for rel_cand in ex.relevant_candidates:
if rel_cand in ex.candidates:
ex_index = min(ex.candidates.index(rel_cand), ex_index)
-            mean_stats['Existing MRR'].append(1 / (ex_index + 1) if ex_index < len(ex.candidates) else 0)
-            mean_stats['Existing MRR@10'].append(1 / (ex_index + 1) if ex_index < 10 else 0)
+            mean_stats['Existing MRR'].append(1 / (ex_index + 1)
+                                              if ex_index < len(ex.candidates)
+                                              else 0)
+            mean_stats['Existing MRR@10'].append(1 / (ex_index + 1)
+                                                 if ex_index < 10 else 0)
for k, v in mean_stats.items():
logging.info(f'{k}: {np.mean(v)}')
-        return [RelevanceExample(Query(text=query_text, id=qid),
-                list(map(lambda s: Text(s[1], dict(docid=s[0])), zip(cands, cands_text))),
-                rel_cands) \
-                for qid, (query_text, cands, cands_text, rel_cands) in example_map.items()]
+        return [RelevanceExample(Query(text=query_text, id=qid),
+                                 list(map(lambda s: Text(s[1],
+                                          dict(docid=s[0])),
+                                          zip(cands, cands_text))),
+                                 rel_cands)
+                for qid, (query_text, cands, cands_text, rel_cands) in
+                example_map.items()]
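
Taken together, MsMarcoDataset reads three TSV files from one folder and converts them into RelevanceExample objects. A hedged usage sketch; the folder and index paths below are placeholders, not part of this commit:

from pygaggle.data.msmarco import MsMarcoDataset

# Expects queries.dev.small.tsv, qrels.dev.small.tsv and
# run.dev.small.tsv (run.mono.dev.small.tsv when is_duo=True) in the folder.
dataset = MsMarcoDataset.from_folder('data/msmarco', split='dev')
examples = dataset.to_relevance_examples('indexes/msmarco-passage')
print(len(examples), examples[0].query.text)
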
3 changes: 2 additions & 1 deletion pygaggle/data/relevance.py
@@ -50,7 +50,8 @@ def load_document(self, id: str) -> Cord19Document:
def unfold(entries):
return '\n'.join(x['text'] for x in entries)
try:
-            article = json.loads(self.searcher.doc(id).lucene_document().get('raw'))
+            article = json.loads(
+                self.searcher.doc(id).lucene_document().get('raw'))
except json.decoder.JSONDecodeError:
raise ValueError('article not found')
except AttributeError:
5 changes: 3 additions & 2 deletions pygaggle/data/unicode.py
@@ -1,8 +1,9 @@
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
"""
Converts `text` to Unicode (if it's not already) assuming utf-8 input."""
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
raise ValueError(f"Unsupported string type: f{type(text)}")
3 changes: 2 additions & 1 deletion pygaggle/logger.py
@@ -4,4 +4,5 @@
__all__ = []


-coloredlogs.install(level='INFO', fmt='%(asctime)s [%(levelname)s] %(module)s: %(message)s')
+coloredlogs.install(level='INFO',
+                    fmt='%(asctime)s [%(levelname)s] %(module)s: %(message)s')
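
With this format string, a record logged anywhere in the package renders along the lines of the comment below; the timestamp and module name are illustrative:

import logging
logging.info('loading run file')
# 2020-05-13 12:00:00 [INFO] msmarco: loading run file
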
4 changes: 4 additions & 0 deletions pygaggle/model/__init__.py
@@ -1,3 +1,7 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
from .decode import *
from .encode import *
from .evaluate import *
18 changes: 11 additions & 7 deletions pygaggle/model/decode.py
@@ -3,29 +3,33 @@
from transformers import PreTrainedModel
import torch


__all__ = ['greedy_decode']

+DecodedOutput = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


@torch.no_grad()
def greedy_decode(model: PreTrainedModel,
input_ids: torch.Tensor,
length: int,
attention_mask: torch.Tensor = None,
-                  return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+                  return_last_logits: bool = True) -> DecodedOutput:
decode_ids = torch.full((input_ids.size(0), 1),
model.config.decoder_start_token_id,
dtype=torch.long).to(input_ids.device)
past = model.get_encoder()(input_ids, attention_mask=attention_mask)
next_token_logits = None
for _ in range(length):
-        model_inputs = model.prepare_inputs_for_generation(decode_ids,
-                                                           past=past,
-                                                           attention_mask=attention_mask,
-                                                           use_cache=True)
+        model_inputs = model.prepare_inputs_for_generation(
+            decode_ids,
+            past=past,
+            attention_mask=attention_mask,
+            use_cache=True)
outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size, vocab_size)
-        decode_ids = torch.cat([decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1)
+        decode_ids = torch.cat([decode_ids,
+                                next_token_logits.max(1)[1].unsqueeze(-1)],
+                               dim=-1)
past = outputs[1]
if return_last_logits:
return decode_ids, next_token_logits
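
greedy_decode advances a seq2seq model one token at a time, reusing the cached encoder output between steps and optionally returning the logits of the final step. A hedged sketch of calling it with a T5 checkpoint from transformers; the checkpoint name and input text are placeholders, and the callable-tokenizer style assumes a newer transformers release than the one this commit targeted:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from pygaggle.model.decode import greedy_decode

tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base').eval()
batch = tokenizer('translate English to German: Hello.', return_tensors='pt')
decode_ids, last_logits = greedy_decode(model,
                                        batch['input_ids'],
                                        length=8,
                                        attention_mask=batch['attention_mask'],
                                        return_last_logits=True)
print(tokenizer.decode(decode_ids[0], skip_special_tokens=True))
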
37 changes: 25 additions & 12 deletions pygaggle/model/encode.py
@@ -9,7 +9,8 @@
from pygaggle.rerank.base import TextType


-__all__ = ['LongBatchEncoder', 'EncoderOutputBatch', 'SingleEncoderOutput', 'SpecialTokensCleaner']
+__all__ = ['LongBatchEncoder', 'EncoderOutputBatch', 'SingleEncoderOutput',
+           'SpecialTokensCleaner']


@dataclass
@@ -26,10 +27,12 @@ class EncoderOutputBatch:
texts: List[TextType]

def as_single(self) -> 'SingleEncoderOutput':
-        return SingleEncoderOutput(self.encoder_output[0], self.token_ids[0], self.texts[0])
+        return SingleEncoderOutput(self.encoder_output[0],
+                                   self.token_ids[0], self.texts[0])

def __iter__(self):
-        return iter(SingleEncoderOutput(enc_out, token_ids, text) for enc_out, token_ids, text
+        return iter(SingleEncoderOutput(enc_out, token_ids, text) for
+                    (enc_out, token_ids, text)
in zip(self.encoder_output, self.token_ids, self.texts))


Expand All @@ -38,14 +41,17 @@ def __init__(self, tokenizer: PreTrainedTokenizer):
self.special_ids = tokenizer.all_special_ids

def clean(self, output: SingleEncoderOutput) -> SingleEncoderOutput:
-        indices = [idx for idx, tok in enumerate(output.token_ids.tolist()) if tok not in self.special_ids]
-        return SingleEncoderOutput(output.encoder_output[indices], output.token_ids[indices], output.text)
+        indices = [idx for idx, tok in enumerate(output.token_ids.tolist())
+                   if tok not in self.special_ids]
+        return SingleEncoderOutput(output.encoder_output[indices],
+                                   output.token_ids[indices], output.text)


class LongBatchEncoder:
"""Encodes batches of documents that are longer than the maximum sequence length by striding a window across
"""
Encodes batches of documents that are longer than the maximum sequence
length by striding a window across
the sequence dimension.
Parameters
----------
encoder : nn.Module
@@ -79,17 +85,24 @@ def encode(self, batch_input: List[TextType]) -> EncoderOutputBatch:
encode_lst = [[] for _ in input_ids]
new_input_ids = [(idx, x[:max_len]) for idx, x in input_ids]
while new_input_ids:
-            attn_mask = [[1] * len(x[1]) + [0] * (max_len - len(x[1])) for x in new_input_ids]
+            attn_mask = [[1] * len(x[1]) +
+                         [0] * (max_len - len(x[1]))
+                         for x in new_input_ids]
attn_mask = torch.tensor(attn_mask).to(self.device)
nonpadded_input_ids = new_input_ids
-            new_input_ids = [x + [0] * (max_len - len(x[:max_len])) for _, x in new_input_ids]
+            new_input_ids = [x + [0] * (max_len - len(x[:max_len]))
+                             for _, x in new_input_ids]
new_input_ids = torch.tensor(new_input_ids).to(self.device)
-            outputs, _ = self.encoder(input_ids=new_input_ids, attention_mask=attn_mask)
+            outputs, _ = self.encoder(input_ids=new_input_ids,
+                                      attention_mask=attn_mask)
for (idx, _), output in zip(nonpadded_input_ids, outputs):
encode_lst[idx].append(output)

-            new_input_ids = [(idx, x[max_len:]) for idx, x in nonpadded_input_ids if len(x) > max_len]
-            max_len = min(max(map(lambda x: len(x[1]), new_input_ids), default=0), self.msl)
+            new_input_ids = [(idx, x[max_len:])
+                             for idx, x in nonpadded_input_ids
+                             if len(x) > max_len]
+            max_len = min(max(map(lambda x: len(x[1]), new_input_ids),
+                          default=0), self.msl)

encode_lst = list(map(torch.cat, encode_lst))
batch_output.extend(encode_lst)
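
The encode loop above interleaves fixed-size windows from every document in the batch, re-batching at each stride so shorter documents drop out as they are exhausted, then concatenates each document's window outputs with torch.cat. Stripped of batching and padding, the windowing itself reduces to this minimal, self-contained sketch:

from typing import List

def windows(token_ids: List[int], max_len: int) -> List[List[int]]:
    # Consecutive, non-overlapping windows of at most max_len tokens.
    return [token_ids[i:i + max_len]
            for i in range(0, len(token_ids), max_len)]

windows(list(range(7)), 3)  # [[0, 1, 2], [3, 4, 5], [6]]
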