Merge branch 'main' into fix_parameter_validation_error
moomindani committed Feb 18, 2025
2 parents e574ccb + 8a6f8e3 commit fd4c96e
Showing 6 changed files with 225 additions and 15 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
## Future Release
- Allow spawning new isolated sessions for the models that require different session configuration
## New version
- Correctly handle EntityNotFound when trying to determine session state, setting the state to "does not exist" instead of STOPPED.
- Allow spawning new isolated sessions for the models that require different session configuration.
- Correctly handle EntityNotFound when listing relations.
- Added configuration property to allow Spark casting of seed column types.
- Fix the get_columns_in_relation function error when on_schema_change is specified.
- Fix error handling.

## v1.9.0
1 change: 1 addition & 0 deletions README.md
@@ -244,6 +244,7 @@ The table below describes all the options.
| glue_session_reuse | Re-use the Glue session to run multiple dbt run commands: if set to true, the Glue session will not be closed, so it can be re-used; if set to false, the session will be closed. A re-used Glue session will still close once idle_timeout has expired. | no |
| datalake_formats | The ACID datalake format that you want to use if you are doing merge, can be `hudi`, `iceberg` or `delta` |no|
| use_arrow | (experimental) use an arrow file instead of stdout to have better scalability. |no|
| enable_spark_seed_casting | Allows Spark to cast the columns depending on the specified model column types. Default `False`. |no|

## Configs

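For orientation, here is a minimal sketch of a connection target that combines these options, written as the plain Python dict the unit tests added in this commit build their profile from; the role ARN, database names, and chosen values are placeholders rather than repository defaults.

# Hypothetical profile target as a Python dict (mirrors the shape used in tests/unit/test_adapter.py).
profile_target = {
    "type": "glue",
    "role_arn": "arn:aws:iam::123456789012:role/GlueInteractiveSessionRole",  # placeholder ARN
    "region": "us-east-1",
    "workers": 2,
    "worker_type": "G.1X",
    "schema": "dbt_unit_test_01",
    "database": "dbt_unit_test_01",
    "location": "path_to_location/",
    "glue_session_reuse": True,           # keep the session alive between dbt commands
    "datalake_formats": "iceberg",        # hudi | iceberg | delta
    "use_arrow": False,                   # experimental arrow-based result transfer
    "enable_spark_seed_casting": True,    # cast seed columns to the model column types
}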
8 changes: 6 additions & 2 deletions dbt/adapters/glue/credentials.py
@@ -3,9 +3,11 @@
from dbt.adapters.contracts.connection import Credentials
from dbt_common.exceptions import DbtRuntimeError


@dataclass
class GlueCredentials(Credentials):
""" Required connections for a Glue connection"""
"""Required connections for a Glue connection"""

role_arn: Optional[str] = None # type: ignore
region: Optional[str] = None # type: ignore
workers: Optional[int] = None # type: ignore
@@ -36,6 +38,7 @@ class GlueCredentials(Credentials):
enable_session_per_model: Optional[bool] = False
use_arrow: Optional[bool] = False
custom_iceberg_catalog_namespace: Optional[str] = "glue_catalog"
enable_spark_seed_casting: Optional[bool] = False

@property
def type(self):
@@ -93,5 +96,6 @@ def _connection_keys(self):
'glue_session_reuse',
'datalake_formats',
'enable_session_per_model',
'use_arrow'
'use_arrow',
'enable_spark_seed_casting',
]
6 changes: 6 additions & 0 deletions dbt/adapters/glue/gluedbapi/connection.py
@@ -355,6 +355,12 @@ def state(self) -> str:
session = response.get("Session", {})
self._state = session.get("Status")
except Exception as e:
if isinstance(e, botocore.exceptions.ClientError):
if e.response['Error']['Code'] == 'EntityNotFoundException':
logger.debug(f"Session {self.session_id} not found")
logger.debug(e)
self._state = None
return self._state
logger.debug(f"Error while checking state of session {self.session_id}")
logger.debug(e)
self._state = GlueSessionState.STOPPED
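The new branch in the state property distinguishes a missing session (EntityNotFoundException) from any other failure. Below is a small, self-contained sketch of that check, built around a hand-constructed botocore ClientError; the error message is invented and the plain-string fallback stands in for GlueSessionState.STOPPED.

# Illustration only, not part of the commit: how the error code is read off a botocore ClientError.
import botocore.exceptions

error = botocore.exceptions.ClientError(
    error_response={"Error": {"Code": "EntityNotFoundException", "Message": "Session not found"}},
    operation_name="GetSession",
)

try:
    raise error
except Exception as e:
    if isinstance(e, botocore.exceptions.ClientError) and e.response["Error"]["Code"] == "EntityNotFoundException":
        state = None        # session does not exist; the caller may create a fresh one
    else:
        state = "stopped"   # stand-in for GlueSessionState.STOPPED, the adapter's fallback

print(state)  # None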
74 changes: 70 additions & 4 deletions dbt/adapters/glue/impl.py
@@ -27,11 +27,41 @@
from dbt_common.exceptions import DbtDatabaseError, CompilationError
from dbt.adapters.base.impl import catch_as_completed
from dbt_common.utils import executor
from dbt_common.clients import agate_helper
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("Glue")


class ColumnCsvMappingStrategy:
_schema_mappings = {
"timestamp": "string",
"bigint": "double",
"date": "string",
}

def __init__(self, column_name, converted_agate_type, specified_type):
self.column_name = column_name
self.converted_agate_type = converted_agate_type
self.specified_type = specified_type

def as_schema_value(self):
return ColumnCsvMappingStrategy._schema_mappings.get(self.converted_agate_type, self.converted_agate_type)

def as_cast_value(self):
return self.specified_type if self.specified_type else self.converted_agate_type

@classmethod
def from_model(cls, model, agate_table):
return [
ColumnCsvMappingStrategy(
column.name,
GlueAdapter.convert_agate_type(agate_table, i),
model.get("config", {}).get("column_types", {}).get(column.name),
)
for i, column in enumerate(agate_table.columns)
]

class GlueAdapter(SQLAdapter):
ConnectionManager = GlueConnectionManager
Relation = SparkRelation
@@ -135,6 +165,8 @@ def list_relations_without_caching(self, schema_relation: SparkRelation):
type=self.relation_type_map.get(table.get("TableType")),
))
return relations
except client.exceptions.EntityNotFoundException as e:
return []
except Exception as e:
logger.error(e)

@@ -242,7 +274,14 @@ def get_columns_in_relation(self, relation: BaseRelation):
logger.debug("get_columns_in_relation called")
session, client = self.get_connection()
computed_schema = self.__compute_schema_based_on_type(schema=relation.schema, identifier=relation.identifier)
code = f"""describe {computed_schema}.{relation.identifier}"""


if relation.identifier.endswith('_tmp') and not relation.identifier.endswith('_dbt_tmp'):
code = f"""describe {relation.identifier}"""
else:
code = f"""describe {computed_schema}.{relation.identifier}"""
logger.debug(f"code: {code}")

columns = []
try:
response = session.cursor().execute(code)
@@ -535,7 +574,7 @@ def create_csv_table(self, model, agate_table):
mode = "False"

csv_chunks = self._split_csv_records_into_chunks(json.loads(f.getvalue()))
statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode)
statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode, ColumnCsvMappingStrategy.from_model(model, agate_table))
try:
cursor = session.cursor()
for statement in statements:
@@ -545,7 +584,14 @@ def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueCon
except Exception as e:
logger.error(e)

def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
def _map_csv_chunks_to_code(
self,
csv_chunks: List[List[dict]],
session: GlueConnection,
model,
mode,
column_mappings: List[ColumnCsvMappingStrategy],
):
statements = []
for i, csv_chunk in enumerate(csv_chunks):
is_first = i == 0
@@ -564,8 +610,28 @@ def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueCon
SqlWrapper2.execute("""select 1""")
'''
else:
code += f'''
if session.credentials.enable_spark_seed_casting:
csv_schema = ", ".join(
[f"{mapping.column_name}: {mapping.as_schema_value()}" for mapping in column_mappings]
)

cast_columns = ", ".join(
[
f'"cast({mapping.column_name} as {mapping.as_cast_value()}) as {mapping.column_name}"'
for mapping in column_mappings
if (cast_value := mapping.as_cast_value())
],
)

code += f"""
df = spark.createDataFrame(csv, "{csv_schema}")
df = df.selectExpr({cast_columns})
"""
else:
code += """
df = spark.createDataFrame(csv)
"""
code += f'''
table_name = '{model["schema"]}.{model["name"]}'
if (spark.sql("show tables in {model["schema"]}").where("tableName == lower('{model["name"]}')").count() > 0):
df.write\
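A rough, self-contained illustration (not from the repository) of the strings this seed-casting path derives from its column mappings before splicing them into the generated Spark code; the column names and types are invented and mirror the unit tests below.

# Hypothetical two-column seed: the model config pins test_column_double to "double",
# while test_column_str keeps the agate-derived type.
from dbt.adapters.glue.impl import ColumnCsvMappingStrategy

mappings = [
    ColumnCsvMappingStrategy("test_column_double", "string", "double"),
    ColumnCsvMappingStrategy("test_column_str", "string", None),
]

# Schema handed to spark.createDataFrame(csv, "...") so every CSV value is first read as-is.
csv_schema = ", ".join(f"{m.column_name}: {m.as_schema_value()}" for m in mappings)

# selectExpr arguments that cast each column to its target type afterwards.
cast_columns = ", ".join(
    f'"cast({m.column_name} as {m.as_cast_value()}) as {m.column_name}"' for m in mappings
)

print(csv_schema)    # test_column_double: string, test_column_str: string
print(cast_columns)  # "cast(test_column_double as double) as test_column_double", "cast(test_column_str as string) as test_column_str"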
143 changes: 136 additions & 7 deletions tests/unit/test_adapter.py
@@ -2,9 +2,12 @@
import unittest
from unittest import mock
from unittest.mock import Mock
import pytest
from multiprocessing import get_context
import agate.data_types
from botocore.client import BaseClient
from moto import mock_aws
import boto3

import agate
from dbt.config import RuntimeConfig
@@ -13,6 +16,9 @@
from dbt.adapters.glue import GlueAdapter
from dbt.adapters.glue.gluedbapi import GlueConnection
from dbt.adapters.glue.relation import SparkRelation
from dbt.adapters.glue.impl import ColumnCsvMappingStrategy
from dbt_common.clients import agate_helper
from dbt.adapters.contracts.relation import RelationConfig
from tests.util import config_from_parts_or_dicts
from .util import MockAWSService

@@ -41,7 +47,7 @@ def setUp(self):
"region": "us-east-1",
"workers": 2,
"worker_type": "G.1X",
"location" : "path_to_location/",
"location": "path_to_location/",
"schema": "dbt_unit_test_01",
"database": "dbt_unit_test_01",
"use_interactive_session_role_for_api_calls": False,
Expand Down Expand Up @@ -71,7 +77,6 @@ def test_glue_connection(self):
self.assertIsNotNone(connection.handle)
self.assertIsInstance(glueSession.client, BaseClient)


@mock_aws
def test_get_table_type(self):
config = self._get_config()
@@ -96,8 +101,10 @@ def test_create_csv_table_slices_big_datasets(self):
adapter = GlueAdapter(config, get_context("spawn"))
model = {"name": "mock_model", "schema": "mock_schema"}
session_mock = Mock()
adapter.get_connection = lambda: (session_mock, 'mock_client')
test_table = agate.Table([(f'mock_value_{i}',f'other_mock_value_{i}') for i in range(2000)], column_names=['value', 'other_value'])
adapter.get_connection = lambda: (session_mock, "mock_client")
test_table = agate.Table(
[(f"mock_value_{i}", f"other_mock_value_{i}") for i in range(2000)], column_names=["value", "other_value"]
)
adapter.create_csv_table(model, test_table)

# test table is between 120000 and 180000 characters so it should be split three times (max chunk is 60000)
Expand All @@ -113,13 +120,135 @@ def test_get_location(self):
with mock.patch("dbt.adapters.glue.connections.open"):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
print(adapter.get_location(relation))
self.assertEqual(adapter.get_location(relation), "LOCATION 'path_to_location/some_database/some_table'")

def test_get_custom_iceberg_catalog_namespace(self):
config = self._get_config()
adapter = GlueAdapter(config, get_context("spawn"))
with mock.patch("dbt.adapters.glue.connections.open"):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")
self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")

def test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enabled(self):
config = self._get_config()
config.credentials.enable_spark_seed_casting = True
adapter = GlueAdapter(config, get_context("spawn"))
csv_chunks = [{"test_column_double": "1.2345", "test_column_str": "test"}]
model = {
"name": "mock_model",
"schema": "mock_schema",
"config": {"column_types": {"test_column_double": "double", "test_column_str": "string"}},
}
column_mappings = [
ColumnCsvMappingStrategy("test_column_double", "string", "double"),
ColumnCsvMappingStrategy("test_column_str", "string", "string"),
]
code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
self.assertIn('spark.createDataFrame(csv, "test_column_double: string, test_column_str: string")', code[0])
self.assertIn(
'df = df.selectExpr("cast(test_column_double as double) as test_column_double", '
+ '"cast(test_column_str as string) as test_column_str")',
code[0],
)

def test_create_csv_table_doesnt_provide_schema_when_spark_seed_cast_is_disabled(self):
config = self._get_config()
config.credentials.enable_spark_seed_casting = False
adapter = GlueAdapter(config, get_context("spawn"))
csv_chunks = [{"test_column": "1.2345"}]
model = {"name": "mock_model", "schema": "mock_schema"}
column_mappings = [ColumnCsvMappingStrategy("test_column", agate.data_types.Text, "double")]
code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
self.assertIn("spark.createDataFrame(csv)", code[0])

@mock_aws
def test_when_database_not_exists_list_relations_without_caching_returns_empty_array(self):
config = self._get_config()
adapter = GlueAdapter(config, get_context("spawn"))
adapter.get_connection = lambda : (None, boto3.client("glue", region_name="us-east-1"))
relation = Mock(SparkRelation)
relation.schema = 'mockdb'
actual = adapter.list_relations_without_caching(relation)
self.assertEqual([],actual)

@mock_aws
def test_list_relations_returns_database_tables(self):
config = self._get_config()
glue_client = boto3.client("glue", region_name="us-east-1")

# Prepare database tables
database_name = 'mockdb'
table_names = ['table1', 'table2', 'table3']
glue_client.create_database(DatabaseInput={"Name":database_name})
for table_name in table_names:
glue_client.create_table(DatabaseName=database_name,TableInput={"Name":table_name})
expected = [(database_name, table_name) for table_name in table_names]

# Prepare adapter for test
adapter = GlueAdapter(config, get_context("spawn"))
adapter.get_connection = lambda : (None, glue_client)
relation = Mock(SparkRelation)
relation.schema = database_name

relations = adapter.list_relations_without_caching(relation)

actual = [(relation.path.schema, relation.path.identifier) for relation in relations]
self.assertCountEqual(expected,actual)


class TestCsvMappingStrategy:
@pytest.mark.parametrize(
"agate_type,specified_type,expected_schema_type,expected_cast_type",
[
("timestamp", None, "string", "timestamp"),
("double", None, "double", "double"),
("bigint", None, "double", "bigint"),
("boolean", None, "boolean", "boolean"),
("date", None, "string", "date"),
("timestamp", None, "string", "timestamp"),
("string", None, "string", "string"),
("string", "double", "string", "double"),
],
ids=[
"test isodatetime cast",
"test number cast",
"test integer cast",
"test boolean cast",
"test date cast",
"test datetime cast",
"test text cast",
"test specified cast",
],
)
def test_mapping_strategy_provides_proper_mappings(
self, agate_type, specified_type, expected_schema_type, expected_cast_type
):
column_mapping = ColumnCsvMappingStrategy("test_column", agate_type, specified_type)
assert column_mapping.as_schema_value() == expected_schema_type
assert column_mapping.as_cast_value() == expected_cast_type

def test_from_model_builds_column_mappings(self):
expected_column_names = ["col_int", "col_str", "col_date", "col_specific"]
expected_converted_agate_types = [
"bigint",
"string",
"date",
"string",
]
expected_specified_types = [None, None, None, "double"]
agate_table = agate.Table(
[(111, "str_val", "2024-01-01", "1.234")],
column_names=expected_column_names,
column_types=[
agate.data_types.Number(),
agate.data_types.Text(),
agate.data_types.Date(),
agate.data_types.Text(),
],
)
model = {"name": "mock_model", "config": {"column_types": {"col_specific": "double"}}}
mappings = ColumnCsvMappingStrategy.from_model(model, agate_table)
assert expected_column_names == [mapping.column_name for mapping in mappings]
assert expected_converted_agate_types == [mapping.converted_agate_type for mapping in mappings]
assert expected_specified_types == [mapping.specified_type for mapping in mappings]
