diff --git a/snowflake_utils/models.py b/snowflake_utils/models.py index 23e2417..eabacb3 100644 --- a/snowflake_utils/models.py +++ b/snowflake_utils/models.py @@ -99,6 +99,16 @@ class Table(BaseModel): table_structure: TableStructure | None = None role: str | None = None database: str | None = None + include_metadata: dict[str, str] = Field(default_factory=dict) + + def _include_metadata(self) -> str: + if not self.include_metadata: + return "INCLUDE_METADATA = NULL" + else: + metadata = ", ".join( + f"{k}=METADATA${v}'" for k, v in self.include_metadata.items() + ) + return f"INCLUDE_METADATA = ({metadata})" @property def fqn(self): @@ -183,60 +193,68 @@ def bulk_insert( ) return None - def copy_into( + def _copy( self, + query: str, + query_args: dict, path: str, file_format: InlineFileFormat | FileFormat, storage_integration: str | None = None, - match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE, full_refresh: bool = False, - target_columns: list[str] | None = None, sync_tags: bool = False, ) -> None: - """Copy files into Snowflake""" with connect() as connection: cursor = connection.cursor() - _execute_statement = partial(execute_statement, cursor) - if self.role is not None: - logging.debug(f"Using role: {self.role}") - _execute_statement(f"USE ROLE {self.role}") - if self.database is not None: - logging.debug(f"Using database: {self.database}") - _execute_statement(f"USE DATABASE {self.database}") - - _execute_statement(self.get_create_schema_statement()) - logging.debug(f"Using schema: {self.schema_}") - _execute_statement(f"USE SCHEMA {self.schema_}") - _execute_statement( - self.get_create_temporary_external_stage( - path=path, storage_integration=storage_integration - ) - ) - - if isinstance(file_format, InlineFileFormat): - _execute_statement( - self.get_create_temporary_file_format_statement( - file_format=file_format.definition - ) - ) - file_format = self.temporary_file_format - - _execute_statement(self.get_create_table_statement(full_refresh)) + execute = self.setup_connection(path, storage_integration, cursor) + file_format = self.setup_file_format(file_format, execute) + self.create_table(full_refresh, execute) if sync_tags and self.table_structure: self.sync_tags(cursor) - logging.info(f"Starting copy into `{self.fqn}` from path '{path}'") - col_str = f"({', '.join(target_columns)})" if target_columns else "" - return _execute_statement( - f""" + query_args = query_args | {"file_format": file_format} + return execute(query.format(**query_args)) + + def copy_into( + self, + path: str, + file_format: InlineFileFormat | FileFormat, + storage_integration: str | None = None, + match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE, + full_refresh: bool = False, + target_columns: list[str] | None = None, + sync_tags: bool = False, + ) -> None: + col_str = f"({', '.join(target_columns)})" if target_columns else "" + return self._copy( + f""" COPY INTO {self.fqn} {col_str} FROM {path} {f"STORAGE_INTEGRATION = {storage_integration}" if storage_integration else ''} - FILE_FORMAT = ( FORMAT_NAME ='{file_format}') + FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}') MATCH_BY_COLUMN_NAME={match_by_column_name.value} - """ + {self._include_metadata()} + """, + {}, + path, + file_format, + storage_integration, + full_refresh, + sync_tags, + ) + + def create_table(self, full_refresh, _execute_statement): + _execute_statement(self.get_create_table_statement(full_refresh)) + + def setup_file_format(self, file_format, _execute_statement): + if isinstance(file_format, InlineFileFormat): + _execute_statement( + self.get_create_temporary_file_format_statement( + file_format=file_format.definition + ) ) + file_format = self.temporary_file_format + return file_format def get_columns(self, cursor: SnowflakeCursor) -> list[Column]: data = cursor.execute(f"desc table {self.fqn}").fetchall() @@ -428,6 +446,53 @@ def _unset_tag(self, cursor: SnowflakeCursor, column: str, tag: str): f'ALTER TABLE {self.fqn} MODIFY COLUMN "{column.upper()}" UNSET TAG {governance_settings.fqn(tag)}' ) + def copy_custom( + self, + column_definitions: dict[str, str], + path: str, + file_format: InlineFileFormat | FileFormat, + storage_integration: str | None = None, + full_refresh: bool = False, + sync_tags: bool = False, + ) -> None: + return self._copy( + f""" + COPY INTO {self.fqn} ({", ".join(column_definitions.keys())}) + FROM @{self.temporary_stage}/ + FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}') + {self._include_metadata()} + """, + {}, + path, + file_format, + storage_integration, + full_refresh, + sync_tags, + ) + + def setup_connection( + self, path: str, storage_integration: str, cursor: SnowflakeCursor + ) -> callable: + """Setup the connection including custom role, database, schema, and temporary stage""" + _execute_statement = partial(execute_statement, cursor) + if self.role is not None: + logging.debug(f"Using role: {self.role}") + _execute_statement(f"USE ROLE {self.role}") + if self.database is not None: + logging.debug(f"Using database: {self.database}") + _execute_statement(f"USE DATABASE {self.database}") + + _execute_statement(self.get_create_schema_statement()) + logging.debug(f"Using schema: {self.schema_}") + _execute_statement(f"USE SCHEMA {self.schema_}") + _execute_statement( + self.get_create_temporary_external_stage( + path=path, storage_integration=storage_integration + ) + ) + + return _execute_statement + def _possibly_cast(s: str, old_column_type: str, new_column_type: str) -> str: if old_column_type == "VARIANT" and new_column_type != "VARIANT":