Skip to content

Commit

Permalink
fix: enums not quoted (#1776)
Browse files Browse the repository at this point in the history
* Fix enums not quoted

* Fix mypy complaints

* Make style and add test case for filter by enum

* Fix mysql unittest fails

* chore: upgrade pypika-tortoise and update changelog

* Fix codacy complaint for duplication

* chore: rollback version and add more type hints
  • Loading branch information
waketzheng authored Nov 19, 2024
1 parent 762fe37 commit 49b36ad
Show file tree
Hide file tree
Showing 32 changed files with 428 additions and 405 deletions.
13 changes: 6 additions & 7 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ Changelog
------
Fixed
^^^^^
- Fix enums not quoted (#1776)
- Primary key field should not be nullable (#1778)

Added
^^^^^
- JSONField adds optional generic support, and supports OpenAPI document generation by specifying `field_type` as a pydantic BaseModel (#1763)

Changed
^^^^^^^
- Change old pydantic docs link to new one (#1775).


0.21.7
0.21.7 <../0.21.7>`_ - 2024-10-14
------
Fixed
^^^^^
Expand All @@ -36,11 +39,7 @@ Added
- Add POSIX Regex support for PostgreSQL and MySQL (#1714)
- support app=None for tortoise.contrib.fastapi.RegisterTortoise (#1733)

Changed
^^^^^^^
- Change old pydantic docs link to new one (#1775).

0.21.6
0.21.6 <../0.21.6>`_ - 2024-08-17
------
Fixed
^^^^^
Expand Down
284 changes: 142 additions & 142 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.8"
pypika-tortoise = "^0.2.1"
pypika-tortoise = "^0.2.2"
iso8601 = "^2.1.0"
aiosqlite = ">=0.16.0, <0.21.0"
pytz = "*"
Expand Down
18 changes: 18 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from decimal import Decimal
from enum import Enum

from tests.testmodels import (
BooleanFields,
Expand All @@ -9,6 +10,15 @@
)
from tortoise.contrib import test
from tortoise.exceptions import FieldError
from tortoise.fields.base import StrEnum


class MyEnum(str, Enum):
moo = "moo"


class MyStrEnum(StrEnum):
moo = "moo"


class TestCharFieldFilters(test.TestCase):
Expand All @@ -29,6 +39,14 @@ async def test_equal(self):
set(await CharFields.filter(char="moo").values_list("char", flat=True)), {"moo"}
)

async def test_enum(self):
self.assertEqual(
set(await CharFields.filter(char=MyEnum.moo).values_list("char", flat=True)), {"moo"}
)
self.assertEqual(
set(await CharFields.filter(char=MyStrEnum.moo).values_list("char", flat=True)), {"moo"}
)

async def test_not(self):
self.assertEqual(
set(await CharFields.filter(char__not="moo").values_list("char", flat=True)),
Expand Down
10 changes: 8 additions & 2 deletions tests/test_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,16 @@ class TestQCall(TestCase):
def setUp(self) -> None:
super().setUp()
self.int_fields_context = ResolveContext(
model=IntFields, table=IntFields._meta.basequery, annotations={}, custom_filters={}
model=IntFields,
table=IntFields._meta.basequery, # type:ignore[arg-type]
annotations={},
custom_filters={},
)
self.char_fields_context = ResolveContext(
model=CharFields, table=CharFields._meta.basequery, annotations={}, custom_filters={}
model=CharFields,
table=CharFields._meta.basequery, # type:ignore[arg-type]
annotations={},
custom_filters={},
)

def test_q_basic(self):
Expand Down
9 changes: 5 additions & 4 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,11 @@ def _build_initial_querysets(cls) -> None:
for model in app.values():
model._meta.finalise_model()
model._meta.basetable = Table(name=model._meta.db_table, schema=model._meta.schema)
model._meta.basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery_all_fields = model._meta.basequery.select(
basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery = basequery # type:ignore[assignment]
model._meta.basequery_all_fields = basequery.select(
*model._meta.db_fields
)
) # type:ignore[assignment]

@classmethod
async def init(
Expand Down Expand Up @@ -517,7 +518,7 @@ async def init(
cls._inited = True

@classmethod
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None):
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None:
from tortoise.router import router

routers = routers or []
Expand Down
13 changes: 7 additions & 6 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,9 @@ def __init__(
self.column_map[column] = field_object.to_db_value

table = self.model._meta.basetable
basequery = cast(QueryBuilder, self.model._meta.basequery)
self.delete_query = str(
self.model._meta.basequery.where(
table[self.model._meta.db_pk_column] == self.parameter(0)
).delete()
basequery.where(table[self.model._meta.db_pk_column] == self.parameter(0)).delete()
)
self.update_cache: Dict[str, str] = {}

Expand All @@ -121,13 +120,13 @@ def __init__(
) = EXECUTOR_CACHE[key]

async def execute_explain(self, query: Query) -> Any:
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql()))
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql())) # type:ignore[attr-defined]
return (await self.db.execute_query(sql))[1]

async def execute_select(
self, query: Union[Query, RawSQL], custom_fields: Optional[list] = None
) -> list:
_, raw_results = await self.db.execute_query(query.get_sql())
_, raw_results = await self.db.execute_query(query.get_sql()) # type:ignore[union-attr]
instance_list = []
for row in raw_results:
if self.select_related_idx:
Expand Down Expand Up @@ -543,7 +542,9 @@ def _make_prefetch_queries(self) -> None:
relation_field = self.model._meta.fields_map[field_name]
related_model: "Type[Model]" = relation_field.related_model # type: ignore
related_query = related_model.all().using_db(self.db)
related_query.query = copy(related_query.model._meta.basequery)
related_query.query = copy(
related_query.model._meta.basequery
) # type:ignore[assignment]
if forwarded_prefetches:
related_query = related_query.prefetch_related(*forwarded_prefetches)
self._prefetch_queries.setdefault(field_name, []).append((to_attr, related_query))
Expand Down
17 changes: 6 additions & 11 deletions tortoise/backends/base_postgres/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Optional, Sequence
from typing import Optional, Sequence, cast

from pypika import Parameter
from pypika.dialects import PostgreSQLQueryBuilder
Expand All @@ -23,7 +23,7 @@
)


def postgres_search(field: Term, value: Term):
def postgres_search(field: Term, value: Term) -> SearchCriterion:
return SearchCriterion(field, expr=value)


Expand All @@ -44,15 +44,10 @@ def parameter(self, pos: int) -> Parameter:
def _prepare_insert_statement(
self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False
) -> PostgreSQLQueryBuilder:
query = (
self.db.query_class.into(self.model._meta.basetable)
.columns(*columns)
.insert(*[self.parameter(i) for i in range(len(columns))])
)
if has_generated:
generated_fields = self.model._meta.generated_db_fields
if generated_fields:
query = query.returning(*generated_fields)
builder = cast(PostgreSQLQueryBuilder, self.db.query_class.into(self.model._meta.basetable))
query = builder.columns(*columns).insert(*[self.parameter(i) for i in range(len(columns))])
if has_generated and (generated_fields := self.model._meta.generated_db_fields):
query = query.returning(*generated_fields)
if ignore_conflicts:
query = query.on_conflict().do_nothing()
return query
Expand Down
10 changes: 5 additions & 5 deletions tortoise/backends/mysql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
)


class StrWrapper(ValueWrapper): # type: ignore
class StrWrapper(ValueWrapper):
"""
Naive str wrapper that doesn't use the monkey-patched pypika ValueWrapper for MySQL
"""

def get_value_sql(self, **kwargs):
def get_value_sql(self, **kwargs) -> str:
quote_char = kwargs.get("secondary_quote_char") or ""
value = self.value.replace(quote_char, quote_char * 2)
return format_quotes(value, quote_char)
Expand Down Expand Up @@ -92,12 +92,12 @@ def mysql_insensitive_ends_with(field: Term, value: str) -> Criterion:
)


def mysql_search(field: Term, value: str):
def mysql_search(field: Term, value: str) -> SearchCriterion:
return SearchCriterion(field, expr=StrWrapper(value))


def mysql_posix_regex(field: Term, value: str):
return BasicCriterion(" REGEXP ", field, StrWrapper(value))
def mysql_posix_regex(field: Term, value: str) -> BasicCriterion:
return BasicCriterion(" REGEXP ", field, StrWrapper(value)) # type:ignore[arg-type]


class MySQLExecutor(BaseExecutor):
Expand Down
2 changes: 1 addition & 1 deletion tortoise/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self) -> None:
self._db_config: Optional["DBConfigType"] = None
self._create_db: bool = False

async def _init(self, db_config: "DBConfigType", create_db: bool):
async def _init(self, db_config: "DBConfigType", create_db: bool) -> None:
if self._db_config is None:
self._db_config = db_config
else:
Expand Down
6 changes: 3 additions & 3 deletions tortoise/contrib/mysql/functions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Union
from __future__ import annotations

from pypika.terms import Function, Parameter


class Rand(Function): # type: ignore
class Rand(Function):
"""
Generate random number, with optional seed.
:samp:`Rand()`
"""

def __init__(self, seed: Union[int, None] = None, alias=None) -> None:
def __init__(self, seed: int | None = None, alias=None) -> None:
super().__init__("RAND", seed, alias=alias)
self.args = [self.wrap_constant(seed) if seed is not None else Parameter("")]
43 changes: 17 additions & 26 deletions tortoise/contrib/mysql/json_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import operator
from typing import Any, Dict, List
Expand All @@ -7,27 +9,28 @@
from pypika.terms import Function as PypikaFunction
from pypika.terms import Term, ValueWrapper

from tortoise.filters import not_equal
from tortoise.filters import get_json_filter_operator, not_equal


class JSONContains(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, target_list: Term):
super(JSONContains, self).__init__("JSON_CONTAINS", column_name, target_list)
class JSONContains(PypikaFunction):
def __init__(self, column_name: Term, target_list: Term) -> None:
super().__init__("JSON_CONTAINS", column_name, target_list)


class JSONExtract(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, query_list: List[Term]):
class JSONExtract(PypikaFunction):
def __init__(self, column_name: Term, query_list: List[int | str | Term]) -> None:
query = self.make_query(query_list)
super(JSONExtract, self).__init__("JSON_EXTRACT", column_name, query)
super().__init__("JSON_EXTRACT", column_name, query)

@classmethod
def serialize_value(cls, value: Any):
def serialize_value(cls, value: Any) -> str:
if isinstance(value, int):
return f"[{value}]"
if isinstance(value, str):
return f".{value}"
return str(value)

def make_query(self, query_list: List[Term]):
def make_query(self, query_list: List[Term | int | str]) -> str:
query = ["$"]
for value in query_list:
query.append(self.serialize_value(value))
Expand All @@ -39,7 +42,7 @@ def mysql_json_contains(field: Term, value: str) -> Criterion:
return JSONContains(field, ValueWrapper(value))


def mysql_json_contained_by(field: Term, value_str: str) -> Criterion:
def mysql_json_contained_by(field: Term, value_str: str) -> JSONContains | None:
values = json.loads(value_str)
contained_by = None
for value in values:
Expand All @@ -50,14 +53,14 @@ def mysql_json_contained_by(field: Term, value_str: str) -> Criterion:
return contained_by


def _mysql_json_is_null(left: Term, is_null: bool):
def _mysql_json_is_null(left: Term, is_null: bool) -> Criterion:
if is_null:
return operator.eq(left, Cast("null", "JSON"))
else:
return not_equal(left, Cast("null", "JSON"))


def _mysql_json_not_is_null(left: Term, is_null: bool):
def _mysql_json_not_is_null(left: Term, is_null: bool) -> Criterion:
return _mysql_json_is_null(left, not is_null)


Expand All @@ -68,18 +71,6 @@ def _mysql_json_not_is_null(left: Term, is_null: bool):
}


def _serialize_value(value: Any):
if type(value) in [dict, list]:
return json.dumps(value)
return value


def mysql_json_filter(field: Term, value: Dict) -> Criterion:
((key, filter_value),) = value.items()
filter_value = _serialize_value(filter_value)
key_parts = [int(item) if item.isdigit() else str(item) for item in key.split("__")]
operator_ = operator.eq
if key_parts[-1] in operator_keywords:
operator_ = operator_keywords[str(key_parts.pop(-1))] # type: ignore

return operator_(JSONExtract(field, key_parts), filter_value)
key_parts, filter_value, operator_ = get_json_filter_operator(value, operator_keywords)
return operator_(JSONExtract(field, key_parts), filter_value) # type:ignore[arg-type]
14 changes: 7 additions & 7 deletions tortoise/contrib/mysql/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pypika.terms import Term


class Comp(Comparator): # type: ignore
class Comp(Comparator):
search = " "


Expand All @@ -18,13 +18,13 @@ class Mode(Enum):
WITH_QUERY_EXPRESSION = "WITH QUERY EXPANSION"


class Match(PypikaFunction): # type: ignore
def __init__(self, *columns: Term):
class Match(PypikaFunction):
def __init__(self, *columns: Term) -> None:
super(Match, self).__init__("MATCH", *columns)


class Against(PypikaFunction): # type: ignore
def __init__(self, expr: Term, mode: Optional[Mode] = None):
class Against(PypikaFunction):
def __init__(self, expr: Term, mode: Optional[Mode] = None) -> None:
super(Against, self).__init__("AGAINST", expr)
self.mode = mode

Expand All @@ -34,10 +34,10 @@ def get_special_params_sql(self, **kwargs: Any) -> Any:
return self.mode.value


class SearchCriterion(BasicCriterion): # type: ignore
class SearchCriterion(BasicCriterion):
"""
Only support for CharField, TextField with full search indexes.
"""

def __init__(self, *columns: Term, expr: Term, mode: Optional[Mode] = None):
def __init__(self, *columns: Term, expr: Term, mode: Optional[Mode] = None) -> None:
super().__init__(Comp.search, Match(*columns), Against(expr, mode))
Loading

0 comments on commit 49b36ad

Please sign in to comment.