
Commit

implement core / adapters decoupling
colin-rogers-dbt committed Jan 11, 2024
1 parent 1b1fcec commit 0a2b73d
Showing 7 changed files with 31 additions and 30 deletions.
3 changes: 2 additions & 1 deletion dagger/run_dbt_spark_tests.py
@@ -2,6 +2,7 @@

import argparse
import sys
+from typing import Dict

import anyio as anyio
import dagger as dagger
@@ -19,7 +20,7 @@
TESTING_ENV_VARS.update({"ODBC_DRIVER": "Simba"})


-def env_variables(envs: dict[str, str]):
+def env_variables(envs: Dict[str, str]):
def env_variables_inner(ctr: dagger.Container):
for key, value in envs.items():
ctr = ctr.with_env_variable(key, value)
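The only change in this file swaps the built-in generic annotation dict[str, str] for typing.Dict, which also parses on Python versions older than 3.9 (likely the motivation for the change). A minimal sketch of the resulting helper, assuming the Dagger Python SDK is installed (dagger.Container and with_env_variable come from that SDK, as used in this script):

from typing import Dict

import dagger


def env_variables(envs: Dict[str, str]):
    # Build a callable that applies each environment variable to a container.
    # typing.Dict keeps the annotation valid on pre-3.9 interpreters, unlike
    # the built-in dict[str, str] subscript it replaces.
    def env_variables_inner(ctr: dagger.Container):
        for key, value in envs.items():
            ctr = ctr.with_env_variable(key, value)
        return ctr

    return env_variables_inner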
12 changes: 6 additions & 6 deletions dbt/adapters/spark/connections.py
@@ -2,13 +2,13 @@

import dbt.exceptions
from dbt.adapters.base import Credentials
-from dbt.adapters.contracts.connection import AdapterResponse, ConnectionState
+from dbt.adapters.contracts.connection import AdapterResponse, ConnectionState, Connection
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.sql import SQLConnectionManager
+from dbt.common.exceptions import DbtConfigError

-from dbt.utils import DECIMALS
+from dbt.common.utils.encoding import DECIMALS
from dbt.adapters.spark import __version__
-from dbt.adapters.spark.session import Connection

try:
from TCLIService.ttypes import TOperationState as ThriftState
@@ -391,7 +391,7 @@ def validate_creds(cls, creds: Any, required: Iterable[str]) -> None:

for key in required:
if not hasattr(creds, key):
-raise dbt.exceptions.DbtProfileError(
+raise DbtConfigError(
"The config '{}' is required when using the {} method"
" to connect to Spark".format(key, method)
)
@@ -482,7 +482,7 @@ def open(cls, connection: Connection) -> Connection:
endpoint=creds.endpoint
)
else:
-raise dbt.exceptions.DbtProfileError(
+raise DbtConfigError(
"Either `cluster` or `endpoint` must set when"
" using the odbc method to connect to Spark"
)
@@ -526,7 +526,7 @@ def open(cls, connection: Connection) -> Connection:
Connection(server_side_parameters=creds.server_side_parameters)
)
else:
-raise dbt.exceptions.DbtProfileError(
+raise DbtConfigError(
f"invalid credential method: {creds.method}"
)
break
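Taken together, these edits move connections.py onto the decoupled interfaces: Connection now comes from dbt.adapters.contracts.connection alongside AdapterResponse and ConnectionState, DECIMALS moves from dbt.utils to dbt.common.utils.encoding, and profile validation failures raise DbtConfigError from dbt.common.exceptions instead of dbt.exceptions.DbtProfileError. A rough sketch of the new validation pattern, using a hypothetical require_key helper rather than the adapter's actual validate_creds classmethod:

from dbt.common.exceptions import DbtConfigError


def require_key(creds, key: str, method: str) -> None:
    # Hypothetical helper mirroring validate_creds: a missing profile key now
    # surfaces as DbtConfigError rather than dbt.exceptions.DbtProfileError.
    if not hasattr(creds, key):
        raise DbtConfigError(
            "The config '{}' is required when using the {} method"
            " to connect to Spark".format(key, method)
        )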
11 changes: 5 additions & 6 deletions dbt/adapters/spark/impl.py
@@ -7,7 +7,6 @@
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.adapters.events.logging import AdapterLogger
from dbt.common.utils import AttrDict, executor
-from dbt.contracts.graph.manifest import Manifest

from typing_extensions import TypeAlias

@@ -28,7 +27,7 @@
AllPurposeClusterPythonJobHelper,
)
from dbt.adapters.base import BaseRelation
-from dbt.adapters.contracts.relation import RelationType
+from dbt.adapters.contracts.relation import RelationType, RelationConfig
from dbt.common.clients.agate_helper import DEFAULT_TYPE_TESTER
from dbt.common.contracts.constraints import ConstraintType

@@ -353,9 +352,9 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str,
yield as_dict

def get_catalog(
-self, manifest: Manifest, selected_nodes: Optional[Set] = None
+self, relation_configs: Iterable[RelationConfig], selected_nodes: Optional[Set] = None
) -> Tuple[agate.Table, List[Exception]]:
-schema_map = self._get_catalog_schemas(manifest)
+schema_map = self._get_catalog_schemas(relation_configs)
if len(schema_map) > 1:
raise dbt.exceptions.CompilationError(
f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
@@ -372,7 +371,7 @@ def get_catalog(
self._get_one_catalog,
info,
[schema],
-manifest,
+relation_configs,
)
)
catalogs, exceptions = catch_as_completed(futures)
@@ -382,7 +381,7 @@ def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Set[str],
-manifest: Manifest,
+relation_configs: Iterable[RelationConfig],
) -> agate.Table:
if len(schemas) != 1:
raise dbt.exceptions.CompilationError(
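This file carries the heart of the decoupling: get_catalog and _get_one_catalog stop taking dbt-core's Manifest and instead accept the RelationConfig objects exposed by dbt.adapters.contracts.relation, so the adapter depends only on the narrow contract it actually reads. A rough, self-contained sketch of that idea, where RelationConfigLike is a hypothetical stand-in protocol and schemas_for_catalog is an illustrative helper (not the adapter's real _get_catalog_schemas):

from typing import Iterable, Protocol, Set, Tuple


class RelationConfigLike(Protocol):
    # Hypothetical stand-in for dbt.adapters.contracts.relation.RelationConfig:
    # the adapter only needs a few relation attributes, not a full Manifest.
    database: str
    schema: str
    identifier: str


def schemas_for_catalog(relation_configs: Iterable[RelationConfigLike]) -> Set[Tuple[str, str]]:
    # Collect the (database, schema) pairs to catalog straight from the
    # relation configs handed over by dbt-core.
    return {(r.database, r.schema) for r in relation_configs}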
2 changes: 1 addition & 1 deletion dbt/adapters/spark/relation.py
@@ -4,7 +4,7 @@
from dbt.adapters.base.relation import BaseRelation, Policy
from dbt.adapters.events.logging import AdapterLogger

-from dbt.exceptions import DbtRuntimeError
+from dbt.common.exceptions import DbtRuntimeError

logger = AdapterLogger("Spark")

4 changes: 2 additions & 2 deletions dbt/adapters/spark/session.py
@@ -8,8 +8,8 @@

from dbt.adapters.spark.connections import SparkConnectionWrapper
from dbt.adapters.events.logging import AdapterLogger
-from dbt.utils import DECIMALS
-from dbt.exceptions import DbtRuntimeError
+from dbt.common.utils.encoding import DECIMALS
+from dbt.common.exceptions import DbtRuntimeError
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.utils import AnalysisException

27 changes: 14 additions & 13 deletions tests/unit/test_adapter.py
@@ -1,4 +1,5 @@
import unittest
+from multiprocessing import get_context
from unittest import mock

import dbt.flags as flags
@@ -146,7 +147,7 @@ def _get_target_odbc_sql_endpoint(self, project):

def test_http_connection(self):
config = self._get_target_http(self.project_cfg)
-adapter = SparkAdapter(config)
+adapter = SparkAdapter(config, get_context("spawn"))

def hive_http_connect(thrift_transport, configuration):
self.assertEqual(thrift_transport.scheme, "https")
@@ -171,7 +172,7 @@ def hive_http_connect(thrift_transport, configuration):

def test_thrift_connection(self):
config = self._get_target_thrift(self.project_cfg)
-adapter = SparkAdapter(config)
+adapter = SparkAdapter(config, get_context("spawn"))

def hive_thrift_connect(
host, port, username, auth, kerberos_service_name, password, configuration
@@ -195,7 +196,7 @@ def hive_thrift_connect(

def test_thrift_ssl_connection(self):
config = self._get_target_use_ssl_thrift(self.project_cfg)
-adapter = SparkAdapter(config)
+adapter = SparkAdapter(config, get_context("spawn"))

def hive_thrift_connect(thrift_transport, configuration):
self.assertIsNotNone(thrift_transport)
@@ -215,7 +216,7 @@ def hive_thrift_connect(thrift_transport, configuration):

def test_thrift_connection_kerberos(self):
config = self._get_target_thrift_kerberos(self.project_cfg)
-adapter = SparkAdapter(config)
+adapter = SparkAdapter(config, get_context("spawn"))

def hive_thrift_connect(
host, port, username, auth, kerberos_service_name, password, configuration
@@ -239,7 +240,7 @@ def hive_thrift_connect(

def test_odbc_cluster_connection(self):
config = self._get_target_odbc_cluster(self.project_cfg)
-adapter = SparkAdapter(config)
+adapter = SparkAdapter(config, get_context("spawn"))

def pyodbc_connect(connection_str, autocommit):
self.assertTrue(autocommit)
@@ -266,7 +267,7 @@ def pyodbc_connect(connection_str, autocommit):

def test_odbc_endpoint_connection(self):
config = self._get_target_odbc_sql_endpoint(self.project_cfg)
-adapter = SparkAdapter(config)
+adapter = SparkAdapter(config, get_context("spawn"))

def pyodbc_connect(connection_str, autocommit):
self.assertTrue(autocommit)
@@ -329,7 +330,7 @@ def test_parse_relation(self):
input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]

config = self._get_target_http(self.project_cfg)
-rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended(relation, input_cols)
self.assertEqual(len(rows), 4)
self.assertEqual(
rows[0].to_column_dict(omit_none=False),
@@ -418,7 +419,7 @@ def test_parse_relation_with_integer_owner(self):
input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]

config = self._get_target_http(self.project_cfg)
-rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended(relation, input_cols)

self.assertEqual(rows[0].to_column_dict().get("table_owner"), "1234")

@@ -454,7 +455,7 @@ def test_parse_relation_with_statistics(self):
input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]

config = self._get_target_http(self.project_cfg)
-rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended(relation, input_cols)
self.assertEqual(len(rows), 1)
self.assertEqual(
rows[0].to_column_dict(omit_none=False),
@@ -483,7 +484,7 @@

def test_relation_with_database(self):
config = self._get_target_http(self.project_cfg)
-adapter = SparkAdapter(config)
+adapter = SparkAdapter(config, get_context("spawn"))
# fine
adapter.Relation.create(schema="different", identifier="table")
with self.assertRaises(DbtRuntimeError):
@@ -564,7 +565,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
)

config = self._get_target_http(self.project_cfg)
-columns = SparkAdapter(config).parse_columns_from_information(relation)
+columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information(relation)
self.assertEqual(len(columns), 4)
self.assertEqual(
columns[0].to_column_dict(omit_none=False),
@@ -649,7 +650,7 @@ def test_parse_columns_from_information_with_view_type(self):
)

config = self._get_target_http(self.project_cfg)
-columns = SparkAdapter(config).parse_columns_from_information(relation)
+columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information(relation)
self.assertEqual(len(columns), 4)
self.assertEqual(
columns[1].to_column_dict(omit_none=False),
@@ -715,7 +716,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
)

config = self._get_target_http(self.project_cfg)
-columns = SparkAdapter(config).parse_columns_from_information(relation)
+columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information(relation)
self.assertEqual(len(columns), 4)

self.assertEqual(
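Every test now passes a multiprocessing context as the adapter's second constructor argument, reflecting that the decoupled base adapter takes the context explicitly instead of obtaining it from dbt-core. A minimal sketch of the updated construction (config stands for a runtime config object like the one the tests build via _get_target_http):

from multiprocessing import get_context

from dbt.adapters.spark.impl import SparkAdapter


def build_adapter(config) -> SparkAdapter:
    # The constructor now takes an explicit multiprocessing context; "spawn"
    # is what the updated unit tests use.
    return SparkAdapter(config, get_context("spawn"))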
2 changes: 1 addition & 1 deletion tests/unit/utils.py
@@ -9,7 +9,7 @@

import agate
import pytest
-from dbt.dataclass_schema import ValidationError
+from dbt.common.dataclass_schema import ValidationError
from dbt.config.project import PartialProject


