From d2dc7bd038935148abd61cc18a0d42375aafdaa9 Mon Sep 17 00:00:00 2001 From: Mayur Singal <39544459+ulixius9@users.noreply.github.com> Date: Mon, 27 Jan 2025 18:15:58 +0530 Subject: [PATCH] Fix #19489: Optimise multithreading for lineage (#19524) --- .../src/metadata/ingestion/lineage/masker.py | 43 +++-- .../src/metadata/ingestion/lineage/parser.py | 45 ++--- .../metadata/ingestion/lineage/sql_lineage.py | 9 +- .../ingestion/ometa/mixins/es_mixin.py | 17 +- .../ingestion/ometa/mixins/lineage_mixin.py | 2 +- .../source/database/lineage_source.py | 169 +++++++++++------- .../database/stored_procedures_mixin.py | 59 +++--- ingestion/src/metadata/utils/db_utils.py | 4 + ingestion/src/metadata/utils/logger.py | 11 +- ingestion/tests/unit/test_sql_lineage.py | 44 ++++- 10 files changed, 230 insertions(+), 173 deletions(-) diff --git a/ingestion/src/metadata/ingestion/lineage/masker.py b/ingestion/src/metadata/ingestion/lineage/masker.py index 49a5999e72c6..69aab2d7ba01 100644 --- a/ingestion/src/metadata/ingestion/lineage/masker.py +++ b/ingestion/src/metadata/ingestion/lineage/masker.py @@ -14,8 +14,7 @@ import traceback -import sqlparse -from sqlfluff.core import Linter +from collate_sqllineage.runner import SQLPARSE_DIALECT, LineageRunner from sqlparse.sql import Comparison from sqlparse.tokens import Literal, Number, String @@ -24,6 +23,7 @@ MASK_TOKEN = "?" +# pylint: disable=protected-access def get_logger(): # pylint: disable=import-outside-toplevel from metadata.utils.logger import utils_logger @@ -31,18 +31,14 @@ def get_logger(): return utils_logger() -def mask_literals_with_sqlparse(query: str): +def mask_literals_with_sqlparse(query: str, parser: LineageRunner): """ Mask literals in a query using sqlparse. """ logger = get_logger() try: - parsed = sqlparse.parse(query) # Parse the query - - if not parsed: - return query - parsed = parsed[0] + parsed = parser._parsed_result def mask_token(token): # Mask all literals: strings, numbers, or other literal values @@ -79,17 +75,16 @@ def mask_token(token): return query -def mask_literals_with_sqlfluff(query: str, dialect: str = Dialect.ANSI.value) -> str: +def mask_literals_with_sqlfluff(query: str, parser: LineageRunner) -> str: """ Mask literals in a query using SQLFluff. """ logger = get_logger() try: - # Initialize SQLFluff linter - linter = Linter(dialect=dialect) + if not parser._evaluated: + parser._eval() - # Parse the query - parsed = linter.parse_string(query) + parsed = parser._parsed_result def replace_literals(segment): """Recursively replace literals with placeholders.""" @@ -114,17 +109,21 @@ def replace_literals(segment): return query -def mask_query(query: str, dialect: str = Dialect.ANSI.value) -> str: +def mask_query( + query: str, dialect: str = Dialect.ANSI.value, parser: LineageRunner = None +) -> str: logger = get_logger() try: - sqlfluff_masked_query = mask_literals_with_sqlfluff(query, dialect) - sqlparse_masked_query = mask_literals_with_sqlparse(query) - # compare both masked queries and return the one with more masked tokens - if sqlfluff_masked_query.count(MASK_TOKEN) >= sqlparse_masked_query.count( - MASK_TOKEN - ): - return sqlfluff_masked_query - return sqlparse_masked_query + if not parser: + try: + parser = LineageRunner(query, dialect=dialect) + len(parser.source_tables) + except Exception: + parser = LineageRunner(query) + len(parser.source_tables) + if parser._dialect == SQLPARSE_DIALECT: + return mask_literals_with_sqlparse(query, parser) + return mask_literals_with_sqlfluff(query, parser) except Exception as exc: logger.debug(f"Failed to mask query with sqlfluff: {exc}") logger.debug(traceback.format_exc()) diff --git a/ingestion/src/metadata/ingestion/lineage/parser.py b/ingestion/src/metadata/ingestion/lineage/parser.py index f7fad5fe812a..93bae226d74f 100644 --- a/ingestion/src/metadata/ingestion/lineage/parser.py +++ b/ingestion/src/metadata/ingestion/lineage/parser.py @@ -71,12 +71,13 @@ def __init__( self.query_parsing_success = True self.query_parsing_failure_reason = None self.dialect = dialect - self._masked_query = mask_query(self.query, dialect.value) + self.masked_query = None self._clean_query = self.clean_raw_query(query) - self._masked_clean_query = mask_query(self._clean_query, dialect.value) self.parser = self._evaluate_best_parser( self._clean_query, dialect=dialect, timeout_seconds=timeout_seconds ) + if self.masked_query is None: + self.masked_query = mask_query(self._clean_query, parser=self.parser) @cached_property def involved_tables(self) -> Optional[List[Table]]: @@ -95,7 +96,7 @@ def involved_tables(self) -> Optional[List[Table]]: except SQLLineageException as exc: logger.debug(traceback.format_exc()) logger.warning( - f"Cannot extract source table information from query [{self._masked_query}]: {exc}" + f"Cannot extract source table information from query [{self.masked_query or self.query}]: {exc}" ) return None @@ -334,12 +335,10 @@ def stateful_add_joins_from_statement( ) if not table_left or not table_right: - logger.warning( - f"Can't extract table names when parsing JOIN information from {comparison}" - ) logger.debug( - f"Query: {mask_query(sql_statement, self.dialect.value)}" + f"Can't extract table names when parsing JOIN information from {comparison}" ) + logger.debug(f"Query: {self.masked_query}") continue left_table_column = TableColumn(table=table_left, column=column_left) @@ -422,10 +421,9 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner: lr_dialect.get_column_lineage() return lr_dialect - sqlfluff_count = 0 try: lr_sqlfluff = get_sqlfluff_lineage_runner(query, dialect.value) - sqlfluff_count = len(lr_sqlfluff.get_column_lineage()) + len( + _ = len(lr_sqlfluff.get_column_lineage()) + len( set(lr_sqlfluff.source_tables).union( set(lr_sqlfluff.target_tables).union( set(lr_sqlfluff.intermediate_tables) @@ -438,23 +436,20 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner: f"Lineage with SqlFluff failed for the [{dialect.value}]. " f"Parser has been running for more than {timeout_seconds} seconds." ) - logger.debug( - f"{self.query_parsing_failure_reason}] query: [{self._masked_clean_query}]" - ) lr_sqlfluff = None except Exception: self.query_parsing_success = False self.query_parsing_failure_reason = ( f"Lineage with SqlFluff failed for the [{dialect.value}]" ) - logger.debug( - f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]" - ) lr_sqlfluff = None + if lr_sqlfluff: + return lr_sqlfluff + lr_sqlparser = LineageRunner(query) try: - sqlparser_count = len(lr_sqlparser.get_column_lineage()) + len( + _ = len(lr_sqlparser.get_column_lineage()) + len( set(lr_sqlparser.source_tables).union( set(lr_sqlparser.target_tables).union( set(lr_sqlparser.intermediate_tables) @@ -463,21 +458,13 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner: ) except Exception: # if both runner have failed we return the usual one + logger.debug(f"Failed to parse query with sqlparse & sqlfluff: {query}") return lr_sqlfluff if lr_sqlfluff else lr_sqlparser - if lr_sqlfluff: - # if sqlparser retrieve more lineage info that sqlfluff - if sqlparser_count > sqlfluff_count: - self.query_parsing_success = False - self.query_parsing_failure_reason = ( - "Lineage computed with SqlFluff did not perform as expected " - f"for the [{dialect.value}]" - ) - logger.debug( - f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]" - ) - return lr_sqlparser - return lr_sqlfluff + self.masked_query = mask_query(self._clean_query, parser=lr_sqlparser) + logger.debug( + f"Using sqlparse for lineage parsing for query: {self.masked_query}" + ) return lr_sqlparser @staticmethod diff --git a/ingestion/src/metadata/ingestion/lineage/sql_lineage.py b/ingestion/src/metadata/ingestion/lineage/sql_lineage.py index c8954fad6f0f..f2876197ac96 100644 --- a/ingestion/src/metadata/ingestion/lineage/sql_lineage.py +++ b/ingestion/src/metadata/ingestion/lineage/sql_lineage.py @@ -37,7 +37,6 @@ from metadata.generated.schema.type.entityLineage import Source as LineageSource from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.api.models import Either -from metadata.ingestion.lineage.masker import mask_query from metadata.ingestion.lineage.models import ( Dialect, QueryParsingError, @@ -614,11 +613,11 @@ def get_lineage_by_query( """ column_lineage = {} query_parsing_failures = QueryParsingFailures() - masked_query = mask_query(query, dialect.value) try: - logger.debug(f"Running lineage with query: {masked_query}") lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds) + masked_query = lineage_parser.masked_query or query + logger.debug(f"Running lineage with query: {masked_query}") raw_column_lineage = lineage_parser.column_lineage column_lineage.update(populate_column_lineage_map(raw_column_lineage)) @@ -715,11 +714,11 @@ def get_lineage_via_table_entity( """Get lineage from table entity""" column_lineage = {} query_parsing_failures = QueryParsingFailures() - masked_query = mask_query(query, dialect.value) try: - logger.debug(f"Getting lineage via table entity using query: {masked_query}") lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds) + masked_query = lineage_parser.masked_query or query + logger.debug(f"Getting lineage via table entity using query: {masked_query}") to_table_name = table_entity.name.root for from_table_name in lineage_parser.source_tables: diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py index 493aadc4214d..fbc315ffdafa 100644 --- a/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py +++ b/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py @@ -344,7 +344,7 @@ def _paginate_es_internal( # Get next page last_hit = response.hits.hits[-1] if response.hits.hits else None if not last_hit or not last_hit.sort: - logger.info("No more pages to fetch") + logger.debug("No more pages to fetch") break after = ",".join(last_hit.sort) @@ -429,10 +429,11 @@ def yield_es_view_def( _, database_name, schema_name, table_name = fqn.split( hit.source["fullyQualifiedName"] ) - yield TableView( - view_definition=hit.source["schemaDefinition"], - service_name=service_name, - db_name=database_name, - schema_name=schema_name, - table_name=table_name, - ) + if hit.source.get("schemaDefinition"): + yield TableView( + view_definition=hit.source["schemaDefinition"], + service_name=service_name, + db_name=database_name, + schema_name=schema_name, + table_name=table_name, + ) diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py index 45769f5935c4..5adc2c18296f 100644 --- a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py +++ b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py @@ -63,7 +63,7 @@ def _merge_column_lineage( ) for column in updated or []: if not isinstance(column, dict): - data = column.dict() + data = column.model_dump() else: data = column if data.get("toColumn") and data.get("fromColumns"): diff --git a/ingestion/src/metadata/ingestion/source/database/lineage_source.py b/ingestion/src/metadata/ingestion/source/database/lineage_source.py index 7148256a03e2..5b13a6516014 100644 --- a/ingestion/src/metadata/ingestion/source/database/lineage_source.py +++ b/ingestion/src/metadata/ingestion/source/database/lineage_source.py @@ -13,11 +13,12 @@ """ import csv import os +import time import traceback from abc import ABC -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from functools import partial -from typing import Callable, Iterable, Iterator, List, Optional, Union +from typing import Any, Callable, Iterable, Iterator, List, Optional, Union from metadata.generated.schema.api.data.createQuery import CreateQueryRequest from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest @@ -39,6 +40,7 @@ from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper, Dialect from metadata.ingestion.lineage.sql_lineage import get_column_fqn, get_lineage_by_query from metadata.ingestion.models.ometa_lineage import OMetaLineageRequest +from metadata.ingestion.models.topology import Queue from metadata.ingestion.source.database.query_parser_source import QueryParserSource from metadata.ingestion.source.models import TableView from metadata.utils import fqn @@ -48,6 +50,9 @@ logger = ingestion_logger() +CHUNK_SIZE = 200 + + class LineageSource(QueryParserSource, ABC): """ This is the base source to handle Lineage-only ingestion. @@ -108,27 +113,57 @@ def get_table_query(self) -> Iterator[TableQuery]: ) yield from self.yield_table_query() - def generate_lineage_in_thread(self, producer_fn: Callable, processor_fn: Callable): - with ThreadPoolExecutor(max_workers=self.source_config.threads) as executor: - futures = [] + def generate_lineage_in_thread( + self, + producer_fn: Callable[[], Iterable[Any]], + processor_fn: Callable[[Any], Iterable[Any]], + chunk_size: int = CHUNK_SIZE, + ): + """ + Optimized multithreaded lineage generation with improved error handling and performance. - for produced_input in producer_fn(): - futures.append(executor.submit(processor_fn, produced_input)) + Args: + producer_fn: Function that yields input items + processor_fn: Function to process each input item + chunk_size: Optional batching to reduce thread creation overhead + """ - # Handle remaining futures after the loop - for future in as_completed( - futures, timeout=self.source_config.parsingTimeoutLimit - ): - try: - results = future.result( - timeout=self.source_config.parsingTimeoutLimit - ) - yield from results - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning( - f"Error processing result for {produced_input}: {exc}" - ) + def chunk_generator(): + temp_chunk = [] + for chunk in producer_fn(): + temp_chunk.append(chunk) + if len(temp_chunk) >= chunk_size: + yield temp_chunk + temp_chunk = [] + + if temp_chunk: + yield temp_chunk + + thread_pool = ThreadPoolExecutor(max_workers=self.source_config.threads) + queue = Queue() + + futures = [ + thread_pool.submit( + processor_fn, + chunk, + queue, + ) + for chunk in chunk_generator() + ] + while True: + if queue.has_tasks(): + yield from queue.process() + + else: + if not futures: + break + + for i, future in enumerate(futures): + if future.done(): + future.result() + futures.pop(i) + + time.sleep(0.01) def yield_table_query(self) -> Iterator[TableQuery]: """ @@ -170,33 +205,38 @@ def _query_already_processed(self, table_query: TableQuery) -> bool: return fqn.get_query_checksum(table_query.query) in checksums or {} def query_lineage_generator( - self, table_query: TableQuery + self, table_queries: List[TableQuery], queue: Queue ) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]: - if not self._query_already_processed(table_query): - lineages: Iterable[Either[AddLineageRequest]] = get_lineage_by_query( - self.metadata, - query=table_query.query, - service_name=table_query.serviceName, - database_name=table_query.databaseName, - schema_name=table_query.databaseSchema, - dialect=self.dialect, - timeout_seconds=self.source_config.parsingTimeoutLimit, - ) + for table_query in table_queries or []: + if not self._query_already_processed(table_query): + lineages: Iterable[Either[AddLineageRequest]] = get_lineage_by_query( + self.metadata, + query=table_query.query, + service_name=table_query.serviceName, + database_name=table_query.databaseName, + schema_name=table_query.databaseSchema, + dialect=self.dialect, + timeout_seconds=self.source_config.parsingTimeoutLimit, + ) - for lineage_request in lineages or []: - yield lineage_request - - # If we identified lineage properly, ingest the original query - if lineage_request.right: - yield Either( - right=CreateQueryRequest( - query=SqlQuery(table_query.query), - query_type=table_query.query_type, - duration=table_query.duration, - processedLineage=True, - service=FullyQualifiedEntityName(self.config.serviceName), + for lineage_request in lineages or []: + queue.put(lineage_request) + + # If we identified lineage properly, ingest the original query + if lineage_request.right: + queue.put( + Either( + right=CreateQueryRequest( + query=SqlQuery(table_query.query), + query_type=table_query.query_type, + duration=table_query.duration, + processedLineage=True, + service=FullyQualifiedEntityName( + self.config.serviceName + ), + ) + ) ) - ) def yield_query_lineage( self, @@ -209,28 +249,33 @@ def yield_query_lineage( self.dialect = ConnectionTypeDialectMapper.dialect_of(connection_type) producer_fn = self.get_table_query processor_fn = self.query_lineage_generator - yield from self.generate_lineage_in_thread(producer_fn, processor_fn) + yield from self.generate_lineage_in_thread( + producer_fn, processor_fn, CHUNK_SIZE + ) def view_lineage_generator( - self, view: TableView + self, views: List[TableView], queue: Queue ) -> Iterable[Either[AddLineageRequest]]: try: - for lineage in get_view_lineage( - view=view, - metadata=self.metadata, - service_name=self.config.serviceName, - connection_type=self.service_connection.type.value, - timeout_seconds=self.source_config.parsingTimeoutLimit, - ): - if lineage.right is not None: - yield Either( - right=OMetaLineageRequest( - lineage_request=lineage.right, - override_lineage=self.source_config.overrideViewLineage, + for view in views: + for lineage in get_view_lineage( + view=view, + metadata=self.metadata, + service_name=self.config.serviceName, + connection_type=self.service_connection.type.value, + timeout_seconds=self.source_config.parsingTimeoutLimit, + ): + if lineage.right is not None: + queue.put( + Either( + right=OMetaLineageRequest( + lineage_request=lineage.right, + override_lineage=self.source_config.overrideViewLineage, + ) + ) ) - ) - else: - yield lineage + else: + queue.put(lineage) except Exception as exc: logger.debug(traceback.format_exc()) logger.warning(f"Error processing view {view}: {exc}") diff --git a/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py b/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py index 5473f8b17009..f7190c696811 100644 --- a/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py +++ b/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py @@ -38,10 +38,11 @@ from metadata.ingestion.api.status import Status from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper from metadata.ingestion.lineage.sql_lineage import get_lineage_by_query +from metadata.ingestion.models.topology import Queue from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.utils.logger import ingestion_logger from metadata.utils.stored_procedures import get_procedure_name_from_call -from metadata.utils.time_utils import convert_timestamp_to_milliseconds +from metadata.utils.time_utils import datetime_to_timestamp logger = ingestion_logger() @@ -176,8 +177,6 @@ def _yield_procedure_lineage( timeout_seconds=self.source_config.parsingTimeoutLimit, lineage_source=LineageSource.QueryLineage, ): - print("&& " * 100) - print(either_lineage) if ( either_lineage.left is None and either_lineage.right.edge.lineageDetails @@ -200,8 +199,8 @@ def yield_procedure_query( query_type=query_by_procedure.query_type, duration=query_by_procedure.query_duration, queryDate=Timestamp( - root=convert_timestamp_to_milliseconds( - int(query_by_procedure.query_start_time.timestamp()) + root=datetime_to_timestamp( + query_by_procedure.query_start_time, True ) ), triggeredBy=EntityReference( @@ -214,29 +213,31 @@ def yield_procedure_query( ) def procedure_lineage_processor( - self, procedure_and_query: ProcedureAndQuery + self, procedure_and_queries: List[ProcedureAndQuery], queue: Queue ) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]: - - try: - yield from self._yield_procedure_lineage( - query_by_procedure=procedure_and_query.query_by_procedure, - procedure=procedure_and_query.procedure, - ) - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning( - f"Could not get lineage for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]." - ) - try: - yield from self.yield_procedure_query( - query_by_procedure=procedure_and_query.query_by_procedure, - procedure=procedure_and_query.procedure, - ) - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning( - f"Could not get query for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]." - ) + for procedure_and_query in procedure_and_queries: + try: + for lineage in self._yield_procedure_lineage( + query_by_procedure=procedure_and_query.query_by_procedure, + procedure=procedure_and_query.procedure, + ): + queue.put(lineage) + except Exception as exc: + logger.debug(traceback.format_exc()) + logger.warning( + f"Could not get lineage for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]." + ) + try: + for lineage in self.yield_procedure_query( + query_by_procedure=procedure_and_query.query_by_procedure, + procedure=procedure_and_query.procedure, + ): + queue.put(lineage) + except Exception as exc: + logger.debug(traceback.format_exc()) + logger.warning( + f"Could not get query for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]." + ) def procedure_lineage_generator(self) -> Iterable[ProcedureAndQuery]: query = { @@ -256,7 +257,9 @@ def procedure_lineage_generator(self) -> Iterable[ProcedureAndQuery]: queries_dict = self.get_stored_procedure_queries_dict() # Then for each procedure, iterate over all its queries for procedure in ( - self.metadata.paginate_es(entity=StoredProcedure, query_filter=query_filter) + self.metadata.paginate_es( + entity=StoredProcedure, query_filter=query_filter, size=10 + ) or [] ): if procedure: diff --git a/ingestion/src/metadata/utils/db_utils.py b/ingestion/src/metadata/utils/db_utils.py index ef3a7c8d6088..afd1e4832940 100644 --- a/ingestion/src/metadata/utils/db_utils.py +++ b/ingestion/src/metadata/utils/db_utils.py @@ -69,6 +69,10 @@ def get_view_lineage( fqn=table_fqn, ) + if not view_definition: + logger.warning(f"View definition for view {table_fqn} not available") + return + try: connection_type = str(connection_type) dialect = ConnectionTypeDialectMapper.dialect_of(connection_type) diff --git a/ingestion/src/metadata/utils/logger.py b/ingestion/src/metadata/utils/logger.py index f437bb4d3fcd..08d648d8c2fd 100644 --- a/ingestion/src/metadata/utils/logger.py +++ b/ingestion/src/metadata/utils/logger.py @@ -28,7 +28,6 @@ from metadata.generated.schema.type.queryParserData import QueryParserData from metadata.generated.schema.type.tableQuery import TableQueries from metadata.ingestion.api.models import Entity -from metadata.ingestion.lineage.masker import mask_query from metadata.ingestion.models.delete_entity import DeleteEntity from metadata.ingestion.models.life_cycle import OMetaLifeCycleData from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification @@ -284,19 +283,13 @@ def _(record: PatchRequest) -> str: @get_log_name.register def _(record: TableQueries) -> str: """Get the log of the TableQuery""" - queries = "\n------\n".join( - mask_query(query.query, query.dialect) for query in record.queries - ) - return f"Table Queries [{queries}]" + return f"Table Queries [{len(record.queries)}]" @get_log_name.register def _(record: QueryParserData) -> str: """Get the log of the ParsedData""" - queries = "\n------\n".join( - mask_query(query.sql, query.dialect) for query in record.parsedData - ) - return f"Usage ParsedData [{queries}]" + return f"Usage ParsedData [{len(record.parsedData)}]" def redacted_config(config: Dict[str, Union[str, dict]]) -> Dict[str, Union[str, dict]]: diff --git a/ingestion/tests/unit/test_sql_lineage.py b/ingestion/tests/unit/test_sql_lineage.py index 4199cf84a5f8..ffe07a0557a1 100644 --- a/ingestion/tests/unit/test_sql_lineage.py +++ b/ingestion/tests/unit/test_sql_lineage.py @@ -229,14 +229,39 @@ def test_table_name_from_query(self): def test_query_masker(self): query_list = [ - """SELECT * FROM user WHERE id=1234 AND name='Alice' AND birthdate=DATE '2023-01-01';""", - """insert into user values ('mayur',123,'my random address 1'), ('mayur',123,'my random address 1');""", - """SELECT * FROM user WHERE address = '5th street' and name = 'john';""", - """INSERT INTO user VALUE ('John', '19', '5TH Street');""", - """SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user;""", - """with test as (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user) select * from test;""", - """select * from (select * from (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user));""", - """select * from users where id > 2 and name <> 'pere';""", + ( + """SELECT * FROM user WHERE id=1234 AND name='Alice' AND birthdate=DATE '2023-01-01';""", + Dialect.MYSQL.value, + ), + ( + """insert into user values ('mayur',123,'my random address 1'), ('mayur',123,'my random address 1');""", + Dialect.ANSI.value, + ), + ( + """SELECT * FROM user WHERE address = '5th street' and name = 'john';""", + Dialect.ANSI.value, + ), + ( + """INSERT INTO user VALUE ('John', '19', '5TH Street');""", + Dialect.ANSI.value, + ), + ( + """SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user;""", + Dialect.ANSI.value, + ), + ( + """with test as (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user) select * from test;""", + Dialect.ANSI.value, + ), + ( + """select * from (select * from (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user));""", + Dialect.ANSI.value, + ), + ( + """select * from users where id > 2 and name <> 'pere';""", + Dialect.ANSI.value, + ), + ("""select * from users where id > 2 and name <> 'pere';""", "random"), ] expected_query_list = [ @@ -248,7 +273,8 @@ def test_query_masker(self): """with test as (SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user) select * from test;""", """select * from (select * from (SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user));""", """select * from users where id > ? and name <> ?;""", + """select * from users where id > ? and name <> ?;""", ] for i, query in enumerate(query_list): - self.assertEqual(mask_query(query), expected_query_list[i]) + self.assertEqual(mask_query(query[0], query[1]), expected_query_list[i])