
Commit

add the merge_duplicate & hash method to the search model and format the module #35
Kashyap Maheshwari committed Jun 8, 2023
1 parent 339d11b commit b2dd999
Showing 1 changed file with 135 additions and 41 deletions.
176 changes: 135 additions & 41 deletions findpapers/models/search.py
@@ -1,15 +1,20 @@
-from datetime import datetime, date
+import itertools
+from datetime import date, datetime
+from difflib import SequenceMatcher
 from typing import List, Optional
 
 from pydantic import BaseModel, Field, validator
 
+from findpapers.data.available_databases import AVAILABLE_DATABASES
 from findpapers.models.paper import Paper
 from findpapers.models.publication import Publication
-from findpapers.data.available_databases import AVAILABLE_DATABASES
 
 
 class Search(BaseModel):
     query: str = Field(
-        ..., examples=["BRAIN"], description="The query used to fetch the papers."
+        ...,
+        examples=["BRAIN"],
+        description="The query used to fetch the papers.",
     )
     since: Optional[date] = Field(
         None,
@@ -49,16 +54,21 @@ class Search(BaseModel):
         examples=["journal", "book"],
         description="The list of publication types that the search will be limited to. If not provided all publication types will be used.",
     )
-    papers: Optional[set] = Field(
+    collected_papers: Optional[set] = Field(
         None,
         examples=["https://doi.org/10.1016/j.jbiome.2019.01.002"],
         description="A list of papers already collected.",
     )
 
+    papers: set = set()
     paper_by_key: dict = {}
     publication_by_key: dict = {}
     paper_by_doi: dict = {}
     papers_by_database: dict = {}
 
+    def __hash__(self) -> int:
+        return self.query.__hash__()
+
     @validator("query", pre=True)
     def check_query(cls, value: str) -> str:
         if not value:
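A note on the new __hash__: it delegates to the query string, so two Search instances built from the same query hash identically. A minimal sketch, assuming the remaining model fields keep their defaults:

s1 = Search(query="BRAIN")
s2 = Search(query="BRAIN")

# __hash__ returns self.query.__hash__(), so both instances hash to
# hash("BRAIN"); whether they also compare equal inside a set depends on
# the model's equality semantics, which this commit does not touch.
assert hash(s1) == hash(s2) == hash("BRAIN")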
@@ -77,8 +87,8 @@ def assign_processed_at(cls, value):
             return datetime.utcnow()
         return value
 
-    @validator("papers")
-    def set_papers(cls, value):
+    @validator("collected_papers")
+    def validate_collected_papers(cls, value):
         if value:
             for paper in value:
                 try:
@@ -97,12 +107,15 @@ def add_database(self, database_name: str) -> None:
         self.databases.add(database_name)
 
     def get_paper_key(
-        self, paper_title: str, publication_date: date, paper_doi: Optional[str] = None
+        self,
+        paper_title: str,
+        paper_publication_date: date,
+        paper_doi: Optional[str] = None,
     ) -> str:
         return (
             f"DOI-{paper_doi}"
             if paper_doi
-            else f"{paper_title.lower()}|{publication_date.year if publication_date else ''}"
+            else f"{paper_title.lower()}|{paper_publication_date.year if paper_publication_date else ''}"
         )
 
     def get_publication_key(
@@ -118,29 +131,29 @@ def get_publication_key(
         else:
             return f"TITLE-{publication_title.lower()}"
 
-    # TODO: check this. Currently the code is copied directly from findpapers
     def add_paper(self, paper: Paper) -> None:
         if not paper.databases:
             raise ValueError(
-                "Paper cannot be added to search without at least one defined database."
+                "Paper cannot be added to search without at least one defined database.",
            )
 
+        databases_lowered = {db.lower() for db in self.databases}
+
         for database in paper.databases:
-            if self.databases and database.lower() not in self.databases:
+            if self.databases and database.lower() not in databases_lowered:
                 raise ValueError(f"Database {database} isn't in databases list.")
             if self.reached_its_limit(database):
                 raise OverflowError(
-                    "When the papers limit is provided, you cannot exceed it."
+                    "When the papers limit is provided, you cannot exceed it.",
                 )
 
-        if database not in self.papers_by_database:
-            self.papers_by_database[database] = set()
+            if database not in self.papers_by_database:
+                self.papers_by_database[database] = set()
 
         if paper.publication:
             publication_key = self.get_publication_key(
-                paper.publication.title,
-                paper.publication.issn,
-                paper.publication.isbn,
+                publication_title=paper.publication.title,
+                publication_issn=paper.publication.issn,
+                publication_isbn=paper.publication.isbn,
             )
             already_collected_publication = self.publication_by_key.get(publication_key)
 
@@ -150,34 +163,45 @@ def add_paper(self, paper: Paper) -> None:
         else:
             self.publication_by_key[publication_key] = paper.publication
 
-        paper_key = self.get_paper_key(paper.title, paper.publication_date, paper.doi)
+        paper_key = self.get_paper_key(
+            paper_title=paper.title,
+            paper_publication_date=paper.publication_date,
+            paper_doi=paper.doi,
+        )
         already_collected_paper = self.paper_by_key.get(paper_key)
 
         if (self.since is None or paper.publication_date >= self.since) and (
             self.until is None or paper.publication_date <= self.until
         ):
-            if already_collected_paper is None:
+            if not already_collected_paper:
                 self.papers.add(paper)
                 self.paper_by_key[paper_key] = paper
 
-                if paper.doi is not None:
+                if paper.doi:
                     self.paper_by_doi[paper.doi] = paper
 
                 for database in paper.databases:
+                    if database not in self.papers_by_database:
+                        self.papers_by_database[database] = set()
                     self.papers_by_database[database].add(paper)
             else:
-                self.papers_by_database[database].add(paper)
-                already_collected_paper.enrich(paper)
+                self.papers_by_database[database].add(already_collected_paper)
+
+                already_collected_paper.enrich(paper)
 
     def get_paper(
-        self, paper_title: str, publication_date: str, paper_doi: Optional[str] = None
+        self,
+        paper_title: str,
+        paper_publication_date: str,
+        paper_doi: Optional[str] = None,
     ) -> Paper:
-        paper_key = self.get_paper_key(paper_title, publication_date, paper_doi)
+        paper_key = self.get_paper_key(paper_title, paper_publication_date, paper_doi)
         return self.paper_by_key.get(paper_key, None)
 
     def get_publication(
-        self, title: str, issn: Optional[str] = None, isbn: Optional[str] = None
+        self,
+        title: str,
+        issn: Optional[str] = None,
+        isbn: Optional[str] = None,
     ) -> Publication:
         publication_key = self.get_publication_key(title, issn, isbn)
         return self.publication_by_key.get(publication_key, None)
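For reference, the keying scheme in the hunk above prefers the DOI and otherwise falls back to the lowercased title joined with the publication year. A standalone restatement of that logic (the helper name paper_key is ours, for illustration only):

from datetime import date
from typing import Optional

def paper_key(title: str, published: Optional[date], doi: Optional[str] = None) -> str:
    # Mirrors Search.get_paper_key: the DOI wins, title|year is the fallback.
    if doi:
        return f"DOI-{doi}"
    return f"{title.lower()}|{published.year if published else ''}"

assert paper_key("BRAIN Mapping", date(2019, 1, 2)) == "brain mapping|2019"
assert paper_key("BRAIN Mapping", None, doi="10.1016/j.jbiome.2019.01.002") == "DOI-10.1016/j.jbiome.2019.01.002"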
@@ -193,20 +217,90 @@ def remove_paper(self, paper: Paper) -> None:
 
         self.papers.discard(paper)
 
-    # TODO: implement this
-    def merge_duplications(self):
-        pass
+    def merge_duplications(self, similarity_threshold: float = 0.95) -> None:
+        paper_key_pairs = list(itertools.combinations(self.paper_by_key.keys(), 2))
+
+        for _, pair in enumerate(paper_key_pairs):
+            paper_1 = self.paper_by_key.get(pair[0])
+            paper_2 = self.paper_by_key.get(pair[1])
+            # de-duplication needs both papers present, each with a title
+            if not (paper_1 and paper_2 and paper_1.title and paper_2.title):
+                continue
+
+            elif (paper_1.doi != paper_2.doi or not paper_1.doi) and (
+                paper_1.abstract
+                and paper_2.abstract
+                and paper_1.abstract not in ["", "[No abstract available]"]
+                and paper_2.abstract not in ["", "[No abstract available]"]
+            ):
+                max_title_length = max(len(paper_1.title), len(paper_2.title))
+                diff_title_length = abs(len(paper_1.title) - len(paper_2.title))
+                max_abstract_length = max(len(paper_1.abstract), len(paper_2.abstract))
+                diff_abstract_length = abs(
+                    len(paper_1.abstract) - len(paper_2.abstract),
+                )
+
+                # adjustment: a larger length difference lowers the threshold
+                adjusted_title_threshold = max(
+                    similarity_threshold * (1 - 0.5 * diff_title_length / max_title_length),
+                    similarity_threshold * 0.75,
+                )
+                adjusted_abstract_threshold = max(
+                    similarity_threshold * (1 - 0.5 * diff_abstract_length / max_abstract_length),
+                    similarity_threshold * 0.75,
+                )
+
+                # calculating the similarity between the titles and abstracts
+                titles_similarity = SequenceMatcher(
+                    None,
+                    paper_1.title.lower(),
+                    paper_2.title.lower(),
+                ).ratio()
+                abstracts_similarity = SequenceMatcher(
+                    None,
+                    paper_1.abstract.lower(),
+                    paper_2.abstract.lower(),
+                ).ratio()
+
+                if (titles_similarity > adjusted_title_threshold) or (
+                    abstracts_similarity > adjusted_abstract_threshold
+                ):
+                    # using the information of paper_2 to enrich paper_1
+                    paper_1.enrich(paper_2)
+
+                    # removing the paper_2 instance
+                    self.remove_paper(paper_2)
+            elif (
+                (paper_1.publication_date and paper_2.publication_date)
+                and (paper_1.publication_date.year == paper_2.publication_date.year)
+                and (paper_1.doi and paper_1.doi == paper_2.doi)
+            ):
+                max_title_length = max(len(paper_1.title), len(paper_2.title))
+                diff_title_length = abs(len(paper_1.title) - len(paper_2.title))
+
+                # adjustment: a larger length difference lowers the threshold
+                adjusted_title_threshold = max(
+                    similarity_threshold * (1 - 0.5 * diff_title_length / max_title_length),
+                    similarity_threshold * 0.75,
+                )
+
+                # calculating the similarity between the titles
+                titles_similarity = SequenceMatcher(
+                    None,
+                    paper_1.title.lower(),
+                    paper_2.title.lower(),
+                ).ratio()
+
+                if titles_similarity > adjusted_title_threshold:
+                    # using the information of paper_2 to enrich paper_1
+                    paper_1.enrich(paper_2)
+
+                    # removing the paper_2 instance
+                    self.remove_paper(paper_2)
 
     def reached_its_limit(self, database: str) -> bool:
-        n_dbs = len(self.papers_by_database.get(database)) if bool(self.papers_by_database.get(database)) else 0
-        reached_general_limit = self.limit is not None and len(self.papers) >= self.limit
+        n_dbs = (
+            len(self.papers_by_database.get(database))
+            if bool(self.papers_by_database.get(database))
+            else 0
+        )
+        reached_general_limit = (
+            self.limit is not None and len(self.papers) >= self.limit
+        )
         reached_database_limit = (
             self.limit_per_database is not None
            and database in self.papers_by_database
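The new merge_duplications walks every pair of collected papers and merges a pair when its title or abstract similarity clears a threshold that relaxes as the two strings diverge in length, floored at 75% of the configured value. A worked sketch of that adjustment, with illustrative numbers:

from difflib import SequenceMatcher

def adjusted_threshold(threshold: float, len_1: int, len_2: int) -> float:
    # Same arithmetic as merge_duplications: the threshold drops as the
    # length difference grows, but never below 75% of the base value.
    longest, diff = max(len_1, len_2), abs(len_1 - len_2)
    return max(threshold * (1 - 0.5 * diff / longest), threshold * 0.75)

# With the default threshold of 0.95 and titles of 40 vs 60 characters,
# a pair only needs a ratio above ~0.792 to be merged.
print(adjusted_threshold(0.95, 40, 60))  # 0.7916...

# SequenceMatcher supplies the ratio that is compared to the threshold.
print(SequenceMatcher(None, "deep learning for x", "deep learning for x.").ratio())  # ~0.97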

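Finally, a hedged end-to-end sketch; the Paper constructor arguments below are assumed, since paper.py is not part of this diff:

# Hypothetical usage; Paper's signature is assumed, not shown in this commit.
search = Search(query="BRAIN")
search.add_paper(Paper(title="Deep learning for X", abstract="An approach to X.", databases={"Scopus"}))
search.add_paper(Paper(title="Deep Learning for X.", abstract="An approach to X.", databases={"Scopus"}))

# The titles differ only in case and a trailing period, so their
# SequenceMatcher ratio clears the adjusted threshold and the pair is
# merged: paper_1 absorbs paper_2's metadata and paper_2 is removed.
search.merge_duplications(similarity_threshold=0.95)
assert len(search.papers) == 1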