diff --git a/findpapers/models/search.py b/findpapers/models/search.py
index cfd6e4b..a2b1c77 100644
--- a/findpapers/models/search.py
+++ b/findpapers/models/search.py
@@ -1,15 +1,20 @@
-from datetime import datetime, date
-
+import itertools
+from datetime import date, datetime
+from difflib import SequenceMatcher
 from typing import List, Optional
+
 from pydantic import BaseModel, Field, validator
+
+from findpapers.data.available_databases import AVAILABLE_DATABASES
 from findpapers.models.paper import Paper
 from findpapers.models.publication import Publication
-from findpapers.data.available_databases import AVAILABLE_DATABASES
 
 
 class Search(BaseModel):
     query: str = Field(
-        ..., examples=["BRAIN"], description="The query used to fetch the papers."
+        ...,
+        examples=["BRAIN"],
+        description="The query used to fetch the papers.",
     )
     since: Optional[date] = Field(
         None,
@@ -49,16 +54,21 @@ class Search(BaseModel):
         examples=["journal", "book"],
         description="The list of publication types that the search will be limited to. If not provided all publication types will be used.",
     )
-    papers: Optional[set] = Field(
+    collected_papers: Optional[set] = Field(
         None,
         examples=["https://doi.org/10.1016/j.jbiome.2019.01.002"],
         description="A list of papers already collected.",
     )
+
+    papers: set = set()
     paper_by_key: dict = {}
     publication_by_key: dict = {}
    paper_by_doi: dict = {}
     papers_by_database: dict = {}
 
+    def __hash__(self) -> int:
+        return self.query.__hash__()
+
     @validator("query", pre=True)
     def check_query(cls, value: str) -> str:
         if not value:
@@ -77,8 +87,8 @@ def assign_processed_at(cls, value):
             return datetime.utcnow()
         return value
 
-    @validator("papers")
-    def set_papers(cls, value):
+    @validator("collected_papers")
+    def validate_collected_papers(cls, value):
         if value:
             for paper in value:
                 try:
@@ -97,12 +107,15 @@ def add_database(self, database_name: str) -> None:
         self.databases.add(database_name)
 
     def get_paper_key(
-        self, paper_title: str, publication_date: date, paper_doi: Optional[str] = None
+        self,
+        paper_title: str,
+        paper_publication_date: date,
+        paper_doi: Optional[str] = None,
     ) -> str:
         return (
             f"DOI-{paper_doi}"
             if paper_doi
-            else f"{paper_title.lower()}|{publication_date.year if publication_date else ''}"
+            else f"{paper_title.lower()}|{paper_publication_date.year if paper_publication_date else ''}"
         )
 
     def get_publication_key(
@@ -118,29 +131,31 @@
         else:
             return f"TITLE-{publication_title.lower()}"
 
-    # TODO: check this. Currently the code is copied directly from findpapers
     def add_paper(self, paper: Paper) -> None:
         if not paper.databases:
             raise ValueError(
-                "Paper cannot be added to search without at least one defined database."
+                "Paper cannot be added to search without at least one defined database.",
             )
 
+        databases_lowered = {db.lower() for db in self.databases}
+
         for database in paper.databases:
-            if self.databases and database.lower() not in self.databases:
+            if self.databases and database.lower() not in databases_lowered:
                 raise ValueError(f"Database {database} isn't in databases list.")
 
             if self.reached_its_limit(database):
                 raise OverflowError(
-                    "When the papers limit is provided, you cannot exceed it."
+                    "When the papers limit is provided, you cannot exceed it.",
                 )
-
-            if database not in self.papers_by_database:
-                self.papers_by_database[database] = set()
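+            # pre-create the papers_by_database bucket for each of the paper's
+            # databases, so later look-ups never hit a missing key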
+ "When the papers limit is provided, you cannot exceed it.", ) - - if database not in self.papers_by_database: - self.papers_by_database[database] = set() + if database not in self.papers_by_database: + self.papers_by_database[database] = set() if paper.publication: publication_key = self.get_publication_key( - paper.publication.title, - paper.publication.issn, - paper.publication.isbn, + publication_title=paper.publication.title, + publication_issn=paper.publication.issn, + publication_isbn=paper.publication.isbn, ) already_collected_publication = self.publication_by_key.get(publication_key) @@ -150,34 +163,45 @@ def add_paper(self, paper: Paper) -> None: else: self.publication_by_key[publication_key] = paper.publication - paper_key = self.get_paper_key(paper.title, paper.publication_date, paper.doi) + paper_key = self.get_paper_key( + paper_title=paper.title, + paper_publication_date=paper.publication_date, + paper_doi=paper.doi, + ) already_collected_paper = self.paper_by_key.get(paper_key) if (self.since is None or paper.publication_date >= self.since) and ( self.until is None or paper.publication_date <= self.until ): - if already_collected_paper is None: + if not already_collected_paper: self.papers.add(paper) self.paper_by_key[paper_key] = paper - if paper.doi is not None: + if paper.doi: self.paper_by_doi[paper.doi] = paper + for database in paper.databases: + if database not in self.papers_by_database: + self.papers_by_database[database] = set() + self.papers_by_database[database].add(paper) else: self.papers_by_database[database].add(already_collected_paper) - - already_collected_paper.enrich(paper) - - self.papers_by_database[database].add(paper) + already_collected_paper.enrich(paper) def get_paper( - self, paper_title: str, publication_date: str, paper_doi: Optional[str] = None + self, + paper_title: str, + paper_publication_date: str, + paper_doi: Optional[str] = None, ) -> Paper: - paper_key = self.get_paper_key(paper_title, publication_date, paper_doi) + paper_key = self.get_paper_key(paper_title, paper_publication_date, paper_doi) return self.paper_by_key.get(paper_key, None) def get_publication( - self, title: str, issn: Optional[str] = None, isbn: Optional[str] = None + self, + title: str, + issn: Optional[str] = None, + isbn: Optional[str] = None, ) -> Publication: publication_key = self.get_publication_key(title, issn, isbn) return self.publication_by_key.get(publication_key, None) @@ -193,20 +217,90 @@ def remove_paper(self, paper: Paper) -> None: self.papers.discard(paper) - # TODO: implement this - def merge_duplications(self): - pass + def merge_duplications(self, similarity_threshold: float = 0.95) -> None: + paper_key_pairs = list(itertools.combinations(self.paper_by_key.keys(), 2)) + + for _, pair in enumerate(paper_key_pairs): + paper_1 = self.paper_by_key.get(pair[0]) + paper_2 = self.paper_by_key.get(pair[1]) + # check if de-duplication can be performed + if not (paper_1 and paper_2 and not paper_1.title and not paper_2.title): + continue + + elif (paper_1.doi != paper_2.doi or not paper_1.doi) and ( + paper_1.abstract + and paper_2.abstract + and paper_1.abstract not in ["", "[No abstract available]"] + and paper_2.abstract not in ["", "[No abstract available]"] + ): + max_title_length = max(len(paper_1.title), len(paper_2.title)) + diff_title_length = abs(len(paper_1.title) - len(paper_2.title)) + max_abstract_length = max(len(paper_1.abstract), len(paper_2.abstract)) + diff_abstract_length = abs( + len(paper_1.abstract) - len(paper_2.abstract), + ) - def 
+                # adjust: larger length differences lower the threshold (floor: 75%)
+                adjusted_title_threshold = max(
+                    similarity_threshold * (1 - 0.5 * diff_title_length / max_title_length),
+                    similarity_threshold * 0.75,
+                )
+                adjusted_abstract_threshold = max(
+                    similarity_threshold * (1 - 0.5 * diff_abstract_length / max_abstract_length),
+                    similarity_threshold * 0.75,
+                )
 
-        reached_general_limit = (
-            self.limit is not None and len(self.papers) >= self.limit
-        )
+                # calculating the similarity between the titles and the abstracts
+                titles_similarity = SequenceMatcher(
+                    None,
+                    paper_1.title.lower(),
+                    paper_2.title.lower(),
+                ).ratio()
+                abstracts_similarity = SequenceMatcher(
+                    None,
+                    paper_1.abstract.lower(),
+                    paper_2.abstract.lower(),
+                ).ratio()
+
+                if (titles_similarity > adjusted_title_threshold) or (
+                    abstracts_similarity > adjusted_abstract_threshold
+                ):
+                    # using the information of paper_2 to enrich paper_1
+                    paper_1.enrich(paper_2)
+
+                    # removing the paper_2 instance
+                    self.remove_paper(paper_2)
+            elif (
+                (paper_1.publication_date and paper_2.publication_date)
+                and (paper_1.publication_date.year == paper_2.publication_date.year)
+                and (paper_1.doi and paper_1.doi == paper_2.doi)
+            ):
+                max_title_length = max(len(paper_1.title), len(paper_2.title))
+                diff_title_length = abs(len(paper_1.title) - len(paper_2.title))
+
+                # adjust: larger length differences lower the threshold (floor: 75%)
+                adjusted_title_threshold = max(
+                    similarity_threshold * (1 - 0.5 * diff_title_length / max_title_length),
+                    similarity_threshold * 0.75,
+                )
+
+                # calculating the title similarity
+                titles_similarity = SequenceMatcher(
+                    None,
+                    paper_1.title.lower(),
+                    paper_2.title.lower(),
+                ).ratio()
+
+                if titles_similarity > adjusted_title_threshold:
+                    # using the information of paper_2 to enrich paper_1
+                    paper_1.enrich(paper_2)
+
+                    # removing the paper_2 instance
+                    self.remove_paper(paper_2)
+
+    def reached_its_limit(self, database: str) -> bool:
+        n_dbs = (
+            len(self.papers_by_database.get(database))
+            if bool(self.papers_by_database.get(database))
+            else 0
+        )
+        reached_general_limit = (
+            self.limit is not None and len(self.papers) >= self.limit
+        )
 
         reached_database_limit = (
             self.limit_per_database is not None
             and database in self.papers_by_database
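
Below, for reviewers, is a minimal sketch of how the reworked add/merge flow is expected to behave. The `Paper` constructor keywords mirror the fields this patch reads (`title`, `abstract`, `publication_date`, `doi`, `databases`), but the actual model signatures, the defaults of the optional `Search` fields, and the full behavior of `remove_paper` are assumptions to verify against the real models, not part of this diff:

```python
from datetime import date

from findpapers.models.paper import Paper
from findpapers.models.search import Search

# assumes the remaining Search fields (databases, limits, dates) are optional
search = Search(query="BRAIN")

# two DOI-less records of the same paper coming from different databases;
# their title|year keys differ, so add_paper keeps both
paper_1 = Paper(
    title="Deep learning for MRI analysis",
    abstract="We survey deep learning methods for MRI analysis.",
    publication_date=date(2020, 5, 1),
    databases={"Scopus"},
)
paper_2 = Paper(
    title="Deep Learning for MRI Analysis.",
    abstract="We survey deep learning methods for MRI analysis.",
    publication_date=date(2020, 5, 1),
    databases={"PubMed"},
)
search.add_paper(paper_1)
search.add_paper(paper_2)
assert len(search.papers) == 2

# the titles and abstracts are near-identical, so the pair is merged:
# paper_1 is enriched with paper_2's metadata and paper_2 is dropped
search.merge_duplications()
assert len(search.papers) == 1
```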