diff --git a/data/pooling.py b/data/pooling.py index 141ddbc..fee7373 100755 --- a/data/pooling.py +++ b/data/pooling.py @@ -6,10 +6,12 @@ import os import json from tqdm import tqdm +import ir_datasets if not pt.started(): pt.init() + @click.command('pooling') @click.option('--retrieval-index', default='msmarco-passage-v2.1', help='The chatnoir index for pooling.') @click.option('--corpus-offset', default=1500, help='The offset for the corpus.') @@ -55,6 +57,18 @@ def main(directory, retrieval_index, feedback_index, corpus_offset): for doc in results['docno']: all_docs.add(doc) print('Corpus-size', len(all_docs)) + docs_store = ir_datasets.load(retrieval_index).docs_store() + + if not os.path.exists(f'{directory}/pyterrier-index'): + documents = [] + + for doc in tqdm(all_docs, 'Load Docs'): + doc = docs_store.get(doc) + documents += [{'docno': doc.doc_id, 'text': doc.default_text()}] + + indexer = pt.IterDictIndexer(os.path.abspath(f'{directory}/pyterrier-index'), meta={'docno': 100, 'text': 20480}) + index_ref = indexer.index(documents) + index = pt.IndexFactory.of(os.path.abspath(f'{directory}/pyterrier-index')) for _, t in tqdm(list(relevant_documents_per_topic.iterrows()), 'Expansion Docs'):