Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #19489: Optimise multithreading for lineage #19524

Merged
merged 4 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions ingestion/src/metadata/ingestion/lineage/masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

import traceback

import sqlparse
from sqlfluff.core import Linter
from collate_sqllineage.runner import SQLPARSE_DIALECT, LineageRunner
from sqlparse.sql import Comparison
from sqlparse.tokens import Literal, Number, String

Expand All @@ -24,25 +23,22 @@
MASK_TOKEN = "?"


# pylint: disable=protected-access
def get_logger():
    """
    Return the logger produced by ``utils_logger``.

    The import is performed inside the function body, at call time,
    rather than at module import time.
    """
    # pylint: disable=import-outside-toplevel
    from metadata.utils.logger import utils_logger

    masker_logger = utils_logger()
    return masker_logger


def mask_literals_with_sqlparse(query: str):
def mask_literals_with_sqlparse(query: str, parser: LineageRunner):
"""
Mask literals in a query using sqlparse.
"""
logger = get_logger()

try:
parsed = sqlparse.parse(query) # Parse the query

if not parsed:
return query
parsed = parsed[0]
parsed = parser._parsed_result

def mask_token(token):
# Mask all literals: strings, numbers, or other literal values
Expand Down Expand Up @@ -79,17 +75,16 @@ def mask_token(token):
return query


def mask_literals_with_sqlfluff(query: str, dialect: str = Dialect.ANSI.value) -> str:
def mask_literals_with_sqlfluff(query: str, parser: LineageRunner) -> str:
"""
Mask literals in a query using SQLFluff.
"""
logger = get_logger()
try:
# Initialize SQLFluff linter
linter = Linter(dialect=dialect)
if not parser._evaluated:
parser._eval()

# Parse the query
parsed = linter.parse_string(query)
parsed = parser._parsed_result

def replace_literals(segment):
"""Recursively replace literals with placeholders."""
Expand All @@ -114,17 +109,21 @@ def replace_literals(segment):
return query


def mask_query(query: str, dialect: str = Dialect.ANSI.value) -> str:
def mask_query(
query: str, dialect: str = Dialect.ANSI.value, parser: LineageRunner = None
) -> str:
logger = get_logger()
try:
sqlfluff_masked_query = mask_literals_with_sqlfluff(query, dialect)
sqlparse_masked_query = mask_literals_with_sqlparse(query)
# compare both masked queries and return the one with more masked tokens
if sqlfluff_masked_query.count(MASK_TOKEN) >= sqlparse_masked_query.count(
MASK_TOKEN
):
return sqlfluff_masked_query
return sqlparse_masked_query
if not parser:
try:
parser = LineageRunner(query, dialect=dialect)
len(parser.source_tables)
except Exception:
parser = LineageRunner(query)
len(parser.source_tables)
ulixius9 marked this conversation as resolved.
Show resolved Hide resolved
if parser._dialect == SQLPARSE_DIALECT:
return mask_literals_with_sqlparse(query, parser)
return mask_literals_with_sqlfluff(query, parser)
except Exception as exc:
logger.debug(f"Failed to mask query with sqlfluff: {exc}")
logger.debug(traceback.format_exc())
Expand Down
45 changes: 16 additions & 29 deletions ingestion/src/metadata/ingestion/lineage/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,13 @@ def __init__(
self.query_parsing_success = True
self.query_parsing_failure_reason = None
self.dialect = dialect
self._masked_query = mask_query(self.query, dialect.value)
self.masked_query = None
self._clean_query = self.clean_raw_query(query)
self._masked_clean_query = mask_query(self._clean_query, dialect.value)
self.parser = self._evaluate_best_parser(
self._clean_query, dialect=dialect, timeout_seconds=timeout_seconds
)
if self.masked_query is None:
self.masked_query = mask_query(self._clean_query, parser=self.parser)

@cached_property
def involved_tables(self) -> Optional[List[Table]]:
Expand All @@ -95,7 +96,7 @@ def involved_tables(self) -> Optional[List[Table]]:
except SQLLineageException as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Cannot extract source table information from query [{self._masked_query}]: {exc}"
f"Cannot extract source table information from query [{self.masked_query or self.query}]: {exc}"
)
return None

Expand Down Expand Up @@ -334,12 +335,10 @@ def stateful_add_joins_from_statement(
)

if not table_left or not table_right:
logger.warning(
f"Can't extract table names when parsing JOIN information from {comparison}"
)
logger.debug(
f"Query: {mask_query(sql_statement, self.dialect.value)}"
f"Can't extract table names when parsing JOIN information from {comparison}"
)
logger.debug(f"Query: {self.masked_query}")
continue

left_table_column = TableColumn(table=table_left, column=column_left)
Expand Down Expand Up @@ -422,10 +421,9 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner:
lr_dialect.get_column_lineage()
return lr_dialect

sqlfluff_count = 0
try:
lr_sqlfluff = get_sqlfluff_lineage_runner(query, dialect.value)
sqlfluff_count = len(lr_sqlfluff.get_column_lineage()) + len(
_ = len(lr_sqlfluff.get_column_lineage()) + len(
set(lr_sqlfluff.source_tables).union(
set(lr_sqlfluff.target_tables).union(
set(lr_sqlfluff.intermediate_tables)
Expand All @@ -438,23 +436,20 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner:
f"Lineage with SqlFluff failed for the [{dialect.value}]. "
f"Parser has been running for more than {timeout_seconds} seconds."
)
logger.debug(
f"{self.query_parsing_failure_reason}] query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None
except Exception:
self.query_parsing_success = False
self.query_parsing_failure_reason = (
f"Lineage with SqlFluff failed for the [{dialect.value}]"
)
logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None

if lr_sqlfluff:
return lr_sqlfluff

lr_sqlparser = LineageRunner(query)
try:
sqlparser_count = len(lr_sqlparser.get_column_lineage()) + len(
_ = len(lr_sqlparser.get_column_lineage()) + len(
set(lr_sqlparser.source_tables).union(
set(lr_sqlparser.target_tables).union(
set(lr_sqlparser.intermediate_tables)
Expand All @@ -463,21 +458,13 @@ def get_sqlfluff_lineage_runner(qry: str, dlct: str) -> LineageRunner:
)
except Exception:
# if both runners have failed we return the usual one
logger.debug(f"Failed to parse query with sqlparse & sqlfluff: {query}")
return lr_sqlfluff if lr_sqlfluff else lr_sqlparser

if lr_sqlfluff:
# if sqlparser retrieves more lineage info than sqlfluff
if sqlparser_count > sqlfluff_count:
self.query_parsing_success = False
self.query_parsing_failure_reason = (
"Lineage computed with SqlFluff did not perform as expected "
f"for the [{dialect.value}]"
)
logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
return lr_sqlparser
return lr_sqlfluff
self.masked_query = mask_query(self._clean_query, parser=lr_sqlparser)
logger.debug(
f"Using sqlparse for lineage parsing for query: {self.masked_query}"
)
return lr_sqlparser

@staticmethod
Expand Down
9 changes: 4 additions & 5 deletions ingestion/src/metadata/ingestion/lineage/sql_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from metadata.generated.schema.type.entityLineage import Source as LineageSource
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.models import Either
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.lineage.models import (
Dialect,
QueryParsingError,
Expand Down Expand Up @@ -614,11 +613,11 @@ def get_lineage_by_query(
"""
column_lineage = {}
query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)

try:
logger.debug(f"Running lineage with query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
masked_query = lineage_parser.masked_query or query
logger.debug(f"Running lineage with query: {masked_query}")

raw_column_lineage = lineage_parser.column_lineage
column_lineage.update(populate_column_lineage_map(raw_column_lineage))
Expand Down Expand Up @@ -715,11 +714,11 @@ def get_lineage_via_table_entity(
"""Get lineage from table entity"""
column_lineage = {}
query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)

try:
logger.debug(f"Getting lineage via table entity using query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
masked_query = lineage_parser.masked_query or query
logger.debug(f"Getting lineage via table entity using query: {masked_query}")
to_table_name = table_entity.name.root

for from_table_name in lineage_parser.source_tables:
Expand Down
17 changes: 9 additions & 8 deletions ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _paginate_es_internal(
# Get next page
last_hit = response.hits.hits[-1] if response.hits.hits else None
if not last_hit or not last_hit.sort:
logger.info("No more pages to fetch")
logger.debug("No more pages to fetch")
break

after = ",".join(last_hit.sort)
Expand Down Expand Up @@ -429,10 +429,11 @@ def yield_es_view_def(
_, database_name, schema_name, table_name = fqn.split(
hit.source["fullyQualifiedName"]
)
yield TableView(
view_definition=hit.source["schemaDefinition"],
service_name=service_name,
db_name=database_name,
schema_name=schema_name,
table_name=table_name,
)
if hit.source.get("schemaDefinition"):
yield TableView(
view_definition=hit.source["schemaDefinition"],
service_name=service_name,
db_name=database_name,
schema_name=schema_name,
table_name=table_name,
)
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _merge_column_lineage(
)
for column in updated or []:
if not isinstance(column, dict):
data = column.dict()
data = column.model_dump()
else:
data = column
if data.get("toColumn") and data.get("fromColumns"):
Expand Down
Loading
Loading