diff --git a/data/pooling.py b/data/pooling.py index 60ed73f..b802a25 100755 --- a/data/pooling.py +++ b/data/pooling.py @@ -61,17 +61,30 @@ def main(directory, retrieval_index, feedback_index, corpus_offset): docs_store = ir_datasets.load('msmarco-segment-v2.1').docs_store() if not os.path.exists(f'{directory}/pyterrier-index'): - documents = [] if not os.path.exists(f'{directory}/documents.jsonl.gz'): with gzip.open(f'{directory}/documents.jsonl.gz', 'wt') as f: for doc in tqdm(all_docs, 'Load Docs'): doc = docs_store.get(doc) f.write(json.dumps({'docno': doc.doc_id, 'text': doc.default_text()}) + '\n') f.flush() + + documents = [] + with gzip.open(f'{directory}/documents.jsonl.gz', 'rt') as f: + for l in f: + documents += [json.loads(l)] + + indexer = pt.IterDictIndexer(os.path.abspath(f'{directory}/pyterrier-index'), meta={'docno': 100, 'text': 20480}) + index_ref = indexer.index(tqdm(documents, 'Index')) - indexer = pt.IterDictIndexer(os.abspath(f'{directory}/pyterrier-index'), meta={'docno': 100, 'text': 20480}) - index_ref = indexer.index(documents) - index = pt.IndexFactory.of(os.abspath(f'{directory}/pyterrier-index')) + for retrieval_model in ['BM25', 'PL2', 'DirichletLM', 'TF_IDF', 'Hiemstra_LM']: + index = pt.IndexFactory.of(os.path.abspath(f'{directory}/pyterrier-index')) + output_file = f'{directory}/pyterrier-{retrieval_model}-run.gz' + if os.path.exists(output_file): + continue + retriever = pt.BatchRetrieve(index, wmodel=retrieval_model) + topics = pt.io.read_topics(f'{directory}/topics.xml', 'trecxml', tags=['title'], tokenise=True) + results = retriever(topics) + pt.io.write_results(results, output_file) for _, t in tqdm(list(relevant_documents_per_topic.iterrows()), 'Expansion Docs'):