diff --git a/dbt/adapters/duckdb/plugins/glue.py b/dbt/adapters/duckdb/plugins/glue.py
index 76f25419..07eee995 100644
--- a/dbt/adapters/duckdb/plugins/glue.py
+++ b/dbt/adapters/duckdb/plugins/glue.py
@@ -8,6 +8,7 @@
 from mypy_boto3_glue import GlueClient
 from mypy_boto3_glue.type_defs import ColumnTypeDef
 from mypy_boto3_glue.type_defs import GetTableResponseTypeDef
+from mypy_boto3_glue.type_defs import PartitionInputTypeDef
 from mypy_boto3_glue.type_defs import SerDeInfoTypeDef
 from mypy_boto3_glue.type_defs import StorageDescriptorTypeDef
 from mypy_boto3_glue.type_defs import TableInputTypeDef
@@ -132,12 +133,52 @@ def _convert_columns(column_list: Sequence[Column]) -> Sequence["ColumnTypeDef"]
     return column_types
 
 
-def _create_table(client: "GlueClient", database: str, table_def: "TableInputTypeDef") -> None:
+def _create_table(
+    client: "GlueClient",
+    database: str,
+    table_def: "TableInputTypeDef",
+    partition_columns: List[Dict[str, str]],
+) -> None:
+    """Create the Glue table, then register its partition if partition columns are given."""
     client.create_table(DatabaseName=database, TableInput=table_def)
+    # Create partition if relevant
+    if partition_columns:
+        partition_input, _ = _parse_partition_columns(partition_columns, table_def)
+
+        client.create_partition(
+            DatabaseName=database, TableName=table_def["Name"], PartitionInput=partition_input
+        )
 
 
-def _update_table(client: "GlueClient", database: str, table_def: "TableInputTypeDef") -> None:
+def _update_table(
+    client: "GlueClient",
+    database: str,
+    table_def: "TableInputTypeDef",
+    partition_columns: List[Dict[str, str]],
+) -> None:
+    """Update the Glue table, then update its partition or create it if it does not exist."""
     client.update_table(DatabaseName=database, TableInput=table_def)
+    # Update or create partition if relevant
+    if partition_columns:
+        partition_input, partition_values = _parse_partition_columns(partition_columns, table_def)
+
+        try:
+            client.get_partition(
+                DatabaseName=database,
+                TableName=table_def["Name"],
+                PartitionValues=partition_values,
+            )
+            client.update_partition(
+                DatabaseName=database,
+                TableName=table_def["Name"],
+                PartitionValueList=partition_values,
+                PartitionInput=partition_input,
+            )
+
+        except client.exceptions.EntityNotFoundException:
+            client.create_partition(
+                DatabaseName=database, TableName=table_def["Name"], PartitionInput=partition_input
+            )
 
 
 def _get_table(
@@ -163,7 +204,9 @@ def _get_column_type_def(
     return None
 
 
-def _add_partition_columns(table_def: TableInputTypeDef, partition_columns) -> TableInputTypeDef:
+def _add_partition_columns(
+    table_def: TableInputTypeDef, partition_columns: List[Dict[str, str]]
+) -> TableInputTypeDef:
     partition_keys = []
     if "PartitionKeys" not in table_def:
         table_def["PartitionKeys"] = []
@@ -172,18 +215,44 @@ def _add_partition_columns(table_def: TableInputTypeDef, partition_columns) -> T
         partition_keys.append(partition_column)
     table_def["PartitionKeys"] = partition_keys
     # Remove columns from StorageDescriptor if they match with partition columns to avoid duplicate columns
-    for partition_column in partition_columns:
+    for p_column in partition_columns:
         table_def["StorageDescriptor"]["Columns"] = [
             column
             for column in table_def["StorageDescriptor"]["Columns"]
-            if not (
-                column["Name"] == partition_column["Name"]
-                and column["Type"] == partition_column["Type"]
-            )
+            if not (column["Name"] == p_column["Name"] and column["Type"] == p_column["Type"])
         ]
     return table_def
 
 
+def _parse_partition_columns(
+    partition_columns: List[Dict[str, str]], table_def: TableInputTypeDef
+):
+    """Build the Glue partition input and partition values for the given partition columns.
+
+    Returns ``(None, [])`` without partition columns; ``table_def`` is never modified.
+    """
+    partition_input = None
+    # Always bound, so the return below cannot fail when there are no partition columns
+    partition_values: List[str] = []
+    if partition_columns:
+        partition_values = [column["Value"] for column in partition_columns]
+        # Location is <table location>/<name>=<value>/...; strip "/" to avoid a double slash
+        partition_components = [table_def["StorageDescriptor"]["Location"].rstrip("/")]
+        for c in partition_columns:
+            partition_components.append("=".join((c["Name"], c["Value"])))
+        partition_location = "/".join(partition_components)
+
+        # Copy the descriptor so the partition location does not overwrite the table's own
+        storage_descriptor = table_def["StorageDescriptor"].copy()
+        storage_descriptor["Location"] = partition_location
+
+        partition_input = PartitionInputTypeDef()
+        partition_input["Values"] = partition_values
+        partition_input["StorageDescriptor"] = storage_descriptor
+
+    return partition_input, partition_values
+
+
 def _get_table_def(
     table: str,
     s3_parent: str,
@@ -252,9 +321,19 @@ def create_or_update_table(
         glue_columns = _get_column_type_def(glue_table)
         # Create new version only if columns are changed
         if glue_columns != columns:
-            _update_table(client=client, database=database, table_def=table_def)
+            _update_table(
+                client=client,
+                database=database,
+                table_def=table_def,
+                partition_columns=partition_columns,
+            )
     else:
-        _create_table(client=client, database=database, table_def=table_def)
+        _create_table(
+            client=client,
+            database=database,
+            table_def=table_def,
+            partition_columns=partition_columns,
+        )
 
 
 class Plugin(BasePlugin):