diff --git a/server/src/weaverbird/backends/pypika_translator/translators/base.py b/server/src/weaverbird/backends/pypika_translator/translators/base.py index 846502d6a..cb42d269d 100644 --- a/server/src/weaverbird/backends/pypika_translator/translators/base.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/base.py @@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from datetime import UTC, date, datetime +from enum import Enum from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, cast, get_args from uuid import uuid4 @@ -11,7 +12,6 @@ Case, Criterion, Field, - Order, Query, Schema, Table, @@ -223,6 +223,12 @@ def __init__(self, date_part, interval, term, alias=None): super().__init__("DATEADD", LiteralValue(date_part), interval, term, alias=alias) +class Order(str, Enum): + ASC = "ASC" + DESC = "DESC" + ASC_NULLS_LAST = "ASC NULLS LAST" + + class SQLTranslator(ABC): DIALECT: SQLDialect QUERY_CLS: Query @@ -602,7 +608,7 @@ def _build_window_subquery() -> Any: window_table = Table("window_subquery") all_windows_subquery = _build_window_subquery() agg_query = ( - prev_step_table.select(*agg_selected, *step.on).groupby(*step.on).orderby(*step.on, order=Order.asc) + prev_step_table.select(*agg_selected, *step.on).groupby(*step.on).orderby(*step.on, order=Order.ASC) ).as_("agg_subquery") agg_table = Table("agg_subquery") merged_selected: list[str | Field] = [ @@ -652,7 +658,9 @@ def _build_window_subquery() -> Any: ) ) selected_col_names = [*columns, *agg_col_names] - return StepContext(query.orderby(*step.on) if step.on else query, selected_col_names) + return StepContext( + query.orderby(*step.on, order=Order.ASC_NULLS_LAST) if step.on else query, selected_col_names + ) else: selected_col_names = [ @@ -661,7 +669,7 @@ def _build_window_subquery() -> Any: *(f[1].alias for f in window_selected), ] return StepContext( - merged_query.orderby(*step.on) if step.on else merged_query, + merged_query.orderby(*step.on, order=Order.ASC_NULLS_LAST) if step.on else merged_query, selected_col_names, ) @@ -1629,7 +1637,7 @@ def rank( analytics_fn = analytics.Rank if step.method == "standard" else analytics.DenseRank rank_column = ( (analytics_fn().over(*(Field(group) for group in step.groupby)) if step.groupby else analytics_fn()) - .orderby(col_field, order=Order.desc if step.order == "desc" else Order.asc) + .orderby(col_field, order=Order.DESC if step.order == "desc" else Order.ASC) .as_(new_col_name) ) @@ -1715,7 +1723,7 @@ def sort( query = prev_step_table.select(*columns) for column_sort in step.columns: - query = query.orderby(column_sort.column, order=Order.desc if column_sort.order == "desc" else Order.asc) + query = query.orderby(column_sort.column, order=Order.DESC if column_sort.order == "desc" else Order.ASC) return StepContext(query, columns) @@ -1857,7 +1865,7 @@ def top( sub_query = sub_query.select( RowNumber() .over(*groups_fields) - .orderby(rank_on_field, order=Order.desc if step.sort == "desc" else Order.asc) + .orderby(rank_on_field, order=Order.DESC if step.sort == "desc" else Order.ASC) .as_("row_number") ) query: QueryBuilder = ( @@ -1866,7 +1874,7 @@ def top( .where(Field("row_number") <= step.limit) # The order of returned results is not necessarily consistent. This ensures we # always get the results in the same order - .orderby(*(Field(f) for f in step.groups + ["row_number"]), order=Order.asc) + .orderby(*(Field(f) for f in step.groups + ["row_number"]), order=Order.ASC) ) return StepContext(query, columns) @@ -1875,7 +1883,7 @@ def top( query = ( prev_step_table.select(*columns) - .orderby(step.rank_on, order=Order.desc if step.sort == "desc" else Order.asc) + .orderby(step.rank_on, order=Order.DESC if step.sort == "desc" else Order.ASC) .limit(step.limit) ) return StepContext(query, columns)