Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle ambiguous column names in queries involving 'any' field and a relation field #5541

Merged
merged 4 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from ..util import cached_classproperty, functemplate
from . import types
from .query import (
AndQuery,
FieldQuery,
FieldQueryType,
FieldSort,
MatchQuery,
Expand Down Expand Up @@ -718,33 +716,6 @@ def set_parse(self, key, string: str):
"""Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)

# Convenient queries.

@classmethod
def field_query(
    cls,
    field,
    pattern,
    query_cls: FieldQueryType = MatchQuery,
) -> FieldQuery:
    """Build a single-field query against this model.

    `query_cls` is instantiated with the field name, the pattern and a
    flag that is true when `field` is a key of `cls._fields`.
    """
    is_model_field = field in cls._fields
    return query_cls(field, pattern, is_model_field)

@classmethod
def all_fields_query(
    cls: type[Model],
    pats: Mapping[str, str],
    query_cls: FieldQueryType = MatchQuery,
):
    """Build a conjunction of per-field queries.

    `pats` maps field names to patterns. Every pair is turned into one
    `field_query` built with `query_cls`, and the returned `AndQuery`
    combines all of them.
    """
    per_field = []
    for name, pat in pats.items():
        per_field.append(cls.field_query(name, pat, query_cls))
    return AndQuery(per_field)


# Database controller and supporting interfaces.

Expand Down
47 changes: 3 additions & 44 deletions beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def match(self, obj: Model):
"""
...

def __and__(self, other: Query) -> AndQuery:
    """Support `a & b` by wrapping both operands in an `AndQuery`."""
    operands = [self, other]
    return AndQuery(operands)

def __repr__(self) -> str:
    """Return the name of the object's class followed by `()`."""
    class_name = self.__class__.__name__
    return f"{class_name}()"

Expand Down Expand Up @@ -505,50 +508,6 @@ def __hash__(self) -> int:
return reduce(mul, map(hash, self.subqueries), 1)


class AnyFieldQuery(CollectionQuery):
    """Match an object when a field query matches on any of several fields.

    One subquery of the given `FieldQuery` subclass is created per field,
    all sharing the same pattern; their clauses are joined with "or".
    """

    def __init__(self, pattern, fields, cls: FieldQueryType):
        self.pattern = pattern
        self.fields = fields
        self.query_class = cls

        # TYPING ERROR
        super().__init__([cls(name, pattern, True) for name in self.fields])

    @property
    def field_names(self) -> set[str]:
        """Return a set with field names that this query operates on."""
        return set(self.fields)

    def clause(self) -> tuple[str | None, Sequence[SQLiteType]]:
        """Return the subquery clauses combined with the "or" joiner."""
        return self.clause_with_joiner("or")

    def match(self, obj: Model) -> bool:
        """Return True as soon as one subquery matches `obj`."""
        return any(subq.match(obj) for subq in self.subqueries)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return (
            f"{class_name}({self.pattern!r}, {self.fields!r}, "
            f"{self.query_class.__name__})"
        )

    def __eq__(self, other) -> bool:
        # Equal collections must additionally share the same query class.
        same_collection = super().__eq__(other)
        return same_collection and self.query_class == other.query_class

    def __hash__(self) -> int:
        key = (self.pattern, tuple(self.fields), self.query_class)
        return hash(key)


class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
Expand Down
40 changes: 14 additions & 26 deletions beets/dbcore/queryparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
import re
from typing import TYPE_CHECKING

from . import Model, query
from . import query

if TYPE_CHECKING:
from collections.abc import Collection, Sequence

from ..library import LibModel
from .query import FieldQueryType, Sort

Prefixes = dict[str, FieldQueryType]


PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
Expand Down Expand Up @@ -112,7 +114,7 @@ def parse_query_part(


def construct_query_part(
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_part: str,
) -> query.Query:
Expand Down Expand Up @@ -147,28 +149,14 @@ def construct_query_part(
query_part, query_classes, prefixes
)

# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)

# Field queries get constructed according to the name of the field
# they are querying.
# If there's no key (field name) specified, this is a "match anything"
# query.
out_query = model_cls.any_field_query(pattern, query_class)
else:
field = table = key.lower()
if field in model_cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to query it in a join.
# Using an explicit table name resolves this.
table = f"{model_cls._table}.{field}"

field_in_db = field in model_cls.all_db_fields
out_query = query_class(table, pattern, field_in_db)
# Field queries get constructed according to the name of the field
# they are querying.
out_query = model_cls.field_query(key.lower(), pattern, query_class)

# Apply negation.
if negate:
Expand All @@ -180,7 +168,7 @@ def construct_query_part(
# TYPING ERROR
def query_from_strings(
query_cls: type[query.CollectionQuery],
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_parts: Collection[str],
) -> query.Query:
Expand All @@ -197,7 +185,7 @@ def query_from_strings(


def construct_sort_part(
model_cls: type[Model],
model_cls: type[LibModel],
part: str,
case_insensitive: bool = True,
) -> Sort:
Expand Down Expand Up @@ -228,7 +216,7 @@ def construct_sort_part(


def sort_from_strings(
model_cls: type[Model],
model_cls: type[LibModel],
sort_parts: Sequence[str],
case_insensitive: bool = True,
) -> Sort:
Expand All @@ -247,7 +235,7 @@ def sort_from_strings(


def parse_sorted_query(
model_cls: type[Model],
model_cls: type[LibModel],
parts: list[str],
prefixes: Prefixes = {},
case_insensitive: bool = True,
Expand Down
8 changes: 2 additions & 6 deletions beets/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,7 @@ def find_duplicates(self, lib):
# use a temporary Album object to generate any computed fields.
tmp_album = library.Album(lib, **info)
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
dup_query = library.Album.all_fields_query(
{key: tmp_album.get(key) for key in keys}
)
dup_query = tmp_album.duplicates_query(keys)

# Don't count albums with the same files as duplicates.
task_paths = {i.path for i in self.items if i}
Expand Down Expand Up @@ -1025,9 +1023,7 @@ def find_duplicates(self, lib):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Album.all_fields_query(
{key: tmp_item.get(key) for key in keys}
)
dup_query = tmp_item.duplicates_query(keys)

found_items = []
for other_item in lib.items(dup_query):
Expand Down
64 changes: 53 additions & 11 deletions beets/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import unicodedata
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING

import platformdirs
from mediafile import MediaFile, UnreadableFileError
Expand All @@ -42,6 +43,9 @@
)
from beets.util.functemplate import Template, template

if TYPE_CHECKING:
from .dbcore.query import FieldQuery, FieldQueryType

# To use the SQLite "blob" type, it doesn't suffice to provide a byte
# string; SQLite treats that as encoded text. Wrapping it in a
# `memoryview` tells it that we actually mean non-text data.
Expand Down Expand Up @@ -346,6 +350,10 @@ class LibModel(dbcore.Model["Library"]):
# Config key that specifies how an instance should be formatted.
_format_config_key: str

@cached_classproperty
def writable_media_fields(cls) -> set[str]:
    """Return the names present in both `MediaFile.fields()` and `cls._fields`."""
    media_fields = set(MediaFile.fields())
    return media_fields.intersection(cls._fields.keys())

def _template_funcs(self):
funcs = DefaultTemplateFunctions(self, self._db).functions()
funcs.update(plugins.template_funcs())
Expand Down Expand Up @@ -375,6 +383,44 @@ def __str__(self):
def __bytes__(self):
    """Return the UTF-8 encoding of this object's `__str__` output."""
    text = self.__str__()
    return text.encode("utf-8")

# Convenient queries.

@classmethod
def field_query(
    cls, field: str, pattern: str, query_cls: FieldQueryType
) -> FieldQuery:
    """Build a `query_cls` query for `field` on this model.

    The query's third argument is true when `field` is one of
    `cls.all_db_fields`; this is decided on the bare field name, before
    any table prefix is added.
    """
    is_db_field = field in cls.all_db_fields

    # A field listed in `shared_db_fields` exists in both tables, so
    # SQLite raises an OperationalError when it is used unqualified in a
    # query. Prefixing it with this model's table name resolves that.
    if field in cls.shared_db_fields:
        column = f"{cls._table}.{field}"
    else:
        column = field

    return query_cls(column, pattern, is_db_field)

@classmethod
def any_field_query(cls, *args, **kwargs) -> dbcore.OrQuery:
    """Build an `OrQuery` over every field in `cls._search_fields`.

    The arguments are forwarded unchanged to `field_query`, once per
    search field.
    """
    per_field = []
    for name in cls._search_fields:
        per_field.append(cls.field_query(name, *args, **kwargs))
    return dbcore.OrQuery(per_field)

@classmethod
def any_writable_media_field_query(cls, *args, **kwargs) -> dbcore.OrQuery:
    """Build an `OrQuery` over every field in `cls.writable_media_fields`.

    The arguments are forwarded unchanged to `field_query`, once per
    field.
    """
    per_field = []
    for name in cls.writable_media_fields:
        per_field.append(cls.field_query(name, *args, **kwargs))
    return dbcore.OrQuery(per_field)

def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
    """Return a query for entities with same values in the given fields.

    For each field a `MatchQuery` is built (via `field_query`) against
    the value this object returns from `self.get(field)`; the conditions
    are combined in an `AndQuery`.
    """
    conditions = []
    for name in fields:
        value = self.get(name)
        conditions.append(self.field_query(name, value, dbcore.MatchQuery))
    return dbcore.AndQuery(conditions)


class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album-level fields.
Expand Down Expand Up @@ -648,6 +694,12 @@ def _getters(cls):
getters["filesize"] = Item.try_filesize # In bytes.
return getters

def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
    """Return a query for items with same values in the given fields.

    The field conditions built by the parent class are and-ed with a
    `NoneQuery` on `album_id`.
    """
    # NOTE(review): presumably this limits matches to items that have no
    # album -- confirm against NoneQuery's semantics.
    same_values = super().duplicates_query(fields)
    no_album = dbcore.query.NoneQuery("album_id")
    return same_values & no_album

@classmethod
def from_path(cls, path):
"""Create a new item from the media file at the specified path."""
Expand Down Expand Up @@ -1866,7 +1918,6 @@ def tmpl_sunique(self, keys=None, disam=None, bracket=None):
Item.all_keys(),
# Do nothing for non singletons.
lambda i: i.album_id is not None,
initial_subqueries=[dbcore.query.NoneQuery("album_id", True)],
)

def _tmpl_unique_memokey(self, name, keys, disam, item_id):
Expand All @@ -1885,7 +1936,6 @@ def _tmpl_unique(
db_item,
item_keys,
skip_item,
initial_subqueries=None,
):
"""Generate a string that is guaranteed to be unique among all items of
the same type as "db_item" who share the same set of keys.
Expand Down Expand Up @@ -1932,15 +1982,7 @@ def _tmpl_unique(
bracket_r = ""

# Find matching items to disambiguate with.
subqueries = []
if initial_subqueries is not None:
subqueries.extend(initial_subqueries)
for key in keys:
value = db_item.get(key, "")
# Use slow queries for flexible attributes.
fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
query = dbcore.AndQuery(subqueries)
query = db_item.duplicates_query(keys)
ambigous_items = (
self.lib.items(query)
if isinstance(db_item, Item)
Expand Down
15 changes: 8 additions & 7 deletions beetsplug/aura.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def translate_filters(self):
value = converter(value)
# Add exact match query to list
# Build it through the model's field_query so it works with all fields
queries.append(MatchQuery(beets_attr, value, fast=False))
queries.append(
self.model_cls.field_query(beets_attr, value, MatchQuery)
)
# NOTE: AURA doesn't officially support multiple queries
return AndQuery(queries)

Expand Down Expand Up @@ -318,13 +320,12 @@ def all_resources(self):
sort = self.translate_sorts(sort_arg)
# For each sort field add a query which ensures all results
# have a non-empty, non-zero value for that field.
for s in sort.sorts:
query.subqueries.append(
NotQuery(
# Match empty fields (^$) or zero fields, (^0$)
RegexpQuery(s.field, "(^$|^0$)", fast=False)
)
query.subqueries.extend(
NotQuery(
self.model_cls.field_query(s.field, "(^$|^0$)", RegexpQuery)
)
for s in sort.sorts
)
else:
sort = None
# Get information from the library
Expand Down
Loading
Loading