Skip to content

Commit

Permalink
Fix #19489: Optimise multithreading for lineage (#19524)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulixius9 authored Jan 27, 2025
1 parent 8117586 commit d2dc7bd
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 173 deletions.
43 changes: 21 additions & 22 deletions ingestion/src/metadata/ingestion/lineage/masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

import traceback

import sqlparse
from sqlfluff.core import Linter
from collate_sqllineage.runner import SQLPARSE_DIALECT, LineageRunner
from sqlparse.sql import Comparison
from sqlparse.tokens import Literal, Number, String

Expand All @@ -24,25 +23,22 @@
MASK_TOKEN = "?"


# pylint: disable=protected-access
def get_logger():
# pylint: disable=import-outside-toplevel
from metadata.utils.logger import utils_logger

return utils_logger()


def mask_literals_with_sqlparse(query: str):
def mask_literals_with_sqlparse(query: str, parser: LineageRunner):
"""
Mask literals in a query using sqlparse.
"""
logger = get_logger()

try:
parsed = sqlparse.parse(query) # Parse the query

if not parsed:
return query
parsed = parsed[0]
parsed = parser._parsed_result

def mask_token(token):
# Mask all literals: strings, numbers, or other literal values
Expand Down Expand Up @@ -79,17 +75,16 @@ def mask_token(token):
return query


def mask_literals_with_sqlfluff(query: str, dialect: str = Dialect.ANSI.value) -> str:
def mask_literals_with_sqlfluff(query: str, parser: LineageRunner) -> str:
"""
Mask literals in a query using SQLFluff.
"""
logger = get_logger()
try:
# Initialize SQLFluff linter
linter = Linter(dialect=dialect)
if not parser._evaluated:
parser._eval()

# Parse the query
parsed = linter.parse_string(query)
parsed = parser._parsed_result

def replace_literals(segment):
"""Recursively replace literals with placeholders."""
Expand All @@ -114,17 +109,21 @@ def replace_literals(segment):
return query


def mask_query(query: str, dialect: str = Dialect.ANSI.value) -> str:
def mask_query(
query: str, dialect: str = Dialect.ANSI.value, parser: LineageRunner = None
) -> str:
logger = get_logger()
try:
sqlfluff_masked_query = mask_literals_with_sqlfluff(query, dialect)
sqlparse_masked_query = mask_literals_with_sqlparse(query)
# compare both masked queries and return the one with more masked tokens
if sqlfluff_masked_query.count(MASK_TOKEN) >= sqlparse_masked_query.count(
MASK_TOKEN
):
return sqlfluff_masked_query
return sqlparse_masked_query
if not parser:
try:
parser = LineageRunner(query, dialect=dialect)
len(parser.source_tables)
except Exception:
parser = LineageRunner(query)
len(parser.source_tables)
if parser._dialect == SQLPARSE_DIALECT:
return mask_literals_with_sqlparse(query, parser)
return mask_literals_with_sqlfluff(query, parser)
except Exception as exc:
logger.debug(f"Failed to mask query with sqlfluff: {exc}")
logger.debug(traceback.format_exc())
Expand Down
45 changes: 16 additions & 29 deletions ingestion/src/metadata/ingestion/lineage/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,13 @@ def __init__(
self.query_parsing_success = True
self.query_parsing_failure_reason = None
self.dialect = dialect
self._masked_query = mask_query(self.query, dialect.value)
self.masked_query = None
self._clean_query = self.clean_raw_query(query)
self._masked_clean_query = mask_query(self._clean_query, dialect.value)
self.parser = self._evaluate_best_parser(
self._clean_query, dialect=dialect, timeout_seconds=timeout_seconds
)
if self.masked_query is None:
self.masked_query = mask_query(self._clean_query, parser=self.parser)

@cached_property
def involved_tables(self) -> Optional[List[Table]]:
Expand All @@ -95,7 +96,7 @@ def involved_tables(self) -> Optional[List[Table]]:
except SQLLineageException as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Cannot extract source table information from query [{self._masked_query}]: {exc}"
f"Cannot extract source table information from query [{self.masked_query or self.query}]: {exc}"
)
return None

Expand Down Expand Up @@ -334,12 +335,10 @@ def stateful_add_joins_from_statement(
)

if not table_left or not table_right:
logger.warning(
f"Can't extract table names when parsing JOIN information from {comparison}"
)
logger.debug(
f"Query: {mask_query(sql_statement, self.dialect.value)}"
f"Can't extract table names when parsing JOIN information from {comparison}"
)
logger.debug(f"Query: {self.masked_query}")
continue

left_table_column = TableColumn(table=table_left, column=column_left)
Expand Down Expand Up @@ -422,10 +421,9 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner:
lr_dialect.get_column_lineage()
return lr_dialect

sqlfluff_count = 0
try:
lr_sqlfluff = get_sqlfluff_lineage_runner(query, dialect.value)
sqlfluff_count = len(lr_sqlfluff.get_column_lineage()) + len(
_ = len(lr_sqlfluff.get_column_lineage()) + len(
set(lr_sqlfluff.source_tables).union(
set(lr_sqlfluff.target_tables).union(
set(lr_sqlfluff.intermediate_tables)
Expand All @@ -438,23 +436,20 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner:
f"Lineage with SqlFluff failed for the [{dialect.value}]. "
f"Parser has been running for more than {timeout_seconds} seconds."
)
logger.debug(
f"{self.query_parsing_failure_reason}] query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None
except Exception:
self.query_parsing_success = False
self.query_parsing_failure_reason = (
f"Lineage with SqlFluff failed for the [{dialect.value}]"
)
logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None

if lr_sqlfluff:
return lr_sqlfluff

lr_sqlparser = LineageRunner(query)
try:
sqlparser_count = len(lr_sqlparser.get_column_lineage()) + len(
_ = len(lr_sqlparser.get_column_lineage()) + len(
set(lr_sqlparser.source_tables).union(
set(lr_sqlparser.target_tables).union(
set(lr_sqlparser.intermediate_tables)
Expand All @@ -463,21 +458,13 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner:
)
except Exception:
# if both runner have failed we return the usual one
logger.debug(f"Failed to parse query with sqlparse & sqlfluff: {query}")
return lr_sqlfluff if lr_sqlfluff else lr_sqlparser

if lr_sqlfluff:
# if sqlparser retrieve more lineage info that sqlfluff
if sqlparser_count > sqlfluff_count:
self.query_parsing_success = False
self.query_parsing_failure_reason = (
"Lineage computed with SqlFluff did not perform as expected "
f"for the [{dialect.value}]"
)
logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
return lr_sqlparser
return lr_sqlfluff
self.masked_query = mask_query(self._clean_query, parser=lr_sqlparser)
logger.debug(
f"Using sqlparse for lineage parsing for query: {self.masked_query}"
)
return lr_sqlparser

@staticmethod
Expand Down
9 changes: 4 additions & 5 deletions ingestion/src/metadata/ingestion/lineage/sql_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from metadata.generated.schema.type.entityLineage import Source as LineageSource
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.models import Either
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.lineage.models import (
Dialect,
QueryParsingError,
Expand Down Expand Up @@ -614,11 +613,11 @@ def get_lineage_by_query(
"""
column_lineage = {}
query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)

try:
logger.debug(f"Running lineage with query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
masked_query = lineage_parser.masked_query or query
logger.debug(f"Running lineage with query: {masked_query}")

raw_column_lineage = lineage_parser.column_lineage
column_lineage.update(populate_column_lineage_map(raw_column_lineage))
Expand Down Expand Up @@ -715,11 +714,11 @@ def get_lineage_via_table_entity(
"""Get lineage from table entity"""
column_lineage = {}
query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)

try:
logger.debug(f"Getting lineage via table entity using query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
masked_query = lineage_parser.masked_query or query
logger.debug(f"Getting lineage via table entity using query: {masked_query}")
to_table_name = table_entity.name.root

for from_table_name in lineage_parser.source_tables:
Expand Down
17 changes: 9 additions & 8 deletions ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _paginate_es_internal(
# Get next page
last_hit = response.hits.hits[-1] if response.hits.hits else None
if not last_hit or not last_hit.sort:
logger.info("No more pages to fetch")
logger.debug("No more pages to fetch")
break

after = ",".join(last_hit.sort)
Expand Down Expand Up @@ -429,10 +429,11 @@ def yield_es_view_def(
_, database_name, schema_name, table_name = fqn.split(
hit.source["fullyQualifiedName"]
)
yield TableView(
view_definition=hit.source["schemaDefinition"],
service_name=service_name,
db_name=database_name,
schema_name=schema_name,
table_name=table_name,
)
if hit.source.get("schemaDefinition"):
yield TableView(
view_definition=hit.source["schemaDefinition"],
service_name=service_name,
db_name=database_name,
schema_name=schema_name,
table_name=table_name,
)
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _merge_column_lineage(
)
for column in updated or []:
if not isinstance(column, dict):
data = column.dict()
data = column.model_dump()
else:
data = column
if data.get("toColumn") and data.get("fromColumns"):
Expand Down
Loading

0 comments on commit d2dc7bd

Please sign in to comment.