From 431045f965509443f0108435f9385db568b25d73 Mon Sep 17 00:00:00 2001
From: Maik Froebe
Date: Wed, 13 Nov 2024 10:34:23 +0100
Subject: [PATCH] mf

---
Review notes (this section is ignored by `git am`):

* all_re_rankers(): added the missing comma after the 'colbert' entry.
  Without it Python parses the two adjacent parenthesised expressions as a
  call on a tuple, so building the list raises
  "TypeError: 'tuple' object is not callable" and no re-ranker ever runs.
* The JSONL loader in the re-ranker loop no longer reuses the ambiguous
  name `l` for both the raw line and the parsed record.
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Hunk line counts are unchanged, but the indentation of
  context lines should be verified against data/pooling.py before applying.
* The blob id on the `index` line predates the edits listed above.

 data/pooling.py | 40 +++++++++++++++++++++++++++++++++++++---
 1 file changed, 37 insertions(+), 3 deletions(-)

diff --git a/data/pooling.py b/data/pooling.py
index 14d59a8..f14a889 100755
--- a/data/pooling.py
+++ b/data/pooling.py
@@ -26,6 +26,15 @@ def passage_ids(doc_id):
 
     return sorted(list(ret))
 
+def all_re_rankers():
+    from pyterrier_t5 import MonoT5ReRanker
+    import pyterrier_dr
+    return [
+        ('mono-t5', lambda: MonoT5ReRanker()),
+        ('colbert', lambda: pyterrier_dr.TctColBert(verbose=True)),
+        ('ance', lambda: pyterrier_dr.Ance(verbose=True))
+    ]
+
 
 @click.command('pooling')
 @click.option('--retrieval-index', default='msmarco-passage-v2.1', help='The chatnoir index for pooling.')
@@ -72,10 +81,10 @@ def main(directory, retrieval_index, feedback_index, corpus_offset):
         for doc in results['docno']:
             all_docs.add(doc)
     print('Corpus-size', len(all_docs))
-    docs_store = ir_datasets.load('msmarco-segment-v2.1').docs_store()
 
     if not os.path.exists(f'{directory}/pyterrier-index'):
         if not os.path.exists(f'{directory}/documents.jsonl.gz'):
+            docs_store = ir_datasets.load('msmarco-segment-v2.1').docs_store()
             with gzip.open(f'{directory}/documents.jsonl.gz', 'wt') as f:
                 for doc in tqdm(all_docs, 'Load Docs'):
                     doc = docs_store.get(doc)
@@ -100,6 +109,30 @@ def main(directory, retrieval_index, feedback_index, corpus_offset):
             results = retriever(topics)
             pt.io.write_results(results, output_file)
 
+    for retrieval_model, reranker in tqdm(all_re_rankers(), 'Re-Rankers'):
+        for query_type in ['title', 'description']:
+            output_file = f'{directory}/neural-{retrieval_model}-on{query_type}-run.gz'
+            if os.path.exists(output_file):
+                continue
+
+            documents = {}
+            with gzip.open(f'{directory}/documents.jsonl.gz', 'rt') as f:
+                for line in f:
+                    entry = json.loads(line)
+                    documents[entry['docno']] = entry['text']
+
+
+            def add_text(df):
+                df['text'] = df['docno'].apply(lambda i: documents[i])
+                return df
+
+            topics = pt.io.read_topics(f'{directory}/topics.xml', 'trecxml', tags=[query_type], tokenise=False)
+            first_stage = pt.io.read_results(f'{directory}/pyterrier-BM25-run.gz')
+            first_stage = pt.transformer.get_transformer(first_stage)
+            first_stage = first_stage >> pt.apply.generic(add_text) >> reranker()
+
+            results = first_stage(topics)
+            pt.io.write_results(results, output_file)
 
     for _, t in tqdm(list(relevant_documents_per_topic.iterrows()), 'Expansion Docs'):
         docs_for_query = set()
@@ -115,9 +148,10 @@ def main(directory, retrieval_index, feedback_index, corpus_offset):
                     found = True
             if not found:
                 for doc in passage_ids(relevant_doc):
-        if len(docs_for_query) == 0 and len(relevant_doc) != 0:
+                    docs_for_query.add(doc)
+
+        if len(docs_for_query) == 0 and len(relevant_docnos) != 0:
             print('Missing relevant docs for topic', relevant_docnos)
-        print(len(docs_for_query))
 
 if __name__ == '__main__':
     main()