Skip to content

Commit

Permalink
support loom
Browse files Browse the repository at this point in the history
  • Loading branch information
rognerud committed Feb 2, 2025
1 parent 75d406c commit 717f438
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions src/dbt_osmosis/core/osmosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
)
from dbt.mp_context import get_mp_context
from dbt.node_types import NodeType
from dbt.parser.models import ModelParser
from dbt.parser.manifest import ManifestLoader, process_node
from dbt.parser.sql import SqlBlockParser, SqlMacroParser
from dbt.task.docs.generate import Catalog
Expand Down Expand Up @@ -250,6 +251,26 @@ def manifest_mutex(self) -> threading.Lock:
"""Return the manifest mutex for thread safety."""
return self._manifest_mutex

def _add_cross_project_references(manifest, dbt_loom, project_name):
"""Add cross-project references to the dbt manifest from dbt-loom defined manifests."""
loomnodes = []
loom = dbt_loom.dbtLoom(project_name)
loom_manifests = loom.manifests
logger.info(":arrows_counterclockwise: Loaded dbt loom manifests!")
for name, loom_manifest in loom_manifests.items():
if loom_manifest.get("nodes"):
loom_manifest_nodes = loom_manifest.get("nodes")
for _, node in loom_manifest_nodes.items():
if node.get("access"):
node_access = node.get("access")
if node_access!="protected":
if node.get("resource_type")=="model":
loomnodes.append(ModelParser.parse_from_dict(None, node))
for node in loomnodes:
manifest.nodes[node.unique_id] = node
logger.info(f":arrows_counterclockwise: added {len(loomnodes)} exposed nodes from {name} to the dbt manifest!")
return manifest


def _instantiate_adapter(runtime_config: RuntimeConfig) -> BaseAdapter:
"""Instantiate a dbt adapter based on the runtime configuration."""
Expand All @@ -276,6 +297,18 @@ def create_dbt_project_context(config: DbtConfiguration) -> DbtProjectContext:
runtime_cfg.load_dependencies(),
)
manifest = loader.load()

# check if dbt-loom is installed
loom_imported = False
try:
dbt_loom = __import__("dbt_loom")
loom_imported = True
except ImportError:
pass

if loom_imported:
manifest = _add_cross_project_references(manifest, dbt_loom, runtime_cfg.project_name)

manifest.build_flat_graph()
logger.info(":arrows_counterclockwise: Loaded the dbt project manifest!")

Expand All @@ -284,6 +317,7 @@ def create_dbt_project_context(config: DbtConfiguration) -> DbtProjectContext:
setattr(runtime_cfg, "adapter", adapter)
adapter.set_macro_resolver(manifest)


sql_parser = SqlBlockParser(runtime_cfg, manifest, runtime_cfg)
macro_parser = SqlMacroParser(runtime_cfg, manifest)

Expand Down Expand Up @@ -739,25 +773,26 @@ def _topological_sort(

def _iter_candidate_nodes(
context: YamlRefactorContext,
include_external: bool = False,
) -> Iterator[tuple[str, ResultNode]]:
"""Iterate over the models in the dbt project manifest applying the filter settings."""
logger.debug(
":mag: Filtering nodes (models/sources/seeds) with user-specified settings => %s",
context.settings,
)

def f(node: ResultNode) -> bool:
def f(node: ResultNode, include_external: bool = False) -> bool:
"""Closure to filter models based on the context settings."""
if node.resource_type not in (NodeType.Model, NodeType.Source, NodeType.Seed):
return False
if node.package_name != context.project.runtime_cfg.project_name:
if node.package_name != context.project.runtime_cfg.project_name and not include_external:
return False
if node.resource_type == NodeType.Model and node.config.materialized == "ephemeral":
return False
if context.settings.models:
if not _is_file_match(
node, context.settings.models, context.project.runtime_cfg.project_root
):
) and not include_external:
return False
if context.settings.fqn:
if not _is_fqn_match(node, context.settings.fqn):
Expand All @@ -768,7 +803,7 @@ def f(node: ResultNode) -> bool:
candidate_nodes: list[t.Any] = []
items = chain(context.project.manifest.nodes.items(), context.project.manifest.sources.items())
for uid, dbt_node in items:
if f(dbt_node):
if f(dbt_node, include_external):
candidate_nodes.append((uid, dbt_node))

for uid, node in _topological_sort(candidate_nodes):
Expand Down Expand Up @@ -1957,7 +1992,7 @@ def inherit_upstream_column_knowledge(
logger.info(":wave: Inheriting column knowledge across all matched nodes.")
for _ in context.pool.map(
partial(inherit_upstream_column_knowledge, context),
(n for _, n in _iter_candidate_nodes(context)),
(n for _, n in _iter_candidate_nodes(context, include_external=True)),
):
...
return
Expand Down

0 comments on commit 717f438

Please sign in to comment.