Skip to content

Commit

Permalink
Refs django#373 -- Added support for using tuple lookups in filters.
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz authored and sarahboyce committed Sep 26, 2024
1 parent f22ff45 commit 5ed7208
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 53 deletions.
79 changes: 45 additions & 34 deletions django/db/models/fields/tuple_lookups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools

from django.core.exceptions import EmptyResultSet
from django.db.models import Field
from django.db.models.expressions import Func, Value
from django.db.models.lookups import (
Exact,
Expand All @@ -16,15 +17,19 @@

class Tuple(Func):
function = ""
output_field = Field()

def __len__(self):
return len(self.source_expressions)

def __iter__(self):
return iter(self.source_expressions)


class TupleLookupMixin:
def get_prep_lookup(self):
self.check_tuple_lookup()
return super().get_prep_lookup()

def check_tuple_lookup(self):
self.check_rhs_length_equals_lhs_length()
return self.rhs

def check_rhs_length_equals_lhs_length(self):
len_lhs = len(self.lhs)
Expand All @@ -34,24 +39,30 @@ def check_rhs_length_equals_lhs_length(self):
f"must have {len_lhs} elements"
)

def as_sql(self, compiler, connection):
# e.g.: (a, b, c) == (x, y, z) as SQL:
# WHERE (a, b, c) = (x, y, z)
vals = [
def get_prep_lhs(self):
if isinstance(self.lhs, (tuple, list)):
return Tuple(*self.lhs)
return super().get_prep_lhs()

def process_lhs(self, compiler, connection, lhs=None):
sql, params = super().process_lhs(compiler, connection, lhs)
if not isinstance(self.lhs, Tuple):
sql = f"({sql})"
return sql, params

def process_rhs(self, compiler, connection):
values = [
Value(val, output_field=col.output_field)
for col, val in zip(self.lhs, self.rhs)
]
lookup_class = self.__class__.__bases__[-1]
lookup = lookup_class(Tuple(self.lhs), Tuple(*vals))
return lookup.as_sql(compiler, connection)
return Tuple(*values).as_sql(compiler, connection)


class TupleExact(TupleLookupMixin, Exact):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) == (x, y, z) as SQL:
# WHERE a = x AND b = y AND c = z
cols = self.lhs.get_cols()
lookups = [Exact(col, val) for col, val in zip(cols, self.rhs)]
lookups = [Exact(col, val) for col, val in zip(self.lhs, self.rhs)]
root = WhereNode(lookups, connector=AND)

return root.as_sql(compiler, connection)
Expand Down Expand Up @@ -83,10 +94,9 @@ class TupleGreaterThan(TupleLookupMixin, GreaterThan):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) > (x, y, z) as SQL:
# WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
cols = self.lhs.get_cols()
lookups = itertools.cycle([GreaterThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
cols_list = [col for col in self.lhs for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list[:-1])
vals_iter = iter(vals_list[:-1])
Expand All @@ -110,10 +120,9 @@ class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) >= (x, y, z) as SQL:
# WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
cols = self.lhs.get_cols()
lookups = itertools.cycle([GreaterThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
cols_list = [col for col in self.lhs for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list)
vals_iter = iter(vals_list)
Expand All @@ -137,10 +146,9 @@ class TupleLessThan(TupleLookupMixin, LessThan):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) < (x, y, z) as SQL:
# WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
cols = self.lhs.get_cols()
lookups = itertools.cycle([LessThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
cols_list = [col for col in self.lhs for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list[:-1])
vals_iter = iter(vals_list[:-1])
Expand All @@ -164,10 +172,9 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) <= (x, y, z) as SQL:
# WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
cols = self.lhs.get_cols()
lookups = itertools.cycle([LessThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
cols_list = [col for col in self.lhs for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list)
vals_iter = iter(vals_list)
Expand All @@ -188,8 +195,9 @@ def as_oracle(self, compiler, connection):


class TupleIn(TupleLookupMixin, In):
def check_tuple_lookup(self):
def get_prep_lookup(self):
self.check_rhs_elements_length_equals_lhs_length()
return super(TupleLookupMixin, self).get_prep_lookup()

def check_rhs_elements_length_equals_lhs_length(self):
len_lhs = len(self.lhs)
Expand All @@ -199,37 +207,40 @@ def check_rhs_elements_length_equals_lhs_length(self):
f"must have {len_lhs} elements each"
)

def as_sql(self, compiler, connection):
if not self.rhs:
def process_rhs(self, compiler, connection):
rhs = self.rhs
if not rhs:
raise EmptyResultSet

# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
# WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
rhs = []
for vals in self.rhs:
rhs.append(
result = []
lhs = self.lhs

for vals in rhs:
result.append(
Tuple(
*[
Value(val, output_field=col.output_field)
for col, val in zip(self.lhs, vals)
for col, val in zip(lhs, vals)
]
)
)

lookup = In(Tuple(self.lhs), Tuple(*rhs))
return lookup.as_sql(compiler, connection)
return Tuple(*result).as_sql(compiler, connection)

def as_sqlite(self, compiler, connection):
if not self.rhs:
rhs = self.rhs
if not rhs:
raise EmptyResultSet

# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
root = WhereNode([], connector=OR)
cols = self.lhs.get_cols()
lhs = self.lhs

for vals in self.rhs:
lookups = [Exact(col, val) for col, val in zip(cols, vals)]
for vals in rhs:
lookups = [Exact(col, val) for col, val in zip(lhs, vals)]
root.children.append(WhereNode(lookups, connector=AND))

return root.as_sql(compiler, connection)
Expand Down
Loading

0 comments on commit 5ed7208

Please sign in to comment.