From 0a2b73db07bd2128519b90952985a837a38d9a01 Mon Sep 17 00:00:00 2001
From: Colin
Date: Wed, 10 Jan 2024 17:19:03 -0800
Subject: [PATCH] implement core / adapters decoupling

---
 dagger/run_dbt_spark_tests.py     |  3 ++-
 dbt/adapters/spark/connections.py | 12 ++++++------
 dbt/adapters/spark/impl.py        | 11 +++++------
 dbt/adapters/spark/relation.py    |  2 +-
 dbt/adapters/spark/session.py     |  4 ++--
 tests/unit/test_adapter.py        | 27 ++++++++++++++-------------
 tests/unit/utils.py               |  2 +-
 7 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py
index 718519909..2fde4a25d 100644
--- a/dagger/run_dbt_spark_tests.py
+++ b/dagger/run_dbt_spark_tests.py
@@ -2,6 +2,7 @@
 
 import argparse
 import sys
+from typing import Dict
 
 import anyio as anyio
 import dagger as dagger
@@ -19,7 +20,7 @@
 TESTING_ENV_VARS.update({"ODBC_DRIVER": "Simba"})
 
 
-def env_variables(envs: dict[str, str]):
+def env_variables(envs: Dict[str, str]):
     def env_variables_inner(ctr: dagger.Container):
         for key, value in envs.items():
             ctr = ctr.with_env_variable(key, value)
diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index 76390a2bc..fa6f48f52 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -2,13 +2,13 @@
 
 import dbt.exceptions
 from dbt.adapters.base import Credentials
-from dbt.adapters.contracts.connection import AdapterResponse, ConnectionState
+from dbt.adapters.contracts.connection import AdapterResponse, ConnectionState, Connection
 from dbt.adapters.events.logging import AdapterLogger
 from dbt.adapters.sql import SQLConnectionManager
+from dbt.common.exceptions import DbtConfigError
 
-from dbt.utils import DECIMALS
+from dbt.common.utils.encoding import DECIMALS
 from dbt.adapters.spark import __version__
-from dbt.adapters.spark.session import Connection
 
 try:
     from TCLIService.ttypes import TOperationState as ThriftState
@@ -391,7 +391,7 @@ def validate_creds(cls, creds: Any, required: Iterable[str]) -> None:
 
         for key in required:
             if not hasattr(creds, key):
-                raise dbt.exceptions.DbtProfileError(
+                raise DbtConfigError(
                     "The config '{}' is required when using the {} method"
                     " to connect to Spark".format(key, method)
                 )
@@ -482,7 +482,7 @@ def open(cls, connection: Connection) -> Connection:
                         endpoint=creds.endpoint
                     )
                 else:
-                    raise dbt.exceptions.DbtProfileError(
+                    raise DbtConfigError(
                         "Either `cluster` or `endpoint` must set when"
                         " using the odbc method to connect to Spark"
                     )
@@ -526,7 +526,7 @@ def open(cls, connection: Connection) -> Connection:
                        Connection(server_side_parameters=creds.server_side_parameters)
                     )
                 else:
-                    raise dbt.exceptions.DbtProfileError(
+                    raise DbtConfigError(
                         f"invalid credential method: {creds.method}"
                     )
                 break
diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 325139911..8cc7d848b 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -7,7 +7,6 @@
 from dbt.adapters.contracts.connection import AdapterResponse
 from dbt.adapters.events.logging import AdapterLogger
 from dbt.common.utils import AttrDict, executor
-from dbt.contracts.graph.manifest import Manifest
 
 from typing_extensions import TypeAlias
@@ -28,7 +27,7 @@
     AllPurposeClusterPythonJobHelper,
 )
 from dbt.adapters.base import BaseRelation
-from dbt.adapters.contracts.relation import RelationType
+from dbt.adapters.contracts.relation import RelationType, RelationConfig
 from dbt.common.clients.agate_helper import DEFAULT_TYPE_TESTER
 from dbt.common.contracts.constraints import ConstraintType
@@ -353,9 +352,9 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str,
             yield as_dict
 
     def get_catalog(
-        self, manifest: Manifest, selected_nodes: Optional[Set] = None
+        self, relation_configs: Iterable[RelationConfig], selected_nodes: Optional[Set] = None
     ) -> Tuple[agate.Table, List[Exception]]:
-        schema_map = self._get_catalog_schemas(manifest)
+        schema_map = self._get_catalog_schemas(relation_configs)
         if len(schema_map) > 1:
             raise dbt.exceptions.CompilationError(
                 f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
             )
@@ -372,7 +371,7 @@ def get_catalog(
                         self._get_one_catalog,
                         info,
                         [schema],
-                        manifest,
+                        relation_configs,
                     )
                 )
         catalogs, exceptions = catch_as_completed(futures)
@@ -382,7 +381,7 @@ def _get_one_catalog(
         self,
         information_schema: InformationSchema,
         schemas: Set[str],
-        manifest: Manifest,
+        relation_configs: Iterable[RelationConfig],
     ) -> agate.Table:
         if len(schemas) != 1:
             raise dbt.exceptions.CompilationError(
diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py
index 1fa1272f4..a6d679d56 100644
--- a/dbt/adapters/spark/relation.py
+++ b/dbt/adapters/spark/relation.py
@@ -4,7 +4,7 @@
 
 from dbt.adapters.base.relation import BaseRelation, Policy
 from dbt.adapters.events.logging import AdapterLogger
-from dbt.exceptions import DbtRuntimeError
+from dbt.common.exceptions import DbtRuntimeError
 
 logger = AdapterLogger("Spark")
diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py
index 1def33be1..d5d3ff050 100644
--- a/dbt/adapters/spark/session.py
+++ b/dbt/adapters/spark/session.py
@@ -8,8 +8,8 @@
 
 from dbt.adapters.spark.connections import SparkConnectionWrapper
 from dbt.adapters.events.logging import AdapterLogger
-from dbt.utils import DECIMALS
-from dbt.exceptions import DbtRuntimeError
+from dbt.common.utils.encoding import DECIMALS
+from dbt.common.exceptions import DbtRuntimeError
 
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.utils import AnalysisException
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index a7da63301..b46f7eef6 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -1,4 +1,5 @@
 import unittest
+from multiprocessing import get_context
 from unittest import mock
 
 import dbt.flags as flags
@@ -146,7 +147,7 @@ def _get_target_odbc_sql_endpoint(self, project):
 
     def test_http_connection(self):
         config = self._get_target_http(self.project_cfg)
-        adapter = SparkAdapter(config)
+        adapter = SparkAdapter(config, get_context("spawn"))
 
         def hive_http_connect(thrift_transport, configuration):
             self.assertEqual(thrift_transport.scheme, "https")
@@ -171,7 +172,7 @@ def hive_http_connect(thrift_transport, configuration):
 
     def test_thrift_connection(self):
         config = self._get_target_thrift(self.project_cfg)
-        adapter = SparkAdapter(config)
+        adapter = SparkAdapter(config, get_context("spawn"))
 
         def hive_thrift_connect(
             host, port, username, auth, kerberos_service_name, password, configuration
         ):
@@ -195,7 +196,7 @@ def hive_thrift_connect(
 
     def test_thrift_ssl_connection(self):
         config = self._get_target_use_ssl_thrift(self.project_cfg)
-        adapter = SparkAdapter(config)
+        adapter = SparkAdapter(config, get_context("spawn"))
 
         def hive_thrift_connect(thrift_transport, configuration):
             self.assertIsNotNone(thrift_transport)
@@ -215,7 +216,7 @@ def hive_thrift_connect(thrift_transport, configuration):
 
     def test_thrift_connection_kerberos(self):
         config = self._get_target_thrift_kerberos(self.project_cfg)
-        adapter = SparkAdapter(config)
+        adapter = SparkAdapter(config, get_context("spawn"))
 
         def hive_thrift_connect(
             host, port, username, auth, kerberos_service_name, password, configuration
         ):
@@ -239,7 +240,7 @@ def hive_thrift_connect(
 
     def test_odbc_cluster_connection(self):
         config = self._get_target_odbc_cluster(self.project_cfg)
-        adapter = SparkAdapter(config)
+        adapter = SparkAdapter(config, get_context("spawn"))
 
         def pyodbc_connect(connection_str, autocommit):
             self.assertTrue(autocommit)
@@ -266,7 +267,7 @@ def pyodbc_connect(connection_str, autocommit):
 
     def test_odbc_endpoint_connection(self):
         config = self._get_target_odbc_sql_endpoint(self.project_cfg)
-        adapter = SparkAdapter(config)
+        adapter = SparkAdapter(config, get_context("spawn"))
 
         def pyodbc_connect(connection_str, autocommit):
             self.assertTrue(autocommit)
@@ -329,7 +330,7 @@ def test_parse_relation(self):
         input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
 
         config = self._get_target_http(self.project_cfg)
-        rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+        rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended(relation, input_cols)
         self.assertEqual(len(rows), 4)
         self.assertEqual(
             rows[0].to_column_dict(omit_none=False),
@@ -418,7 +419,7 @@ def test_parse_relation_with_integer_owner(self):
         input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
 
         config = self._get_target_http(self.project_cfg)
-        rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+        rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended(relation, input_cols)
 
         self.assertEqual(rows[0].to_column_dict().get("table_owner"), "1234")
@@ -454,7 +455,7 @@ def test_parse_relation_with_statistics(self):
         input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
 
         config = self._get_target_http(self.project_cfg)
-        rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+        rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended(relation, input_cols)
         self.assertEqual(len(rows), 1)
         self.assertEqual(
             rows[0].to_column_dict(omit_none=False),
@@ -483,7 +484,7 @@ def test_parse_relation_with_statistics(self):
 
     def test_relation_with_database(self):
         config = self._get_target_http(self.project_cfg)
-        adapter = SparkAdapter(config)
+        adapter = SparkAdapter(config, get_context("spawn"))
         # fine
         adapter.Relation.create(schema="different", identifier="table")
         with self.assertRaises(DbtRuntimeError):
@@ -564,7 +565,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
         )
 
         config = self._get_target_http(self.project_cfg)
-        columns = SparkAdapter(config).parse_columns_from_information(relation)
+        columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information(relation)
         self.assertEqual(len(columns), 4)
         self.assertEqual(
             columns[0].to_column_dict(omit_none=False),
@@ -649,7 +650,7 @@ def test_parse_columns_from_information_with_view_type(self):
         )
 
         config = self._get_target_http(self.project_cfg)
-        columns = SparkAdapter(config).parse_columns_from_information(relation)
+        columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information(relation)
         self.assertEqual(len(columns), 4)
         self.assertEqual(
             columns[1].to_column_dict(omit_none=False),
@@ -715,7 +716,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
         )
 
         config = self._get_target_http(self.project_cfg)
-        columns = SparkAdapter(config).parse_columns_from_information(relation)
+        columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information(relation)
         self.assertEqual(len(columns), 4)
         self.assertEqual(
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index ac8c62244..a32d6608d 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -9,7 +9,7 @@
 
 import agate
 import pytest
 
-from dbt.dataclass_schema import ValidationError
+from dbt.common.dataclass_schema import ValidationError
 from dbt.config.project import PartialProject