From 2b9608955690cd4162a8afc132845859c5a0fe59 Mon Sep 17 00:00:00 2001 From: Luka Peschke Date: Fri, 15 Dec 2023 14:35:21 +0100 Subject: [PATCH] chore(deps): replace black with ruff Signed-off-by: Luka Peschke --- server/.pre-commit-config.yaml | 4 +- server/Makefile | 6 +- server/playground.py | 52 +---- server/poetry.lock | 55 +---- server/pyproject.toml | 5 - .../mongo_translator/steps/addmissingdates.py | 16 +- .../backends/mongo_translator/steps/cumsum.py | 4 +- .../mongo_translator/steps/date_extract.py | 4 +- .../mongo_translator/steps/evolution.py | 10 +- .../backends/mongo_translator/steps/join.py | 12 +- .../mongo_translator/steps/moving_average.py | 7 +- .../backends/mongo_translator/steps/rank.py | 4 +- .../backends/mongo_translator/steps/rollup.py | 4 +- .../backends/mongo_translator/steps/sort.py | 9 +- .../mongo_translator/steps/statistics.py | 22 +- .../mongo_translator/steps/substring.py | 18 +- .../backends/mongo_translator/steps/todate.py | 23 +- .../backends/mongo_translator/steps/totals.py | 10 +- .../mongo_translator/steps/waterfall.py | 59 ++--- .../backends/mongo_translator/utils.py | 14 +- .../backends/pandas_executor/geo_utils.py | 4 +- .../pandas_executor/pipeline_executor.py | 8 +- .../pandas_executor/steps/aggregate.py | 6 +- .../pandas_executor/steps/date_extract.py | 8 +- .../pandas_executor/steps/hierarchy.py | 4 +- .../pandas_executor/steps/ifthenelse.py | 8 +- .../backends/pandas_executor/steps/join.py | 4 +- .../pandas_executor/steps/replacetext.py | 4 +- .../backends/pandas_executor/steps/todate.py | 4 +- .../backends/pandas_executor/steps/totals.py | 8 +- .../pandas_executor/steps/utils/condition.py | 4 +- .../pandas_executor/steps/utils/formula.py | 6 +- .../pandas_executor/steps/waterfall.py | 26 +-- .../pypika_translator/translators/athena.py | 4 +- .../pypika_translator/translators/base.py | 208 +++++------------- .../translators/googlebigquery.py | 32 +-- .../pypika_translator/translators/mysql.py | 20 +- .../pypika_translator/translators/redshift.py | 8 +- .../translators/snowflake.py | 8 +- .../weaverbird/pipeline/formula_ast/eval.py | 20 +- server/src/weaverbird/pipeline/pipeline.py | 12 +- .../src/weaverbird/pipeline/steps/append.py | 8 +- .../weaverbird/pipeline/steps/customsql.py | 4 +- .../src/weaverbird/pipeline/steps/dissolve.py | 8 +- server/src/weaverbird/pipeline/steps/join.py | 4 +- .../pipeline/steps/utils/combination.py | 4 +- .../mongo_translator/steps/test_formula.py | 14 +- .../test_mongo_translator_steps.py | 12 +- .../test_pandas_executor_steps.py | 4 +- .../pandas_executor/utils/test_dates.py | 4 +- .../tests/backends/sql_translator/common.py | 5 +- .../test_sql_athena_translator_steps.py | 8 +- .../test_sql_bigquery_translator_steps.py | 8 +- .../test_sql_mysql_translator_steps.py | 16 +- .../test_sql_postgres_translator_steps.py | 16 +- .../test_sql_redshift_translator_steps.py | 12 +- .../test_sql_snowflake_translator_steps.py | 16 +- .../test_base_translator.py | 116 +++------- .../test_base_translator_strings.py | 8 +- .../test_date_format_translators.py | 8 +- .../test_filter_translators.py | 70 ++---- .../test_google_big_query.py | 4 +- .../test_row_number_translators.py | 33 +-- .../test_snowflake_translator.py | 34 +-- .../test_split_part_translators.py | 8 +- ...t_translation_with_mergeable_first_step.py | 12 +- server/tests/steps/test_addmissingdates.py | 50 ++--- server/tests/steps/test_convert.py | 4 +- server/tests/steps/test_cumsum.py | 4 +- server/tests/steps/test_date_extract.py | 4 +- 
server/tests/steps/test_duration.py | 4 +- server/tests/steps/test_evolution.py | 8 +- server/tests/steps/test_fillna.py | 4 +- server/tests/steps/test_filter.py | 28 +-- server/tests/steps/test_formula.py | 4 +- server/tests/steps/test_ifthenelse.py | 12 +- server/tests/steps/test_join.py | 8 +- server/tests/steps/test_moving_average.py | 13 +- server/tests/steps/test_percentage.py | 4 +- server/tests/steps/test_rank.py | 4 +- server/tests/steps/test_replacetext.py | 4 +- server/tests/steps/test_substring.py | 8 +- server/tests/steps/test_todate.py | 5 +- server/tests/steps/test_totals.py | 12 +- server/tests/steps/test_unpivot.py | 5 +- server/tests/test_pipeline.py | 26 +-- server/tests/test_pipeline_executor.py | 8 +- 87 files changed, 337 insertions(+), 1062 deletions(-) diff --git a/server/.pre-commit-config.yaml b/server/.pre-commit-config.yaml index 73df6eae78..145bcd6673 100644 --- a/server/.pre-commit-config.yaml +++ b/server/.pre-commit-config.yaml @@ -14,8 +14,8 @@ repos: language: system - id: system - name: Lint with Black - entry: black + name: Lint with Ruff format + entry: ruff format types: [python] language: system diff --git a/server/Makefile b/server/Makefile index a38f66aae3..dfab107b57 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,6 +1,6 @@ .DEFAULT_GOAL := all ruff = ruff src tests playground.py -black = black src tests playground.py +format = ruff format src tests playground.py .PHONY: clean clean: @@ -24,12 +24,12 @@ install-playground: .PHONY: format format: poetry run $(ruff) --fix - poetry run $(black) + poetry run $(format) .PHONY: lint lint: $(ruff) - $(black) --check + $(format) --check mypy .PHONY: test diff --git a/server/playground.py b/server/playground.py index 0bc7b68089..9ac0cbd370 100644 --- a/server/playground.py +++ b/server/playground.py @@ -122,14 +122,8 @@ class ColumnType(str, Enum): DOMAINS = { **{splitext(basename(csv_file))[0]: pd.read_csv(csv_file) for csv_file in csv_files}, - **{ - splitext(basename(json_file))[0]: pd.read_json(json_file, orient="table") - for json_file in json_files - }, - **{ - splitext(basename(geojson_file))[0]: gpd.read_file(geojson_file) - for geojson_file in geojson_files - }, + **{splitext(basename(json_file))[0]: pd.read_json(json_file, orient="table") for json_file in json_files}, + **{splitext(basename(geojson_file))[0]: gpd.read_file(geojson_file) for geojson_file in geojson_files}, } @@ -140,9 +134,7 @@ def get_available_domains(): def sanitize_table_schema(schema: dict) -> dict: return { "fields": [ - {"name": field["name"], "type": "geometry"} - if field.get("extDtype") == "geometry" - else field + {"name": field["name"], "type": "geometry"} if field.get("extDtype") == "geometry" else field for field in schema["fields"] ] } @@ -467,11 +459,7 @@ def get_table_columns(): for table in tables_info: with suppress(Exception): table_name = table[1] - infos = ( - _SNOWFLAKE_CONNECTION.cursor() - .execute(f'DESCRIBE TABLE "{table_name}";') - .fetchall() - ) + infos = _SNOWFLAKE_CONNECTION.cursor().execute(f'DESCRIBE TABLE "{table_name}";').fetchall() tables_columns[table_name] = [info[0] for info in infos if info[2] == "COLUMN"] return tables_columns @@ -497,11 +485,7 @@ async def handle_snowflake_backend_request(): tables_columns=tables_columns, ) - total_count = ( - _SNOWFLAKE_CONNECTION.cursor() - .execute(f"SELECT COUNT(*) FROM ({query})") - .fetchone()[0] - ) + total_count = _SNOWFLAKE_CONNECTION.cursor().execute(f"SELECT COUNT(*) FROM ({query})").fetchone()[0] # By using snowflake's connector 
ability to turn results into a DataFrame, # we can re-use all the methods to parse this data- interchange format in the front-end df_results = ( @@ -558,16 +542,12 @@ def postgresql_type_to_data_type(pg_type: str) -> ColumnType | None: @app.route("/postgresql", methods=["GET", "POST"]) async def handle_postgres_backend_request(): # improve by using a connexion pool - postgresql_connexion = await psycopg.AsyncConnection.connect( - os.getenv("POSTGRESQL_CONNECTION_STRING") - ) + postgresql_connexion = await psycopg.AsyncConnection.connect(os.getenv("POSTGRESQL_CONNECTION_STRING")) db_schema = "public" if request.method == "GET": async with postgresql_connexion.cursor() as cur: - tables_info_exec = await cur.execute( - f"SELECT * FROM pg_catalog.pg_tables WHERE schemaname='{db_schema}';" - ) + tables_info_exec = await cur.execute(f"SELECT * FROM pg_catalog.pg_tables WHERE schemaname='{db_schema}';") tables_info = await tables_info_exec.fetchall() return jsonify([table_infos[1] for table_infos in tables_info]) @@ -598,9 +578,7 @@ async def handle_postgres_backend_request(): ) async with postgresql_connexion.cursor() as cur: - query_total_count_exec = await cur.execute( - f"WITH Q AS ({sql_query}) SELECT COUNT(*) FROM Q" - ) + query_total_count_exec = await cur.execute(f"WITH Q AS ({sql_query}) SELECT COUNT(*) FROM Q") # fetchone() returns a tuple query_total_count = (await query_total_count_exec.fetchone())[0] @@ -623,9 +601,7 @@ async def handle_postgres_backend_request(): query_results_columns = [ { "name": c.name, - "type": [ - postgresql_type_to_data_type(t[1]) for t in types if t[0] == c.type_code - ][0], + "type": [postgresql_type_to_data_type(t[1]) for t in types if t[0] == c.type_code][0], } for c in query_results_desc ] @@ -681,10 +657,7 @@ async def handle_athena_post_request(): # Find all columns for all available tables table_info = _athena_table_info() - tables_columns = { - row["Table"]: [c.strip() for c in row["Columns"].split(",")] - for _, row in table_info.iterrows() - } + tables_columns = {row["Table"]: [c.strip() for c in row["Columns"].split(",")] for _, row in table_info.iterrows()} sql_query = ( pypika_translate_pipeline( @@ -731,10 +704,7 @@ def _bigquery_tables_list(client: bigquery.Client) -> list[str]: def _bigquery_tables_info(client: bigquery.Client) -> dict[str, list[str]]: - return { - table: [field.name for field in client.get_table(table).schema] - for table in _bigquery_tables_list(client) - } + return {table: [field.name for field in client.get_table(table).schema] for table in _bigquery_tables_list(client)} @app.get("/google-big-query") diff --git a/server/poetry.lock b/server/poetry.lock index 3e49c4d32f..702089b4f7 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -228,46 +228,6 @@ soupsieve = ">1.2" html5lib = ["html5lib"] lxml = ["lxml"] -[[package]] -name = "black" -version = "23.11.0" -description = "The uncompromising code formatter." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, - {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, - {file = "black-23.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d136ef5b418c81660ad847efe0e55c58c8208b77a57a28a503a5f345ccf01394"}, - {file = "black-23.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c1cac07e64433f646a9a838cdc00c9768b3c362805afc3fce341af0e6a9ae9f"}, - {file = "black-23.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf57719e581cfd48c4efe28543fea3d139c6b6f1238b3f0102a9c73992cbb479"}, - {file = "black-23.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:698c1e0d5c43354ec5d6f4d914d0d553a9ada56c85415700b81dc90125aac244"}, - {file = "black-23.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:760415ccc20f9e8747084169110ef75d545f3b0932ee21368f63ac0fee86b221"}, - {file = "black-23.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:58e5f4d08a205b11800332920e285bd25e1a75c54953e05502052738fe16b3b5"}, - {file = "black-23.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:45aa1d4675964946e53ab81aeec7a37613c1cb71647b5394779e6efb79d6d187"}, - {file = "black-23.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c44b7211a3a0570cc097e81135faa5f261264f4dfaa22bd5ee2875a4e773bd6"}, - {file = "black-23.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a9acad1451632021ee0d146c8765782a0c3846e0e0ea46659d7c4f89d9b212b"}, - {file = "black-23.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:fc7f6a44d52747e65a02558e1d807c82df1d66ffa80a601862040a43ec2e3142"}, - {file = "black-23.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f622b6822f02bfaf2a5cd31fdb7cd86fcf33dab6ced5185c35f5db98260b055"}, - {file = "black-23.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:250d7e60f323fcfc8ea6c800d5eba12f7967400eb6c2d21ae85ad31c204fb1f4"}, - {file = "black-23.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5133f5507007ba08d8b7b263c7aa0f931af5ba88a29beacc4b2dc23fcefe9c06"}, - {file = "black-23.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:421f3e44aa67138ab1b9bfbc22ee3780b22fa5b291e4db8ab7eee95200726b07"}, - {file = "black-23.11.0-py3-none-any.whl", hash = "sha256:54caaa703227c6e0c87b76326d0862184729a69b73d3b7305b6288e1d830067e"}, - {file = "black-23.11.0.tar.gz", hash = "sha256:4c68855825ff432d197229846f971bc4d6666ce90492e5b02013bcaca4d9ab05"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "blinker" version = "1.6.2" @@ -525,7 +485,7 @@ files = [ name = "click" version = "8.1.7" description = "Composable command line interface toolkit" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, @@ -2020,17 +1980,6 @@ pytz = ">=2020.1" [package.extras] test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] -[[package]] -name = "pathspec" -version = "0.11.2" -description = "Utility 
library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.7" -files = [ - {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, - {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, -] - [[package]] name = "platformdirs" version = "3.8.1" @@ -3637,4 +3586,4 @@ pypika = ["PyPika"] [metadata] lock-version = "2.0" python-versions = ">=3.11, <3.12" -content-hash = "06e4f01ba9605761f4088ef165bf890385fd7d88fc0776a34ec5afff66bc53dc" +content-hash = "8e923c75b5bff3f198b5913a26d6fe558eb930f488d712c5346bb76c6f4bb655" diff --git a/server/pyproject.toml b/server/pyproject.toml index 510f69e6fd..12834be776 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -42,7 +42,6 @@ snowflake-sqlalchemy = "^1.5.0" types-python-dateutil = "^2.8.19" pytest = "^7.4.1" pytest-xdist = ">=2.5,<4.0" -black = "^23.7.0" mypy = ">=0.990,<2" docker = "^6.1.3" sqlalchemy = "^1.4.49" @@ -62,10 +61,6 @@ all = ["pandas", "geopandas", "pypika"] # playground playground = ["quart", "Quart-CORS", "hypercorn", "pymongo", "pandas", "psycopg", "toucan-connectors"] -[tool.black] -line-length = 100 -target-version = ["py310"] - [tool.mypy] files = "src/" exclude = "weaverbird/backends/sql_translator" diff --git a/server/src/weaverbird/backends/mongo_translator/steps/addmissingdates.py b/server/src/weaverbird/backends/mongo_translator/steps/addmissingdates.py index 0585129555..eebbab1584 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/addmissingdates.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/addmissingdates.py @@ -99,9 +99,7 @@ def _add_missing_dates_day_or_month(step: AddMissingDatesStep) -> list[MongoStep # use the variable in the following expression, in which we recreate a date which granularity will # depend on the user-specified granularity "in": { - "$dateFromParts": _generate_date_from_parts( - "$$currentDay", step.dates_granularity - ), + "$dateFromParts": _generate_date_from_parts("$$currentDay", step.dates_granularity), }, }, }, @@ -132,9 +130,7 @@ def _add_missing_dates_day_or_month(step: AddMissingDatesStep) -> list[MongoStep add_missing_dates = { "$map": { # loop over unique dates array - "input": all_days_range - if step.dates_granularity == "day" - else unique_days_for_month_granularity, + "input": all_days_range if step.dates_granularity == "day" else unique_days_for_month_granularity, # use a variable "date" as cursor "as": "date", # and apply the following expression to every "date" @@ -167,9 +163,7 @@ def _add_missing_dates_day_or_month(step: AddMissingDatesStep) -> list[MongoStep { "$addFields": { "_vqbDay": { - "$dateFromParts": _generate_date_from_parts( - f"${step.dates_column}", step.dates_granularity - ), + "$dateFromParts": _generate_date_from_parts(f"${step.dates_column}", step.dates_granularity), }, }, }, @@ -200,9 +194,7 @@ def _add_missing_dates_day_or_month(step: AddMissingDatesStep) -> list[MongoStep def translate_addmissingdates(step: AddMissingDatesStep) -> list[MongoStep]: return ( - _add_missing_dates_year(step) - if step.dates_granularity == "year" - else _add_missing_dates_day_or_month(step) + _add_missing_dates_year(step) if step.dates_granularity == "year" else _add_missing_dates_day_or_month(step) ) + [ # Get back to 1 row per document {"$unwind": "$_vqbAllDates"}, diff --git a/server/src/weaverbird/backends/mongo_translator/steps/cumsum.py 
b/server/src/weaverbird/backends/mongo_translator/steps/cumsum.py index c20d8e21cb..08ef81f7fa 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/cumsum.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/cumsum.py @@ -19,9 +19,7 @@ def translate_cumsum(step: CumSumStep) -> list[MongoStep]: "$project": { **{col: f"$_id.{col}" for col in groupby}, **{ - new_name - if new_name - else f"{name}_CUMSUM": { + new_name if new_name else f"{name}_CUMSUM": { "$sum": {"$slice": [f"${name}", {"$add": ["$_VQB_INDEX", 1]}]} } for name, new_name in step.to_cumsum diff --git a/server/src/weaverbird/backends/mongo_translator/steps/date_extract.py b/server/src/weaverbird/backends/mongo_translator/steps/date_extract.py index 30cb9e55ca..950672f78c 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/date_extract.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/date_extract.py @@ -299,9 +299,7 @@ def translate_date_extract(step: DateExtractStep) -> list[MongoStep]: # For retrocompatibility if step.operation: date_info = [step.operation] if step.operation else step.date_info - new_columns = [ - step.new_column_name if step.new_column_name else f"{step.column}_{step.operation}" - ] + new_columns = [step.new_column_name if step.new_column_name else f"{step.column}_{step.operation}"] else: date_info = step.date_info.copy() new_columns = step.new_columns.copy() diff --git a/server/src/weaverbird/backends/mongo_translator/steps/evolution.py b/server/src/weaverbird/backends/mongo_translator/steps/evolution.py index ae721701c3..65645b9691 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/evolution.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/evolution.py @@ -5,11 +5,7 @@ def translate_evolution(step: EvolutionStep) -> list[MongoStep]: - new_column = ( - step.new_column - if step.new_column - else f"{step.value_col}_EVOL_{step.evolution_format.upper()}" - ) + new_column = step.new_column if step.new_column else f"{step.value_col}_EVOL_{step.evolution_format.upper()}" error_msg = "Error: More than one previous date found for the specified index columns" add_field_result: dict[str, Any] = {} @@ -56,9 +52,7 @@ def translate_evolution(step: EvolutionStep) -> list[MongoStep]: { "$facet": { "_VQB_ORIGINALS": [{"$project": {"_id": 0}}], - "_VQB_COPIES_ARRAY": [ - {"$group": {"_id": None, "_VQB_ALL_DOCS": {"$push": "$$ROOT"}}} - ], + "_VQB_COPIES_ARRAY": [{"$group": {"_id": None, "_VQB_ALL_DOCS": {"$push": "$$ROOT"}}}], }, }, {"$unwind": "$_VQB_ORIGINALS"}, diff --git a/server/src/weaverbird/backends/mongo_translator/steps/join.py b/server/src/weaverbird/backends/mongo_translator/steps/join.py index a1ee2f991e..09feb47e79 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/join.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/join.py @@ -29,9 +29,7 @@ def translate_join(step: JoinStep) -> list[MongoStep]: right_without_domain.steps = [s.copy(deep=True) for s in right[1:]] else: right_domain = DomainStep(**right[0]) - right_without_domain.steps = [ - getattr(steps, f"{s['name'].capitalize()}Step")(**s) for s in right[1:] - ] + right_without_domain.steps = [getattr(steps, f"{s['name'].capitalize()}Step")(**s) for s in right[1:]] mongo_let: dict[str, str] = {} mongo_expr_and: list[dict[str, list[str]]] = [] @@ -57,14 +55,10 @@ def translate_join(step: JoinStep) -> list[MongoStep]: if step.type == "inner": mongo_pipeline.append({"$unwind": "$_vqbJoinKey"}) elif step.type == "left": - mongo_pipeline.append( - 
{"$unwind": {"path": "$_vqbJoinKey", "preserveNullAndEmptyArrays": True}} - ) + mongo_pipeline.append({"$unwind": {"path": "$_vqbJoinKey", "preserveNullAndEmptyArrays": True}}) else: mongo_pipeline.append({"$match": {"_vqbJoinKey": {"$eq": []}}}) - mongo_pipeline.append( - {"$unwind": {"path": "$_vqbJoinKey", "preserveNullAndEmptyArrays": True}} - ) + mongo_pipeline.append({"$unwind": {"path": "$_vqbJoinKey", "preserveNullAndEmptyArrays": True}}) mongo_pipeline.append( { diff --git a/server/src/weaverbird/backends/mongo_translator/steps/moving_average.py b/server/src/weaverbird/backends/mongo_translator/steps/moving_average.py index 727f6551ba..267efd845a 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/moving_average.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/moving_average.py @@ -26,9 +26,7 @@ def translate_moving_average(step: MovingAverageStep) -> list[MongoStep]: "in": { "$cond": [ # If the index is less than the moving window minus 1... - { - "$lt": ["$$idx", (step.moving_window) - 1] - }, # explicit type for typescript + {"$lt": ["$$idx", (step.moving_window) - 1]}, # explicit type for typescript # ... then we cannot apply the moving average computation, and # we just keep the original document without any new field... {"$arrayElemAt": ["$_vqbArray", "$$idx"]}, @@ -40,8 +38,7 @@ def translate_moving_average(step: MovingAverageStep) -> list[MongoStep]: {"$arrayElemAt": ["$_vqbArray", "$$idx"]}, # and add the new moving average column { - step.new_column_name - or f"{step.value_column}_MOVING_AVG": { + step.new_column_name or f"{step.value_column}_MOVING_AVG": { "$avg": { "$slice": [ f"$_vqbArray.{step.value_column}", diff --git a/server/src/weaverbird/backends/mongo_translator/steps/rank.py b/server/src/weaverbird/backends/mongo_translator/steps/rank.py index f35bbb9d42..5c2df849f9 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/rank.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/rank.py @@ -63,9 +63,7 @@ def translate_rank(step: RankStep) -> list[MongoStep]: "$$this", { ( - step.new_column_name - if step.new_column_name - else f"{step.value_col}_RANK" + step.new_column_name if step.new_column_name else f"{step.value_col}_RANK" ): "$$rank" }, ], diff --git a/server/src/weaverbird/backends/mongo_translator/steps/rollup.py b/server/src/weaverbird/backends/mongo_translator/steps/rollup.py index 2aff6a084d..23e35a31cd 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/rollup.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/rollup.py @@ -23,9 +23,7 @@ def translate_rollup(step: RollupStep) -> list[MongoStep]: cols = agg_step.columns # _id is a name reserved by mongo. 
If for some reason, a user wants to aggregate the _id # column, the aggregation result will be stored in a new __id column - new_cols = [ - col if col != _ID_COLUMN else _NEW_ID_COLUMN for col in agg_step.new_columns - ] + new_cols = [col if col != _ID_COLUMN else _NEW_ID_COLUMN for col in agg_step.new_columns] if agg_step.agg_function == "count": for i in range(len(cols)): diff --git a/server/src/weaverbird/backends/mongo_translator/steps/sort.py b/server/src/weaverbird/backends/mongo_translator/steps/sort.py index b238460c9f..fd9760bd0e 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/sort.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/sort.py @@ -3,11 +3,4 @@ def translate_sort(step: SortStep) -> list[MongoStep]: - return [ - { - "$sort": { - sort_column.column: 1 if sort_column.order == "asc" else -1 - for sort_column in step.columns - } - } - ] + return [{"$sort": {sort_column.column: 1 if sort_column.order == "asc" else -1 for sort_column in step.columns}}] diff --git a/server/src/weaverbird/backends/mongo_translator/steps/statistics.py b/server/src/weaverbird/backends/mongo_translator/steps/statistics.py index 0ac3495fb3..b59aad77d3 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/statistics.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/statistics.py @@ -38,11 +38,7 @@ def _need_to_compute_column_square(step: StatisticsStep) -> bool: def _need_to_compute_average(step: StatisticsStep) -> bool: - return ( - "average" in step.statistics - or "variance" in step.statistics - or "standard deviation" in step.statistics - ) + return "average" in step.statistics or "variance" in step.statistics or "standard deviation" in step.statistics def _need_to_sort(step: StatisticsStep) -> bool: @@ -72,11 +68,7 @@ def translate_statistics(step: StatisticsStep) -> list[MongoStep]: "$project": { **{col: 1 for col in step.groupby_columns}, "column": f"${step.column}", - **( - {"column_square": {"$pow": [f"${step.column}", 2]}} - if _need_to_compute_column_square(step) - else {} - ), + **({"column_square": {"$pow": [f"${step.column}", 2]}} if _need_to_compute_column_square(step) else {}), }, }, { @@ -92,11 +84,7 @@ def translate_statistics(step: StatisticsStep) -> list[MongoStep]: **({"count": {"$sum": 1}} if _need_to_count(step) else {}), **({"max": {"$max": "$column"}} if "max" in step.statistics else {}), **({"min": {"$min": "$column"}} if "min" in step.statistics else {}), - **( - {"average_sum_square": {"$avg": "$column_square"}} - if _need_to_compute_column_square(step) - else {} - ), + **({"average_sum_square": {"$avg": "$column_square"}} if _need_to_compute_column_square(step) else {}), **({"average": {"$avg": "$column"}} if _need_to_compute_average(step) else {}), }, }, @@ -108,9 +96,7 @@ def translate_statistics(step: StatisticsStep) -> list[MongoStep]: **{statistic: _STATISTICS_FORMULA[statistic] for statistic in step.statistics}, # quantiles **{ - quantile.label - if quantile.label - else f"{quantile.nth}-th {quantile.order}-quantile": _get_quantile( + quantile.label if quantile.label else f"{quantile.nth}-th {quantile.order}-quantile": _get_quantile( quantile.nth, quantile.order ) for quantile in step.quantiles diff --git a/server/src/weaverbird/backends/mongo_translator/steps/substring.py b/server/src/weaverbird/backends/mongo_translator/steps/substring.py index 9588691f0b..fe2e059ff0 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/substring.py +++ 
b/server/src/weaverbird/backends/mongo_translator/steps/substring.py @@ -4,15 +4,11 @@ def translate_substring(step: SubstringStep) -> list[MongoStep]: pos_start_index = ( - step.start_index - 1 - if step.start_index > 0 - else {"$add": [{"$strLenCP": f"${step.column}"}, step.start_index]} + step.start_index - 1 if step.start_index > 0 else {"$add": [{"$strLenCP": f"${step.column}"}, step.start_index]} ) pos_end_index = ( - step.end_index - 1 - if step.end_index > 0 - else {"$add": [{"$strLenCP": f"${step.column}"}, step.end_index]} + step.end_index - 1 if step.end_index > 0 else {"$add": [{"$strLenCP": f"${step.column}"}, step.end_index]} ) length_to_keep = { @@ -26,12 +22,4 @@ def translate_substring(step: SubstringStep) -> list[MongoStep]: substr_mongo = {"$substrCP": [f"${step.column}", pos_start_index, length_to_keep]} - return [ - { - "$addFields": { - ( - step.new_column_name if step.new_column_name else f"{step.column}_SUBSTR" - ): substr_mongo - } - } - ] + return [{"$addFields": {(step.new_column_name if step.new_column_name else f"{step.column}_SUBSTR"): substr_mongo}}] diff --git a/server/src/weaverbird/backends/mongo_translator/steps/todate.py b/server/src/weaverbird/backends/mongo_translator/steps/todate.py index 66f408f685..3bad4d82ae 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/todate.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/todate.py @@ -30,9 +30,7 @@ def translate_todate(step: ToDateStep) -> list[MongoStep]: {"$lt": [f"${step.column}", 10_000]}, { "$dateFromString": { - "dateString": { - "$concat": [col_as_string, "-01-01"] - }, + "dateString": {"$concat": [col_as_string, "-01-01"]}, "format": "%Y-%m-%d", "onError": {"$literal": None}, } @@ -109,9 +107,7 @@ def translate_todate(step: ToDateStep) -> list[MongoStep]: {"$addFields": {"_vqbTempArray": {"$split": [col_as_string, " "]}}}, _extract_date_parts_to_temp_fields(1, 0), MONTH_REPLACEMENT_STEP, - _concat_fields_to_date( - step.column, ["01-", "$_vqbTempMonth", "-", "$_vqbTempYear"], "%d-%m-%Y" - ), + _concat_fields_to_date(step.column, ["01-", "$_vqbTempMonth", "-", "$_vqbTempYear"], "%d-%m-%Y"), _clean_temp_fields(), ] @@ -120,9 +116,7 @@ def translate_todate(step: ToDateStep) -> list[MongoStep]: {"$addFields": {"_vqbTempArray": {"$split": [col_as_string, "-"]}}}, _extract_date_parts_to_temp_fields(1, 0), MONTH_REPLACEMENT_STEP, - _concat_fields_to_date( - step.column, ["01-", "$_vqbTempMonth", "-", "$_vqbTempYear"], "%d-%m-%Y" - ), + _concat_fields_to_date(step.column, ["01-", "$_vqbTempMonth", "-", "$_vqbTempYear"], "%d-%m-%Y"), _clean_temp_fields(), ] @@ -185,8 +179,7 @@ def translate_todate(step: ToDateStep) -> list[MongoStep]: MONTH_REPLACEMENT_STEP: MongoStep = { "$addFields": { - "_vqbTempMonth" - "$switch": { + "_vqbTempMonth" "$switch": { "branches": [ { "case": {"$in": month_names}, @@ -208,9 +201,7 @@ def _extract_date_parts_to_temp_fields( } if month_position is not None: - date_parts_temp_fields["_vqbTempMonth"] = { - "$toLower": {"$arrayElemAt": ["$_vqbTempArray", month_position]} - } + date_parts_temp_fields["_vqbTempMonth"] = {"$toLower": {"$arrayElemAt": ["$_vqbTempArray", month_position]}} if day_position is not None: date_parts_temp_fields["_vqbTempDay"] = {"$arrayElemAt": ["$_vqbTempArray", day_position]} @@ -221,9 +212,7 @@ def _extract_date_parts_to_temp_fields( def _clean_temp_fields(): - return { - "$project": {"_vqbTempArray": 0, "_vqbTempMonth": 0, "_vqbTempYear": 0, "_vqbTempDate": 0} - } + return {"$project": {"_vqbTempArray": 0, "_vqbTempMonth": 
0, "_vqbTempYear": 0, "_vqbTempDate": 0}} def _concat_fields_to_date(target_col: str, fields: list[str | dict], format: str): diff --git a/server/src/weaverbird/backends/mongo_translator/steps/totals.py b/server/src/weaverbird/backends/mongo_translator/steps/totals.py index 1505293311..ab8062f1d5 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/totals.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/totals.py @@ -6,11 +6,7 @@ def combinations(iterable: list) -> list: """combinations([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)""" - return list( - itertools.chain.from_iterable( - itertools.combinations(iterable, r) for r in range(1, len(iterable) + 1) - ) - ) + return list(itertools.chain.from_iterable(itertools.combinations(iterable, r) for r in range(1, len(iterable) + 1))) def column_map(s: list[str]) -> dict[str, str]: @@ -62,9 +58,7 @@ def translate_totals(step: TotalsStep) -> list[MongoStep]: else: aggs[aggregated_col] = {f"${agg_func}": f"${value_col}"} - add_fields_to_add_to_pipeline = ( - [{"$addFields": count_distinct_add_fields}] if count_distinct_add_fields else [] - ) + add_fields_to_add_to_pipeline = [{"$addFields": count_distinct_add_fields}] if count_distinct_add_fields else [] facet[f"combo_{i}"] = [ { diff --git a/server/src/weaverbird/backends/mongo_translator/steps/waterfall.py b/server/src/weaverbird/backends/mongo_translator/steps/waterfall.py index 7ea2b4e31a..c1d86716ba 100644 --- a/server/src/weaverbird/backends/mongo_translator/steps/waterfall.py +++ b/server/src/weaverbird/backends/mongo_translator/steps/waterfall.py @@ -16,9 +16,7 @@ _VQB_ORDER = "_vqbOrder" -def _facet_keys_and_elements( - step: WaterfallStep, *, group_key: list[str], project_key: list[str] -) -> list[MongoStep]: +def _facet_keys_and_elements(step: WaterfallStep, *, group_key: list[str], project_key: list[str]) -> list[MongoStep]: """Facets the input document in order to get: * A list of all existing group keys @@ -72,19 +70,11 @@ def _facet_keys_and_elements( ] -def _filter_out_incomplete_elements( - *, group_key: list[str], project_key: list[str] -) -> list[MongoStep]: +def _filter_out_incomplete_elements(*, group_key: list[str], project_key: list[str]) -> list[MongoStep]: """Filters out elements which do not have a start and end value""" return [ # Determining which keys should be kept: We only what those with a start and end value - { - "$addFields": { - _VQB_KEYS_TO_KEEP: { - "$setIntersection": [f"${_VQB_START_KEYS}", f"${_VQB_END_KEYS}"] - } - } - }, + {"$addFields": {_VQB_KEYS_TO_KEEP: {"$setIntersection": [f"${_VQB_START_KEYS}", f"${_VQB_END_KEYS}"]}}}, # filtering start and elements on keys that should be kept { "$project": { @@ -105,9 +95,7 @@ def _filter_out_incomplete_elements( ] -def _backfill_missing_values( - step: WaterfallStep, *, group_key: list[str], project_key: list[str] -) -> list[MongoStep]: +def _backfill_missing_values(step: WaterfallStep, *, group_key: list[str], project_key: list[str]) -> list[MongoStep]: """Backfills the missing start and end values.""" mongo_step = { "$project": { @@ -121,9 +109,7 @@ def _backfill_missing_values( "$map": { # Determining the missing keys by doing a set difference between all # keys and the start keys - "input": { - "$setDifference": [f"${_VQB_ALL_KEYS}", f"${_VQB_START_KEYS}"] - }, + "input": {"$setDifference": [f"${_VQB_ALL_KEYS}", f"${_VQB_START_KEYS}"]}, "in": { "$mergeObjects": [ "$$this", @@ -165,9 +151,7 @@ def _calculate_children_deltas(step: WaterfallStep) -> list[MongoStep]: 
_VQB_CHILDREN: { # Iterating over a zip op (end_element, start_element) pairs "$map": { - "input": { - "$zip": {"inputs": [f"${_VQB_END_ELEMENTS}", f"${_VQB_START_ELEMENTS}"]} - }, + "input": {"$zip": {"inputs": [f"${_VQB_END_ELEMENTS}", f"${_VQB_START_ELEMENTS}"]}}, "in": { # Here, we are merging the start object with another object containg the # value column with the start element's value subtracted from the end @@ -222,9 +206,7 @@ def _sort_elements(*, group_key: list[str], project_key: list[str]) -> list[Mong def _facet_results(step: WaterfallStep) -> list[MongoStep]: - children_group = ( - step.groupby + [step.labelsColumn] + ([step.parentsColumn] if step.parentsColumn else []) - ) + children_group = step.groupby + [step.labelsColumn] + ([step.parentsColumn] if step.parentsColumn else []) facet: dict[str, list] = { _VQB_CHILDREN: [ {"$unwind": f"${_VQB_CHILDREN}"}, @@ -238,11 +220,7 @@ def _facet_results(step: WaterfallStep) -> list[MongoStep]: "$project": { **{col: f"$_id.{col}" for col in step.groupby}, "LABEL_waterfall": f"$_id.{step.labelsColumn}", - **( - {"GROUP_waterfall": f"$_id.{step.parentsColumn}"} - if step.parentsColumn - else {} - ), + **({"GROUP_waterfall": f"$_id.{step.parentsColumn}"} if step.parentsColumn else {}), "TYPE_waterfall": "child" if step.parentsColumn else "parent", step.valueColumn: f"${step.valueColumn}", _VQB_ORDER: {"$literal": 1}, @@ -253,9 +231,7 @@ def _facet_results(step: WaterfallStep) -> list[MongoStep]: {"$unwind": f"${_VQB_START_ELEMENTS}"}, { "$group": { - "_id": {col: f"${_VQB_START_ELEMENTS}.{col}" for col in step.groupby} - if step.groupby - else True, + "_id": {col: f"${_VQB_START_ELEMENTS}.{col}" for col in step.groupby} if step.groupby else True, step.valueColumn: {"$sum": f"${_VQB_START_ELEMENTS}.{step.valueColumn}"}, } }, @@ -274,9 +250,7 @@ def _facet_results(step: WaterfallStep) -> list[MongoStep]: {"$unwind": f"${_VQB_END_ELEMENTS}"}, { "$group": { - "_id": {col: f"${_VQB_END_ELEMENTS}.{col}" for col in step.groupby} - if step.groupby - else True, + "_id": {col: f"${_VQB_END_ELEMENTS}.{col}" for col in step.groupby} if step.groupby else True, step.valueColumn: {"$sum": f"${_VQB_END_ELEMENTS}.{step.valueColumn}"}, } }, @@ -298,10 +272,7 @@ def _facet_results(step: WaterfallStep) -> list[MongoStep]: {"$unwind": f"${_VQB_CHILDREN}"}, { "$group": { - "_id": { - col: f"${_VQB_CHILDREN}.{col}" - for col in (step.groupby + [step.parentsColumn]) - }, + "_id": {col: f"${_VQB_CHILDREN}.{col}" for col in (step.groupby + [step.parentsColumn])}, step.valueColumn: {"$sum": f"${_VQB_CHILDREN}.{step.valueColumn}"}, } }, @@ -337,9 +308,7 @@ def _column_map(colnames: list[str]) -> dict[str, str]: def translate_waterfall(step: WaterfallStep) -> list[MongoStep]: - group_key = ( - step.groupby + ([step.parentsColumn] if step.parentsColumn else []) + [step.labelsColumn] - ) + group_key = step.groupby + ([step.parentsColumn] if step.parentsColumn else []) + [step.labelsColumn] project_key = group_key + [step.milestonesColumn, step.valueColumn, "_id"] steps = _facet_keys_and_elements(step, group_key=group_key, project_key=project_key) @@ -367,9 +336,7 @@ def translate_waterfall(step: WaterfallStep) -> list[MongoStep]: { "$sort": { _VQB_ORDER: 1, - ("LABEL_waterfall" if step.sortBy == "label" else step.valueColumn): 1 - if step.order == "asc" - else -1, + ("LABEL_waterfall" if step.sortBy == "label" else step.valueColumn): 1 if step.order == "asc" else -1, }, }, {"$unset": unset}, diff --git a/server/src/weaverbird/backends/mongo_translator/utils.py 
b/server/src/weaverbird/backends/mongo_translator/utils.py index cf64587d23..29611ec13b 100644 --- a/server/src/weaverbird/backends/mongo_translator/utils.py +++ b/server/src/weaverbird/backends/mongo_translator/utils.py @@ -71,24 +71,20 @@ def build_cond_expression( return cond_expression -def build_dates_expressions( - cond: SimpleCondition, cond_expression: dict[str, Any], operator_mapping: dict[str, str] -): +def build_dates_expressions(cond: SimpleCondition, cond_expression: dict[str, Any], operator_mapping: dict[str, str]): if cond.operator == "until": if isinstance(cond.value, datetime.datetime): cond_expression[operator_mapping[cond.operator]][1] = [ - datetime.datetime( - day=cond.value.day, month=cond.value.month, year=cond.value.month - ).replace(hour=23, minute=59, second=59, microsecond=999999) + datetime.datetime(day=cond.value.day, month=cond.value.month, year=cond.value.month).replace( + hour=23, minute=59, second=59, microsecond=999999 + ) ] if cond.operator == "from" or cond.operator == "until": cond_expression = { operator_mapping[cond.operator]: [ truncate_to_day(f"${cond.column}"), truncate_to_day( - translate_relative_date(cond.value) - if isinstance(cond.value, RelativeDate) - else cond.value + translate_relative_date(cond.value) if isinstance(cond.value, RelativeDate) else cond.value ), ] } diff --git a/server/src/weaverbird/backends/pandas_executor/geo_utils.py b/server/src/weaverbird/backends/pandas_executor/geo_utils.py index e92385deb6..0ffabdd259 100644 --- a/server/src/weaverbird/backends/pandas_executor/geo_utils.py +++ b/server/src/weaverbird/backends/pandas_executor/geo_utils.py @@ -18,6 +18,4 @@ def df_to_geodf(df: pd.DataFrame) -> gpd.GeoDataFrame: try: return gpd.GeoDataFrame(df) except Exception as exc: - raise UnsupportedGeoOperation( - f"Could not convert DataFrame to GeoDataFrame: {exc}" - ) from exc + raise UnsupportedGeoOperation(f"Could not convert DataFrame to GeoDataFrame: {exc}") from exc diff --git a/server/src/weaverbird/backends/pandas_executor/pipeline_executor.py b/server/src/weaverbird/backends/pandas_executor/pipeline_executor.py index 9c628cb8ed..1bf2c789e2 100644 --- a/server/src/weaverbird/backends/pandas_executor/pipeline_executor.py +++ b/server/src/weaverbird/backends/pandas_executor/pipeline_executor.py @@ -73,9 +73,7 @@ def execute_pipeline( return df, PipelineExecutionReport(steps_reports=step_reports) -def preview_pipeline( - pipeline: Pipeline, domain_retriever: DomainRetriever, limit: int = 50, offset: int = 0 -) -> str: +def preview_pipeline(pipeline: Pipeline, domain_retriever: DomainRetriever, limit: int = 50, offset: int = 0) -> str: """ Execute a pipeline but returns only a slice of the results, determined by `limit` and `offset` parameters, as JSON. 
@@ -99,9 +97,7 @@ def _default_formatter(obj): "limit": limit, "total": df.shape[0], "data": json.loads( - df[offset : offset + limit].to_json( - orient="records", default_handler=_default_formatter - ) + df[offset : offset + limit].to_json(orient="records", default_handler=_default_formatter) ), } ) diff --git a/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py b/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py index 4b36cadfc9..16f8af10c4 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py @@ -52,11 +52,7 @@ def execute_aggregate( else: for aggregation in step.aggregations: for col, new_col in zip(aggregation.columns, aggregation.new_columns, strict=True): - agg_serie = ( - grouped_by_df[col] - .agg(get_aggregate_fn(aggregation.agg_function)) - .rename(new_col) - ) + agg_serie = grouped_by_df[col].agg(get_aggregate_fn(aggregation.agg_function)).rename(new_col) aggregated_cols.append(agg_serie) df_result = concat(aggregated_cols, axis=1).reset_index() diff --git a/server/src/weaverbird/backends/pandas_executor/steps/date_extract.py b/server/src/weaverbird/backends/pandas_executor/steps/date_extract.py index 9c197bea08..1335665aa0 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/date_extract.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/date_extract.py @@ -50,9 +50,7 @@ def execute_date_extract( result = to_datetime(DataFrame({"year": series_dt_accessor.year, "month": 1, "day": 1})) elif dt_info == "firstDayOfMonth": result = to_datetime( - DataFrame( - {"year": series_dt_accessor.year, "month": series_dt_accessor.month, "day": 1} - ) + DataFrame({"year": series_dt_accessor.year, "month": series_dt_accessor.month, "day": 1}) ) elif dt_info == "firstDayOfWeek": # dayofweek should be between 1 (sunday) and 7 (saturday) @@ -83,9 +81,7 @@ def execute_date_extract( # the result should be returned with 0-ed time information result = to_datetime(result.dt.date) elif dt_info == "firstDayOfPreviousYear": - result = to_datetime( - DataFrame({"year": series_dt_accessor.year - 1, "month": 1, "day": 1}) - ) + result = to_datetime(DataFrame({"year": series_dt_accessor.year - 1, "month": 1, "day": 1})) elif dt_info == "firstDayOfPreviousMonth": prev_month = series_dt_accessor.month - 1 prev_month = prev_month.replace({0: 12}) diff --git a/server/src/weaverbird/backends/pandas_executor/steps/hierarchy.py b/server/src/weaverbird/backends/pandas_executor/steps/hierarchy.py index 048d4d45a9..4049c2db7a 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/hierarchy.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/hierarchy.py @@ -7,9 +7,7 @@ from .dissolve import execute_dissolve -def _dissolve_one_level( - df: pd.DataFrame, groups: list[str], level_col: str, level: int -) -> pd.DataFrame: +def _dissolve_one_level(df: pd.DataFrame, groups: list[str], level_col: str, level: int) -> pd.DataFrame: dissolved = execute_dissolve(DissolveStep(groups=groups), df) dissolved[level_col] = level return dissolved diff --git a/server/src/weaverbird/backends/pandas_executor/steps/ifthenelse.py b/server/src/weaverbird/backends/pandas_executor/steps/ifthenelse.py index e166820ef0..b35e925182 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/ifthenelse.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/ifthenelse.py @@ -15,13 +15,7 @@ def _execute_ifthenelse(ifthenelse: IfThenElse, df: DataFrame, new_column) -> 
Da else_branch = eval_formula(df, str(ifthenelse.else_value)) then_branch = eval_formula(df, str(ifthenelse.then)) - return df.assign( - **{ - new_column: np.where( - apply_condition(ifthenelse.condition, df), then_branch, else_branch - ) - } - ) + return df.assign(**{new_column: np.where(apply_condition(ifthenelse.condition, df), then_branch, else_branch)}) def execute_ifthenelse( diff --git a/server/src/weaverbird/backends/pandas_executor/steps/join.py b/server/src/weaverbird/backends/pandas_executor/steps/join.py index db4c922bf8..7bd593a36d 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/join.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/join.py @@ -13,9 +13,7 @@ def execute_join( domain_retriever: DomainRetriever, execute_pipeline: PipelineExecutor, ) -> DataFrame: - right_df = resolve_pipeline_for_combination( - step.right_pipeline, domain_retriever, execute_pipeline - ) + right_df = resolve_pipeline_for_combination(step.right_pipeline, domain_retriever, execute_pipeline) if step.type == "left outer": how = "outer" diff --git a/server/src/weaverbird/backends/pandas_executor/steps/replacetext.py b/server/src/weaverbird/backends/pandas_executor/steps/replacetext.py index 16adcd864f..eb53d0b98e 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/replacetext.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/replacetext.py @@ -10,6 +10,4 @@ def execute_replacetext( domain_retriever: DomainRetriever | None = None, execute_pipeline: PipelineExecutor | None = None, ) -> DataFrame: - return df.assign( - **{step.search_column: df[step.search_column].str.replace(step.old_str, step.new_str)} - ) + return df.assign(**{step.search_column: df[step.search_column].str.replace(step.old_str, step.new_str)}) diff --git a/server/src/weaverbird/backends/pandas_executor/steps/todate.py b/server/src/weaverbird/backends/pandas_executor/steps/todate.py index 7980932560..abea5f9ad2 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/todate.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/todate.py @@ -21,7 +21,5 @@ def execute_todate( # Timestamps are expected in ms (not in ns, which is pandas' default) timestamp_unit = "ms" - datetime_serie = to_datetime( - df[step.column], format=format, errors="coerce", unit=timestamp_unit - ) + datetime_serie = to_datetime(df[step.column], format=format, errors="coerce", unit=timestamp_unit) return df.assign(**{step.column: datetime_serie}) diff --git a/server/src/weaverbird/backends/pandas_executor/steps/totals.py b/server/src/weaverbird/backends/pandas_executor/steps/totals.py index 8a3f251055..88004a6407 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/totals.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/totals.py @@ -13,9 +13,7 @@ def get_total_for_dimension( # get all group_by columns: all total_dimensions, except the current one + groups # all columns that are either not aggregated, or groups, or total will be null group_by_columns = step.groups + [ - group_column.total_column - for group_column in step.total_dimensions - if group_column != total_dimension + group_column.total_column for group_column in step.total_dimensions if group_column != total_dimension ] aggregations = [] for aggregation in step.aggregations: @@ -40,9 +38,7 @@ def get_total_for_dimension( full_aggregation = concat( [ full_aggregation, - get_total_for_dimension( - step, aggregated_df, dimension, dimensions_to_skip + [dimension] - ), + get_total_for_dimension(step, 
aggregated_df, dimension, dimensions_to_skip + [dimension]), ] ) return full_aggregation diff --git a/server/src/weaverbird/backends/pandas_executor/steps/utils/condition.py b/server/src/weaverbird/backends/pandas_executor/steps/utils/condition.py index fb2df1e035..1706278081 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/utils/condition.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/utils/condition.py @@ -78,9 +78,7 @@ def apply_condition(condition: Condition, df: DataFrame) -> Series: hour=0, minute=0, second=0, microsecond=0, nanosecond=0 ) # Do the same with the value to compare it to - value_without_time = value - DateOffset( - hour=0, minute=0, second=0, microsecond=0, nanosecond=0 - ) + value_without_time = value - DateOffset(hour=0, minute=0, second=0, microsecond=0, nanosecond=0) return getattr(column_without_time, comparison_method)(value_without_time) diff --git a/server/src/weaverbird/backends/pandas_executor/steps/utils/formula.py b/server/src/weaverbird/backends/pandas_executor/steps/utils/formula.py index fb5be90b0c..23d01361a5 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/utils/formula.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/utils/formula.py @@ -17,9 +17,9 @@ def _eval_operation(df: DataFrame, op: Operation) -> Series: - return _OP_MAP[op.operator]( - _eval_expression(df, op.left), _eval_expression(df, op.right) - ).replace([np.inf, -np.inf], np.nan) + return _OP_MAP[op.operator](_eval_expression(df, op.left), _eval_expression(df, op.right)).replace( + [np.inf, -np.inf], np.nan + ) def _eval_expression(df: DataFrame, expr: Expression) -> Series: diff --git a/server/src/weaverbird/backends/pandas_executor/steps/waterfall.py b/server/src/weaverbird/backends/pandas_executor/steps/waterfall.py index 501e158c4e..4d7a07fd0c 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/waterfall.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/waterfall.py @@ -49,18 +49,12 @@ def execute_waterfall( # Backfilling missing values if step.backfill: - start_df = _backfill_missing_values( - start_df, step, step.labelsColumn, group_columns, unique_values_groups - ) - end_df = _backfill_missing_values( - end_df, step, step.labelsColumn, group_columns, unique_values_groups - ) + start_df = _backfill_missing_values(start_df, step, step.labelsColumn, group_columns, unique_values_groups) + end_df = _backfill_missing_values(end_df, step, step.labelsColumn, group_columns, unique_values_groups) # Otherwise, filter out rows which do not have a start and end date else: # We want to remove all value groups which are not in both the start and end dataframes - value_groups_to_remove = unique_values_groups - start_value_groups.intersection( - end_value_groups - ) + value_groups_to_remove = unique_values_groups - start_value_groups.intersection(end_value_groups) start_df = _filter_out_rows(start_df, group_columns, value_groups_to_remove) end_df = _filter_out_rows(end_df, group_columns, value_groups_to_remove) @@ -119,9 +113,7 @@ def _merge(step: WaterfallStep, start_df: DataFrame, end_df: DataFrame) -> DataF # we join the result to compare them merged_df = start_df.merge(end_df, on=_get_join_key(step)) - merged_df[RESULT_COLUMN] = ( - merged_df[f"{step.valueColumn}_end"] - merged_df[f"{step.valueColumn}_start"] - ) + merged_df[RESULT_COLUMN] = merged_df[f"{step.valueColumn}_end"] - merged_df[f"{step.valueColumn}_start"] merged_df = merged_df.drop( columns=[ f"{step.valueColumn}_start", @@ -131,9 +123,9 @@ def 
_merge(step: WaterfallStep, start_df: DataFrame, end_df: DataFrame) -> DataF # if there is a parent column, we need to aggregate for them if step.parentsColumn is not None: - parents_results = merged_df.groupby( - step.groupby + [step.parentsColumn], as_index=False - ).agg({RESULT_COLUMN: "sum"}) + parents_results = merged_df.groupby(step.groupby + [step.parentsColumn], as_index=False).agg( + {RESULT_COLUMN: "sum"} + ) parents_results[step.labelsColumn] = parents_results[step.parentsColumn] return pd.concat([merged_df, parents_results]) return merged_df @@ -150,9 +142,7 @@ def _filter_out_rows( return df for value_group_to_remove in value_groups_to_remove: - conditions = ( - df[col] == val for col, val in zip(group_columns, value_group_to_remove, strict=True) - ) + conditions = (df[col] == val for col, val in zip(group_columns, value_group_to_remove, strict=True)) condition = reduce(lambda s1, s2: s1 & s2, conditions) df = df[~condition] diff --git a/server/src/weaverbird/backends/pypika_translator/translators/athena.py b/server/src/weaverbird/backends/pypika_translator/translators/athena.py index 996fe5ee65..d3e32a3753 100644 --- a/server/src/weaverbird/backends/pypika_translator/translators/athena.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/athena.py @@ -46,9 +46,7 @@ class AthenaTranslator(SQLTranslator): FROM_DATE_OP = FromDateOp.DATE_FORMAT @classmethod - def _add_date( - cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None - ) -> Term: + def _add_date(cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None) -> Term: # We need implement our own function for athena because Presto requires the units to be # quoted. PyPika's DateAdd function removes them by applying LiteralValue to the unit custom = CustomFunction("DATE_ADD", ["unit", "duration", "target"]) diff --git a/server/src/weaverbird/backends/pypika_translator/translators/base.py b/server/src/weaverbird/backends/pypika_translator/translators/base.py index a294c56dff..efd334c01a 100644 --- a/server/src/weaverbird/backends/pypika_translator/translators/base.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/base.py @@ -241,9 +241,7 @@ def _step_context_from_first_step(self, step: "DomainStep | CustomSqlStep") -> S columns = self._extract_columns_from_customsql_step(step=step) return StepContext(self._custom_query(step=step), columns) - def _merge_first_steps( - self: Self, *, domain_step: "DomainStep", second_step: TopStep | FilterStep - ) -> StepContext: + def _merge_first_steps(self: Self, *, domain_step: "DomainStep", second_step: TopStep | FilterStep) -> StepContext: columns = self._extract_columns_from_domain_step(step=domain_step) # If we have a reference, self._extract_columns_from_domain_step raises assert isinstance(domain_step.domain, str) @@ -283,11 +281,7 @@ def get_query_builder( assert steps[0].name == "domain" or steps[0].name == "customsql" self._step_count = 0 - if ( - len(steps) > 1 - and isinstance(steps[0], DomainStep) - and isinstance(steps[1], FilterStep | TopStep) - ): + if len(steps) > 1 and isinstance(steps[0], DomainStep) and isinstance(steps[1], FilterStep | TopStep): ctx = self._merge_first_steps(domain_step=steps[0], second_step=steps[1]) remaining_steps = steps[2:] else: @@ -295,17 +289,13 @@ def get_query_builder( remaining_steps = steps[1:] table_name = self._next_step_name() - builder = (query_builder if query_builder is not None else self.QUERY_CLS).with_( - ctx.selectable, table_name - ) 
+ builder = (query_builder if query_builder is not None else self.QUERY_CLS).with_(ctx.selectable, table_name) for step in remaining_steps: step_method: Callable[..., StepContext] | None = getattr(self, step.name, None) if step_method is None: raise NotImplementedError(f"[{self.DIALECT}] step {step.name} is not implemented") - ctx = step_method( - step=step, prev_step_table=table_name, builder=builder, columns=ctx.columns - ) + ctx = step_method(step=step, prev_step_table=table_name, builder=builder, columns=ctx.columns) table_name = self._next_step_name() builder = ctx.update_builder(builder=builder, step_name=table_name) return QueryBuilderContext(builder=builder, columns=ctx.columns, table_name=table_name) @@ -315,9 +305,7 @@ def get_query_str(self: Self, *, steps: Sequence["PipelineStep"]) -> str: # All other methods implement step from https://weaverbird.toucantoco.com/docs/steps/, # the name of the method being the name of the step and the kwargs the rest of the params - def _get_aggregate_function( - self: Self, agg_function: "AggregateFn" - ) -> type[functions.AggregateFunction] | None: + def _get_aggregate_function(self: Self, agg_function: "AggregateFn") -> type[functions.AggregateFunction] | None: match agg_function: case "avg": return functions.Avg @@ -334,9 +322,7 @@ def _get_aggregate_function( case _: return None - def _get_window_function( - self: Self, window_function: "AggregateFn" - ) -> analytics.AnalyticFunction | None: + def _get_window_function(self: Self, window_function: "AggregateFn") -> analytics.AnalyticFunction | None: match window_function: case "first": return analytics.FirstValue @@ -393,11 +379,7 @@ def _build_window_subquery() -> Any: self.QUERY_CLS.from_(window_subquery_list[0]) .select( *step.on, - *[ - getattr(first_wq, col[1].alias) - for col in window_selected - if col[0] == min_window_index - ], + *[getattr(first_wq, col[1].alias) for col in window_selected if col[0] == min_window_index], ) .as_("window_subquery") ) @@ -421,9 +403,7 @@ def _build_window_subquery() -> Any: for step_index, aggregation in enumerate(step.aggregations): if agg_fn := self._get_aggregate_function(aggregation.agg_function): - for agg_column_name, new_column_name in zip( - aggregation.columns, aggregation.new_columns, strict=True - ): + for agg_column_name, new_column_name in zip(aggregation.columns, aggregation.new_columns, strict=True): column_field: Field = Table(prev_step_table)[agg_column_name] new_agg_col = agg_fn(column_field).as_(new_column_name) agg_selected.append(new_agg_col) @@ -444,10 +424,7 @@ def _build_window_subquery() -> Any: window_selected.append((step_index, new_window_col)) agg_cols.append(new_window_col) window_subquery_list.append( - self.QUERY_CLS.from_(prev_step_table) - .select(*step.on, *agg_cols) - .distinct() - .as_(f"wq{step_index}") + self.QUERY_CLS.from_(prev_step_table).select(*step.on, *agg_cols).distinct().as_(f"wq{step_index}") ) else: # pragma: no cover @@ -478,16 +455,10 @@ def _build_window_subquery() -> Any: ) else: # If there is no `step.on` columns to join, just put the 2 subqueries side by side: - merged_query = ( - self.QUERY_CLS.from_(agg_query) - .from_(all_windows_subquery) - .select(*merged_selected) - ) + merged_query = self.QUERY_CLS.from_(agg_query).from_(all_windows_subquery).select(*merged_selected) elif agg_selected: selected_cols = [*step.on, *agg_selected] - merged_query = ( - self.QUERY_CLS.from_(prev_step_table).select(*selected_cols).groupby(*step.on) - ) + merged_query = 
self.QUERY_CLS.from_(prev_step_table).select(*selected_cols).groupby(*step.on) elif window_subquery_list: merged_query = _build_window_subquery() else: @@ -550,10 +521,7 @@ def append( columns: list[str], step: "AppendStep", ) -> StepContext: - pipelines = [ - self._pipeline_or_domain_name_or_reference_to_pipeline(pipeline) - for pipeline in step.pipelines - ] + pipelines = [self._pipeline_or_domain_name_or_reference_to_pipeline(pipeline) for pipeline in step.pipelines] tables: list[str] = [] column_lists: list[list[str]] = [] for pipeline in pipelines: @@ -641,10 +609,7 @@ def comparetext( table = Table(prev_step_table) query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select( *columns, - Case() - .when(table[step.str_col_1] == table[step.str_col_2], True) - .else_(False) - .as_(step.new_column_name), + Case().when(table[step.str_col_1] == table[step.str_col_2], True).else_(False).as_(step.new_column_name), ) return StepContext(query, columns + [step.new_column_name]) @@ -682,9 +647,7 @@ def convert( query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select( *(c for c in columns if c not in step.columns), *( - functions.Cast(col_field, getattr(self.DATA_TYPE_MAPPING, step.data_type)).as_( - col_field.name - ) + functions.Cast(col_field, getattr(self.DATA_TYPE_MAPPING, step.data_type)).as_(col_field.name) for col_field in col_fields ), ) @@ -712,16 +675,15 @@ def cumsum( # In case some cumsum columns have overwritten previously existing columns, don't select twice original_column_names = [col for col in columns if col not in cumsum_colnames] query: "QueryBuilder" = ( - self.QUERY_CLS.from_(prev_step_table).select(*original_column_names, *cumsum_cols) + self.QUERY_CLS.from_(prev_step_table) + .select(*original_column_names, *cumsum_cols) # Depending on the backend, results are ordered by partition or by reference column, so # we choose an arbitrary ordering here .orderby(order_by) ) return StepContext(query, original_column_names + cumsum_colnames) - def _custom_query( - self: Self, *, step: "CustomSqlStep", prev_step_table: str | None = None - ) -> CustomQuery: + def _custom_query(self: Self, *, step: "CustomSqlStep", prev_step_table: str | None = None) -> CustomQuery: table_name = prev_step_table or "_" return CustomQuery( name=f"custom_from_{table_name}", @@ -759,9 +721,7 @@ def _get_date_extract_func(cls, *, date_unit: DATE_INFO, target_column: Field) - "year", "yearofweek", ): - return Extract( - lowered_date_unit.removesuffix("s"), cls._cast_to_timestamp(target_column) - ) + return Extract(lowered_date_unit.removesuffix("s"), cls._cast_to_timestamp(target_column)) # ms aren't supported by snowflake's EXTRACT, even if the docs state otherwise: # https://community.snowflake.com/s/question/0D50Z00008dWkrpSAC/supported-time-parts-in-datepart elif lowered_date_unit == "milliseconds": @@ -774,11 +734,7 @@ def _get_date_extract_func(cls, *, date_unit: DATE_INFO, target_column: Field) - return Extract("week", target_column) elif lowered_date_unit == "isodayofweek": # We want monday as 1, sunday as 7.
Redshift goes from sunday as 0 to saturday as 6 - return ( - Case() - .when(cls._day_of_week(target_column) == 0, 7) - .else_(cls._day_of_week(target_column)) - ) + return Case().when(cls._day_of_week(target_column) == 0, 7).else_(cls._day_of_week(target_column)) elif lowered_date_unit == "firstdayofyear": return cls._date_trunc("year", target_column) elif lowered_date_unit == "firstdayofmonth": @@ -810,13 +766,9 @@ def _get_date_extract_func(cls, *, date_unit: DATE_INFO, target_column: Field) - cls._add_date(target_column=target_column, unit="months", duration=-1), ) elif lowered_date_unit == "previousweek": - return Extract( - "week", cls._add_date(target_column=target_column, unit="weeks", duration=-1) - ) + return Extract("week", cls._add_date(target_column=target_column, unit="weeks", duration=-1)) elif lowered_date_unit == "previousisoweek": - return Extract( - "week", cls._add_date(target_column=target_column, unit="weeks", duration=-1) - ) + return Extract("week", cls._add_date(target_column=target_column, unit="weeks", duration=-1)) elif lowered_date_unit == "previousquarter": return Extract( "quarter", @@ -827,18 +779,12 @@ def _get_date_extract_func(cls, *, date_unit: DATE_INFO, target_column: Field) - ), ) elif lowered_date_unit == "firstdayofpreviousyear": - return cls._add_date( - target_column=cls._date_trunc("year", target_column), unit="years", duration=-1 - ) + return cls._add_date(target_column=cls._date_trunc("year", target_column), unit="years", duration=-1) elif lowered_date_unit == "firstdayofpreviousmonth": - return cls._add_date( - target_column=cls._date_trunc("month", target_column), unit="months", duration=-1 - ) + return cls._add_date(target_column=cls._date_trunc("month", target_column), unit="months", duration=-1) elif lowered_date_unit == "firstdayofpreviousquarter": # Postgres does not support quarters in intervals - return cls._add_date( - target_column=cls._date_trunc("year", target_column), unit="months", duration=-3 - ) + return cls._add_date(target_column=cls._date_trunc("year", target_column), unit="months", duration=-3) elif lowered_date_unit == "firstdayofpreviousweek": return cls._add_date( target_column=cls._add_date( @@ -852,9 +798,7 @@ def _get_date_extract_func(cls, *, date_unit: DATE_INFO, target_column: Field) - duration=-1, ) elif lowered_date_unit == "firstdayofpreviousisoweek": - return cls._add_date( - target_column=cls._date_trunc("week", target_column), unit="weeks", duration=-1 - ) + return cls._add_date(target_column=cls._date_trunc("week", target_column), unit="weeks", duration=-1) # Postgres supports EXTRACT(isoyear) but redshift doesn't so... 
elif lowered_date_unit == "isoyear": return Extract("year", cls._date_trunc("week", target_column)) @@ -879,9 +823,7 @@ def dateextract( col_field = functions.Cast(col_field, self.DATA_TYPE_MAPPING.integer) extracted_dates.append(col_field.as_(new_column_name)) - query: "Selectable" = self.QUERY_CLS.from_(prev_step_table).select( - *columns, *extracted_dates - ) + query: "Selectable" = self.QUERY_CLS.from_(prev_step_table).select(*columns, *extracted_dates) return StepContext(query, columns + [col.alias for col in extracted_dates]) @@ -909,9 +851,7 @@ def _extract_columns_from_domain_step(self: Self, *, step: "DomainStep") -> list # getattr(self, step_name) def _domain(self: Self, *, step: "DomainStep") -> StepContext: selected_cols = self._extract_columns_from_domain_step(step=step) - query: "QueryBuilder" = self.QUERY_CLS.from_( - Table(step.domain, schema=self._db_schema) - ).select(*selected_cols) + query: "QueryBuilder" = self.QUERY_CLS.from_(Table(step.domain, schema=self._db_schema)).select(*selected_cols) if self._source_rows_subset: query = query.limit(self._source_rows_subset) return StepContext(query, selected_cols) @@ -956,9 +896,7 @@ def duration( return StepContext(query, columns + [step.new_column_name]) @classmethod - def _add_date( - cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None - ) -> Term: + def _add_date(cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None) -> Term: return target_column + Interval(**{unit: duration, "dialect": dialect}) def evolution( @@ -985,12 +923,8 @@ def evolution( prev_table.field(step.value_col) - right_table.field(step.value_col) if step.evolution_format == "abs" else ( - functions.Cast( - prev_table.field(step.value_col), self.DATA_TYPE_MAPPING.float - ) - / functions.Cast( - right_table.field(step.value_col), self.DATA_TYPE_MAPPING.float - ) + functions.Cast(prev_table.field(step.value_col), self.DATA_TYPE_MAPPING.float) + / functions.Cast(right_table.field(step.value_col), self.DATA_TYPE_MAPPING.float) ) - 1.0 ).as_(new_col), @@ -1021,10 +955,7 @@ def fillna( the_table = Table(prev_step_table) query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select( *(c for c in columns if c not in step.columns), - *( - functions.Coalesce(the_table[col_name], step.value).as_(col_name) - for col_name in step.columns - ), + *(functions.Coalesce(the_table[col_name], step.value).as_(col_name) for col_name in step.columns), ) return StepContext(query, columns) @@ -1032,9 +963,7 @@ def fillna( def _cast_to_timestamp(value: str | datetime | Field | Term) -> functions.Function: return functions.Cast(value, "TIMESTAMP") - def _get_single_condition_criterion( - self: Self, condition: "SimpleCondition", prev_step_table: Table - ) -> Criterion: + def _get_single_condition_criterion(self: Self, condition: "SimpleCondition", prev_step_table: Table) -> Criterion: column_field = prev_step_table[condition.column] self._ensure_term_uses_wrapper(column_field) @@ -1103,9 +1032,7 @@ def _get_single_condition_criterion( column_field.wrap_constant(compliant_regex), ) case _: - raise NotImplementedError( - f"[{self.DIALECT}] doesn't have regexp operator" - ) + raise NotImplementedError(f"[{self.DIALECT}] doesn't have regexp operator") elif condition.operator == "notmatches": # Casting the field to str first as it is the only compatible type for regex @@ -1138,9 +1065,7 @@ def _get_single_condition_criterion( column_field.wrap_constant(compliant_regex), ) case _: - raise NotImplementedError( - 
f"[{self.DIALECT}] doesn't have regexp operator" - ) + raise NotImplementedError(f"[{self.DIALECT}] doesn't have regexp operator") case NullCondition(): # type:ignore[misc] if condition.operator == "isnull": @@ -1170,9 +1095,7 @@ def _get_single_condition_criterion( case _: # pragma: no cover raise KeyError(f"Operator {condition.operator!r} does not exist") - def _get_filter_criterion( - self: Self, condition: "Condition", prev_step_table: Table - ) -> Criterion: + def _get_filter_criterion(self: Self, condition: "Condition", prev_step_table: Table) -> Criterion: from weaverbird.pipeline.conditions import ConditionComboAnd, ConditionComboOr # NOTE: type ignore comments below are because of 'Expected type in class pattern; found @@ -1180,13 +1103,11 @@ def _get_filter_criterion( match condition: case ConditionComboOr(): # type:ignore[misc] return Criterion.any( - self._get_filter_criterion(condition, prev_step_table) - for condition in condition.or_ + self._get_filter_criterion(condition, prev_step_table) for condition in condition.or_ ) case ConditionComboAnd(): # type:ignore[misc] return Criterion.all( - self._get_filter_criterion(condition, prev_step_table) - for condition in condition.and_ + self._get_filter_criterion(condition, prev_step_table) for condition in condition.and_ ) case _: return self._get_single_condition_criterion(condition, prev_step_table) @@ -1201,9 +1122,7 @@ def filter( ) -> StepContext: table = Table(prev_step_table) if isinstance(prev_step_table, str) else prev_step_table query: "QueryBuilder" = ( - self.QUERY_CLS.from_(table) - .select(*columns) - .where(self._get_filter_criterion(step.condition, table)) + self.QUERY_CLS.from_(table).select(*columns).where(self._get_filter_criterion(step.condition, table)) ) return StepContext(query, columns) @@ -1267,9 +1186,7 @@ def _build_ifthenelse_case( except (json.JSONDecodeError, TypeError): # the value is a formula or a string literal that can't be parsed then_value = formula_to_term(then_, table) - case_ = case_.when( - self._get_filter_criterion(if_, Table(prev_step_table)), LiteralValue(then_value) - ) + case_ = case_.when(self._get_filter_criterion(if_, Table(prev_step_table)), LiteralValue(then_value)) if isinstance(else_, IfThenElse): return self._build_ifthenelse_case( @@ -1363,11 +1280,7 @@ def join( self.QUERY_CLS.from_(left_table) .select(*left_cols, *right_cols) .join(right_table, self._get_join_type(step.type)) - .on( - Criterion.all( - Field(f[0], table=left_table) == Field(f[1], table=right_table) for f in step.on - ) - ) + .on(Criterion.all(Field(f[0], table=left_table) == Field(f[1], table=right_table) for f in step.on)) # Order of results is not consistent depending on the SQL Engine (inconsistencies # observed with Athena and BigQuery). 
.orderby(*(c[0] for c in step.on)) @@ -1409,19 +1322,14 @@ def percentage( # If we have groups, we need to select them as well, and group the sum by the groups if len(step.group) > 0: - agg_query = ( - self.QUERY_CLS.from_(prev_step_table) - .select(sum_col, *step.group) - .groupby(*step.group) - ) + agg_query = self.QUERY_CLS.from_(prev_step_table).select(sum_col, *step.group).groupby(*step.group) # Otherwise we just need a simple sum else: agg_query = self.QUERY_CLS.from_(prev_step_table).select(sum_col) - perc_column = ( - functions.Cast(table[step.column], self.DATA_TYPE_MAPPING.float) - / agg_query[sum_col_name] - ).as_(new_col_name) + perc_column = (functions.Cast(table[step.column], self.DATA_TYPE_MAPPING.float) / agg_query[sum_col_name]).as_( + new_col_name + ) query = self.QUERY_CLS.from_(prev_step_table).select(*columns, perc_column) @@ -1451,11 +1359,7 @@ def rank( analytics_fn = analytics.Rank if step.method == "standard" else analytics.DenseRank rank_column = ( - ( - analytics_fn().over(*(Field(group) for group in step.groupby)) - if step.groupby - else analytics_fn() - ) + (analytics_fn().over(*(Field(group) for group in step.groupby)) if step.groupby else analytics_fn()) .orderby(col_field, order=Order.desc if step.order == "desc" else Order.asc) .as_(new_col_name) ) @@ -1547,9 +1451,7 @@ def sort( query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select(*columns) for column_sort in step.columns: - query = query.orderby( - column_sort.column, order=Order.desc if column_sort.order == "desc" else Order.asc - ) + query = query.orderby(column_sort.column, order=Order.desc if column_sort.order == "desc" else Order.asc) return StepContext(query, columns) @@ -1572,9 +1474,7 @@ def split( query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select( *columns, *( - self._wrap_split_part( - functions.SplitPart(col_field, step.delimiter, i + 1) - ).as_(new_cols[i]) + self._wrap_split_part(functions.SplitPart(col_field, step.delimiter, i + 1)).as_(new_cols[i]) for i in range(step.number_cols_to_keep) ), ) @@ -1590,15 +1490,13 @@ def substring( columns: list[str], step: "SubstringStep", ) -> StepContext: - step.new_column_name = ( - f"{step.column}_substr" if step.new_column_name is None else step.new_column_name - ) + step.new_column_name = f"{step.column}_substr" if step.new_column_name is None else step.new_column_name col_field: Field = Table(prev_step_table)[step.column] query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select( *columns, - functions.Substring( - col_field, step.start_index, (step.end_index - step.start_index) + 1 - ).as_(step.new_column_name), + functions.Substring(col_field, step.start_index, (step.end_index - step.start_index) + 1).as_( + step.new_column_name + ), ) return StepContext(query, columns + [step.new_column_name]) @@ -1769,9 +1667,7 @@ def _build_unpivot_col( if cls.SUPPORT_UNPIVOT: return f"UNPIVOT({value_col} FOR {unpivot_col} IN ({in_cols}))" - in_single_quote_cols = ", ".join( - format_quotes(col, secondary_quote_char) for col in step.unpivot - ) + in_single_quote_cols = ", ".join(format_quotes(col, secondary_quote_char) for col in step.unpivot) return f" t1 CROSS JOIN UNNEST(ARRAY[{in_single_quote_cols}], ARRAY[{in_cols}]) t2 ({unpivot_col}, {value_col})" def unpivot( diff --git a/server/src/weaverbird/backends/pypika_translator/translators/googlebigquery.py b/server/src/weaverbird/backends/pypika_translator/translators/googlebigquery.py index 7f349a2180..ca34b7ccc7 100644 --- 
a/server/src/weaverbird/backends/pypika_translator/translators/googlebigquery.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/googlebigquery.py @@ -50,9 +50,7 @@ def __init__(self, field: Field, delimiter: str | None = None) -> None: super().__init__(*args) -GQBTimestampDiffUnit: TypeAlias = Literal[ - "MICROSECOND", "MILLISECOND", "SECOND", "MINUTE", "HOUR", "DAY" -] +GQBTimestampDiffUnit: TypeAlias = Literal["MICROSECOND", "MILLISECOND", "SECOND", "MINUTE", "HOUR", "DAY"] class GBQTimestampDiff(Function): @@ -83,9 +81,7 @@ class GoogleBigQueryQueryBuilder(QueryBuilder): QUERY_CLS = GoogleBigQueryQuery def __init__(self, **kwargs: Any) -> None: - super().__init__( - dialect=SQLDialect.GOOGLEBIGQUERY, wrapper_cls=GoogleBigQueryValueWrapper, **kwargs - ) + super().__init__(dialect=SQLDialect.GOOGLEBIGQUERY, wrapper_cls=GoogleBigQueryValueWrapper, **kwargs) class GoogleBigQueryDateAdd(Function): @@ -121,9 +117,7 @@ class GoogleBigQueryTranslator(SQLTranslator): REGEXP_OP = RegexOp.REGEXP_CONTAINS @classmethod - def _add_date( - cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None - ) -> Term: + def _add_date(cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None) -> Term: return GoogleBigQueryDateAdd( target_column=target_column, # Cheating a bit here: MySQL's syntax is compatible with GBQ for intervals @@ -154,9 +148,7 @@ def gen_splitted_cols(): split_str = GBQSplit(col_field, step.delimiter).get_sql( quote_char=GoogleBigQueryQueryBuilder.QUOTE_CHAR ) - safe_offset_str = safe_offset(i).get_sql( - quote_char=GoogleBigQueryQueryBuilder.QUOTE_CHAR - ) + safe_offset_str = safe_offset(i).get_sql(quote_char=GoogleBigQueryQueryBuilder.QUOTE_CHAR) # LiteralValue is ugly, but it does not seem like pypika supports "[]" array # accessing, and GBQ does not seem to provide functions to access array value. 
# @@ -167,9 +159,7 @@ def gen_splitted_cols(): ) splitted_cols = list(gen_splitted_cols()) - query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select( - *columns, *splitted_cols - ) + query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select(*columns, *splitted_cols) return StepContext(query, columns + splitted_cols) @classmethod @@ -184,9 +174,7 @@ def _get_date_extract_func(cls, *, date_unit: DATE_INFO, target_column: Field) - if date_unit == "week": return functions.Extract("isoweek", target_column) if date_unit == "previousWeek": - return functions.Extract( - "isoweek", cls._add_date(target_column=target_column, unit="weeks", duration=-1) - ) + return functions.Extract("isoweek", cls._add_date(target_column=target_column, unit="weeks", duration=-1)) if date_unit == "isoWeek": return ( @@ -225,15 +213,11 @@ def _get_date_extract_func(cls, *, date_unit: DATE_INFO, target_column: Field) - if date_unit == "firstDayOfWeek": return cls._date_trunc("WEEK", target_column) if date_unit == "firstDayOfPreviousWeek": - return cls._add_date( - target_column=cls._date_trunc("WEEK", target_column), duration=-1, unit="weeks" - ) + return cls._add_date(target_column=cls._date_trunc("WEEK", target_column), duration=-1, unit="weeks") if date_unit == "firstDayOfIsoWeek": return cls._date_trunc("ISOWEEK", target_column) if date_unit == "firstDayOfPreviousIsoWeek": - return cls._add_date( - target_column=cls._date_trunc("ISOWEEK", target_column), duration=-1, unit="weeks" - ) + return cls._add_date(target_column=cls._date_trunc("ISOWEEK", target_column), duration=-1, unit="weeks") if date_unit == "firstDayOfPreviousMonth": # We need to cast the truncated timestamp to a date to prevent the following error: # "DATE_ADD does not support the MONTH date part when the argument is TIMESTAMP type at [1:8]" diff --git a/server/src/weaverbird/backends/pypika_translator/translators/mysql.py b/server/src/weaverbird/backends/pypika_translator/translators/mysql.py index 11609a2420..ea4390b472 100644 --- a/server/src/weaverbird/backends/pypika_translator/translators/mysql.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/mysql.py @@ -53,9 +53,7 @@ class MySQLTranslator(SQLTranslator): TO_DATE_OP = ToDateOp.STR_TO_DATE @classmethod - def _build_unpivot_col( - cls, *, step: "UnpivotStep", quote_char: str | None, secondary_quote_char: str - ) -> str: + def _build_unpivot_col(cls, *, step: "UnpivotStep", quote_char: str | None, secondary_quote_char: str) -> str: value_col = format_quotes(step.value_column_name, quote_char) unpivot_col = format_quotes(step.unpivot_column_name, quote_char) in_cols = [format_quotes(col, quote_char) for col in step.unpivot] @@ -90,17 +88,13 @@ def build_columns(): # but having another Case statement for i=0 would be hell col_field.regexp(f"(({step.delimiter}).*){{{i}}}"), # https://stackoverflow.com/a/32500349 - SubstringIndex( - SubstringIndex(col_field, step.delimiter, i + 1), step.delimiter, -1 - ), + SubstringIndex(SubstringIndex(col_field, step.delimiter, i + 1), step.delimiter, -1), ) .else_("") .as_(new_cols[i]) ) - query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select( - *columns, *build_columns() - ) + query: "QueryBuilder" = self.QUERY_CLS.from_(prev_step_table).select(*columns, *build_columns()) return StepContext(query, columns + new_cols) @staticmethod @@ -108,9 +102,7 @@ def _cast_to_timestamp(value: str | datetime | Field | Term) -> functions.Functi return functions.Timestamp(value) @classmethod - def _add_date( - cls, *, 
target_column: Field, duration: int, unit: str, dialect: Dialects | None = None - ) -> Term: + def _add_date(cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None) -> Term: return super()._add_date( target_column=target_column, duration=duration, @@ -133,7 +125,5 @@ def dateextract( class SubstringIndex(functions.Function): - def __init__( - self, term: str | Field, delimiter: str, count: int, alias: str | None = None - ) -> None: + def __init__(self, term: str | Field, delimiter: str, count: int, alias: str | None = None) -> None: super().__init__("SUBSTRING_INDEX", term, delimiter, count, alias=alias) diff --git a/server/src/weaverbird/backends/pypika_translator/translators/redshift.py b/server/src/weaverbird/backends/pypika_translator/translators/redshift.py index 6c2e724347..272b5c33ce 100644 --- a/server/src/weaverbird/backends/pypika_translator/translators/redshift.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/redshift.py @@ -45,9 +45,7 @@ class RedshiftTranslator(PostgreSQLTranslator): # helpers allow to nest concatenations: # https://docs.aws.amazon.com/redshift/latest/dg/r_CONCAT.html - def _recursive_concat( - self, concat: functions.Concat | None, tokens: list[str] - ) -> functions.Concat: + def _recursive_concat(self, concat: functions.Concat | None, tokens: list[str]) -> functions.Concat: if len(tokens) == 0: assert concat is not None return concat @@ -79,9 +77,7 @@ def concatenate( return StepContext(query, columns + [step.new_column_name]) @classmethod - def _add_date( - cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None - ) -> Term: + def _add_date(cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None) -> Term: return DateAddWithoutUnderscore(date_part=unit, interval=duration, term=target_column) diff --git a/server/src/weaverbird/backends/pypika_translator/translators/snowflake.py b/server/src/weaverbird/backends/pypika_translator/translators/snowflake.py index 0bda9aa40e..e44eab7a1f 100644 --- a/server/src/weaverbird/backends/pypika_translator/translators/snowflake.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/snowflake.py @@ -61,12 +61,8 @@ class SnowflakeTranslator(SQLTranslator): TO_DATE_OP = ToDateOp.TO_TIMESTAMP_NTZ @classmethod - def _add_date( - cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None - ) -> Term: - return DateAddWithoutUnderscore( - date_part=unit.removesuffix("s"), interval=duration, term=target_column - ) + def _add_date(cls, *, target_column: Field, duration: int, unit: str, dialect: Dialects | None = None) -> Term: + return DateAddWithoutUnderscore(date_part=unit.removesuffix("s"), interval=duration, term=target_column) @classmethod def _interval_to_seconds(cls, value: Selectable) -> functions.Function: diff --git a/server/src/weaverbird/pipeline/formula_ast/eval.py b/server/src/weaverbird/pipeline/formula_ast/eval.py index a1dd758419..5f883e354d 100644 --- a/server/src/weaverbird/pipeline/formula_ast/eval.py +++ b/server/src/weaverbird/pipeline/formula_ast/eval.py @@ -79,9 +79,7 @@ def parse_col_name(prev_token: tokenize.TokenInfo) -> str: self._columns[col_name] = self._formula[start:end] return col_name prev_token = tok - raise UnclosedColumnName( - f"Expected column to be closed near {prev_token.string} at {prev_token.start[1]}" - ) + raise UnclosedColumnName(f"Expected column to be closed near {prev_token.string} at {prev_token.start[1]}") while token := 
next_token(): if token.type in ( @@ -110,9 +108,7 @@ def sanitize_formula(self) -> str: self._columns = {} # Stripping because strings starting with whitespace raise UnexpectedIndent when parsed by # the ast module - return " ".join( - self._iterate_tokens(tokenize.tokenize(BytesIO(self._formula.encode()).readline)) - ).strip() + return " ".join(self._iterate_tokens(tokenize.tokenize(BytesIO(self._formula.encode()).readline))).strip() @staticmethod def _operator_from_ast_op(op: ast.operator) -> types.Operator: @@ -165,15 +161,11 @@ def _parse_expr(self, expr: ast.expr) -> types.Expression: # -colname: -mycol, -[my col] case ast.UnaryOp(op=ast.USub(), operand=ast.Name(id=name)): # Cheating a bit here, assuming the column is numeric - return types.Operation( - left=-1, operator=types.Operator.MUL, right=self._build_name(name) - ) + return types.Operation(left=-1, operator=types.Operator.MUL, right=self._build_name(name)) # Recursing down into both branches of the operation case ast.BinOp(left=left, right=right, op=op): operator = self._operator_from_ast_op(op) - return types.Operation( - left=self._parse_expr(left), right=self._parse_expr(right), operator=operator - ) + return types.Operation(left=self._parse_expr(left), right=self._parse_expr(right), operator=operator) # Constant: number, string literal or boolean case ast.Constant(value=value): # bool is a subtype of int @@ -182,9 +174,7 @@ def _parse_expr(self, expr: ast.expr) -> types.Expression: elif isinstance(value, str): return f"'{value}'" else: - raise UnsupportedConstant( - f"Unsupported constant '{expr}' of type {type(value)}" - ) + raise UnsupportedConstant(f"Unsupported constant '{expr}' of type {type(value)}") # Column name case ast.Name(id=name): return self._build_name(name) diff --git a/server/src/weaverbird/pipeline/pipeline.py b/server/src/weaverbird/pipeline/pipeline.py index 2f98c79101..8fa8de598e 100644 --- a/server/src/weaverbird/pipeline/pipeline.py +++ b/server/src/weaverbird/pipeline/pipeline.py @@ -375,9 +375,7 @@ def _sanitize_query_matches(query: dict | list[dict]) -> Any: if bool(query) and "$match" not in query[0]: query = [{"$match": {}}] + query - return [ - {"$match": _sanitize_match(q["$match"])} if _is_match_statement(q) else q for q in query - ] + return [{"$match": _sanitize_match(q["$match"])} if _is_match_statement(q) else q for q in query] return query @@ -416,9 +414,7 @@ class PipelineWithRefs(BaseModel): steps: list[PipelineStepWithRefs | PipelineStep | PipelineStepWithVariables] - async def resolve_references( - self, reference_resolver: ReferenceResolver - ) -> PipelineWithVariables | None: + async def resolve_references(self, reference_resolver: ReferenceResolver) -> PipelineWithVariables | None: """ Walk the pipeline steps and replace any reference by its corresponding pipeline. The sub-pipelines added should also be handled, so that there will be no references anymore in the result.
@@ -426,9 +422,7 @@ async def resolve_references( resolved_steps: list[PipelineStepWithRefs | PipelineStepWithVariables | PipelineStep] = [] for step in self.steps: resolved_step = ( - await step.resolve_references(reference_resolver) - if hasattr(step, "resolve_references") - else step + await step.resolve_references(reference_resolver) if hasattr(step, "resolve_references") else step ) if isinstance(resolved_step, PipelineWithVariables): resolved_steps.extend(resolved_step.steps) diff --git a/server/src/weaverbird/pipeline/steps/append.py b/server/src/weaverbird/pipeline/steps/append.py index 1e9f04fb40..85569075c4 100644 --- a/server/src/weaverbird/pipeline/steps/append.py +++ b/server/src/weaverbird/pipeline/steps/append.py @@ -26,12 +26,8 @@ class AppendStepWithVariable(AppendStep, StepWithVariablesMixin): class AppendStepWithRefs(BaseAppendStep): pipelines: list[PipelineWithRefsOrDomainNameOrReference] - async def resolve_references( - self, reference_resolver: ReferenceResolver - ) -> AppendStepWithVariable | None: - resolved_pipelines = [ - await resolve_if_reference(reference_resolver, p) for p in self.pipelines - ] + async def resolve_references(self, reference_resolver: ReferenceResolver) -> AppendStepWithVariable | None: + resolved_pipelines = [await resolve_if_reference(reference_resolver, p) for p in self.pipelines] resolved_pipelines_without_nones = [p for p in resolved_pipelines if p is not None] if len(resolved_pipelines_without_nones) == 0: return None # skip the step diff --git a/server/src/weaverbird/pipeline/steps/customsql.py b/server/src/weaverbird/pipeline/steps/customsql.py index a98949cf34..48f3297550 100644 --- a/server/src/weaverbird/pipeline/steps/customsql.py +++ b/server/src/weaverbird/pipeline/steps/customsql.py @@ -17,9 +17,7 @@ def _strip_query(query: str) -> str: @field_validator("query") @classmethod def _validate_query(cls, query: str) -> str: - assert ";" not in ( - stripped := cls._strip_query(query) - ), "Custom SQL queries must not contain semicolons" + assert ";" not in (stripped := cls._strip_query(query)), "Custom SQL queries must not contain semicolons" return stripped diff --git a/server/src/weaverbird/pipeline/steps/dissolve.py b/server/src/weaverbird/pipeline/steps/dissolve.py index d14cf42d51..2867a84e50 100644 --- a/server/src/weaverbird/pipeline/steps/dissolve.py +++ b/server/src/weaverbird/pipeline/steps/dissolve.py @@ -36,10 +36,6 @@ def _validate_aggregations(cls, values: list[Aggregation]) -> list[Aggregation]: assert len(agg.columns) == 1, "aggregations can only contain a single column" assert len(agg.new_columns) == 1, "aggregations can only contain a single new column" - cls._ensure_unique_and_non_empty( - list(ichain.from_iterable(agg.columns for agg in values)), "columns" - ) - cls._ensure_unique_and_non_empty( - list(ichain.from_iterable(agg.new_columns for agg in values)), "new_columns" - ) + cls._ensure_unique_and_non_empty(list(ichain.from_iterable(agg.columns for agg in values)), "columns") + cls._ensure_unique_and_non_empty(list(ichain.from_iterable(agg.new_columns for agg in values)), "new_columns") return values diff --git a/server/src/weaverbird/pipeline/steps/join.py b/server/src/weaverbird/pipeline/steps/join.py index 676cb5b63f..4f2ca67b63 100644 --- a/server/src/weaverbird/pipeline/steps/join.py +++ b/server/src/weaverbird/pipeline/steps/join.py @@ -33,9 +33,7 @@ class JoinStepWithVariable(JoinStep, StepWithVariablesMixin): class JoinStepWithRef(BaseJoinStep): right_pipeline: PipelineWithRefsOrDomainNameOrReference
- async def resolve_references( - self, reference_resolver: ReferenceResolver - ) -> JoinStepWithVariable | None: + async def resolve_references(self, reference_resolver: ReferenceResolver) -> JoinStepWithVariable | None: right_pipeline = await resolve_if_reference(reference_resolver, self.right_pipeline) if right_pipeline is None: from weaverbird.pipeline.pipeline import ReferenceUnresolved diff --git a/server/src/weaverbird/pipeline/steps/utils/combination.py b/server/src/weaverbird/pipeline/steps/utils/combination.py index 06b55fb96d..97f2e27436 100644 --- a/server/src/weaverbird/pipeline/steps/utils/combination.py +++ b/server/src/weaverbird/pipeline/steps/utils/combination.py @@ -64,9 +64,7 @@ def iter_() -> Iterable["PipelineStep"]: # can be either a domain name or a complete pipeline -PipelineOrDomainName = Annotated[ - str | list["PipelineStep"], BeforeValidator(_ensure_is_pipeline_step) -] +PipelineOrDomainName = Annotated[str | list["PipelineStep"], BeforeValidator(_ensure_is_pipeline_step)] def _ensure_is_pipeline_step_with_ref( diff --git a/server/tests/backends/mongo_translator/steps/test_formula.py b/server/tests/backends/mongo_translator/steps/test_formula.py index 8aed15ccf4..eeb62a63f3 100644 --- a/server/tests/backends/mongo_translator/steps/test_formula.py +++ b/server/tests/backends/mongo_translator/steps/test_formula.py @@ -13,11 +13,7 @@ def test_formula_basic_operators(): {"$addFields": {"diff": {"$subtract": ["$you", "$two"]}}} ] assert translate_formula(FormulaStep(new_column="conquer", formula="1 / pi")) == [ - { - "$addFields": { - "conquer": {"$cond": [{"$in": ["$pi", [0, None]]}, None, {"$divide": [1, "$pi"]}]} - } - } + {"$addFields": {"conquer": {"$cond": [{"$in": ["$pi", [0, None]]}, None, {"$divide": [1, "$pi"]}]}}} ] @@ -52,9 +48,7 @@ def test_formula_nested(): } ] - assert translate_formula( - FormulaStep(new_column="bar", formula="1 / ((column_1 + column_2 + column_3)) * 10") - ) == [ + assert translate_formula(FormulaStep(new_column="bar", formula="1 / ((column_1 + column_2 + column_3)) * 10")) == [ { "$addFields": { "bar": { @@ -197,9 +191,7 @@ def test_special_column_name(): def test_special_column_name_and_normal_column_name(): - assert translate_formula( - FormulaStep(new_column="test", formula="[column with spaces] + A") - ) == [ + assert translate_formula(FormulaStep(new_column="test", formula="[column with spaces] + A")) == [ { "$addFields": { "test": { diff --git a/server/tests/backends/mongo_translator/test_mongo_translator_steps.py b/server/tests/backends/mongo_translator/test_mongo_translator_steps.py index 8efffc9151..a6d1e49a14 100644 --- a/server/tests/backends/mongo_translator/test_mongo_translator_steps.py +++ b/server/tests/backends/mongo_translator/test_mongo_translator_steps.py @@ -72,9 +72,7 @@ def _sanitized_df_from_pandas_table(df_spec: dict) -> pd.DataFrame: @pytest.mark.parametrize("case_id,case_spec_file_path", test_cases) -def test_mongo_translator_pipeline( - mongo_database, case_id, case_spec_file_path, available_variables -): +def test_mongo_translator_pipeline(mongo_database, case_id, case_spec_file_path, available_variables): # insert in mongoDB collection_uid = uuid.uuid4().hex spec = get_spec_from_json_fixture(case_id, case_spec_file_path) @@ -84,17 +82,13 @@ def test_mongo_translator_pipeline( "join" in case_id or "append" in case_id ): # needed for join & append steps tests as we need a != collection [ - mongo_database[k].insert_many( - pd.read_json(json.dumps(v), orient="table").to_dict(orient="records") - ) + 
mongo_database[k].insert_many(pd.read_json(json.dumps(v), orient="table").to_dict(orient="records")) for k, v in spec.get("other_inputs", {}).items() ] # create query steps = spec["step"]["pipeline"] - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) query = translate_pipeline(pipeline) # execute query result = list(mongo_database[collection_uid].aggregate(query)) diff --git a/server/tests/backends/pandas_executor/test_pandas_executor_steps.py b/server/tests/backends/pandas_executor/test_pandas_executor_steps.py index 0c9987ca15..a1012f0431 100644 --- a/server/tests/backends/pandas_executor/test_pandas_executor_steps.py +++ b/server/tests/backends/pandas_executor/test_pandas_executor_steps.py @@ -29,9 +29,7 @@ def test_pandas_execute_pipeline(case_id, case_spec_file_path, available_variabl steps = spec["step"]["pipeline"] steps.insert(0, {"name": "domain", "domain": "in"}) - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) domains = {"in": df_in, **dfs_in_others} result = execute_pipeline(pipeline, domain_retriever=lambda x: domains[x])[0] diff --git a/server/tests/backends/pandas_executor/utils/test_dates.py b/server/tests/backends/pandas_executor/utils/test_dates.py index b81dc59ac2..f765af8cab 100644 --- a/server/tests/backends/pandas_executor/utils/test_dates.py +++ b/server/tests/backends/pandas_executor/utils/test_dates.py @@ -15,7 +15,5 @@ def test_evaluate_relative_date(): ) == datetime(year=2020, month=5, day=1) assert evaluate_relative_date( - RelativeDate( - date=datetime(year=2020, month=8, day=1), operator="from", quantity=3, duration="day" - ) + RelativeDate(date=datetime(year=2020, month=8, day=1), operator="from", quantity=3, duration="day") ) == datetime(year=2020, month=8, day=4) diff --git a/server/tests/backends/sql_translator/common.py b/server/tests/backends/sql_translator/common.py index 1c722f1b09..f0f87a20b2 100644 --- a/server/tests/backends/sql_translator/common.py +++ b/server/tests/backends/sql_translator/common.py @@ -7,10 +7,7 @@ def standardized_columns(df: pd.DataFrame, colname_lowercase: bool = False): - df.columns = [ - (c.replace("-", "_").lower() if colname_lowercase else c.replace("-", "_")) - for c in df.columns - ] + df.columns = [(c.replace("-", "_").lower() if colname_lowercase else c.replace("-", "_")) for c in df.columns] def standardized_values(df: pd.DataFrame, convert_nan_to_none: bool = False) -> None: diff --git a/server/tests/backends/sql_translator_integration_tests/test_sql_athena_translator_steps.py b/server/tests/backends/sql_translator_integration_tests/test_sql_athena_translator_steps.py index 0a6db8e82d..e231f9b2f8 100644 --- a/server/tests/backends/sql_translator_integration_tests/test_sql_athena_translator_steps.py +++ b/server/tests/backends/sql_translator_integration_tests/test_sql_athena_translator_steps.py @@ -40,18 +40,14 @@ def boto_session() -> Session: ] -@pytest.mark.parametrize( - "case_id, case_spec_file", retrieve_case("sql_translator", "athena_pypika") -) +@pytest.mark.parametrize("case_id, case_spec_file", retrieve_case("sql_translator", "athena_pypika")) def test_athena_translator_pipeline( boto_session: Session, case_id: str, case_spec_file: str, available_variables: dict ): 
pipeline_spec = get_spec_from_json_fixture(case_id, case_spec_file) steps = [{"name": "domain", "domain": "beers_tiny"}] + pipeline_spec["step"]["pipeline"] - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) query = translate_pipeline( sql_dialect=SQLDialect.ATHENA, diff --git a/server/tests/backends/sql_translator_integration_tests/test_sql_bigquery_translator_steps.py b/server/tests/backends/sql_translator_integration_tests/test_sql_bigquery_translator_steps.py index 74ee8847dd..552d37e7dc 100644 --- a/server/tests/backends/sql_translator_integration_tests/test_sql_bigquery_translator_steps.py +++ b/server/tests/backends/sql_translator_integration_tests/test_sql_bigquery_translator_steps.py @@ -49,18 +49,14 @@ def bigquery_client() -> Client: return Client(credentials=credentials) -@pytest.mark.parametrize( - "case_id, case_spec_file", retrieve_case("sql_translator", "bigquery_pypika") -) +@pytest.mark.parametrize("case_id, case_spec_file", retrieve_case("sql_translator", "bigquery_pypika")) def test_bigquery_translator_pipeline( bigquery_client: Client, case_id: str, case_spec_file: str, available_variables: dict ): pipeline_spec = get_spec_from_json_fixture(case_id, case_spec_file) steps = [{"name": "domain", "domain": "beers_tiny"}] + pipeline_spec["step"]["pipeline"] - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) query = translate_pipeline( sql_dialect=SQLDialect.GOOGLEBIGQUERY, diff --git a/server/tests/backends/sql_translator_integration_tests/test_sql_mysql_translator_steps.py b/server/tests/backends/sql_translator_integration_tests/test_sql_mysql_translator_steps.py index 3678810e8b..40ec719755 100644 --- a/server/tests/backends/sql_translator_integration_tests/test_sql_mysql_translator_steps.py +++ b/server/tests/backends/sql_translator_integration_tests/test_sql_mysql_translator_steps.py @@ -59,9 +59,7 @@ def mysql_container(): time.sleep(1) try: if container.status == "created" and pymysql.connect(**_CON_PARAMS): - dataset = pd.read_csv( - f"{path.join(path.dirname(path.realpath(__file__)))}/beers.csv" - ) + dataset = pd.read_csv(f"{path.join(path.dirname(path.realpath(__file__)))}/beers.csv") dataset["brewing_date"] = dataset["brewing_date"].apply(pd.to_datetime) engine = create_engine(_CONNECTION_STRING) dataset.to_sql("beers_tiny", engine) @@ -73,20 +71,14 @@ def mysql_container(): # Translation from Pipeline json to SQL query @pytest.mark.serial -@pytest.mark.parametrize( - "case_id, case_spec_file_path", retrieve_case("sql_translator", "mysql_pypika") -) +@pytest.mark.parametrize("case_id, case_spec_file_path", retrieve_case("sql_translator", "mysql_pypika")) @pytest.mark.skip("MySQL result order is not consistent with CTEs") -def test_sql_translator_pipeline( - case_id: str, case_spec_file_path: str, engine: Any, available_variables: dict -): +def test_sql_translator_pipeline(case_id: str, case_spec_file_path: str, engine: Any, available_variables: dict): spec = get_spec_from_json_fixture(case_id, case_spec_file_path) steps = spec["step"]["pipeline"] steps.insert(0, {"name": "domain", "domain": "beers_tiny"}) - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + 
pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) # Convert Pipeline object to Postgres Query query = translate_pipeline( diff --git a/server/tests/backends/sql_translator_integration_tests/test_sql_postgres_translator_steps.py b/server/tests/backends/sql_translator_integration_tests/test_sql_postgres_translator_steps.py index 6fb3b27d4e..1e6be1a8b5 100644 --- a/server/tests/backends/sql_translator_integration_tests/test_sql_postgres_translator_steps.py +++ b/server/tests/backends/sql_translator_integration_tests/test_sql_postgres_translator_steps.py @@ -53,9 +53,7 @@ def postgres_container(): time.sleep(1) try: if container.status == "created" and psycopg2.connect(**con_params): - dataset = pd.read_csv( - f"{path.join(path.dirname(path.realpath(__file__)))}/beers.csv" - ) + dataset = pd.read_csv(f"{path.join(path.dirname(path.realpath(__file__)))}/beers.csv") dataset["brewing_date"] = dataset["brewing_date"].apply(pd.to_datetime) engine = create_engine(connection_string) dataset.to_sql("beers_tiny", engine) @@ -67,19 +65,13 @@ def postgres_container(): # Translation from Pipeline json to SQL query @pytest.mark.serial -@pytest.mark.parametrize( - "case_id, case_spec_file_path", retrieve_case("sql_translator", "postgres_pypika") -) -def test_sql_translator_pipeline( - case_id: str, case_spec_file_path: str, engine: Any, available_variables: dict -): +@pytest.mark.parametrize("case_id, case_spec_file_path", retrieve_case("sql_translator", "postgres_pypika")) +def test_sql_translator_pipeline(case_id: str, case_spec_file_path: str, engine: Any, available_variables: dict): spec = get_spec_from_json_fixture(case_id, case_spec_file_path) steps = spec["step"]["pipeline"] steps.insert(0, {"name": "domain", "domain": "beers_tiny"}) - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) # Convert Pipeline object to Postgres Query query = translate_pipeline( diff --git a/server/tests/backends/sql_translator_integration_tests/test_sql_redshift_translator_steps.py b/server/tests/backends/sql_translator_integration_tests/test_sql_redshift_translator_steps.py index 28aa79316a..224fc32bb3 100644 --- a/server/tests/backends/sql_translator_integration_tests/test_sql_redshift_translator_steps.py +++ b/server/tests/backends/sql_translator_integration_tests/test_sql_redshift_translator_steps.py @@ -49,18 +49,12 @@ def engine(): ] -@pytest.mark.parametrize( - "case_id, case_spec_file", retrieve_case("sql_translator", "redshift_pypika") -) -def test_redshift_translator_pipeline( - engine: Any, case_id: str, case_spec_file: str, available_variables: dict -): +@pytest.mark.parametrize("case_id, case_spec_file", retrieve_case("sql_translator", "redshift_pypika")) +def test_redshift_translator_pipeline(engine: Any, case_id: str, case_spec_file: str, available_variables: dict): pipeline_spec = get_spec_from_json_fixture(case_id, case_spec_file) steps = [{"name": "domain", "domain": "beers_tiny"}] + pipeline_spec["step"]["pipeline"] - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) query = translate_pipeline( sql_dialect=SQLDialect.REDSHIFT, diff --git 
a/server/tests/backends/sql_translator_integration_tests/test_sql_snowflake_translator_steps.py b/server/tests/backends/sql_translator_integration_tests/test_sql_snowflake_translator_steps.py index c7d65af541..d563a91935 100644 --- a/server/tests/backends/sql_translator_integration_tests/test_sql_snowflake_translator_steps.py +++ b/server/tests/backends/sql_translator_integration_tests/test_sql_snowflake_translator_steps.py @@ -50,21 +50,13 @@ def engine(): connection.close() -@pytest.skip( - "Should be skipped, waiting the payment of creds on november...", allow_module_level=True -) -@pytest.mark.parametrize( - "case_id, case_spec_file", retrieve_case("sql_translator", "snowflake_pypika") -) -def test_snowflake_translator_pipeline( - engine: Any, case_id: str, case_spec_file: str, available_variables: dict -): +@pytest.skip("Should be skipped, waiting the payment of creds on november...", allow_module_level=True) +@pytest.mark.parametrize("case_id, case_spec_file", retrieve_case("sql_translator", "snowflake_pypika")) +def test_snowflake_translator_pipeline(engine: Any, case_id: str, case_spec_file: str, available_variables: dict): pipeline_spec = get_spec_from_json_fixture(case_id, case_spec_file) steps = [{"name": "domain", "domain": "beers_tiny"}] + pipeline_spec["step"]["pipeline"] - pipeline = PipelineWithVariables(steps=steps).render( - available_variables, nosql_apply_parameters_to_query - ) + pipeline = PipelineWithVariables(steps=steps).render(available_variables, nosql_apply_parameters_to_query) query = translate_pipeline( sql_dialect=SQLDialect.SNOWFLAKE, diff --git a/server/tests/backends/sql_translator_unit_tests/test_base_translator.py b/server/tests/backends/sql_translator_unit_tests/test_base_translator.py index 3ae479f821..0eff59e2e4 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_base_translator.py +++ b/server/tests/backends/sql_translator_unit_tests/test_base_translator.py @@ -106,14 +106,9 @@ def test_get_query_builder_more_than_one_step(base_translator: BaseTranslator): schema = Schema(DB_SCHEMA) step_0_query = Query.from_(schema.users).select(*ALL_TABLES["users"]) - expected = ( - Query.with_(step_0_query, "__step_0_basetranslator__").from_(schema.users).select("*") - ) + expected = Query.with_(step_0_query, "__step_0_basetranslator__").from_(schema.users).select("*") - columns = ( - Field(col) if col is not to_rename else Field(col).as_(rename_as) - for col in ALL_TABLES["users"] - ) + columns = (Field(col) if col is not to_rename else Field(col).as_(rename_as) for col in ALL_TABLES["users"]) expected_cols = (col if col != to_rename else rename_as for col in ALL_TABLES["users"]) step_1_query = Query.from_(AliasedQuery('"__step_0_basetranslator__"')).select(*columns) @@ -158,45 +153,34 @@ def test__get_window_function(base_translator: BaseTranslator, agg_type): @pytest.mark.parametrize("agg_type", ["count distinct including empty"]) -def test_aggregate_raise_expection( - base_translator: BaseTranslator, agg_type: str, default_step_kwargs: dict[str, Any] -): +def test_aggregate_raise_expection(base_translator: BaseTranslator, agg_type: str, default_step_kwargs: dict[str, Any]): new_column = "countDistinctAge" agg_field = "age" step = steps.AggregateStep( on=[agg_field], - aggregations=[ - steps.Aggregation(new_columns=[new_column], agg_function=agg_type, columns=[agg_field]) - ], + aggregations=[steps.Aggregation(new_columns=[new_column], agg_function=agg_type, columns=[agg_field])], ) with pytest.raises(NotImplementedError): 
base_translator.aggregate(step=step, columns=["*"], **default_step_kwargs) @pytest.mark.parametrize("agg_type", ["avg", "count", "count distinct", "max", "min", "sum"]) -def test_aggregate( - base_translator: BaseTranslator, agg_type: str, default_step_kwargs: dict[str, Any] -): +def test_aggregate(base_translator: BaseTranslator, agg_type: str, default_step_kwargs: dict[str, Any]): new_column = "avgAge" previous_step = "previous_with" agg_field = "age" step = steps.AggregateStep( on=[agg_field], - aggregations=[ - steps.Aggregation(new_columns=[new_column], agg_function=agg_type, columns=[agg_field]) - ], + aggregations=[steps.Aggregation(new_columns=[new_column], agg_function=agg_type, columns=[agg_field])], ) ctx = base_translator.aggregate(step=step, columns=["*"], **default_step_kwargs) agg_func = base_translator._get_aggregate_function(agg_type) field = Field(agg_field) expected_query = ( - Query.from_(previous_step) - .groupby(field) - .orderby(agg_field) - .select(field, agg_func(field).as_(new_column)) + Query.from_(previous_step).groupby(field).orderby(agg_field).select(field, agg_func(field).as_(new_column)) ) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -213,18 +197,14 @@ def test_aggregate_with_original_granularity( step = steps.AggregateStep( on=[agg_field], - aggregations=[ - steps.Aggregation(new_columns=[new_column], agg_function=agg_type, columns=[agg_field]) - ], + aggregations=[steps.Aggregation(new_columns=[new_column], agg_function=agg_type, columns=[agg_field])], keepOriginalGranularity=True, ) ctx = base_translator.aggregate(step=step, columns=original_select, **default_step_kwargs) agg_func = base_translator._get_aggregate_function(agg_type) field = Field(agg_field) - agg_query = ( - Query.from_(previous_step).groupby(field).select(field, agg_func(field).as_(new_column)) - ) + agg_query = Query.from_(previous_step).groupby(field).select(field, agg_func(field).as_(new_column)) expected_query = ( Query.from_(previous_step) @@ -244,9 +224,7 @@ def test_comparetext(base_translator: BaseTranslator, default_step_kwargs: dict[ compare_b = "pseudonyme" previous_step = "previous_with" selected_columns = ["*"] - step = steps.CompareTextStep( - newColumnName=new_column_name, strCol1=compare_a, strCol2=compare_b - ) + step = steps.CompareTextStep(newColumnName=new_column_name, strCol1=compare_a, strCol2=compare_b) ctx = base_translator.comparetext(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( @@ -264,9 +242,7 @@ def test_concatenate(base_translator: BaseTranslator, default_step_kwargs: dict[ concat_columns = ["name", "pseudonyme"] separator = "," - step = steps.ConcatenateStep( - columns=concat_columns, separator=separator, new_column_name=new_column_name - ) + step = steps.ConcatenateStep(columns=concat_columns, separator=separator, new_column_name=new_column_name) ctx = base_translator.concatenate(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( @@ -435,9 +411,7 @@ def test_ifthenelse_columns(base_translator: BaseTranslator, default_step_kwargs then = "a" reject = "b" - step = steps.IfthenelseStep( - condition=statement, then=then, else_value=reject, newColumn=new_column_name - ) + step = steps.IfthenelseStep(condition=statement, then=then, else_value=reject, newColumn=new_column_name) ctx = base_translator.ifthenelse(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( @@ 
-458,9 +432,7 @@ def test_ifthenelse_strings(base_translator: BaseTranslator, default_step_kwargs then = "'a'" reject = '"b"' - step = steps.IfthenelseStep( - condition=statement, then=then, else_value=reject, newColumn=new_column_name - ) + step = steps.IfthenelseStep(condition=statement, then=then, else_value=reject, newColumn=new_column_name) ctx = base_translator.ifthenelse(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( @@ -478,9 +450,7 @@ def test_lowercase(base_translator: BaseTranslator, default_step_kwargs: dict[st step = steps.LowercaseStep(column=column) ctx = base_translator.lowercase(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = Query.from_(previous_step).select( - Field("pseudonyme"), functions.Lower(Field(column)).as_("name") - ) + expected_query = Query.from_(previous_step).select(Field("pseudonyme"), functions.Lower(Field(column)).as_("name")) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -493,9 +463,7 @@ def test_uppercase(base_translator: BaseTranslator, default_step_kwargs: dict[st step = steps.UppercaseStep(column=column) ctx = base_translator.uppercase(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = Query.from_(previous_step).select( - Field("pseudonyme"), functions.Upper(Field(column)).as_("name") - ) + expected_query = Query.from_(previous_step).select(Field("pseudonyme"), functions.Upper(Field(column)).as_("name")) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -556,9 +524,7 @@ def test_sort(base_translator: BaseTranslator, default_step_kwargs: dict[str, An step = steps.SortStep(columns=columns) ctx = base_translator.sort(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = ( - Query.from_(previous_step).select(*selected_columns).orderby(Field("name"), order=Order.asc) - ) + expected_query = Query.from_(previous_step).select(*selected_columns).orderby(Field("name"), order=Order.asc) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -569,9 +535,7 @@ def test_substring(base_translator: BaseTranslator, default_step_kwargs: dict[st column = "name" new_column_name = "name" - step = steps.SubstringStep( - column=column, newColumnName=new_column_name, start_index=0, end_index=10 - ) + step = steps.SubstringStep(column=column, newColumnName=new_column_name, start_index=0, end_index=10) ctx = base_translator.substring(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( @@ -606,9 +570,7 @@ def test_text_with_datetime(base_translator: BaseTranslator, default_step_kwargs step = steps.TextStep(text=text, new_column=new_column_name) ctx = base_translator.text(step=step, columns=selected_columns, **default_step_kwargs) - text_as_str = ( - text.astimezone(ZoneInfo("UTC")).replace(tzinfo=None).strftime("%Y-%m-%d %H:%M:%S") - ) + text_as_str = text.astimezone(ZoneInfo("UTC")).replace(tzinfo=None).strftime("%Y-%m-%d %H:%M:%S") expected_query = Query.from_(previous_step).select( *selected_columns, @@ -627,9 +589,7 @@ def test_trim(base_translator: BaseTranslator, default_step_kwargs: dict[str, An step = steps.TrimStep(columns=columns) ctx = base_translator.trim(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = Query.from_(previous_step).select( - Field("pseudonyme"), functions.Trim(Field(column)).as_(column) - ) + expected_query = Query.from_(previous_step).select(Field("pseudonyme"), 
functions.Trim(Field(column)).as_(column)) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -643,12 +603,7 @@ def test_uniquegroups(base_translator: BaseTranslator, default_step_kwargs: dict step = steps.UniqueGroupsStep(on=columns) ctx = base_translator.uniquegroups(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = ( - Query.from_(previous_step) - .select(Field(column)) - .groupby(Field(column)) - .orderby(Field(column)) - ) + expected_query = Query.from_(previous_step).select(Field(column)).groupby(Field(column)).orderby(Field(column)) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -662,9 +617,7 @@ def test_absolutevalue(base_translator: BaseTranslator, default_step_kwargs: dic step = steps.AbsoluteValueStep(column=column, new_column=new_column) ctx = base_translator.absolutevalue(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = Query.from_(previous_step).select( - *selected_columns, functions.Abs(Field(column)).as_(new_column) - ) + expected_query = Query.from_(previous_step).select(*selected_columns, functions.Abs(Field(column)).as_(new_column)) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -685,9 +638,7 @@ def test_join_simple( join_columns = [("project_id", "id")] previous_step = "previous_with" - step = steps.JoinStep( - right_pipeline=[steps.DomainStep(domain=right_domain)], type=join_type, on=join_columns - ) + step = steps.JoinStep(right_pipeline=[steps.DomainStep(domain=right_domain)], type=join_type, on=join_columns) ctx = base_translator.join(step=step, columns=selected_columns, **default_step_kwargs) left_table = Table(previous_step) @@ -727,11 +678,7 @@ def test_append_simple(base_translator: BaseTranslator, default_step_kwargs: dic expected_query = ( Query.from_(previous_step) .select(*selected_columns, LiteralValue("NULL").as_("user_id")) - .union_all( - Query.from_(right_table).select( - "name", LiteralValue("NULL").as_("created_at"), "user_id" - ) - ) + .union_all(Query.from_(right_table).select("name", LiteralValue("NULL").as_("created_at"), "user_id")) .orderby("name", "created_at", "user_id") ) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -746,19 +693,19 @@ def test_no_extra_quotes_in_base_translator( step = steps.RenameStep(toRename=[(to_rename, rename_as)]) ctx = base_translator.rename(step=step, columns=selected_columns, **default_step_kwargs) - assert ( - '''SELECT "name","age" "old-est","created_at" FROM "previous_with"''' - in ctx.selectable.get_sql() - ) + assert '''SELECT "name","age" "old-est","created_at" FROM "previous_with"''' in ctx.selectable.get_sql() def test_no_extra_quotes_in_base_translator_with_entire_pipeline(base_translator: BaseTranslator): pipeline = [steps.DomainStep(domain="users")] translated = base_translator.get_query_str(steps=pipeline) - assert translated == ( - 'WITH __step_0_basetranslator__ AS (SELECT "name","pseudonyme","age","id","project_id" FROM "test_schema"."users") ' # noqa: E501 - 'SELECT "name","pseudonyme","age","id","project_id" FROM "__step_0_basetranslator__"' + assert ( + translated + == ( + 'WITH __step_0_basetranslator__ AS (SELECT "name","pseudonyme","age","id","project_id" FROM "test_schema"."users") ' # noqa: E501 + 'SELECT "name","pseudonyme","age","id","project_id" FROM "__step_0_basetranslator__"' + ) ) @@ -770,8 +717,7 @@ def test_materialize_customsql_query_with_no_columns(base_translator: BaseTransl translated = base_translator.get_query_str(steps=pipeline) assert translated == ( - "WITH 
__step_0_basetranslator__ AS (SELECT titi, tata FROM toto) " - 'SELECT * FROM "__step_0_basetranslator__"' + "WITH __step_0_basetranslator__ AS (SELECT titi, tata FROM toto) " 'SELECT * FROM "__step_0_basetranslator__"' ) diff --git a/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py b/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py index 680f94e311..88c48bdfa8 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py +++ b/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py @@ -14,9 +14,7 @@ steps.DomainStep(domain="beers_tiny"), steps.AggregateStep( on=["beer_kind"], - aggregations=[ - {"aggfunction": "count", "new_columns": ["beer_count"], "columns": ["name"]} - ], + aggregations=[{"aggfunction": "count", "new_columns": ["beer_count"], "columns": ["name"]}], ), ], 'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") ,__step_1_dummy__ AS (SELECT "beer_kind",COUNT("name") "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind") SELECT "beer_kind","beer_count" FROM "__step_1_dummy__"', @@ -26,9 +24,7 @@ steps.DomainStep(domain="beers_tiny"), steps.AggregateStep( on=["beer_kind"], - aggregations=[ - {"aggfunction": "count", "new_columns": ["beer_count"], "columns": ["name"]} - ], + aggregations=[{"aggfunction": "count", "new_columns": ["beer_count"], "columns": ["name"]}], keep_original_granularity=True, ), ], diff --git a/server/tests/backends/sql_translator_unit_tests/test_date_format_translators.py b/server/tests/backends/sql_translator_unit_tests/test_date_format_translators.py index 7c3ba962fe..df3502748d 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_date_format_translators.py +++ b/server/tests/backends/sql_translator_unit_tests/test_date_format_translators.py @@ -38,18 +38,14 @@ def date_format_translators(): ) -def test_fromdate( - date_format_translators: DateFormatTranslator, default_step_kwargs: dict[str, Any] -): +def test_fromdate(date_format_translators: DateFormatTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "pseudonyme"] previous_step = "previous_with" column = "birthday" format_ = "dd/yy" step = steps.FromdateStep(column=column, format=format_) - ctx = date_format_translators.fromdate( - step=step, columns=selected_columns, **default_step_kwargs - ) + ctx = date_format_translators.fromdate(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( *selected_columns, functions.ToChar(Field(column), format_).as_(column) diff --git a/server/tests/backends/sql_translator_unit_tests/test_filter_translators.py b/server/tests/backends/sql_translator_unit_tests/test_filter_translators.py index da4d163284..072a338297 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_filter_translators.py +++ b/server/tests/backends/sql_translator_unit_tests/test_filter_translators.py @@ -34,9 +34,7 @@ def filter_translator(): @pytest.mark.parametrize("op", ["eq", "ne", "lt", "le", "gt", "ge"]) -def test_comparison_filter( - filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any] -): +def test_comparison_filter(filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -50,19 +48,13 @@ def 
test_comparison_filter( import operator op_func = getattr(operator, op) - expected_query = ( - Query.from_(previous_step) - .where(op_func(Field(column), anonymous)) - .select(*selected_columns) - ) + expected_query = Query.from_(previous_step).where(op_func(Field(column), anonymous)).select(*selected_columns) assert ctx.selectable.get_sql() == expected_query.get_sql() @pytest.mark.parametrize("op", ["in", "nin"]) -def test_inclusion_filter( - filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any] -): +def test_inclusion_filter(filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -75,17 +67,13 @@ def test_inclusion_filter( step = steps.FilterStep(condition=condition) ctx = filter_translator.filter(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = ( - Query.from_(previous_step).where(op_func[op](inclusion)).select(*selected_columns) - ) + expected_query = Query.from_(previous_step).where(op_func[op](inclusion)).select(*selected_columns) assert ctx.selectable.get_sql() == expected_query.get_sql() @pytest.mark.parametrize("op", ["isnull", "notnull"]) -def test_null_filter( - filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any] -): +def test_null_filter(filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -103,9 +91,7 @@ def test_null_filter( @pytest.mark.parametrize("op", ["from", "until"]) -def test_datebound_filter( - filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any] -): +def test_datebound_filter(filter_translator: FilterTranslator, op: str, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -117,13 +103,9 @@ def test_datebound_filter( value_str_time = dateutil_parser.parse(datetime).astimezone().strftime("%Y-%m-%d %H:%M:%S") if op == "from": - op_func = functions.Cast(Field(column), "TIMESTAMP") >= functions.Cast( - value_str_time, "TIMESTAMP" - ) + op_func = functions.Cast(Field(column), "TIMESTAMP") >= functions.Cast(value_str_time, "TIMESTAMP") else: - op_func = functions.Cast(Field(column), "TIMESTAMP") <= functions.Cast( - value_str_time, "TIMESTAMP" - ) + op_func = functions.Cast(Field(column), "TIMESTAMP") <= functions.Cast(value_str_time, "TIMESTAMP") step = steps.FilterStep(condition=condition) ctx = filter_translator.filter(step=step, columns=selected_columns, **default_step_kwargs) @@ -165,9 +147,7 @@ def regexp_translator(): ) -def test_matches_regexp_filter( - regexp_translator: REGEXPTranslator, default_step_kwargs: dict[str, Any] -): +def test_matches_regexp_filter(regexp_translator: REGEXPTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -186,9 +166,7 @@ def test_matches_regexp_filter( assert ctx.selectable.get_sql() == expected_query.get_sql() -def test_notmatches_regexp_filter( - regexp_translator: REGEXPTranslator, default_step_kwargs: dict[str, Any] -): +def test_notmatches_regexp_filter(regexp_translator: REGEXPTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -229,9 +207,7 @@ def similar_to_translator(): ) -def test_matches_similar_to_filter( - similar_to_translator: 
SimilarToTranslator, default_step_kwargs: dict[str, Any] -): +def test_matches_similar_to_filter(similar_to_translator: SimilarToTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -256,9 +232,7 @@ def test_matches_similar_to_filter( assert ctx.selectable.get_sql() == expected_query.get_sql() -def test_notmatches_similar_to_filter( - similar_to_translator: SimilarToTranslator, default_step_kwargs: dict[str, Any] -): +def test_notmatches_similar_to_filter(similar_to_translator: SimilarToTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -305,9 +279,7 @@ def contains_translator(): ) -def test_matches_contains_filter( - contains_translator: ContainsTranslator, default_step_kwargs: dict[str, Any] -): +def test_matches_contains_filter(contains_translator: ContainsTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -332,9 +304,7 @@ def test_matches_contains_filter( assert ctx.selectable.get_sql() == expected_query.get_sql() -def test_notmatches_contains_filter( - contains_translator: ContainsTranslator, default_step_kwargs: dict[str, Any] -): +def test_notmatches_contains_filter(contains_translator: ContainsTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -380,9 +350,7 @@ def regexp_like_translator(): ) -def test_matches_regexp_like_filter( - regexp_translator: REGEXP_LIKE_Translator, default_step_kwargs: dict[str, Any] -): +def test_matches_regexp_like_filter(regexp_translator: REGEXP_LIKE_Translator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -401,9 +369,7 @@ def test_matches_regexp_like_filter( assert ctx.selectable.get_sql() == expected_query.get_sql() -def test_notmatches_regexp__like_filter( - regexp_translator: REGEXP_LIKE_Translator, default_step_kwargs: dict[str, Any] -): +def test_notmatches_regexp__like_filter(regexp_translator: REGEXP_LIKE_Translator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" @@ -443,9 +409,7 @@ def regexp_contains_translator(): ) -def test_matches_regexp_contains_filter( - regexp_translator: REGEXP_LIKE_Translator, default_step_kwargs: dict[str, Any] -): +def test_matches_regexp_contains_filter(regexp_translator: REGEXP_LIKE_Translator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "age"] previous_step = "previous_with" column = "name" diff --git a/server/tests/backends/sql_translator_unit_tests/test_google_big_query.py b/server/tests/backends/sql_translator_unit_tests/test_google_big_query.py index f0e65d2ec7..806c7607bb 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_google_big_query.py +++ b/server/tests/backends/sql_translator_unit_tests/test_google_big_query.py @@ -14,9 +14,7 @@ def test_escape_names(gbq_translator: GoogleBigQueryTranslator) -> None: query = gbq_translator.get_query_str( steps=[ DomainStep(domain="table"), - FilterStep( - condition={"column": "col1", "operator": "in", "value": ["pika", "l'alcool"]} - ), + FilterStep(condition={"column": "col1", "operator": "in", "value": ["pika", "l'alcool"]}), ] ) assert query == ( diff --git a/server/tests/backends/sql_translator_unit_tests/test_row_number_translators.py 
b/server/tests/backends/sql_translator_unit_tests/test_row_number_translators.py index ab6cab2a39..9c57139e28 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_row_number_translators.py +++ b/server/tests/backends/sql_translator_unit_tests/test_row_number_translators.py @@ -51,10 +51,7 @@ def get_top_query(sort_order, previous_step, group, rank_on, selected_columns, l name_field = Field(group) sub_query = Query.from_(previous_step).select(*selected_columns) rank_select = sub_query.select( - RowNumber() - .as_("row_number") - .over(name_field) - .orderby(age_field, order=getattr(Order, sort_order)) + RowNumber().as_("row_number").over(name_field).orderby(age_field, order=getattr(Order, sort_order)) ) expected_query = ( Query.from_(rank_select) @@ -77,13 +74,9 @@ def test_top_with_enabled_row_number( selected_columns = ["*"] step = steps.TopStep(rank_on=rank_on, groups=[group], sort=sort_order, limit=100) - ctx = row_number_enabled_translator.top( - step=step, columns=selected_columns, **default_step_kwargs - ) + ctx = row_number_enabled_translator.top(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = get_top_query( - sort_order, previous_step, group, rank_on, selected_columns, step.limit - ) + expected_query = get_top_query(sort_order, previous_step, group, rank_on, selected_columns, step.limit) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -100,9 +93,7 @@ def test_top_with_disabled_row_number( step = steps.TopStep(rank_on=rank_on, groups=[group], sort=sort_order, limit=100) with pytest.raises(NotImplementedError): - row_number_disabled_translator.top( - step=step, columns=selected_columns, **default_step_kwargs - ) + row_number_disabled_translator.top(step=step, columns=selected_columns, **default_step_kwargs) def test_argmax_with_enabled_split_part( @@ -114,9 +105,7 @@ def test_argmax_with_enabled_split_part( selected_columns = ["*"] step = steps.ArgmaxStep(column=rank_on, groups=[group]) - ctx = row_number_enabled_translator.argmax( - step=step, columns=selected_columns, **default_step_kwargs - ) + ctx = row_number_enabled_translator.argmax(step=step, columns=selected_columns, **default_step_kwargs) expected_query = get_top_query("desc", previous_step, group, rank_on, selected_columns, 1) @@ -132,9 +121,7 @@ def test_argmax_with_disabled_split_part( step = steps.ArgmaxStep(column=rank_on, groups=[group]) with pytest.raises(NotImplementedError): - row_number_disabled_translator.argmax( - step=step, columns=selected_columns, **default_step_kwargs - ) + row_number_disabled_translator.argmax(step=step, columns=selected_columns, **default_step_kwargs) def test_argmin_with_enabled_split_part( @@ -146,9 +133,7 @@ def test_argmin_with_enabled_split_part( selected_columns = ["*"] step = steps.ArgminStep(column=rank_on, groups=[group]) - ctx = row_number_enabled_translator.argmin( - step=step, columns=selected_columns, **default_step_kwargs - ) + ctx = row_number_enabled_translator.argmin(step=step, columns=selected_columns, **default_step_kwargs) expected_query = get_top_query("asc", previous_step, group, rank_on, selected_columns, 1) @@ -165,6 +150,4 @@ def test_argmin_with_disabled_split_part( step = steps.ArgminStep(column=rank_on, groups=[group]) with pytest.raises(NotImplementedError): - row_number_disabled_translator.argmin( - step=step, columns=selected_columns, **default_step_kwargs - ) + row_number_disabled_translator.argmin(step=step, columns=selected_columns, **default_step_kwargs) diff --git 
a/server/tests/backends/sql_translator_unit_tests/test_snowflake_translator.py b/server/tests/backends/sql_translator_unit_tests/test_snowflake_translator.py index b05185ec83..2f6e087d1f 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_snowflake_translator.py +++ b/server/tests/backends/sql_translator_unit_tests/test_snowflake_translator.py @@ -17,9 +17,7 @@ def snowflake_translator() -> SnowflakeTranslator: return SnowflakeTranslator(tables_columns={}) -def test_evolution_abs_day( - snowflake_translator: SnowflakeTranslator, default_step_kwargs: dict[str, Any] -) -> None: +def test_evolution_abs_day(snowflake_translator: SnowflakeTranslator, default_step_kwargs: dict[str, Any]) -> None: selected_columns = ["name", "brewing_date"] previous_step = "previous_with" new_column = "evol" @@ -39,17 +37,13 @@ def test_evolution_abs_day( Query.from_(previous_step) .select( *[prev_table.field(col) for col in selected_columns], - (prev_table.field(step.value_col) - right_table.field(step.value_col)).as_( - step.new_column - ), + (prev_table.field(step.value_col) - right_table.field(step.value_col)).as_(step.new_column), ) .left_join( Query.from_(previous_step) .select( step.value_col, - DateAddWithoutUnderscore("day", 1, prev_table.field(step.date_col)).as_( - step.date_col - ), + DateAddWithoutUnderscore("day", 1, prev_table.field(step.date_col)).as_(step.date_col), ) .as_("right_table") ) @@ -81,18 +75,16 @@ def test_evolution_perc_groups_day( Query.from_(previous_step) .select( *[prev_table.field(col) for col in selected_columns], - (prev_table.field(step.value_col) - right_table.field(step.value_col)).as_( - step.new_column - ), + (prev_table.field(step.value_col) - right_table.field(step.value_col)).as_(step.new_column), *[prev_table.field(col).as_(f"left_table_{col}") for col in step.index_columns], ) .left_join( Query.from_(previous_step) .select( step.value_col, - DateAddWithoutUnderscore( - date_part="day", interval=1, term=prev_table.field(step.date_col) - ).as_(step.date_col), + DateAddWithoutUnderscore(date_part="day", interval=1, term=prev_table.field(step.date_col)).as_( + step.date_col + ), *step.index_columns, ) .as_("right_table") @@ -115,9 +107,7 @@ def test_date_extract_extract_kw( date_info=["year", "month", "day"], column="brewing_date", ) - ctx = snowflake_translator.dateextract( - step=step, columns=selected_columns, **default_step_kwargs - ) + ctx = snowflake_translator.dateextract(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( *[prev_table.field(col) for col in selected_columns], Cast( @@ -145,9 +135,7 @@ def test_date_extract_extract_kw( assert ctx.selectable.get_sql(quote_char='"') == expected_query.get_sql(quote_char='"') -def test_date_extract_func( - snowflake_translator: SnowflakeTranslator, default_step_kwargs: dict[str, Any] -) -> None: +def test_date_extract_func(snowflake_translator: SnowflakeTranslator, default_step_kwargs: dict[str, Any]) -> None: selected_columns = ["name", "brewing_date"] previous_step = "previous_with" prev_table = Table(previous_step) @@ -157,9 +145,7 @@ def test_date_extract_func( date_info=["isoWeek"], column="brewing_date", ) - ctx = snowflake_translator.dateextract( - step=step, columns=selected_columns, **default_step_kwargs - ) + ctx = snowflake_translator.dateextract(step=step, columns=selected_columns, **default_step_kwargs) expected_query = Query.from_(previous_step).select( *[prev_table.field(col) for col in selected_columns], Cast( diff --git 
a/server/tests/backends/sql_translator_unit_tests/test_split_part_translators.py b/server/tests/backends/sql_translator_unit_tests/test_split_part_translators.py index 1ee652117d..e9a899ec14 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_split_part_translators.py +++ b/server/tests/backends/sql_translator_unit_tests/test_split_part_translators.py @@ -46,9 +46,7 @@ def split_disabled_translator(): ) -def test_split_enabled( - split_enabled_translator: SplitEnabledTranslator, default_step_kwargs: dict[str, Any] -): +def test_split_enabled(split_enabled_translator: SplitEnabledTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "pseudonyme"] previous_step = "previous_with" column = "name" @@ -66,9 +64,7 @@ def test_split_enabled( assert ctx.selectable.get_sql() == expected_query.get_sql() -def test_split_disabled( - split_disabled_translator: SplitDisabledTranslator, default_step_kwargs: dict[str, Any] -): +def test_split_disabled(split_disabled_translator: SplitDisabledTranslator, default_step_kwargs: dict[str, Any]): selected_columns = ["name", "pseudonyme"] column = "name" delimiter = "," diff --git a/server/tests/backends/sql_translator_unit_tests/test_translation_with_mergeable_first_step.py b/server/tests/backends/sql_translator_unit_tests/test_translation_with_mergeable_first_step.py index 21563c2174..84621629fb 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_translation_with_mergeable_first_step.py +++ b/server/tests/backends/sql_translator_unit_tests/test_translation_with_mergeable_first_step.py @@ -57,12 +57,8 @@ ] -@pytest.mark.parametrize( - "steps, expected", zip(_CASES, _EXPECTED_NO_SOURCE_ROW_SUBSET, strict=True) -) -def test_base_translator_merge_first_steps( - translator: SQLTranslator, steps: list[PipelineStep], expected: str -): +@pytest.mark.parametrize("steps, expected", zip(_CASES, _EXPECTED_NO_SOURCE_ROW_SUBSET, strict=True)) +def test_base_translator_merge_first_steps(translator: SQLTranslator, steps: list[PipelineStep], expected: str): pipeline = Pipeline(steps=steps) assert translator.get_query_str(steps=pipeline.steps) == expected @@ -79,9 +75,7 @@ def test_base_translator_merge_first_steps( ] -@pytest.mark.parametrize( - "steps, expected", zip(_CASES, _EXPECTED_WITH_SOURCE_ROW_SUBSET, strict=True) -) +@pytest.mark.parametrize("steps, expected", zip(_CASES, _EXPECTED_WITH_SOURCE_ROW_SUBSET, strict=True)) def test_base_translator_merge_first_steps_with_subset( translator: SQLTranslator, steps: list[PipelineStep], expected: str ): diff --git a/server/tests/steps/test_addmissingdates.py b/server/tests/steps/test_addmissingdates.py index dd23c7b052..d6c81dcbb2 100644 --- a/server/tests/steps/test_addmissingdates.py +++ b/server/tests/steps/test_addmissingdates.py @@ -29,29 +29,22 @@ def test_missing_date(today): } ) - step = AddMissingDatesStep( - name="addmissingdates", datesColumn="date", datesGranularity="day", groups=[] - ) + step = AddMissingDatesStep(name="addmissingdates", datesColumn="date", datesGranularity="day", groups=[]) result = execute_addmissingdates(step, df) - expected_result = pd.concat( - [df, pd.DataFrame({"date": missing_dates, "value": [None, None]})] - ).sort_values(by="date") + expected_result = pd.concat([df, pd.DataFrame({"date": missing_dates, "value": [None, None]})]).sort_values( + by="date" + ) assert_dataframes_equals(result, expected_result) def test_missing_date_years(today): - dates = [ - today + timedelta(days=nb_years * 365) - for nb_years in list(range(1, 10)) + 
list(range(12, 20)) - ] + dates = [today + timedelta(days=nb_years * 365) for nb_years in list(range(1, 10)) + list(range(12, 20))] missing_dates = [today + timedelta(days=10 * 365), today + timedelta(days=11 * 365)] # dates added by pandas are at the beginning of the last day of the year - missing_dates = [ - datetime.datetime(year=missing_date.year, month=1, day=1) for missing_date in missing_dates - ] + missing_dates = [datetime.datetime(year=missing_date.year, month=1, day=1) for missing_date in missing_dates] values = [idx for (idx, value) in enumerate(dates)] df = pd.DataFrame( @@ -61,14 +54,12 @@ def test_missing_date_years(today): } ) - step = AddMissingDatesStep( - name="addmissingdates", datesColumn="date", datesGranularity="year", groups=[] - ) + step = AddMissingDatesStep(name="addmissingdates", datesColumn="date", datesGranularity="year", groups=[]) result = execute_addmissingdates(step, df) - expected_result = pd.concat( - [df, pd.DataFrame({"date": missing_dates, "value": [None, None]})] - ).sort_values(by="date") + expected_result = pd.concat([df, pd.DataFrame({"date": missing_dates, "value": [None, None]})]).sort_values( + by="date" + ) assert_dataframes_equals(result, expected_result) @@ -89,9 +80,7 @@ def test_missing_date_with_groups_correct_indexing(today): } ) - step = AddMissingDatesStep( - name="addmissingdates", datesColumn="date", datesGranularity="day", groups=["country"] - ) + step = AddMissingDatesStep(name="addmissingdates", datesColumn="date", datesGranularity="day", groups=["country"]) result = execute_addmissingdates(step, df) expected_result = pd.concat( [ @@ -110,10 +99,7 @@ def test_missing_date_with_groups_correct_indexing(today): def test_missing_date_with_groups_various_length(today): - dates = [ - datetime.datetime(year=2020, month=nb_month, day=1) - for nb_month in list(range(1, 5)) + list(range(8, 10)) - ] + dates = [datetime.datetime(year=2020, month=nb_month, day=1) for nb_month in list(range(1, 5)) + list(range(8, 10))] missing_dates = [datetime.datetime(year=2020, month=nb_month, day=1) for nb_month in [5, 6, 7]] @@ -126,9 +112,7 @@ def test_missing_date_with_groups_various_length(today): } ) - step = AddMissingDatesStep( - name="addmissingdates", datesColumn="date", datesGranularity="month", groups=["country"] - ) + step = AddMissingDatesStep(name="addmissingdates", datesColumn="date", datesGranularity="month", groups=["country"]) result = execute_addmissingdates(step, df) expected_result = pd.concat( [ @@ -165,9 +149,7 @@ def test_benchmark_addmissingdate(benchmark, today): } ) - step = AddMissingDatesStep( - name="addmissingdates", datesColumn="date", datesGranularity="day", groups=[] - ) + step = AddMissingDatesStep(name="addmissingdates", datesColumn="date", datesGranularity="day", groups=[]) result = benchmark(execute_addmissingdates, step, df) assert len(result) == 2000 @@ -198,9 +180,7 @@ def test_add_missing_dates_with_tz_aware_timestamps(): } ) - step = AddMissingDatesStep( - name="addmissingdates", dates_column="date", dates_granularity="day", groups=[] - ) + step = AddMissingDatesStep(name="addmissingdates", dates_column="date", dates_granularity="day", groups=[]) result = execute_addmissingdates(step, df) assert_dataframes_equals(result, expected_result) diff --git a/server/tests/steps/test_convert.py b/server/tests/steps/test_convert.py index 575b27be2f..3e072cf76c 100644 --- a/server/tests/steps/test_convert.py +++ b/server/tests/steps/test_convert.py @@ -7,9 +7,7 @@ def test_benchmark_convert(benchmark): - dates = [ - 
str(datetime.datetime.today() + timedelta(days=nb_day)) for nb_day in list(range(1, 2001)) - ] + dates = [str(datetime.datetime.today() + timedelta(days=nb_day)) for nb_day in list(range(1, 2001))] df = pandas.DataFrame( { "date": dates, diff --git a/server/tests/steps/test_cumsum.py b/server/tests/steps/test_cumsum.py index d6470ee31b..a66bcd1cde 100644 --- a/server/tests/steps/test_cumsum.py +++ b/server/tests/steps/test_cumsum.py @@ -16,9 +16,7 @@ def test_cumsum_legacy_syntax(): def test_benchmark_cumsum_legacy_syntax(benchmark): big_df = DataFrame({"value": list(range(1000))}) - step = CumSumStep( - name="cumsum", referenceColumn="value", value_column="value", new_column="my_cumsum" - ) + step = CumSumStep(name="cumsum", referenceColumn="value", value_column="value", new_column="my_cumsum") benchmark(execute_cumsum, step, big_df) diff --git a/server/tests/steps/test_date_extract.py b/server/tests/steps/test_date_extract.py index 6908441543..a03bcb1e14 100644 --- a/server/tests/steps/test_date_extract.py +++ b/server/tests/steps/test_date_extract.py @@ -298,8 +298,6 @@ def test_benchmark_dateextract(benchmark): "date": dates, } ) - step = DateExtractStep( - name="dateextract", column="date", operation="day", new_column_name="date" - ) + step = DateExtractStep(name="dateextract", column="date", operation="day", new_column_name="date") benchmark(execute_date_extract, step, df) diff --git a/server/tests/steps/test_duration.py b/server/tests/steps/test_duration.py index 6c38533142..6f6aa7d389 100644 --- a/server/tests/steps/test_duration.py +++ b/server/tests/steps/test_duration.py @@ -51,9 +51,7 @@ def test_duration(time_delta_parameters: dict[str, int], duration_in: str, expec ({"hours": 1}, "days", 1 / 24.0), ], ) -def test_duration_with_dates( - time_delta_parameters: dict[str, int], duration_in: str, expected_result: float -): +def test_duration_with_dates(time_delta_parameters: dict[str, int], duration_in: str, expected_result: float): step = DurationStep( name="duration", newColumnName="DURATION", diff --git a/server/tests/steps/test_evolution.py b/server/tests/steps/test_evolution.py index 76d508e617..a8c907aeda 100644 --- a/server/tests/steps/test_evolution.py +++ b/server/tests/steps/test_evolution.py @@ -70,9 +70,7 @@ def test_evolution_percentage(sample_df: DataFrame): ) df_result = execute_evolution(step, sample_df) - expected_result = sample_df.assign( - VALUE_EVOL_PCT=[None, 0.0253164, -0.0493827, -0.025974, None, 0.1282051] - ) + expected_result = sample_df.assign(VALUE_EVOL_PCT=[None, 0.0253164, -0.0493827, -0.025974, None, 0.1282051]) assert_dataframes_equals(df_result, expected_result) @@ -114,9 +112,7 @@ def test_evolution_with_groups(df_with_groups: DataFrame): ) df_result = execute_evolution(step, df_with_groups) - expected_result = df_with_groups.assign( - MY_EVOL=[None, 2, -4, -2, None, 10, None, 0, -1, -1, 3, None] - ) + expected_result = df_with_groups.assign(MY_EVOL=[None, 2, -4, -2, None, 10, None, 0, -1, -1, 3, None]) assert_dataframes_equals(df_result, expected_result) diff --git a/server/tests/steps/test_fillna.py b/server/tests/steps/test_fillna.py index 0da3fa85b4..c3160866f3 100644 --- a/server/tests/steps/test_fillna.py +++ b/server/tests/steps/test_fillna.py @@ -9,9 +9,7 @@ @pytest.fixture def sample_df(): - return DataFrame( - {"colA": ["toto", "tutu", None], "colB": [1, 2, None], "colC": [100, 50, None]} - ) + return DataFrame({"colA": ["toto", "tutu", None], "colB": [1, 2, None], "colC": [100, 50, None]}) def test_simple_fillna(sample_df): diff --git 
a/server/tests/steps/test_filter.py b/server/tests/steps/test_filter.py index 39cc1b9099..5bbaa8da54 100644 --- a/server/tests/steps/test_filter.py +++ b/server/tests/steps/test_filter.py @@ -57,9 +57,7 @@ def test_simple_ne_filter(sample_df): ) df_result = execute_filter(step, sample_df) - assert_dataframes_equals( - df_result, DataFrame({"colA": ["toto", "tata"], "colB": [1, 3], "colC": [100, 25]}) - ) + assert_dataframes_equals(df_result, DataFrame({"colA": ["toto", "tata"], "colB": [1, 3], "colC": [100, 25]})) def test_simple_gt_filter(sample_df): @@ -87,9 +85,7 @@ def test_simple_ge_filter(sample_df): ) df_result = execute_filter(step, sample_df) - assert_dataframes_equals( - df_result, DataFrame({"colA": ["tutu", "tata"], "colB": [2, 3], "colC": [50, 25]}) - ) + assert_dataframes_equals(df_result, DataFrame({"colA": ["tutu", "tata"], "colB": [2, 3], "colC": [50, 25]})) def test_simple_lt_filter(sample_df): @@ -117,9 +113,7 @@ def test_simple_le_filter(sample_df): ) df_result = execute_filter(step, sample_df) - assert_dataframes_equals( - df_result, DataFrame({"colA": ["toto", "tutu"], "colB": [1, 2], "colC": [100, 50]}) - ) + assert_dataframes_equals(df_result, DataFrame({"colA": ["toto", "tutu"], "colB": [1, 2], "colC": [100, 50]})) def test_simple_in_filter(sample_df): @@ -133,9 +127,7 @@ def test_simple_in_filter(sample_df): ) df_result = execute_filter(step, sample_df) - assert_dataframes_equals( - df_result, DataFrame({"colA": ["toto", "tutu"], "colB": [1, 2], "colC": [100, 50]}) - ) + assert_dataframes_equals(df_result, DataFrame({"colA": ["toto", "tutu"], "colB": [1, 2], "colC": [100, 50]})) def test_simple_nin_filter(sample_df): @@ -203,9 +195,7 @@ def test_simple_notmatches_filter(sample_df): ) df_result = execute_filter(step, sample_df) - assert_dataframes_equals( - df_result, DataFrame({"colA": ["toto", "tutu"], "colB": [1, 2], "colC": [100, 50]}) - ) + assert_dataframes_equals(df_result, DataFrame({"colA": ["toto", "tutu"], "colB": [1, 2], "colC": [100, 50]})) def test_and_logical_conditions(sample_df): @@ -251,9 +241,7 @@ def test_or_logical_conditions(sample_df): ) df_result = execute_filter(step, sample_df) - assert_dataframes_equals( - df_result, DataFrame({"colA": ["toto", "tata"], "colB": [1, 3], "colC": [100, 25]}) - ) + assert_dataframes_equals(df_result, DataFrame({"colA": ["toto", "tata"], "colB": [1, 3], "colC": [100, 25]})) def test_nested_logical_conditions(sample_df): @@ -370,7 +358,9 @@ def test_date_filter(date_df: DataFrame, expected_date_filter_result: DataFrame) condition=ConditionComboAnd( and_=[ DateBoundCondition( - column="Transaction_date", operator="from", value="2009-01-02T00:00:00" # naive + column="Transaction_date", + operator="from", + value="2009-01-02T00:00:00", # naive ), DateBoundCondition( column="Transaction_date", diff --git a/server/tests/steps/test_formula.py b/server/tests/steps/test_formula.py index 76d278ab8e..074878207a 100644 --- a/server/tests/steps/test_formula.py +++ b/server/tests/steps/test_formula.py @@ -22,9 +22,7 @@ def sample_df() -> DataFrame: def test_formula(sample_df: DataFrame): - step = FormulaStep( - name="formula", new_column="z", formula="(colA + [col B]) * ([col C] + [col D]) / 10" - ) + step = FormulaStep(name="formula", new_column="z", formula="(colA + [col B]) * ([col C] + [col D]) / 10") df_result = execute_formula(step, sample_df) expected_result = sample_df.assign(z=[2.1, 210.0]) diff --git a/server/tests/steps/test_ifthenelse.py b/server/tests/steps/test_ifthenelse.py index bfc514ad79..817d647da7 100644 
--- a/server/tests/steps/test_ifthenelse.py +++ b/server/tests/steps/test_ifthenelse.py @@ -41,9 +41,7 @@ def test_simple_condition_strings(): ) result_df = execute_ifthenelse(step, sample_df) - expected_df = DataFrame( - {"a_str": ["test", "test", "autre chose"], "test": ["foo", "foo", "bar"]} - ) + expected_df = DataFrame({"a_str": ["test", "test", "autre chose"], "test": ["foo", "foo", "bar"]}) assert_dataframes_equals(result_df, expected_df) @@ -61,9 +59,7 @@ def test_then_should_support_formulas(): ) result_df = execute_ifthenelse(step, base_df) - expected_df = DataFrame( - {"a_bool": [True, False, True], "a_number": [1, 2, 3], "result": [1, -2, 3]} - ) + expected_df = DataFrame({"a_bool": [True, False, True], "a_number": [1, 2, 3], "result": [1, -2, 3]}) assert_dataframes_equals(result_df, expected_df) @@ -89,9 +85,7 @@ def test_then_should_support_nested_else(): ) result_df = execute_ifthenelse(step, base_df) - expected_df = DataFrame( - {"a_bool": [True, False, False], "a_number": [1, 2, 3], "result": [3, 2, 1]} - ) + expected_df = DataFrame({"a_bool": [True, False, False], "a_number": [1, 2, 3], "result": [3, 2, 1]}) assert_dataframes_equals(result_df, expected_df) diff --git a/server/tests/steps/test_join.py b/server/tests/steps/test_join.py index b99c6601f1..614899d398 100644 --- a/server/tests/steps/test_join.py +++ b/server/tests/steps/test_join.py @@ -49,9 +49,7 @@ def test_join_left( execute_pipeline=mock_execute_pipeline, ) - expected_result = DataFrame( - {"NAME": ["foo", "bar"], "name": [None, "bar"], "AGE": [42, 43], "score": [None, 100]} - ) + expected_result = DataFrame({"NAME": ["foo", "bar"], "name": [None, "bar"], "AGE": [42, 43], "score": [None, 100]}) assert_dataframes_equals(df_result, expected_result) @@ -140,9 +138,7 @@ def test_join_domain_name( execute_pipeline=mock_execute_pipeline, ) - expected_result = DataFrame( - {"NAME": ["foo", "bar"], "name": [None, "bar"], "AGE": [42, 43], "score": [None, 1]} - ) + expected_result = DataFrame({"NAME": ["foo", "bar"], "name": [None, "bar"], "AGE": [42, 43], "score": [None, 1]}) assert_dataframes_equals(df_result, expected_result) diff --git a/server/tests/steps/test_moving_average.py b/server/tests/steps/test_moving_average.py index 6088a7f722..a75a446031 100644 --- a/server/tests/steps/test_moving_average.py +++ b/server/tests/steps/test_moving_average.py @@ -8,9 +8,7 @@ def test_moving_average_basic(): - df = DataFrame( - {"date": [f"2018-01-0{i}" for i in range(1, 9)], "value": [75, 80, 82, 83, 80, 86, 79, 76]} - ) + df = DataFrame({"date": [f"2018-01-0{i}" for i in range(1, 9)], "value": [75, 80, 82, 83, 80, 86, 79, 76]}) df["date"] = pd.to_datetime(df["date"]) step = MovingAverageStep( @@ -21,9 +19,7 @@ def test_moving_average_basic(): ) df_result = execute_moving_average(step, df) - expected_result = df.assign( - **{"value_MOVING_AVG": [None, 77.5, 81, 82.5, 81.5, 83, 82.5, 77.5]} - ) + expected_result = df.assign(**{"value_MOVING_AVG": [None, 77.5, 81, 82.5, 81.5, 83, 82.5, 77.5]}) assert_dataframes_equals(df_result, expected_result) @@ -48,10 +44,7 @@ def test_moving_average_with_groups(): df_result = execute_moving_average(step, df) expected_result = df.assign( - **{ - "rolling_average": [None, None, 79, 81.6667, 81.6667, 83] - + [None, None, 71.6667, 73.6667, 72.6667, 73.6667] - } + **{"rolling_average": [None, None, 79, 81.6667, 81.6667, 83] + [None, None, 71.6667, 73.6667, 72.6667, 73.6667]} ) assert_dataframes_equals(df_result, expected_result) diff --git a/server/tests/steps/test_percentage.py 
b/server/tests/steps/test_percentage.py index 6fee16601f..7fc3cdf3c1 100644 --- a/server/tests/steps/test_percentage.py +++ b/server/tests/steps/test_percentage.py @@ -19,9 +19,7 @@ def test_simple_percentage(): def test_percentage_with_groups(): sample_df = pd.DataFrame({"a_bool": [True, False, True, False], "values": [50, 25, 50, 75]}) - step = PercentageStep( - name="percentage", column="values", group=["a_bool"], newColumnName="result" - ) + step = PercentageStep(name="percentage", column="values", group=["a_bool"], newColumnName="result") result = execute_percentage(step, sample_df) expected_df = pd.DataFrame( { diff --git a/server/tests/steps/test_rank.py b/server/tests/steps/test_rank.py index 55b8953bab..3e7d74f490 100644 --- a/server/tests/steps/test_rank.py +++ b/server/tests/steps/test_rank.py @@ -11,9 +11,7 @@ @pytest.fixture def sample_df(): - return DataFrame( - {"COUNTRY": ["France"] * 3 + ["USA"] * 4, "VALUE": [10, 20, 30, 10, 40, 30, 50]} - ) + return DataFrame({"COUNTRY": ["France"] * 3 + ["USA"] * 4, "VALUE": [10, 20, 30, 10, 40, 30, 50]}) def test_rank(sample_df: DataFrame): diff --git a/server/tests/steps/test_replacetext.py b/server/tests/steps/test_replacetext.py index 751fe822e4..ba4cf44645 100644 --- a/server/tests/steps/test_replacetext.py +++ b/server/tests/steps/test_replacetext.py @@ -15,9 +15,7 @@ def sample_df() -> pd.DataFrame: def test_simple_replace(sample_df): - step = ReplaceTextStep( - name="replacetext", search_column="values", old_str="FR", new_str="France" - ) + step = ReplaceTextStep(name="replacetext", search_column="values", old_str="FR", new_str="France") result = execute_replacetext(step, sample_df) expected_df = pd.DataFrame({"values": ["France", "a string with France in it", "UK"]}) diff --git a/server/tests/steps/test_substring.py b/server/tests/steps/test_substring.py index b542eac195..1977364f88 100644 --- a/server/tests/steps/test_substring.py +++ b/server/tests/steps/test_substring.py @@ -62,9 +62,7 @@ def test_substring_negative_start_negative_end(sample_df): def test_substring_new_column_name(sample_df): - step = SubstringStep( - name="substring", column="Label", start_index=-3, end_index=-1, newColumnName="FOO" - ) + step = SubstringStep(name="substring", column="Label", start_index=-3, end_index=-1, newColumnName="FOO") result_df = execute_substring(step, sample_df) expected_df = pd.DataFrame( @@ -87,7 +85,5 @@ def test_benchmark_substring(benchmark): } ) - step = SubstringStep( - name="substring", column="group", start_index=0, end_index=3, newColumnName="FOO" - ) + step = SubstringStep(name="substring", column="group", start_index=0, end_index=3, newColumnName="FOO") benchmark(execute_substring, step, df) diff --git a/server/tests/steps/test_todate.py b/server/tests/steps/test_todate.py index 55b27d7ccd..c8e26795cf 100644 --- a/server/tests/steps/test_todate.py +++ b/server/tests/steps/test_todate.py @@ -15,10 +15,7 @@ def test_benchmark_sort(benchmark): { "value": np.random.random(1000), "id": list(range(1000)), - "date": [ - (today + datetime.timedelta(days=nb_day)).strftime("%Y-%m-%d") - for nb_day in list(range(1000)) - ], + "date": [(today + datetime.timedelta(days=nb_day)).strftime("%Y-%m-%d") for nb_day in list(range(1000))], "group": [random.choice(groups) for _ in range(1000)], } ) diff --git a/server/tests/steps/test_totals.py b/server/tests/steps/test_totals.py index 219ad0b284..ec40c1a9de 100644 --- a/server/tests/steps/test_totals.py +++ b/server/tests/steps/test_totals.py @@ -29,9 +29,7 @@ def 
test_single_totals_without_groups(): expected_result = pd.concat( [ sample_df, - pd.DataFrame( - {"COUNTRY": "All countries", "PRODUCT": [None], "YEAR": [None], "VALUE": [135]} - ), + pd.DataFrame({"COUNTRY": "All countries", "PRODUCT": [None], "YEAR": [None], "VALUE": [135]}), ] ) @@ -85,9 +83,7 @@ def test_totals_2(): pd.DataFrame( { "COUNTRY": ["USA", "France"] * 2 + ["All countries"] * 6, - "PRODUCT": ["All products"] * 4 - + ["product B", "product A"] * 2 - + ["All products"] * 2, + "PRODUCT": ["All products"] * 4 + ["product B", "product A"] * 2 + ["All products"] * 2, "YEAR": (["2020"] * 2 + ["2019"] * 2) * 2 + ["2020", "2019"], "VALUE_2": [450, 500, 250, 150, 550, 400, 250, 150, 950, 400], "VALUE_1-sum": [45, 50, 25, 15, 55, 40, 25, 15, 95, 40], @@ -141,9 +137,7 @@ def test_total_must_contains_aggregation(): with pytest.raises(ValueError): TotalsStep( name="totals", - totalDimensions=[ - TotalDimension(total_column="COUNTRY", total_rows_label="All countries") - ], + totalDimensions=[TotalDimension(total_column="COUNTRY", total_rows_label="All countries")], aggregations=[], groups=[], ) diff --git a/server/tests/steps/test_unpivot.py b/server/tests/steps/test_unpivot.py index fcf5bd6182..12901e4103 100644 --- a/server/tests/steps/test_unpivot.py +++ b/server/tests/steps/test_unpivot.py @@ -69,10 +69,7 @@ def test_unpivot_with_dropna_false(sample_df: DataFrame): result = execute_unpivot(step, sample_df, domain_retriever=None, execute_pipeline=None) expected_result = DataFrame( { - "COMPANY": ["Company 1"] * 2 - + ["Company 2"] * 2 - + ["Company 1"] * 2 - + ["Company 2"] * 2, + "COMPANY": ["Company 1"] * 2 + ["Company 2"] * 2 + ["Company 1"] * 2 + ["Company 2"] * 2, "COUNTRY": ["France"] * 4 + ["USA"] * 4, "KPI": ["NB_CLIENTS", "REVENUES"] * 4, "VALUE": [7, 10, 2, None, 12, 6, 1, 3], diff --git a/server/tests/test_pipeline.py b/server/tests/test_pipeline.py index c867118115..d3ccf0eab3 100644 --- a/server/tests/test_pipeline.py +++ b/server/tests/test_pipeline.py @@ -58,9 +58,7 @@ def get_render_variables_test_cases(): def test_step_with_variables(case: Case): pipeline_with_variables = PipelineWithVariables(**case.data) - pipeline = pipeline_with_variables.render( - case.context, renderer=nosql_apply_parameters_to_query - ) + pipeline = pipeline_with_variables.render(case.context, renderer=nosql_apply_parameters_to_query) expected_result = Pipeline(steps=case.expected_result) assert pipeline == expected_result @@ -246,11 +244,7 @@ def test_skip_void_parameter_from_variables(): {"name": "filter", "condition": {"column": "colB", "operator": "eq", "value": 32}}, { "name": "filter", - "condition": { - "and_": [ - {"column": "ColD", "operator": "until", "value": datetime(2009, 1, 3, 0, 0)} - ] - }, + "condition": {"and_": [{"column": "ColD", "operator": "until", "value": datetime(2009, 1, 3, 0, 0)}]}, }, ] @@ -301,13 +295,7 @@ def test_skip_void_parameter_from_variables_for_mongo_steps(): "pipeline": [], } }, - { - "$project": { - "_vqbPipelinesUnion": { - "$concatArrays": ["$_vqbPipelineInline", "$_vqbPipelineToAppend_0"] - } - } - }, + {"$project": {"_vqbPipelinesUnion": {"$concatArrays": ["$_vqbPipelineInline", "$_vqbPipelineToAppend_0"]}}}, {"$unwind": "$_vqbPipelinesUnion"}, {"$replaceRoot": {"newRoot": "$_vqbPipelinesUnion"}}, {"$project": {"_id": 0}}, @@ -316,13 +304,7 @@ def test_skip_void_parameter_from_variables_for_mongo_steps(): {"$match": {"$eq": None}}, {"$group": {"_id": None, "_vqbPipelineInline": {"$push": "$$ROOT"}}}, {"$lookup": {"as": "_vqbPipelineToAppend_0", "from": 
"slide_data-append", "pipeline": []}}, - { - "$project": { - "_vqbPipelinesUnion": { - "$concatArrays": ["$_vqbPipelineInline", "$_vqbPipelineToAppend_0"] - } - } - }, + {"$project": {"_vqbPipelinesUnion": {"$concatArrays": ["$_vqbPipelineInline", "$_vqbPipelineToAppend_0"]}}}, {"$unwind": "$_vqbPipelinesUnion"}, {"$replaceRoot": {"newRoot": "$_vqbPipelinesUnion"}}, {"$project": {"_id": 0}}, diff --git a/server/tests/test_pipeline_executor.py b/server/tests/test_pipeline_executor.py index 913b22685c..855ab7cca4 100644 --- a/server/tests/test_pipeline_executor.py +++ b/server/tests/test_pipeline_executor.py @@ -69,9 +69,7 @@ def test_preview_pipeline_limit(pipeline_previewer): ), limit=1, ) - assert json.loads(result)["data"] == [ - {"colA": "toto", "colB": 1, "colC": 100} - ] # first row of the data frame + assert json.loads(result)["data"] == [{"colA": "toto", "colB": 1, "colC": 100}] # first row of the data frame def test_preview_pipeline_limit_offset(pipeline_previewer): @@ -127,9 +125,7 @@ def test_rename(pipeline_executor): assert_dataframes_equals( df, - pd.DataFrame( - {"col_a": ["toto", "tutu", "tata"], "col_b": [1, 2, 3], "colC": [100, 50, 25]} - ), + pd.DataFrame({"col_a": ["toto", "tutu", "tata"], "col_b": [1, 2, 3], "colC": [100, 50, 25]}), )