From 7d879f0de92e7e2eccdb541369098e6783c2a294 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 25 Jun 2024 13:50:06 -0400 Subject: [PATCH] Add record/replay support. --- dbt/adapters/snowflake/connections.py | 40 ++++++++++++------ dbt/adapters/snowflake/record.py | 59 +++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 13 deletions(-) create mode 100644 dbt/adapters/snowflake/record.py diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index 4db007f19..c786167db 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -36,6 +36,7 @@ DbtConfigError, ) from dbt_common.exceptions import DbtDatabaseError +from dbt_common.record import get_record_mode_from_env, RecorderMode from dbt.adapters.exceptions.connection import FailedToConnectError from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials from dbt.adapters.sql import SQLConnectionManager @@ -43,6 +44,7 @@ from dbt_common.events.functions import warn_or_error from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError from dbt_common.ui import line_wrap_message, warning_tag +from dbt.adapters.snowflake.record import SnowflakeRecordReplayHandle if TYPE_CHECKING: import agate @@ -370,20 +372,32 @@ def connect(): if creds.query_tag: session_parameters.update({"QUERY_TAG": creds.query_tag}) + handle = None + + # In replay mode, we won't connect to a real database at all, while + # in record and diff modes we do, but insert an intermediate handle + # object which monitors native connection activity. + rec_mode = get_record_mode_from_env() + handle = None + if rec_mode != RecorderMode.REPLAY: + handle = snowflake.connector.connect( + account=creds.account, + database=creds.database, + schema=creds.schema, + warehouse=creds.warehouse, + role=creds.role, + autocommit=True, + client_session_keep_alive=creds.client_session_keep_alive, + application="dbt", + insecure_mode=creds.insecure_mode, + session_parameters=session_parameters, + **creds.auth_args(), + ) - handle = snowflake.connector.connect( - account=creds.account, - database=creds.database, - schema=creds.schema, - warehouse=creds.warehouse, - role=creds.role, - autocommit=True, - client_session_keep_alive=creds.client_session_keep_alive, - application="dbt", - insecure_mode=creds.insecure_mode, - session_parameters=session_parameters, - **creds.auth_args(), - ) + if rec_mode is not None: + # If using the record/replay mechanism, regardless of mode, we + # use a wrapper. + handle = SnowflakeRecordReplayHandle(handle, connection) return handle diff --git a/dbt/adapters/snowflake/record.py b/dbt/adapters/snowflake/record.py new file mode 100644 index 000000000..6d475b3e8 --- /dev/null +++ b/dbt/adapters/snowflake/record.py @@ -0,0 +1,59 @@ +import dataclasses +from typing import Optional + +from dbt.adapters.record import RecordReplayHandle, RecordReplayCursor +from dbt_common.record import record_function, Record, Recorder + + +class SnowflakeRecordReplayHandle(RecordReplayHandle): + def cursor(self): + cursor = None if self.native_handle is None else self.native_handle.cursor() + return SnowflakeRecordReplayCursor(cursor, self.connection) + + +@dataclasses.dataclass +class CursorGetSqlStateParams: + connection_name: str + + +@dataclasses.dataclass +class CursorGetSqlStateResult: + msg: Optional[str] + + +class CursorGetSqlStateRecord(Record): + params_cls = CursorGetSqlStateParams + result_cls = CursorGetSqlStateResult + + +Recorder.register_record_type(CursorGetSqlStateRecord) + + +@dataclasses.dataclass +class CursorGetSqfidParams: + connection_name: str + + +@dataclasses.dataclass +class CursorGetSqfidResult: + msg: Optional[str] + + +class CursorGetSqfidRecord(Record): + params_cls = CursorGetSqfidParams + result_cls = CursorGetSqfidResult + + +Recorder.register_record_type(CursorGetSqfidRecord) + + +class SnowflakeRecordReplayCursor(RecordReplayCursor): + @property + @record_function(CursorGetSqlStateRecord, method=True, id_field_name="connection_name") + def sqlstate(self): + return self.native_cursor.sqlstate + + @property + @record_function(CursorGetSqfidRecord, method=True, id_field_name="connection_name") + def sfqid(self): + return self.native_cursor.sfqid