From f4d7e98bf881243b74b6c342d30daea7c1780f62 Mon Sep 17 00:00:00 2001 From: yindaheng98 Date: Mon, 26 Feb 2024 21:47:01 -0800 Subject: [PATCH] first cache --- citation_crawler/summarizers/neo4j.py | 35 ++++++++++++++++------ citation_crawler/summarizers/neo4jcache.py | 25 ++++++++++++++++ setup.py | 2 +- 3 files changed, 52 insertions(+), 10 deletions(-) create mode 100644 citation_crawler/summarizers/neo4jcache.py diff --git a/citation_crawler/summarizers/neo4j.py b/citation_crawler/summarizers/neo4j.py index d6ff961..f94d509 100644 --- a/citation_crawler/summarizers/neo4j.py +++ b/citation_crawler/summarizers/neo4j.py @@ -1,3 +1,4 @@ +import abc import logging from typing import AsyncIterable from citation_crawler import Summarizer, Paper @@ -6,12 +7,14 @@ from neo4j import AsyncSession import neo4j.time +from .neo4jcache import Neo4jSummarizerCache + '''Use with dblp-crawler''' logger = logging.getLogger("graph") -async def add_paper(tx, paper: Paper): +async def add_paper(tx, cache: Neo4jSummarizerCache, paper: Paper): n4jset = "MERGE (p:Publication {title_hash: $title_hash}) "\ "SET p.title=$title, p.year=$year" if paper.doi(): @@ -32,6 +35,19 @@ async def add_paper(tx, paper: Paper): n4jset += ", p.date=$date" except Exception as e: logger.error(f"Cannot parse date {paper.date()}: {e}") + if cache.try_get_paper( + key=paper.title_hash(), + data=dict( + title_hash=paper.title_hash(), + title=paper.title(), + year=paper.year(), + paperId=paper.paperId(), + dblp_id=paper.dblp_id(), + doi=paper.doi(), + date=date + ) + ): + return await tx.run(n4jset, title_hash=paper.title_hash(), title=paper.title(), @@ -87,7 +103,7 @@ async def divide_author(tx, paper: Paper, author_kv, write_fields, division_kv): title_hash=paper.title_hash(), **write_fields) -async def _add_references(tx, paper: Paper): +async def _add_references(tx, cache: Neo4jSummarizerCache, paper: Paper): title_hash_exists = set([ title_hash for (title_hash,) in await (await tx.run("MATCH 
(a:Publication)-[:CITE]->(p:Publication {title_hash: $title_hash}) RETURN a.title_hash", @@ -96,11 +112,11 @@ async def _add_references(tx, paper: Paper): async for ref in paper.get_references(): if ref.title_hash() in title_hash_exists: continue - await add_paper(tx, ref) + await add_paper(tx, cache, ref) await add_reference(tx, paper, ref) -async def _add_citations(tx, paper: Paper): +async def _add_citations(tx, cache: Neo4jSummarizerCache, paper: Paper): title_hash_exists = set([ title_hash for (title_hash,) in await (await tx.run("MATCH (p:Publication {title_hash: $title_hash})-[:CITE]->(a:Publication) RETURN a.title_hash", @@ -109,19 +125,20 @@ async def _add_citations(tx, paper: Paper): async for cit in paper.get_references(): if cit.title_hash() in title_hash_exists: continue - await add_paper(tx, cit) + await add_paper(tx, cache, cit) await add_reference(tx, cit, paper) class Neo4jSummarizer(Summarizer): - def __init__(self, session: AsyncSession, *args, **kwargs): + def __init__(self, session: AsyncSession, cache: Neo4jSummarizerCache = Neo4jSummarizerCache(), *args, **kwargs): super().__init__(*args, **kwargs) self.session = session + self.cache = cache async def write_paper(self, paper) -> None: - await self.session.execute_write(add_paper, paper) - await self.session.execute_write(_add_references, paper) - await self.session.execute_write(_add_citations, paper) + await self.session.execute_write(add_paper, self.cache, paper) + await self.session.execute_write(_add_references, self.cache, paper) + await self.session.execute_write(_add_citations, self.cache, paper) async def write_reference(self, paper, reference) -> None: await self.session.execute_write(add_reference, paper, reference) diff --git a/citation_crawler/summarizers/neo4jcache.py b/citation_crawler/summarizers/neo4jcache.py new file mode 100644 index 0000000..ce0d441 --- /dev/null +++ b/citation_crawler/summarizers/neo4jcache.py @@ -0,0 +1,25 @@ +from typing import Dict, List + + +class 
class Neo4jSummarizerCache:
    """Bounded FIFO cache of paper-attribute dicts.

    Used by the Neo4j summarizer to skip redundant writes: when a paper's
    attributes are identical to what was cached on the previous write, the
    caller may skip re-running the MERGE/SET query.
    """

    def __init__(self, size: int = 2**20) -> None:
        # key (title hash) -> last-seen attribute dict for that paper
        self.papers: Dict[str, Dict] = dict()
        # insertion order of keys; the oldest key is evicted first (FIFO)
        self.keys: List[str] = []
        # maximum number of cached entries before eviction kicks in
        self.size = size

    def try_get_paper(self, key: str, data: Dict) -> bool:
        """Record *data* under *key*; report whether it was already cached.

        Returns True when every field of *data* equals the cached value
        (the caller may skip its database write).  Returns False when the
        key is new or any field changed; in either case the cache is
        updated to the latest values.  Once the cache holds more than
        ``size`` entries the oldest one is evicted.
        """
        if key not in self.papers:
            self.papers[key] = data
            self.keys.append(key)
            if len(self.keys) > self.size:
                # Evict the oldest entry.  Use a distinct name so the
                # `key` parameter is not shadowed (the original code
                # reassigned `key` here, which was confusing though
                # harmless).  NOTE: list.pop(0) is O(n); acceptable for
                # the default size, a deque would be O(1).
                oldest = self.keys.pop(0)
                del self.papers[oldest]
            return False
        old_data = self.papers[key]
        same = True
        for k, v in data.items():
            if k not in old_data or old_data[k] != v:
                # Mutate the cached dict in place; `old_data` IS the
                # cached object, so no write-back assignment is needed
                # (the original `self.papers[key] = old_data` was a no-op).
                old_data[k] = v
                same = False
        return same