diff --git a/cli/tirex.py b/cli/tirex.py index 59519a9..766b6cf 100644 --- a/cli/tirex.py +++ b/cli/tirex.py @@ -13,6 +13,7 @@ from pyterrier.io import read_topics, write_results, read_results from pyterrier.apply import generic from tqdm import tqdm +from glob import glob from trectools import TrecRun, TrecPoolMaker @@ -129,9 +130,15 @@ def get_documents(pooling_path: Path): documents_path = pooling_path / "documents.jsonl.gz" if not documents_path.exists(): docs_store = irds_load("msmarco-segment-v2.1").docs_store() + all_docs = set() + for file_name in glob(f'{pooling_path}/corpus-chatnoir*run.gz'): + run = TrecRun(file_name).run_data + for doc in run['docid']: + all_docs.add(doc) + print('docs size', len(all_docs)) + + with gzip_open(documents_path, "wt") as file: - # FIXME: Where does this come from? - all_docs = NotImplemented for doc in tqdm(all_docs, "Load Docs"): doc = docs_store.get(doc) file.write(