From 717f43897e1efe9d2efa96fe1544f87390c9f31c Mon Sep 17 00:00:00 2001
From: Gisle Rognerud
Date: Sun, 2 Feb 2025 18:57:45 +0100
Subject: [PATCH] support loom

---
 src/dbt_osmosis/core/osmosis.py | 45 +++++++++++++++++++++++++++++----
 1 file changed, 40 insertions(+), 5 deletions(-)

diff --git a/src/dbt_osmosis/core/osmosis.py b/src/dbt_osmosis/core/osmosis.py
index e22e8a6..d6d705f 100644
--- a/src/dbt_osmosis/core/osmosis.py
+++ b/src/dbt_osmosis/core/osmosis.py
@@ -52,6 +52,7 @@
 )
 from dbt.mp_context import get_mp_context
 from dbt.node_types import NodeType
+from dbt.parser.models import ModelParser
 from dbt.parser.manifest import ManifestLoader, process_node
 from dbt.parser.sql import SqlBlockParser, SqlMacroParser
 from dbt.task.docs.generate import Catalog
@@ -250,6 +251,26 @@ def manifest_mutex(self) -> threading.Lock:
         """Return the manifest mutex for thread safety."""
         return self._manifest_mutex
 
+def _add_cross_project_references(manifest, dbt_loom, project_name):
+    """Add cross-project references to the dbt manifest from dbt-loom defined manifests."""
+    loomnodes = []
+    loom = dbt_loom.dbtLoom(project_name)
+    loom_manifests = loom.manifests
+    logger.info(":arrows_counterclockwise: Loaded dbt loom manifests!")
+    for name, loom_manifest in loom_manifests.items():
+        if loom_manifest.get("nodes"):
+            loom_manifest_nodes = loom_manifest.get("nodes")
+            for _, node in loom_manifest_nodes.items():
+                if node.get("access"):
+                    node_access = node.get("access")
+                    if node_access!="protected":
+                        if node.get("resource_type")=="model":
+                            loomnodes.append(ModelParser.parse_from_dict(None, node))
+        for node in loomnodes:
+            manifest.nodes[node.unique_id] = node
+        logger.info(f":arrows_counterclockwise: added {len(loomnodes)} exposed nodes from {name} to the dbt manifest!")
+    return manifest
+
 
 def _instantiate_adapter(runtime_config: RuntimeConfig) -> BaseAdapter:
     """Instantiate a dbt adapter based on the runtime configuration."""
@@ -276,6 +297,18 @@ def create_dbt_project_context(config: DbtConfiguration) -> DbtProjectContext:
         runtime_cfg.load_dependencies(),
     )
     manifest = loader.load()
+
+    # check if dbt-loom is installed
+    loom_imported = False
+    try:
+        dbt_loom = __import__("dbt_loom")
+        loom_imported = True
+    except ImportError:
+        pass
+
+    if loom_imported:
+        manifest = _add_cross_project_references(manifest, dbt_loom, runtime_cfg.project_name)
+
     manifest.build_flat_graph()
     logger.info(":arrows_counterclockwise: Loaded the dbt project manifest!")
 
@@ -284,6 +317,7 @@ def create_dbt_project_context(config: DbtConfiguration) -> DbtProjectContext:
     setattr(runtime_cfg, "adapter", adapter)
     adapter.set_macro_resolver(manifest)
 
+
     sql_parser = SqlBlockParser(runtime_cfg, manifest, runtime_cfg)
     macro_parser = SqlMacroParser(runtime_cfg, manifest)
 
@@ -739,6 +773,7 @@ def _topological_sort(
 
 def _iter_candidate_nodes(
     context: YamlRefactorContext,
+    include_external: bool = False,
 ) -> Iterator[tuple[str, ResultNode]]:
     """Iterate over the models in the dbt project manifest applying the filter settings."""
     logger.debug(
@@ -746,18 +781,18 @@ def _iter_candidate_nodes(
         context.settings,
     )
 
-    def f(node: ResultNode) -> bool:
+    def f(node: ResultNode, include_external: bool = False) -> bool:
         """Closure to filter models based on the context settings."""
         if node.resource_type not in (NodeType.Model, NodeType.Source, NodeType.Seed):
             return False
-        if node.package_name != context.project.runtime_cfg.project_name:
+        if node.package_name != context.project.runtime_cfg.project_name and not include_external:
             return False
         if node.resource_type == NodeType.Model and node.config.materialized == "ephemeral":
             return False
         if context.settings.models:
             if not _is_file_match(
                 node, context.settings.models, context.project.runtime_cfg.project_root
-            ):
+            ) and not include_external:
                 return False
         if context.settings.fqn:
             if not _is_fqn_match(node, context.settings.fqn):
@@ -768,7 +803,7 @@ def f(node: ResultNode) -> bool:
     candidate_nodes: list[t.Any] = []
     items = chain(context.project.manifest.nodes.items(), context.project.manifest.sources.items())
     for uid, dbt_node in items:
-        if f(dbt_node):
+        if f(dbt_node, include_external):
             candidate_nodes.append((uid, dbt_node))
 
     for uid, node in _topological_sort(candidate_nodes):
@@ -1957,7 +1992,7 @@ def inherit_upstream_column_knowledge(
         logger.info(":wave: Inheriting column knowledge across all matched nodes.")
         for _ in context.pool.map(
             partial(inherit_upstream_column_knowledge, context),
-            (n for _, n in _iter_candidate_nodes(context)),
+            (n for _, n in _iter_candidate_nodes(context, include_external=True)),
         ):
             ...
         return