Skip to content

Commit

Permalink
Merge pull request #643 from NASA-IMPACT/638-add-title-rules-with-f-s…
Browse files Browse the repository at this point in the history
…tring-formatting

638 add title rules with f string formatting
  • Loading branch information
bishwaspraveen authored Mar 7, 2024
2 parents 2eadfd0 + 650ba88 commit d7f64c8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 47 deletions.
68 changes: 29 additions & 39 deletions sde_collections/models/pattern.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re

from django.apps import apps
from django.db import models

from ..pattern_interpreter import interpret_title_pattern
from ..pattern_interpreter import safe_f_string_evaluation
from .collection_choice_fields import DocumentTypes


Expand All @@ -22,9 +23,7 @@ class MatchPatternTypeChoices(models.IntegerChoices):
help_text="This pattern is compared against the URL of all the documents in the collection "
"and matching documents will be returned",
)
match_pattern_type = models.IntegerField(
choices=MatchPatternTypeChoices.choices, default=1
)
match_pattern_type = models.IntegerField(choices=MatchPatternTypeChoices.choices, default=1)
candidate_urls = models.ManyToManyField(
"CandidateURL",
related_name="%(class)s_urls",
Expand All @@ -34,14 +33,10 @@ def matched_urls(self):
"""Find all the urls matching the pattern."""
escaped_match_pattern = re.escape(self.match_pattern)
if self.match_pattern_type == self.MatchPatternTypeChoices.INDIVIDUAL_URL:
return self.collection.candidate_urls.filter(
url__regex=f"{escaped_match_pattern}$"
)
return self.collection.candidate_urls.filter(url__regex=f"{escaped_match_pattern}$")
elif self.match_pattern_type == self.MatchPatternTypeChoices.MULTI_URL_PATTERN:
return self.collection.candidate_urls.filter(
url__regex=escaped_match_pattern.replace(
r"\*", ".*"
) # allow * wildcards
url__regex=escaped_match_pattern.replace(r"\*", ".*") # allow * wildcards
)
else:
raise NotImplementedError
Expand All @@ -56,10 +51,7 @@ def _process_match_pattern(self) -> str:
if not processed_pattern.startswith("http"):
# if it doesn't begin with http, it must need a star at the beginning
processed_pattern = f"*{processed_pattern}"
if (
self.match_pattern_type
== BaseMatchPattern.MatchPatternTypeChoices.MULTI_URL_PATTERN
):
if self.match_pattern_type == BaseMatchPattern.MatchPatternTypeChoices.MULTI_URL_PATTERN:
# all multi urls should have a star at the end, but individuals should not
processed_pattern = f"{processed_pattern}*"
return processed_pattern
Expand Down Expand Up @@ -97,9 +89,7 @@ def apply(self) -> None:
candidate_url_ids = list(matched_urls.values_list("id", flat=True))
self.candidate_urls.through.objects.bulk_create(
objs=[
ExcludePattern.candidate_urls.through(
candidateurl_id=candidate_url_id, excludepattern_id=self.id
)
ExcludePattern.candidate_urls.through(candidateurl_id=candidate_url_id, excludepattern_id=self.id)
for candidate_url_id in candidate_url_ids
]
)
Expand All @@ -122,9 +112,7 @@ def apply(self) -> None:
candidate_url_ids = list(matched_urls.values_list("id", flat=True))
self.candidate_urls.through.objects.bulk_create(
objs=[
IncludePattern.candidate_urls.through(
candidateurl_id=candidate_url_id, includepattern_id=self.id
)
IncludePattern.candidate_urls.through(candidateurl_id=candidate_url_id, includepattern_id=self.id)
for candidate_url_id in candidate_url_ids
]
)
Expand All @@ -149,26 +137,28 @@ class TitlePattern(BaseMatchPattern):
)

def apply(self) -> None:
CandidateURL = apps.get_model("sde_collections", "CandidateURL")
matched_urls = self.matched_urls()

# since this is not running in celery, this is a bit slow
for url, scraped_title in matched_urls.values_list("url", "scraped_title"):
generated_title = interpret_title_pattern(
url, scraped_title, self.title_pattern
)
matched_urls.filter(url=url, scraped_title=scraped_title).update(
generated_title=generated_title
)

candidate_url_ids = list(matched_urls.values_list("id", flat=True))
self.candidate_urls.through.objects.bulk_create(
objs=[
TitlePattern.candidate_urls.through(
candidateurl_id=candidate_url_id, titlepattern_id=self.id
)
for candidate_url_id in candidate_url_ids
]
)
updated_urls = []

for candidate_url in matched_urls:
context = {"url": candidate_url.url, "scraped_title": candidate_url.scraped_title}

try:
generated_title = safe_f_string_evaluation(self.title_pattern, context)
candidate_url.generated_title = generated_title
updated_urls.append(candidate_url)
except ValueError as e:
print(f"Error applying title pattern to {candidate_url.url}: {e}")

if updated_urls:
CandidateURL.objects.bulk_update(updated_urls, ["generated_title"])

TitlePatternCandidateURL = TitlePattern.candidate_urls.through
pattern_url_associations = [
TitlePatternCandidateURL(titlepattern_id=self.id, candidateurl_id=url.id) for url in updated_urls
]
TitlePatternCandidateURL.objects.bulk_create(pattern_url_associations, ignore_conflicts=True)

def unapply(self) -> None:
self.candidate_urls.update(generated_title="")
Expand Down
26 changes: 18 additions & 8 deletions sde_collections/pattern_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
def interpret_title_pattern(url, scraped_title, title_pattern):
    """Return ``title_pattern`` with every ``{title}`` replaced by ``scraped_title``.

    ``url`` is accepted for interface compatibility but is not used. A pattern
    that does not contain the ``{title}`` placeholder is returned unchanged, and
    ``scraped_title`` is not touched in that case.
    """
    placeholder = "{title}"
    # Guard clause: nothing to substitute, hand the pattern back as-is.
    if placeholder not in title_pattern:
        return title_pattern
    return title_pattern.replace(placeholder, scraped_title)
import _ast
import ast


def safe_f_string_evaluation(pattern, context):
    """Interpolate ``context`` variables into an f-string style ``pattern``.

    Only bare variable names that are keys of ``context`` may appear inside
    ``{...}`` placeholders (optionally with a conversion such as ``!r`` or a
    format spec such as ``:>10``). Anything else -- calls, attribute access,
    arithmetic, subscripts -- is rejected before evaluation.

    Args:
        pattern: The template text, e.g. ``"{scraped_title} | {url}"``. It is
            treated purely as data: quotes and backslashes in it are literal.
        context: Mapping of allowed variable names to their values.

    Returns:
        The pattern with every placeholder replaced by its formatted value.

    Raises:
        ValueError: If the pattern is not a well-formed f-string body, contains
            an expression other than a plain name, or names a variable that is
            not in ``context``.
    """
    # repr() escapes quotes and backslashes, so the pattern can never terminate
    # the string literal early and smuggle code outside of the f-string.
    source = "f" + repr(str(pattern))
    try:
        parsed = ast.parse(source, mode="eval")
    except SyntaxError as exc:
        # Callers only handle ValueError; unbalanced braces etc. must not crash them.
        raise ValueError(f"Invalid f-string pattern: {exc.msg}") from exc

    # Defense in depth: the whole expression must be a single string literal.
    if not isinstance(parsed.body, (ast.JoinedStr, ast.Constant)):
        raise ValueError("Unsupported expression in f-string pattern.")

    # Whitelist node types instead of only inspecting FormattedValue nodes, so
    # nothing executable can hide anywhere in the tree (including format specs).
    allowed_nodes = (ast.Expression, ast.JoinedStr, ast.Constant, ast.FormattedValue, ast.Name, ast.Load)
    for node in ast.walk(parsed):
        if not isinstance(node, allowed_nodes):
            raise ValueError("Unsupported expression in f-string pattern.")
        if isinstance(node, ast.Name) and node.id not in context:
            raise ValueError(f"Variable {node.id} not allowed in f-string pattern.")

    compiled = compile(parsed, "<string>", "eval")
    # NOTE: eval is only reached after the whitelist above; builtins are removed
    # so that even an overlooked name could not resolve to anything callable.
    return eval(compiled, {"__builtins__": {}}, context)

0 comments on commit d7f64c8

Please sign in to comment.