Commit

mf
mam10eks committed Nov 13, 2024
1 parent 6b055af commit 431045f
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions data/pooling.py
@@ -26,6 +26,15 @@ def passage_ids(doc_id):

return sorted(list(ret))

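# Second-stage re-rankers as (name, factory) pairs; the lambdas defer model loading until a re-ranker is actually used.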
def all_re_rankers():
from pyterrier_t5 import MonoT5ReRanker
import pyterrier_dr
return [
('mono-t5', lambda: MonoT5ReRanker()),
        ('colbert', lambda: pyterrier_dr.TctColBert(verbose=True)),
('ance', lambda: pyterrier_dr.Ance(verbose=True))
]


@click.command('pooling')
@click.option('--retrieval-index', default='msmarco-passage-v2.1', help='The chatnoir index for pooling.')
@@ -72,10 +81,10 @@ def main(directory, retrieval_index, feedback_index, corpus_offset):
for doc in results['docno']:
all_docs.add(doc)
print('Corpus-size', len(all_docs))
docs_store = ir_datasets.load('msmarco-segment-v2.1').docs_store()

if not os.path.exists(f'{directory}/pyterrier-index'):
if not os.path.exists(f'{directory}/documents.jsonl.gz'):
docs_store = ir_datasets.load('msmarco-segment-v2.1').docs_store()
with gzip.open(f'{directory}/documents.jsonl.gz', 'wt') as f:
for doc in tqdm(all_docs, 'Load Docs'):
doc = docs_store.get(doc)
@@ -100,6 +109,30 @@ def main(directory, retrieval_index, feedback_index, corpus_offset):
results = retriever(topics)
pt.io.write_results(results, output_file)

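# Re-rank the stored BM25 run with each neural model, once for the title and once for the description query.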
for retrieval_model, reranker in tqdm(all_re_rankers(), 'Re-Rankers'):
for query_type in ['title', 'description']:
output_file = f'{directory}/neural-{retrieval_model}-on{query_type}-run.gz'
if os.path.exists(output_file):
continue

documents = {}
with gzip.open(f'{directory}/documents.jsonl.gz', 'rt') as f:
for l in f:
l = json.loads(l)
documents[l['docno']] = l['text']
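            # Attach the raw passage text (needed by the neural re-rankers) to each retrieved docno.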
def add_text(df):
df['text'] = df['docno'].apply(lambda i: documents[i])
return df

topics = pt.io.read_topics(f'{directory}/topics.xml', 'trecxml', tags=[query_type], tokenise=False)
first_stage = pt.io.read_results(f'{directory}/pyterrier-BM25-run.gz')
first_stage = pt.transformer.get_transformer(first_stage)
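            # Pipeline: stored BM25 results -> add passage text -> neural re-ranker.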
first_stage = first_stage >> pt.apply.generic(add_text) >> reranker()

results = first_stage(topics)
pt.io.write_results(results, output_file)

for _, t in tqdm(list(relevant_documents_per_topic.iterrows()), 'Expansion Docs'):
docs_for_query = set()
@@ -115,9 +148,10 @@ def main(directory, retrieval_index, feedback_index, corpus_offset):
found = True
if not found:
for doc in passage_ids(relevant_doc):
if len(docs_for_query) == 0 and len(relevant_doc) != 0:
docs_for_query.add(doc)

if len(docs_for_query) == 0 and len(relevant_docnos) != 0:
print('Missing relevant docs for topic', relevant_docnos)
print(len(docs_for_query))

if __name__ == '__main__':
main()