Skip to content

Commit

Permalink
feat: add copy custom
Browse files Browse the repository at this point in the history
  • Loading branch information
pquadri committed Jul 30, 2024
1 parent ef506e4 commit 21a68a1
Showing 1 changed file with 101 additions and 36 deletions.
137 changes: 101 additions & 36 deletions snowflake_utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ class Table(BaseModel):
table_structure: TableStructure | None = None
role: str | None = None
database: str | None = None
include_metadata: dict[str, str] = Field(default_factory=dict)

def _include_metadata(self) -> str:
    """Render the ``INCLUDE_METADATA`` clause appended to COPY statements.

    ``self.include_metadata`` maps a target column name to a metadata field
    name; each entry is rendered as ``<column>=METADATA$<field>``.

    Returns:
        ``"INCLUDE_METADATA = NULL"`` when the mapping is empty, otherwise
        ``"INCLUDE_METADATA = (<col>=METADATA$<field>, ...)"``.
    """
    if not self.include_metadata:
        return "INCLUDE_METADATA = NULL"
    # The field name is emitted bare: the earlier version appended a stray
    # single quote after it, which produced an unbalanced quote in the SQL.
    pairs = ", ".join(
        f"{column}=METADATA${metadata_field}"
        for column, metadata_field in self.include_metadata.items()
    )
    return f"INCLUDE_METADATA = ({pairs})"

@property
def fqn(self):
Expand Down Expand Up @@ -183,60 +193,68 @@ def bulk_insert(
)
return None

def _copy(
    self,
    query: str,
    query_args: dict,
    path: str,
    file_format: InlineFileFormat | FileFormat,
    storage_integration: str | None = None,
    full_refresh: bool = False,
    sync_tags: bool = False,
) -> None:
    """Copy files into Snowflake using a caller-supplied COPY statement.

    Shared driver behind the public copy methods: it opens a connection,
    prepares the session and temporary stage via ``setup_connection``,
    resolves the file format via ``setup_file_format``, creates the table,
    optionally syncs tags, and finally executes ``query``.

    Args:
        query: ``str.format`` template for the statement. It is formatted
            with ``query_args`` plus a ``file_format`` entry, so it should
            contain a ``{file_format}`` placeholder; any other literal
            braces in it must be doubled.
        query_args: Extra values substituted into ``query``.
        path: Source location, forwarded to ``setup_connection``.
        file_format: File format, or an inline definition that
            ``setup_file_format`` turns into a temporary named format.
        storage_integration: Forwarded to ``setup_connection``.
        full_refresh: Forwarded to ``create_table``.
        sync_tags: When true and ``self.table_structure`` is set, call
            ``self.sync_tags`` before copying.

    Returns:
        Whatever the statement executor returns for the final statement.
    """
    with connect() as connection:
        cursor = connection.cursor()
        execute = self.setup_connection(path, storage_integration, cursor)
        file_format = self.setup_file_format(file_format, execute)
        self.create_table(full_refresh, execute)

        if sync_tags and self.table_structure:
            self.sync_tags(cursor)

        logging.info(f"Starting copy into `{self.fqn}` from path '{path}'")
        # Inject the resolved format name last so it always wins over any
        # "file_format" key supplied by the caller.
        query_args = query_args | {"file_format": file_format}
        return execute(query.format(**query_args))

def copy_into(
    self,
    path: str,
    file_format: InlineFileFormat | FileFormat,
    storage_integration: str | None = None,
    match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE,
    full_refresh: bool = False,
    target_columns: list[str] | None = None,
    sync_tags: bool = False,
) -> None:
    """Copy files from ``path`` into this table with a plain COPY INTO.

    Builds the COPY statement and hands it to ``_copy``, which prepares the
    session and executes it.

    Args:
        path: Source location placed verbatim after ``FROM``.
        file_format: File format, resolved to a name by ``_copy``.
        storage_integration: When set, adds a ``STORAGE_INTEGRATION`` clause.
        match_by_column_name: Its ``.value`` is used for
            ``MATCH_BY_COLUMN_NAME``.
        full_refresh: Forwarded to ``_copy``.
        target_columns: Optional explicit column list for the target table.
        sync_tags: Forwarded to ``_copy``.

    Returns:
        The result of ``_copy``.
    """
    col_str = f"({', '.join(target_columns)})" if target_columns else ""
    integration_clause = (
        f"STORAGE_INTEGRATION = {storage_integration}" if storage_integration else ""
    )
    # ``{{file_format}}`` is deliberately escaped: it survives this f-string
    # as ``{file_format}`` and is filled in by ``_copy`` once the (possibly
    # temporary) format name is known.
    return self._copy(
        f"""
COPY INTO {self.fqn} {col_str}
FROM {path}
{integration_clause}
FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}')
MATCH_BY_COLUMN_NAME={match_by_column_name.value}
{self._include_metadata()}
""",
        {},
        path,
        file_format,
        storage_integration,
        full_refresh,
        sync_tags,
    )

def create_table(self, full_refresh, _execute_statement):
    """Build the table-creation statement and run it.

    Args:
        full_refresh: Passed to ``self.get_create_table_statement``.
        _execute_statement: Callable invoked with the generated statement.
    """
    statement = self.get_create_table_statement(full_refresh)
    _execute_statement(statement)

def setup_file_format(self, file_format, _execute_statement):
    """Resolve ``file_format`` to the value used in the COPY statement.

    Anything that is not an ``InlineFileFormat`` is returned unchanged. For
    an ``InlineFileFormat``, a temporary file format is created from its
    ``definition`` and ``self.temporary_file_format`` is returned instead.

    Args:
        file_format: The format to resolve.
        _execute_statement: Callable used to run the creation statement.
    """
    if not isinstance(file_format, InlineFileFormat):
        return file_format

    create_statement = self.get_create_temporary_file_format_statement(
        file_format=file_format.definition
    )
    _execute_statement(create_statement)
    return self.temporary_file_format

def get_columns(self, cursor: SnowflakeCursor) -> list[Column]:
data = cursor.execute(f"desc table {self.fqn}").fetchall()
Expand Down Expand Up @@ -428,6 +446,53 @@ def _unset_tag(self, cursor: SnowflakeCursor, column: str, tag: str):
f'ALTER TABLE {self.fqn} MODIFY COLUMN "{column.upper()}" UNSET TAG {governance_settings.fqn(tag)}'
)

def copy_custom(
    self,
    column_definitions: dict[str, str],
    path: str,
    file_format: InlineFileFormat | FileFormat,
    storage_integration: str | None = None,
    full_refresh: bool = False,
    sync_tags: bool = False,
) -> None:
    """Copy from the temporary stage into an explicit list of columns.

    Builds a COPY statement whose target column list is the keys of
    ``column_definitions`` and whose source is ``@<temporary_stage>/``, then
    delegates execution to ``_copy``.

    NOTE(review): only the keys of ``column_definitions`` are used here; the
    values are never read. Confirm whether they were meant to appear in the
    statement.

    Args:
        column_definitions: Mapping whose keys name the target columns.
        path: Forwarded to ``_copy``.
        file_format: Forwarded to ``_copy``.
        storage_integration: Forwarded to ``_copy``.
        full_refresh: Forwarded to ``_copy``.
        sync_tags: Forwarded to ``_copy``.

    Returns:
        The result of ``_copy``.
    """
    target = self.fqn
    column_list = ", ".join(column_definitions.keys())
    # ``{{file_format}}`` stays as a ``{file_format}`` placeholder for
    # ``_copy`` to fill in.
    statement = f"""
COPY INTO {target} ({column_list})
FROM @{self.temporary_stage}/
FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}')
{self._include_metadata()}
"""
    return self._copy(
        statement,
        {},
        path,
        file_format,
        storage_integration,
        full_refresh,
        sync_tags,
    )

def setup_connection(
    self, path: str, storage_integration: str | None, cursor: SnowflakeCursor
) -> callable:
    """Setup the connection including custom role, database, schema, and temporary stage

    Statements are issued in this order: ``USE ROLE`` (if ``self.role`` is
    set), ``USE DATABASE`` (if ``self.database`` is set), the create-schema
    statement, ``USE SCHEMA``, and the create-temporary-external-stage
    statement.

    Args:
        path: Location passed to ``get_create_temporary_external_stage``.
        storage_integration: Integration name passed to
            ``get_create_temporary_external_stage``; may be ``None``, as the
            callers default it to ``None``.
        cursor: Cursor every statement is executed on.

    Returns:
        ``execute_statement`` bound to ``cursor``, so callers can keep
        running statements on the same session.
    """
    run = partial(execute_statement, cursor)

    if self.role is not None:
        logging.debug(f"Using role: {self.role}")
        run(f"USE ROLE {self.role}")
    if self.database is not None:
        logging.debug(f"Using database: {self.database}")
        run(f"USE DATABASE {self.database}")

    # The schema is created before it is selected.
    run(self.get_create_schema_statement())
    logging.debug(f"Using schema: {self.schema_}")
    run(f"USE SCHEMA {self.schema_}")

    stage_statement = self.get_create_temporary_external_stage(
        path=path, storage_integration=storage_integration
    )
    run(stage_statement)

    return run


def _possibly_cast(s: str, old_column_type: str, new_column_type: str) -> str:
if old_column_type == "VARIANT" and new_column_type != "VARIANT":
Expand Down

0 comments on commit 21a68a1

Please sign in to comment.