From 74680c8eed14c1098c995fd76aa4d88254a5251b Mon Sep 17 00:00:00 2001 From: Jonas Kalderstam Date: Thu, 20 Oct 2022 15:33:23 +0200 Subject: [PATCH 1/5] Ran Black on files --- tap_postgres/__init__.py | 398 ++++-- tap_postgres/db.py | 182 ++- tap_postgres/discovery_utils.py | 381 ++--- tap_postgres/stream_utils.py | 88 +- tap_postgres/sync_strategies/common.py | 34 +- tap_postgres/sync_strategies/full_table.py | 129 +- tap_postgres/sync_strategies/incremental.py | 113 +- .../sync_strategies/logical_replication.py | 575 +++++--- .../test_clear_state_on_replication_change.py | 293 +++- tests/test_db.py | 197 +-- tests/test_discovery.py | 1242 ++++++++++++----- tests/test_full_table_interruption.py | 437 +++--- tests/test_logical_replication.py | 757 +++++----- tests/test_streams_utils.py | 170 ++- tests/test_unsupported_pk.py | 137 +- tests/utils.py | 161 ++- 16 files changed, 3413 insertions(+), 1881 deletions(-) diff --git a/tap_postgres/__init__.py b/tap_postgres/__init__.py index 9f69ace2..c986134b 100644 --- a/tap_postgres/__init__.py +++ b/tap_postgres/__init__.py @@ -18,18 +18,16 @@ from tap_postgres.sync_strategies import incremental from tap_postgres.discovery_utils import discover_db from tap_postgres.stream_utils import ( - dump_catalog, clear_state_on_replication_change, - is_selected_via_metadata, refresh_streams_schema, any_logical_streams) + dump_catalog, + clear_state_on_replication_change, + is_selected_via_metadata, + refresh_streams_schema, + any_logical_streams, +) -LOGGER = singer.get_logger('tap_postgres') +LOGGER = singer.get_logger("tap_postgres") -REQUIRED_CONFIG_KEYS = [ - 'dbname', - 'host', - 'port', - 'user', - 'password' -] +REQUIRED_CONFIG_KEYS = ["dbname", "host", "port", "user", "password"] def do_discovery(conn_config): @@ -41,11 +39,11 @@ def do_discovery(conn_config): Returns: list of discovered streams """ with post_db.open_connection(conn_config) as conn: - LOGGER.info("Discovering db %s", conn_config['dbname']) - streams = discover_db(conn, conn_config.get('filter_schemas')) + LOGGER.info("Discovering db %s", conn_config["dbname"]) + streams = discover_db(conn, conn_config.get("filter_schemas")) if len(streams) == 0: - raise RuntimeError('0 tables were discovered across the entire cluster') + raise RuntimeError("0 tables were discovered across the entire cluster") dump_catalog(streams) return streams @@ -55,12 +53,16 @@ def do_sync_full_table(conn_config, stream, state, desired_columns, md_map): """ Runs full table sync """ - LOGGER.info("Stream %s is using full_table replication", stream['tap_stream_id']) + LOGGER.info("Stream %s is using full_table replication", stream["tap_stream_id"]) sync_common.send_schema_message(stream, []) - if md_map.get((), {}).get('is-view'): - state = full_table.sync_view(conn_config, stream, state, desired_columns, md_map) + if md_map.get((), {}).get("is-view"): + state = full_table.sync_view( + conn_config, stream, state, desired_columns, md_map + ) else: - state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map) + state = full_table.sync_table( + conn_config, stream, state, desired_columns, md_map + ) return state @@ -69,18 +71,28 @@ def do_sync_incremental(conn_config, stream, state, desired_columns, md_map): """ Runs Incremental sync """ - replication_key = md_map.get((), {}).get('replication-key') - LOGGER.info("Stream %s is using incremental replication with replication key %s", - stream['tap_stream_id'], - replication_key) - - stream_state = state.get('bookmarks', 
{}).get(stream['tap_stream_id']) + replication_key = md_map.get((), {}).get("replication-key") + LOGGER.info( + "Stream %s is using incremental replication with replication key %s", + stream["tap_stream_id"], + replication_key, + ) + + stream_state = state.get("bookmarks", {}).get(stream["tap_stream_id"]) illegal_bk_keys = set(stream_state.keys()).difference( - {'replication_key', 'replication_key_value', 'version', 'last_replication_method'}) + { + "replication_key", + "replication_key_value", + "version", + "last_replication_method", + } + ) if len(illegal_bk_keys) != 0: raise Exception("invalid keys found in state: {}".format(illegal_bk_keys)) - state = singer.write_bookmark(state, stream['tap_stream_id'], 'replication_key', replication_key) + state = singer.write_bookmark( + state, stream["tap_stream_id"], "replication_key", replication_key + ) sync_common.send_schema_message(stream, [replication_key]) state = incremental.sync_table(conn_config, stream, state, desired_columns, md_map) @@ -97,55 +109,76 @@ def sync_method_for_streams(streams, state, default_replication_method): logical_streams = [] for stream in streams: - stream_metadata = metadata.to_map(stream['metadata']) - replication_method = stream_metadata.get((), {}).get('replication-method', default_replication_method) - replication_key = stream_metadata.get((), {}).get('replication-key') - - state = clear_state_on_replication_change(state, stream['tap_stream_id'], replication_key, replication_method) - - if replication_method not in {'LOG_BASED', 'FULL_TABLE', 'INCREMENTAL'}: - raise Exception("Unrecognized replication_method {}".format(replication_method)) - - md_map = metadata.to_map(stream['metadata']) - desired_columns = [c for c in stream['schema']['properties'].keys() if - sync_common.should_sync_column(md_map, c)] + stream_metadata = metadata.to_map(stream["metadata"]) + replication_method = stream_metadata.get((), {}).get( + "replication-method", default_replication_method + ) + replication_key = stream_metadata.get((), {}).get("replication-key") + + state = clear_state_on_replication_change( + state, stream["tap_stream_id"], replication_key, replication_method + ) + + if replication_method not in {"LOG_BASED", "FULL_TABLE", "INCREMENTAL"}: + raise Exception( + "Unrecognized replication_method {}".format(replication_method) + ) + + md_map = metadata.to_map(stream["metadata"]) + desired_columns = [ + c + for c in stream["schema"]["properties"].keys() + if sync_common.should_sync_column(md_map, c) + ] desired_columns.sort() if len(desired_columns) == 0: - LOGGER.warning('There are no columns selected for stream %s, skipping it', stream['tap_stream_id']) + LOGGER.warning( + "There are no columns selected for stream %s, skipping it", + stream["tap_stream_id"], + ) continue - if replication_method == 'LOG_BASED' and stream_metadata.get((), {}).get('is-view'): - raise Exception(f'Logical Replication is NOT supported for views. ' \ - f'Please change the replication method for {stream["tap_stream_id"]}') + if replication_method == "LOG_BASED" and stream_metadata.get((), {}).get( + "is-view" + ): + raise Exception( + f"Logical Replication is NOT supported for views. 
" + f'Please change the replication method for {stream["tap_stream_id"]}' + ) - if replication_method == 'FULL_TABLE': - lookup[stream['tap_stream_id']] = 'full' + if replication_method == "FULL_TABLE": + lookup[stream["tap_stream_id"]] = "full" traditional_steams.append(stream) - elif replication_method == 'INCREMENTAL': - lookup[stream['tap_stream_id']] = 'incremental' + elif replication_method == "INCREMENTAL": + lookup[stream["tap_stream_id"]] = "incremental" traditional_steams.append(stream) - elif get_bookmark(state, stream['tap_stream_id'], 'xmin') and \ - get_bookmark(state, stream['tap_stream_id'], 'lsn'): + elif get_bookmark(state, stream["tap_stream_id"], "xmin") and get_bookmark( + state, stream["tap_stream_id"], "lsn" + ): # finishing previously interrupted full-table (first stage of logical replication) - lookup[stream['tap_stream_id']] = 'logical_initial_interrupted' + lookup[stream["tap_stream_id"]] = "logical_initial_interrupted" traditional_steams.append(stream) # inconsistent state - elif get_bookmark(state, stream['tap_stream_id'], 'xmin') and \ - not get_bookmark(state, stream['tap_stream_id'], 'lsn'): - raise Exception("Xmin found(%s) in state implying full-table replication but no lsn is present") - - elif not get_bookmark(state, stream['tap_stream_id'], 'xmin') and \ - not get_bookmark(state, stream['tap_stream_id'], 'lsn'): + elif get_bookmark(state, stream["tap_stream_id"], "xmin") and not get_bookmark( + state, stream["tap_stream_id"], "lsn" + ): + raise Exception( + "Xmin found(%s) in state implying full-table replication but no lsn is present" + ) + + elif not get_bookmark( + state, stream["tap_stream_id"], "xmin" + ) and not get_bookmark(state, stream["tap_stream_id"], "lsn"): # initial full-table phase of logical replication - lookup[stream['tap_stream_id']] = 'logical_initial' + lookup[stream["tap_stream_id"]] = "logical_initial" traditional_steams.append(stream) else: # no xmin but we have an lsn # initial stage of logical replication(full-table) has been completed. 
moving onto pure logical replication - lookup[stream['tap_stream_id']] = 'pure_logical' + lookup[stream["tap_stream_id"]] = "pure_logical" logical_streams.append(stream) return lookup, traditional_steams, logical_streams @@ -155,39 +188,58 @@ def sync_traditional_stream(conn_config, stream, state, sync_method, end_lsn): """ Sync INCREMENTAL and FULL_TABLE streams """ - LOGGER.info("Beginning sync of stream(%s) with sync method(%s)", stream['tap_stream_id'], sync_method) - md_map = metadata.to_map(stream['metadata']) - conn_config['dbname'] = md_map.get(()).get('database-name') - desired_columns = [c for c in stream['schema']['properties'].keys() if sync_common.should_sync_column(md_map, c)] + LOGGER.info( + "Beginning sync of stream(%s) with sync method(%s)", + stream["tap_stream_id"], + sync_method, + ) + md_map = metadata.to_map(stream["metadata"]) + conn_config["dbname"] = md_map.get(()).get("database-name") + desired_columns = [ + c + for c in stream["schema"]["properties"].keys() + if sync_common.should_sync_column(md_map, c) + ] desired_columns.sort() if len(desired_columns) == 0: - LOGGER.warning('There are no columns selected for stream %s, skipping it', stream['tap_stream_id']) + LOGGER.warning( + "There are no columns selected for stream %s, skipping it", + stream["tap_stream_id"], + ) return state register_type_adapters(conn_config) - if sync_method == 'full': - state = singer.set_currently_syncing(state, stream['tap_stream_id']) + if sync_method == "full": + state = singer.set_currently_syncing(state, stream["tap_stream_id"]) state = do_sync_full_table(conn_config, stream, state, desired_columns, md_map) - elif sync_method == 'incremental': - state = singer.set_currently_syncing(state, stream['tap_stream_id']) + elif sync_method == "incremental": + state = singer.set_currently_syncing(state, stream["tap_stream_id"]) state = do_sync_incremental(conn_config, stream, state, desired_columns, md_map) - elif sync_method == 'logical_initial': - state = singer.set_currently_syncing(state, stream['tap_stream_id']) + elif sync_method == "logical_initial": + state = singer.set_currently_syncing(state, stream["tap_stream_id"]) LOGGER.info("Performing initial full table sync") - state = singer.write_bookmark(state, stream['tap_stream_id'], 'lsn', end_lsn) + state = singer.write_bookmark(state, stream["tap_stream_id"], "lsn", end_lsn) sync_common.send_schema_message(stream, []) - state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map) - state = singer.write_bookmark(state, stream['tap_stream_id'], 'xmin', None) - elif sync_method == 'logical_initial_interrupted': - state = singer.set_currently_syncing(state, stream['tap_stream_id']) + state = full_table.sync_table( + conn_config, stream, state, desired_columns, md_map + ) + state = singer.write_bookmark(state, stream["tap_stream_id"], "xmin", None) + elif sync_method == "logical_initial_interrupted": + state = singer.set_currently_syncing(state, stream["tap_stream_id"]) LOGGER.info("Initial stage of full table sync was interrupted. 
resuming...") sync_common.send_schema_message(stream, []) - state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map) + state = full_table.sync_table( + conn_config, stream, state, desired_columns, md_map + ) else: - raise Exception("unknown sync method {} for stream {}".format(sync_method, stream['tap_stream_id'])) + raise Exception( + "unknown sync method {} for stream {}".format( + sync_method, stream["tap_stream_id"] + ) + ) state = singer.set_currently_syncing(state, None) singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) @@ -199,26 +251,39 @@ def sync_logical_streams(conn_config, logical_streams, state, end_lsn, state_fil Sync streams that use LOG_BASED method """ if logical_streams: - LOGGER.info("Pure Logical Replication upto lsn %s for (%s)", end_lsn, - [s['tap_stream_id'] for s in logical_streams]) - - logical_streams = [logical_replication.add_automatic_properties( - s, conn_config.get('debug_lsn', False)) for s in logical_streams] + LOGGER.info( + "Pure Logical Replication upto lsn %s for (%s)", + end_lsn, + [s["tap_stream_id"] for s in logical_streams], + ) + + logical_streams = [ + logical_replication.add_automatic_properties( + s, conn_config.get("debug_lsn", False) + ) + for s in logical_streams + ] # Remove LOG_BASED stream bookmarks from state if it has been de-selected # This is to avoid sending very old starting and flushing positions to source selected_streams = set() for stream in logical_streams: - selected_streams.add("{}".format(stream['tap_stream_id'])) + selected_streams.add("{}".format(stream["tap_stream_id"])) - new_state = dict(currently_syncing=state['currently_syncing'], bookmarks={}) + new_state = dict(currently_syncing=state["currently_syncing"], bookmarks={}) - for stream, bookmark in state['bookmarks'].items(): - if bookmark == {} or bookmark['last_replication_method'] != 'LOG_BASED' or stream in selected_streams: - new_state['bookmarks'][stream] = bookmark + for stream, bookmark in state["bookmarks"].items(): + if ( + bookmark == {} + or bookmark["last_replication_method"] != "LOG_BASED" + or stream in selected_streams + ): + new_state["bookmarks"][stream] = bookmark state = new_state - state = logical_replication.sync_tables(conn_config, logical_streams, state, end_lsn, state_file) + state = logical_replication.sync_tables( + conn_config, logical_streams, state, end_lsn, state_file + ) return state @@ -235,28 +300,36 @@ def register_type_adapters(conn_config): if citext_array_oid: psycopg2.extensions.register_type( psycopg2.extensions.new_array_type( - (citext_array_oid[0],), 'CITEXT[]', psycopg2.STRING)) + (citext_array_oid[0],), "CITEXT[]", psycopg2.STRING + ) + ) # bit[] cur.execute("SELECT typarray FROM pg_type where typname = 'bit'") bit_array_oid = cur.fetchone()[0] psycopg2.extensions.register_type( psycopg2.extensions.new_array_type( - (bit_array_oid,), 'BIT[]', psycopg2.STRING)) + (bit_array_oid,), "BIT[]", psycopg2.STRING + ) + ) # UUID[] cur.execute("SELECT typarray FROM pg_type where typname = 'uuid'") uuid_array_oid = cur.fetchone()[0] psycopg2.extensions.register_type( psycopg2.extensions.new_array_type( - (uuid_array_oid,), 'UUID[]', psycopg2.STRING)) + (uuid_array_oid,), "UUID[]", psycopg2.STRING + ) + ) # money[] cur.execute("SELECT typarray FROM pg_type where typname = 'money'") money_array_oid = cur.fetchone()[0] psycopg2.extensions.register_type( psycopg2.extensions.new_array_type( - (money_array_oid,), 'MONEY[]', psycopg2.STRING)) + (money_array_oid,), "MONEY[]", psycopg2.STRING + ) + ) 
# json and jsonb # pylint: disable=unnecessary-lambda @@ -264,12 +337,16 @@ def register_type_adapters(conn_config): psycopg2.extras.register_default_jsonb(loads=lambda x: str(x)) # enum[]'s - cur.execute("SELECT distinct(t.typarray) FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid") + cur.execute( + "SELECT distinct(t.typarray) FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid" + ) for oid in cur.fetchall(): enum_oid = oid[0] psycopg2.extensions.register_type( psycopg2.extensions.new_array_type( - (enum_oid,), 'ENUM_{}[]'.format(enum_oid), psycopg2.STRING)) + (enum_oid,), "ENUM_{}[]".format(enum_oid), psycopg2.STRING + ) + ) def do_sync(conn_config, catalog, default_replication_method, state, state_file=None): @@ -277,9 +354,9 @@ def do_sync(conn_config, catalog, default_replication_method, state, state_file= Orchestrates sync of all streams """ currently_syncing = singer.get_currently_syncing(state) - streams = list(filter(is_selected_via_metadata, catalog['streams'])) - streams.sort(key=lambda s: s['tap_stream_id']) - LOGGER.info("Selected streams: %s ", [s['tap_stream_id'] for s in streams]) + streams = list(filter(is_selected_via_metadata, catalog["streams"])) + streams.sort(key=lambda s: s["tap_stream_id"]) + LOGGER.info("Selected streams: %s ", [s["tap_stream_id"] for s in streams]) if any_logical_streams(streams, default_replication_method): # Use of logical replication requires fetching an lsn end_lsn = logical_replication.fetch_current_lsn(conn_config) @@ -289,37 +366,56 @@ def do_sync(conn_config, catalog, default_replication_method, state, state_file= refresh_streams_schema(conn_config, streams) - sync_method_lookup, traditional_streams, logical_streams = \ - sync_method_for_streams(streams, state, default_replication_method) + sync_method_lookup, traditional_streams, logical_streams = sync_method_for_streams( + streams, state, default_replication_method + ) if currently_syncing: LOGGER.debug("Found currently_syncing: %s", currently_syncing) - currently_syncing_stream = list(filter(lambda s: s['tap_stream_id'] == currently_syncing, traditional_streams)) + currently_syncing_stream = list( + filter( + lambda s: s["tap_stream_id"] == currently_syncing, traditional_streams + ) + ) if not currently_syncing_stream: - LOGGER.warning("unable to locate currently_syncing(%s) amongst selected traditional streams(%s). " - "Will ignore", - currently_syncing, - {s['tap_stream_id'] for s in traditional_streams}) - - other_streams = list(filter(lambda s: s['tap_stream_id'] != currently_syncing, traditional_streams)) + LOGGER.warning( + "unable to locate currently_syncing(%s) amongst selected traditional streams(%s). 
" + "Will ignore", + currently_syncing, + {s["tap_stream_id"] for s in traditional_streams}, + ) + + other_streams = list( + filter( + lambda s: s["tap_stream_id"] != currently_syncing, traditional_streams + ) + ) traditional_streams = currently_syncing_stream + other_streams else: LOGGER.info("No streams marked as currently_syncing in state file") for stream in traditional_streams: - state = sync_traditional_stream(conn_config, - stream, - state, - sync_method_lookup[stream['tap_stream_id']], - end_lsn) - - logical_streams.sort(key=lambda s: metadata.to_map(s['metadata']).get(()).get('database-name')) - for dbname, streams in itertools.groupby(logical_streams, - lambda s: metadata.to_map(s['metadata']).get(()).get('database-name')): - conn_config['dbname'] = dbname - state = sync_logical_streams(conn_config, list(streams), state, end_lsn, state_file) + state = sync_traditional_stream( + conn_config, + stream, + state, + sync_method_lookup[stream["tap_stream_id"]], + end_lsn, + ) + + logical_streams.sort( + key=lambda s: metadata.to_map(s["metadata"]).get(()).get("database-name") + ) + for dbname, streams in itertools.groupby( + logical_streams, + lambda s: metadata.to_map(s["metadata"]).get(()).get("database-name"), + ): + conn_config["dbname"] = dbname + state = sync_logical_streams( + conn_config, list(streams), state, end_lsn, state_file + ) return state @@ -341,44 +437,38 @@ def parse_args(required_config_keys): load and parse the JSON file.""" parser = argparse.ArgumentParser() - parser.add_argument( - '-c', '--config', - help='Config file', - required=True) + parser.add_argument("-c", "--config", help="Config file", required=True) - parser.add_argument( - '-s', '--state', - help='state file') + parser.add_argument("-s", "--state", help="state file") parser.add_argument( - '-p', '--properties', - help='Property selections: DEPRECATED, Please use --catalog instead') + "-p", + "--properties", + help="Property selections: DEPRECATED, Please use --catalog instead", + ) - parser.add_argument( - '--catalog', - help='Catalog file') + parser.add_argument("--catalog", help="Catalog file") parser.add_argument( - '-d', '--discover', - action='store_true', - help='Do schema discovery') + "-d", "--discover", action="store_true", help="Do schema discovery" + ) args = parser.parse_args() if args.config: - setattr(args, 'config_path', args.config) + setattr(args, "config_path", args.config) args.config = utils.load_json(args.config) if args.state: - setattr(args, 'state_path', args.state) + setattr(args, "state_path", args.state) args.state_file = args.state args.state = utils.load_json(args.state) else: args.state_file = None args.state = {} if args.properties: - setattr(args, 'properties_path', args.properties) + setattr(args, "properties_path", args.properties) args.properties = utils.load_json(args.properties) if args.catalog: - setattr(args, 'catalog_path', args.catalog) + setattr(args, "catalog_path", args.catalog) args.catalog = Catalog.load(args.catalog) utils.check_config(args.config, required_config_keys) @@ -393,33 +483,41 @@ def main_impl(): args = parse_args(REQUIRED_CONFIG_KEYS) conn_config = { # Required config keys - 'host': args.config['host'], - 'user': args.config['user'], - 'password': args.config['password'], - 'port': args.config['port'], - 'dbname': args.config['dbname'], - + "host": args.config["host"], + "user": args.config["user"], + "password": args.config["password"], + "port": args.config["port"], + "dbname": args.config["dbname"], # Optional config keys - 'tap_id': 
args.config.get('tap_id'), - 'filter_schemas': args.config.get('filter_schemas'), - 'debug_lsn': args.config.get('debug_lsn') == 'true', - 'max_run_seconds': args.config.get('max_run_seconds', 43200), - 'break_at_end_lsn': args.config.get('break_at_end_lsn', True), - 'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0)) + "tap_id": args.config.get("tap_id"), + "filter_schemas": args.config.get("filter_schemas"), + "debug_lsn": args.config.get("debug_lsn") == "true", + "max_run_seconds": args.config.get("max_run_seconds", 43200), + "break_at_end_lsn": args.config.get("break_at_end_lsn", True), + "logical_poll_total_seconds": float( + args.config.get("logical_poll_total_seconds", 0) + ), } - if args.config.get('ssl') == 'true': - conn_config['sslmode'] = 'require' + if args.config.get("ssl") == "true": + conn_config["sslmode"] = "require" - post_db.CURSOR_ITER_SIZE = int(args.config.get('itersize', post_db.CURSOR_ITER_SIZE)) + post_db.CURSOR_ITER_SIZE = int( + args.config.get("itersize", post_db.CURSOR_ITER_SIZE) + ) if args.discover: do_discovery(conn_config) elif args.properties or args.catalog: state = args.state state_file = args.state_file - do_sync(conn_config, args.catalog.to_dict() if args.catalog else args.properties, - args.config.get('default_replication_method'), state, state_file) + do_sync( + conn_config, + args.catalog.to_dict() if args.catalog else args.properties, + args.config.get("default_replication_method"), + state, + state_file, + ) else: LOGGER.info("No properties were selected") diff --git a/tap_postgres/db.py b/tap_postgres/db.py index 591323cd..e53693df 100644 --- a/tap_postgres/db.py +++ b/tap_postgres/db.py @@ -11,14 +11,14 @@ from typing import List from dateutil.parser import parse -LOGGER = singer.get_logger('tap_postgres') +LOGGER = singer.get_logger("tap_postgres") CURSOR_ITER_SIZE = 20000 # pylint: disable=invalid-name,missing-function-docstring def calculate_destination_stream_name(stream, md_map): - return "{}-{}".format(md_map.get((), {}).get('schema-name'), stream['stream']) + return "{}-{}".format(md_map.get((), {}).get("schema-name"), stream["stream"]) # from the postgres docs: @@ -29,103 +29,119 @@ def canonicalize_identifier(identifier): def fully_qualified_column_name(schema, table, column): - return '"{}"."{}"."{}"'.format(canonicalize_identifier(schema), - canonicalize_identifier(table), - canonicalize_identifier(column)) + return '"{}"."{}"."{}"'.format( + canonicalize_identifier(schema), + canonicalize_identifier(table), + canonicalize_identifier(column), + ) def fully_qualified_table_name(schema, table): - return '"{}"."{}"'.format(canonicalize_identifier(schema), canonicalize_identifier(table)) + return '"{}"."{}"'.format( + canonicalize_identifier(schema), canonicalize_identifier(table) + ) def open_connection(conn_config, logical_replication=False): cfg = { - 'application_name': 'pipelinewise', - 'host': conn_config['host'], - 'dbname': conn_config['dbname'], - 'user': conn_config['user'], - 'password': conn_config['password'], - 'port': conn_config['port'], - 'connect_timeout': 30 + "application_name": "pipelinewise", + "host": conn_config["host"], + "dbname": conn_config["dbname"], + "user": conn_config["user"], + "password": conn_config["password"], + "port": conn_config["port"], + "connect_timeout": 30, } - if conn_config.get('sslmode'): - cfg['sslmode'] = conn_config['sslmode'] + if conn_config.get("sslmode"): + cfg["sslmode"] = conn_config["sslmode"] if logical_replication: - cfg['connection_factory'] = 
psycopg2.extras.LogicalReplicationConnection + cfg["connection_factory"] = psycopg2.extras.LogicalReplicationConnection conn = psycopg2.connect(**cfg) return conn + def prepare_columns_for_select_sql(c, md_map): column_name = ' "{}" '.format(canonicalize_identifier(c)) - if ('properties', c) in md_map: - sql_datatype = md_map[('properties', c)]['sql-datatype'] - if sql_datatype.startswith('timestamp') and not sql_datatype.endswith('[]'): - return f'CASE ' \ - f'WHEN {column_name} < \'0001-01-01 00:00:00.000\' ' \ - f'OR {column_name} > \'9999-12-31 23:59:59.999\' THEN \'9999-12-31 23:59:59.999\' ' \ - f'ELSE {column_name} ' \ - f'END AS {column_name}' + if ("properties", c) in md_map: + sql_datatype = md_map[("properties", c)]["sql-datatype"] + if sql_datatype.startswith("timestamp") and not sql_datatype.endswith("[]"): + return ( + f"CASE " + f"WHEN {column_name} < '0001-01-01 00:00:00.000' " + f"OR {column_name} > '9999-12-31 23:59:59.999' THEN '9999-12-31 23:59:59.999' " + f"ELSE {column_name} " + f"END AS {column_name}" + ) return column_name + def prepare_columns_sql(c): column_name = """ "{}" """.format(canonicalize_identifier(c)) return column_name def filter_dbs_sql_clause(sql, filter_dbs): - in_clause = " AND datname in (" + ",".join(["'{}'".format(b.strip(' ')) for b in filter_dbs.split(',')]) + ")" + in_clause = ( + " AND datname in (" + + ",".join(["'{}'".format(b.strip(" ")) for b in filter_dbs.split(",")]) + + ")" + ) return sql + in_clause def filter_schemas_sql_clause(sql, filer_schemas): - in_clause = " AND n.nspname in (" + ",".join(["'{}'".format(b.strip(' ')) for b in filer_schemas.split(',')]) + ")" + in_clause = ( + " AND n.nspname in (" + + ",".join(["'{}'".format(b.strip(" ")) for b in filer_schemas.split(",")]) + + ")" + ) return sql + in_clause # pylint: disable=too-many-branches,too-many-nested-blocks,too-many-statements def selected_value_to_singer_value_impl(elem, sql_datatype): - sql_datatype = sql_datatype.replace('[]', '') + sql_datatype = sql_datatype.replace("[]", "") if elem is None: cleaned_elem = elem - elif sql_datatype == 'money': + elif sql_datatype == "money": cleaned_elem = elem - elif sql_datatype in ['json', 'jsonb']: + elif sql_datatype in ["json", "jsonb"]: cleaned_elem = json.loads(elem) - elif sql_datatype == 'time with time zone': + elif sql_datatype == "time with time zone": # time with time zone values will be converted to UTC and time zone dropped # Replace hour=24 with hour=0 elem = str(elem) - if elem.startswith('24'): - elem = elem.replace('24', '00', 1) + if elem.startswith("24"): + elem = elem.replace("24", "00", 1) # convert to UTC - elem = datetime.datetime.strptime(elem, '%H:%M:%S%z') + elem = datetime.datetime.strptime(elem, "%H:%M:%S%z") if elem.utcoffset() != datetime.timedelta(seconds=0): - LOGGER.warning('time with time zone values are converted to UTC') + LOGGER.warning("time with time zone values are converted to UTC") elem = elem.astimezone(pytz.utc) # drop time zone - elem = str(elem.strftime('%H:%M:%S')) - cleaned_elem = parse(elem).isoformat().split('T')[1] - elif sql_datatype == 'time without time zone': + elem = str(elem.strftime("%H:%M:%S")) + cleaned_elem = parse(elem).isoformat().split("T")[1] + elif sql_datatype == "time without time zone": # Replace hour=24 with hour=0 elem = str(elem) - if elem.startswith('24'): - elem = elem.replace('24', '00', 1) - cleaned_elem = parse(elem).isoformat().split('T')[1] + if elem.startswith("24"): + elem = elem.replace("24", "00", 1) + cleaned_elem = 
parse(elem).isoformat().split("T")[1] elif isinstance(elem, datetime.datetime): - if sql_datatype == 'timestamp with time zone': + if sql_datatype == "timestamp with time zone": cleaned_elem = elem.isoformat() else: # timestamp WITH OUT time zone - cleaned_elem = elem.isoformat() + '+00:00' + cleaned_elem = elem.isoformat() + "+00:00" elif isinstance(elem, datetime.date): - cleaned_elem = elem.isoformat() + 'T00:00:00+00:00' - elif sql_datatype == 'bit': - cleaned_elem = elem == '1' - elif sql_datatype == 'boolean': + cleaned_elem = elem.isoformat() + "T00:00:00+00:00" + elif sql_datatype == "bit": + cleaned_elem = elem == "1" + elif sql_datatype == "boolean": cleaned_elem = elem elif isinstance(elem, int): cleaned_elem = elem @@ -149,38 +165,53 @@ def selected_value_to_singer_value_impl(elem, sql_datatype): else: cleaned_elem = elem elif isinstance(elem, dict): - if sql_datatype == 'hstore': + if sql_datatype == "hstore": cleaned_elem = elem else: - raise Exception("do not know how to marshall a dict if its not an hstore or json: {}".format(sql_datatype)) + raise Exception( + "do not know how to marshall a dict if its not an hstore or json: {}".format( + sql_datatype + ) + ) else: raise Exception( - "do not know how to marshall value of class( {} ) and sql_datatype ( {} )".format(elem.__class__, - sql_datatype)) + "do not know how to marshall value of class( {} ) and sql_datatype ( {} )".format( + elem.__class__, sql_datatype + ) + ) return cleaned_elem def selected_array_to_singer_value(elem, sql_datatype): if isinstance(elem, list): - return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype), elem)) + return list( + map(lambda elem: selected_array_to_singer_value(elem, sql_datatype), elem) + ) return selected_value_to_singer_value_impl(elem, sql_datatype) def selected_value_to_singer_value(elem, sql_datatype): # are we dealing with an array? 
- if sql_datatype.find('[]') > 0: - return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype), (elem or []))) + if sql_datatype.find("[]") > 0: + return list( + map( + lambda elem: selected_array_to_singer_value(elem, sql_datatype), + (elem or []), + ) + ) return selected_value_to_singer_value_impl(elem, sql_datatype) # pylint: disable=too-many-arguments -def selected_row_to_singer_message(stream, row, version, columns, time_extracted, md_map): +def selected_row_to_singer_message( + stream, row, version, columns, time_extracted, md_map +): row_to_persist = () for idx, elem in enumerate(row): - sql_datatype = md_map.get(('properties', columns[idx]))['sql-datatype'] + sql_datatype = md_map.get(("properties", columns[idx]))["sql-datatype"] cleaned_elem = selected_value_to_singer_value(elem, sql_datatype) row_to_persist += (cleaned_elem,) @@ -190,13 +221,18 @@ def selected_row_to_singer_message(stream, row, version, columns, time_extracted stream=calculate_destination_stream_name(stream, md_map), record=rec, version=version, - time_extracted=time_extracted) + time_extracted=time_extracted, + ) def hstore_available(conn_info): with open_connection(conn_info) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: - cur.execute(""" SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """) + with conn.cursor( + cursor_factory=psycopg2.extras.DictCursor, name="stitch_cursor" + ) as cur: + cur.execute( + """ SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """ + ) res = cur.fetchone() if res and res[0]: return True @@ -204,7 +240,7 @@ def hstore_available(conn_info): def compute_tap_stream_id(schema_name, table_name): - return schema_name + '-' + table_name + return schema_name + "-" + table_name # NB> numeric/decimal columns in postgres without a specified scale && precision @@ -220,7 +256,7 @@ def numeric_precision(c): return MAX_PRECISION if c.numeric_precision > MAX_PRECISION: - LOGGER.warning('capping decimal precision to 100. THIS MAY CAUSE TRUNCATION') + LOGGER.warning("capping decimal precision to 100. THIS MAY CAUSE TRUNCATION") return MAX_PRECISION return c.numeric_precision @@ -230,7 +266,7 @@ def numeric_scale(c): if c.numeric_scale is None: return MAX_SCALE if c.numeric_scale > MAX_SCALE: - LOGGER.warning('capping decimal scale to 38. THIS MAY CAUSE TRUNCATION') + LOGGER.warning("capping decimal scale to 38. 
THIS MAY CAUSE TRUNCATION") return MAX_SCALE return c.numeric_scale @@ -245,28 +281,38 @@ def numeric_max(precision, scale): def numeric_min(precision, scale): - return -10 ** (precision - scale) + return -(10 ** (precision - scale)) def filter_tables_sql_clause(sql, tables: List[str]): - in_clause = " AND pg_class.relname in (" + ",".join(["'{}'".format(b.strip(' ')) for b in tables]) + ")" + in_clause = ( + " AND pg_class.relname in (" + + ",".join(["'{}'".format(b.strip(" ")) for b in tables]) + + ")" + ) return sql + in_clause + def get_database_name(connection): cur = connection.cursor() rows = cur.execute("SELECT name FROM v$database").fetchall() return rows[0][0] + def attempt_connection_to_db(conn_config, dbname): nascent_config = copy.deepcopy(conn_config) - nascent_config['dbname'] = dbname - LOGGER.info('(%s) Testing connectivity...', dbname) + nascent_config["dbname"] = dbname + LOGGER.info("(%s) Testing connectivity...", dbname) try: conn = open_connection(nascent_config) - LOGGER.info('(%s) connectivity verified', dbname) + LOGGER.info("(%s) connectivity verified", dbname) conn.close() return True except Exception as err: - LOGGER.warning('Unable to connect to %s. This maybe harmless if you ' - 'have not desire to replicate from this database: "%s"', dbname, err) + LOGGER.warning( + "Unable to connect to %s. This maybe harmless if you " + 'have not desire to replicate from this database: "%s"', + dbname, + err, + ) return False diff --git a/tap_postgres/discovery_utils.py b/tap_postgres/discovery_utils.py index 1e9891c2..672ee01f 100644 --- a/tap_postgres/discovery_utils.py +++ b/tap_postgres/discovery_utils.py @@ -8,35 +8,49 @@ import tap_postgres.db as post_db # LogMiner do not support LONG, LONG RAW, CLOB, BLOB, NCLOB, ADT, or COLLECTION datatypes. 
-Column = collections.namedtuple('Column', [ - "column_name", - "is_primary_key", - "sql_data_type", - "character_maximum_length", - "numeric_precision", - "numeric_scale", - "is_array", - "is_enum" - -]) - -INTEGER_TYPES = {'integer', 'smallint', 'bigint'} -FLOAT_TYPES = {'real', 'double precision'} -JSON_TYPES = {'json', 'jsonb'} +Column = collections.namedtuple( + "Column", + [ + "column_name", + "is_primary_key", + "sql_data_type", + "character_maximum_length", + "numeric_precision", + "numeric_scale", + "is_array", + "is_enum", + ], +) + +INTEGER_TYPES = {"integer", "smallint", "bigint"} +FLOAT_TYPES = {"real", "double precision"} +JSON_TYPES = {"json", "jsonb"} BASE_RECURSIVE_SCHEMAS = { - 'sdc_recursive_integer_array': {'type': ['null', 'integer', 'array'], - 'items': {'$ref': '#/definitions/sdc_recursive_integer_array'}}, - 'sdc_recursive_number_array': {'type': ['null', 'number', 'array'], - 'items': {'$ref': '#/definitions/sdc_recursive_number_array'}}, - 'sdc_recursive_string_array': {'type': ['null', 'string', 'array'], - 'items': {'$ref': '#/definitions/sdc_recursive_string_array'}}, - 'sdc_recursive_boolean_array': {'type': ['null', 'boolean', 'array'], - 'items': {'$ref': '#/definitions/sdc_recursive_boolean_array'}}, - 'sdc_recursive_timestamp_array': {'type': ['null', 'string', 'array'], - 'format': 'date-time', - 'items': {'$ref': '#/definitions/sdc_recursive_timestamp_array'}}, - 'sdc_recursive_object_array': {'type': ['null', 'object', 'array'], - 'items': {'$ref': '#/definitions/sdc_recursive_object_array'}} + "sdc_recursive_integer_array": { + "type": ["null", "integer", "array"], + "items": {"$ref": "#/definitions/sdc_recursive_integer_array"}, + }, + "sdc_recursive_number_array": { + "type": ["null", "number", "array"], + "items": {"$ref": "#/definitions/sdc_recursive_number_array"}, + }, + "sdc_recursive_string_array": { + "type": ["null", "string", "array"], + "items": {"$ref": "#/definitions/sdc_recursive_string_array"}, + }, + "sdc_recursive_boolean_array": { + "type": ["null", "boolean", "array"], + "items": {"$ref": "#/definitions/sdc_recursive_boolean_array"}, + }, + "sdc_recursive_timestamp_array": { + "type": ["null", "string", "array"], + "format": "date-time", + "items": {"$ref": "#/definitions/sdc_recursive_timestamp_array"}, + }, + "sdc_recursive_object_array": { + "type": ["null", "object", "array"], + "items": {"$ref": "#/definitions/sdc_recursive_object_array"}, + }, } @@ -61,7 +75,9 @@ def produce_table_info(conn, filter_schemas=None, tables: Optional[List[str]] = # select typname from pg_attribute as pga join pg_type as pgt on pgt.oid = pga.atttypid # where typlen = -1 and typelem != 0 and pga.attndims > 0; - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: + with conn.cursor( + cursor_factory=psycopg2.extras.DictCursor, name="stitch_cursor" + ) as cur: cur.itersize = post_db.CURSOR_ITER_SIZE table_info = {} # SELECT CASE WHEN $2.typtype = 'd' THEN $2.typbasetype ELSE $1.atttypid END @@ -122,11 +138,15 @@ def produce_table_info(conn, filter_schemas=None, tables: Optional[List[str]] = table_info[schema_name] = {} if table_info[schema_name].get(table_name) is None: - table_info[schema_name][table_name] = {'is_view': is_view, 'row_count': row_count, 'columns': {}} + table_info[schema_name][table_name] = { + "is_view": is_view, + "row_count": row_count, + "columns": {}, + } col_name = col_info[0] - table_info[schema_name][table_name]['columns'][col_name] = Column(*col_info) + 
table_info[schema_name][table_name]["columns"][col_name] = Column(*col_info) return table_info @@ -140,44 +160,69 @@ def discover_columns(connection, table_info): for table_name in table_info[schema_name].keys(): mdata = {} - columns = table_info[schema_name][table_name]['columns'] - table_pks = [col_name for col_name, col_info in columns.items() if col_info.is_primary_key] + columns = table_info[schema_name][table_name]["columns"] + table_pks = [ + col_name + for col_name, col_info in columns.items() + if col_info.is_primary_key + ] with connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(" SELECT current_database()") database_name = cur.fetchone()[0] - metadata.write(mdata, (), 'table-key-properties', table_pks) - metadata.write(mdata, (), 'schema-name', schema_name) - metadata.write(mdata, (), 'database-name', database_name) - metadata.write(mdata, (), 'row-count', table_info[schema_name][table_name]['row_count']) - metadata.write(mdata, (), 'is-view', table_info[schema_name][table_name].get('is_view')) + metadata.write(mdata, (), "table-key-properties", table_pks) + metadata.write(mdata, (), "schema-name", schema_name) + metadata.write(mdata, (), "database-name", database_name) + metadata.write( + mdata, (), "row-count", table_info[schema_name][table_name]["row_count"] + ) + metadata.write( + mdata, (), "is-view", table_info[schema_name][table_name].get("is_view") + ) - column_schemas = {col_name: schema_for_column(col_info) for col_name, col_info in columns.items()} + column_schemas = { + col_name: schema_for_column(col_info) + for col_name, col_info in columns.items() + } - schema = {'type': 'object', - 'properties': column_schemas, - 'definitions': {}} + schema = {"type": "object", "properties": column_schemas, "definitions": {}} schema = include_array_schemas(columns, schema) for c_name in column_schemas.keys(): mdata = write_sql_data_type_md(mdata, columns[c_name]) - if column_schemas[c_name].get('type') is None: - mdata = metadata.write(mdata, ('properties', c_name), 'inclusion', 'unsupported') - mdata = metadata.write(mdata, ('properties', c_name), 'selected-by-default', False) - elif table_info[schema_name][table_name]['columns'][c_name].is_primary_key: - mdata = metadata.write(mdata, ('properties', c_name), 'inclusion', 'automatic') - mdata = metadata.write(mdata, ('properties', c_name), 'selected-by-default', True) + if column_schemas[c_name].get("type") is None: + mdata = metadata.write( + mdata, ("properties", c_name), "inclusion", "unsupported" + ) + mdata = metadata.write( + mdata, ("properties", c_name), "selected-by-default", False + ) + elif table_info[schema_name][table_name]["columns"][ + c_name + ].is_primary_key: + mdata = metadata.write( + mdata, ("properties", c_name), "inclusion", "automatic" + ) + mdata = metadata.write( + mdata, ("properties", c_name), "selected-by-default", True + ) else: - mdata = metadata.write(mdata, ('properties', c_name), 'inclusion', 'available') - mdata = metadata.write(mdata, ('properties', c_name), 'selected-by-default', True) - - entry = {'table_name': table_name, - 'stream': table_name, - 'metadata': metadata.to_list(mdata), - 'tap_stream_id': post_db.compute_tap_stream_id(schema_name, table_name), - 'schema': schema} + mdata = metadata.write( + mdata, ("properties", c_name), "inclusion", "available" + ) + mdata = metadata.write( + mdata, ("properties", c_name), "selected-by-default", True + ) + + entry = { + "table_name": table_name, + "stream": table_name, + "metadata": metadata.to_list(mdata), + 
"tap_stream_id": post_db.compute_tap_stream_id(schema_name, table_name), + "schema": schema, + } entries.append(entry) @@ -191,93 +236,93 @@ def schema_for_column_datatype(col): """ schema = {} # remove any array notation from type information as we use a separate field for that - data_type = col.sql_data_type.lower().replace('[]', '') + data_type = col.sql_data_type.lower().replace("[]", "") if data_type in INTEGER_TYPES: - schema['type'] = nullable_column('integer', col.is_primary_key) - schema['minimum'] = -1 * (2 ** (col.numeric_precision - 1)) - schema['maximum'] = 2 ** (col.numeric_precision - 1) - 1 + schema["type"] = nullable_column("integer", col.is_primary_key) + schema["minimum"] = -1 * (2 ** (col.numeric_precision - 1)) + schema["maximum"] = 2 ** (col.numeric_precision - 1) - 1 return schema - if data_type == 'money': - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type == "money": + schema["type"] = nullable_column("string", col.is_primary_key) return schema if col.is_enum: - schema['type'] = nullable_column('string', col.is_primary_key) + schema["type"] = nullable_column("string", col.is_primary_key) return schema - if data_type == 'bit' and col.character_maximum_length == 1: - schema['type'] = nullable_column('boolean', col.is_primary_key) + if data_type == "bit" and col.character_maximum_length == 1: + schema["type"] = nullable_column("boolean", col.is_primary_key) return schema - if data_type == 'boolean': - schema['type'] = nullable_column('boolean', col.is_primary_key) + if data_type == "boolean": + schema["type"] = nullable_column("boolean", col.is_primary_key) return schema - if data_type == 'uuid': - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type == "uuid": + schema["type"] = nullable_column("string", col.is_primary_key) return schema - if data_type == 'hstore': - schema['type'] = nullable_column('object', col.is_primary_key) - schema['properties'] = {} + if data_type == "hstore": + schema["type"] = nullable_column("object", col.is_primary_key) + schema["properties"] = {} return schema - if data_type == 'citext': - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type == "citext": + schema["type"] = nullable_column("string", col.is_primary_key) return schema if data_type in JSON_TYPES: - schema['type'] = nullable_columns(['object', 'array'], col.is_primary_key) + schema["type"] = nullable_columns(["object", "array"], col.is_primary_key) return schema - if data_type == 'numeric': - schema['type'] = nullable_column('number', col.is_primary_key) + if data_type == "numeric": + schema["type"] = nullable_column("number", col.is_primary_key) scale = post_db.numeric_scale(col) precision = post_db.numeric_precision(col) - schema['exclusiveMaximum'] = True - schema['maximum'] = post_db.numeric_max(precision, scale) - schema['multipleOf'] = post_db.numeric_multiple_of(scale) - schema['exclusiveMinimum'] = True - schema['minimum'] = post_db.numeric_min(precision, scale) + schema["exclusiveMaximum"] = True + schema["maximum"] = post_db.numeric_max(precision, scale) + schema["multipleOf"] = post_db.numeric_multiple_of(scale) + schema["exclusiveMinimum"] = True + schema["minimum"] = post_db.numeric_min(precision, scale) return schema - if data_type in {'time without time zone', 'time with time zone'}: + if data_type in {"time without time zone", "time with time zone"}: # times are treated as ordinary strings as they can not possible match RFC3339 - schema['type'] = nullable_column('string', 
col.is_primary_key) - schema['format'] = 'time' + schema["type"] = nullable_column("string", col.is_primary_key) + schema["format"] = "time" return schema - if data_type in ('date', 'timestamp without time zone', 'timestamp with time zone'): - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type in ("date", "timestamp without time zone", "timestamp with time zone"): + schema["type"] = nullable_column("string", col.is_primary_key) - schema['format'] = 'date-time' + schema["format"] = "date-time" return schema if data_type in FLOAT_TYPES: - schema['type'] = nullable_column('number', col.is_primary_key) + schema["type"] = nullable_column("number", col.is_primary_key) return schema - if data_type == 'text': - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type == "text": + schema["type"] = nullable_column("string", col.is_primary_key) return schema - if data_type == 'character varying': - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type == "character varying": + schema["type"] = nullable_column("string", col.is_primary_key) if col.character_maximum_length: - schema['maxLength'] = col.character_maximum_length + schema["maxLength"] = col.character_maximum_length return schema - if data_type == 'character': - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type == "character": + schema["type"] = nullable_column("string", col.is_primary_key) if col.character_maximum_length: - schema['maxLength'] = col.character_maximum_length + schema["maxLength"] = col.character_maximum_length return schema - if data_type in {'cidr', 'inet', 'macaddr'}: - schema['type'] = nullable_column('string', col.is_primary_key) + if data_type in {"cidr", "inet", "macaddr"}: + schema["type"] = nullable_column("string", col.is_primary_key) return schema return schema @@ -291,62 +336,62 @@ def schema_for_column(col_info): # either. These means we can say nothing about an array column. 
its items may be more arrays or primitive types # like integers and this can vary on a row by row basis - column_schema = {'type': ["null", "array"]} + column_schema = {"type": ["null", "array"]} if not col_info.is_array: return schema_for_column_datatype(col_info) - if col_info.sql_data_type == 'integer[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_integer_array'} - elif col_info.sql_data_type == 'bigint[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_integer_array'} - elif col_info.sql_data_type == 'bit[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_boolean_array'} - elif col_info.sql_data_type == 'boolean[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_boolean_array'} - elif col_info.sql_data_type == 'character varying[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'cidr[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'citext[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'date[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_timestamp_array'} - elif col_info.sql_data_type == 'numeric[]': + if col_info.sql_data_type == "integer[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_integer_array"} + elif col_info.sql_data_type == "bigint[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_integer_array"} + elif col_info.sql_data_type == "bit[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_boolean_array"} + elif col_info.sql_data_type == "boolean[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_boolean_array"} + elif col_info.sql_data_type == "character varying[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "cidr[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "citext[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "date[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_timestamp_array"} + elif col_info.sql_data_type == "numeric[]": scale = post_db.numeric_scale(col_info) precision = post_db.numeric_precision(col_info) schema_name = schema_name_for_numeric_array(precision, scale) - column_schema['items'] = {'$ref': '#/definitions/{}'.format(schema_name)} - elif col_info.sql_data_type == 'double precision[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_number_array'} - elif col_info.sql_data_type == 'hstore[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_object_array'} - elif col_info.sql_data_type == 'inet[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'json[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_object_array'} - elif col_info.sql_data_type == 'jsonb[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_object_array'} - elif col_info.sql_data_type == 'mac[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'money[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'real[]': - column_schema['items'] = 
{'$ref': '#/definitions/sdc_recursive_number_array'} - elif col_info.sql_data_type == 'smallint[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_integer_array'} - elif col_info.sql_data_type == 'text[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'timestamp without time zone[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_timestamp_array'} - elif col_info.sql_data_type == 'timestamp with time zone[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_timestamp_array'} - elif col_info.sql_data_type == 'time[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} - elif col_info.sql_data_type == 'uuid[]': - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} + column_schema["items"] = {"$ref": "#/definitions/{}".format(schema_name)} + elif col_info.sql_data_type == "double precision[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_number_array"} + elif col_info.sql_data_type == "hstore[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_object_array"} + elif col_info.sql_data_type == "inet[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "json[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_object_array"} + elif col_info.sql_data_type == "jsonb[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_object_array"} + elif col_info.sql_data_type == "mac[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "money[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "real[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_number_array"} + elif col_info.sql_data_type == "smallint[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_integer_array"} + elif col_info.sql_data_type == "text[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "timestamp without time zone[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_timestamp_array"} + elif col_info.sql_data_type == "timestamp with time zone[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_timestamp_array"} + elif col_info.sql_data_type == "time[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} + elif col_info.sql_data_type == "uuid[]": + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} else: # custom datatypes like enums - column_schema['items'] = {'$ref': '#/definitions/sdc_recursive_string_array'} + column_schema["items"] = {"$ref": "#/definitions/sdc_recursive_string_array"} return column_schema @@ -354,45 +399,55 @@ def schema_for_column(col_info): def nullable_columns(col_types, pk): if pk: return col_types - return ['null'] + col_types + return ["null"] + col_types def nullable_column(col_type, pk): if pk: return [col_type] - return ['null', col_type] + return ["null", col_type] def schema_name_for_numeric_array(precision, scale): - schema_name = 'sdc_recursive_decimal_{}_{}_array'.format(precision, scale) + schema_name = "sdc_recursive_decimal_{}_{}_array".format(precision, scale) return schema_name def include_array_schemas(columns, schema): - schema['definitions'] = copy.deepcopy(BASE_RECURSIVE_SCHEMAS) 
+ schema["definitions"] = copy.deepcopy(BASE_RECURSIVE_SCHEMAS) - decimal_array_columns = [key for key, value in columns.items() if value.sql_data_type == 'numeric[]'] + decimal_array_columns = [ + key for key, value in columns.items() if value.sql_data_type == "numeric[]" + ] for col in decimal_array_columns: scale = post_db.numeric_scale(columns[col]) precision = post_db.numeric_precision(columns[col]) schema_name = schema_name_for_numeric_array(precision, scale) - schema['definitions'][schema_name] = {'type': ['null', 'number', 'array'], - 'multipleOf': post_db.numeric_multiple_of(scale), - 'exclusiveMaximum': True, - 'maximum': post_db.numeric_max(precision, scale), - 'exclusiveMinimum': True, - 'minimum': post_db.numeric_min(precision, scale), - 'items': {'$ref': '#/definitions/{}'.format(schema_name)}} + schema["definitions"][schema_name] = { + "type": ["null", "number", "array"], + "multipleOf": post_db.numeric_multiple_of(scale), + "exclusiveMaximum": True, + "maximum": post_db.numeric_max(precision, scale), + "exclusiveMinimum": True, + "minimum": post_db.numeric_min(precision, scale), + "items": {"$ref": "#/definitions/{}".format(schema_name)}, + } return schema def write_sql_data_type_md(mdata, col_info): c_name = col_info.column_name - if col_info.sql_data_type == 'bit' and col_info.character_maximum_length > 1: - mdata = metadata.write(mdata, ('properties', c_name), - 'sql-datatype', "bit({})".format(col_info.character_maximum_length)) + if col_info.sql_data_type == "bit" and col_info.character_maximum_length > 1: + mdata = metadata.write( + mdata, + ("properties", c_name), + "sql-datatype", + "bit({})".format(col_info.character_maximum_length), + ) else: - mdata = metadata.write(mdata, ('properties', c_name), 'sql-datatype', col_info.sql_data_type) + mdata = metadata.write( + mdata, ("properties", c_name), "sql-datatype", col_info.sql_data_type + ) return mdata diff --git a/tap_postgres/stream_utils.py b/tap_postgres/stream_utils.py index c5f0a564..de168ac7 100644 --- a/tap_postgres/stream_utils.py +++ b/tap_postgres/stream_utils.py @@ -9,7 +9,7 @@ from tap_postgres.db import open_connection from tap_postgres.discovery_utils import discover_db -LOGGER = singer.get_logger('tap_postgres') +LOGGER = singer.get_logger("tap_postgres") def dump_catalog(all_streams: List[Dict]) -> None: @@ -18,7 +18,7 @@ def dump_catalog(all_streams: List[Dict]) -> None: Args: all_streams: List of streams to dump """ - json.dump({'streams': all_streams}, sys.stdout, indent=2) + json.dump({"streams": all_streams}, sys.stdout, indent=2) def is_selected_via_metadata(stream: Dict) -> bool: @@ -29,29 +29,35 @@ def is_selected_via_metadata(stream: Dict) -> bool: Returns: True if selected, False otherwise. 
""" - table_md = metadata_util.to_map(stream['metadata']).get((), {}) - return table_md.get('selected', False) + table_md = metadata_util.to_map(stream["metadata"]).get((), {}) + return table_md.get("selected", False) -def clear_state_on_replication_change(state: Dict, - tap_stream_id: str, - replication_key: str, - replication_method: str) -> Dict: +def clear_state_on_replication_change( + state: Dict, tap_stream_id: str, replication_key: str, replication_method: str +) -> Dict: """ Update state if replication method change is detected Returns: new state dictionary """ # user changed replication, nuke state - last_replication_method = singer.get_bookmark(state, tap_stream_id, 'last_replication_method') - if last_replication_method is not None and (replication_method != last_replication_method): + last_replication_method = singer.get_bookmark( + state, tap_stream_id, "last_replication_method" + ) + if last_replication_method is not None and ( + replication_method != last_replication_method + ): state = singer.reset_stream(state, tap_stream_id) # key changed - if replication_method == 'INCREMENTAL' and \ - replication_key != singer.get_bookmark(state, tap_stream_id, 'replication_key'): + if replication_method == "INCREMENTAL" and replication_key != singer.get_bookmark( + state, tap_stream_id, "replication_key" + ): state = singer.reset_stream(state, tap_stream_id) - state = singer.write_bookmark(state, tap_stream_id, 'last_replication_method', replication_method) + state = singer.write_bookmark( + state, tap_stream_id, "last_replication_method", replication_method + ) return state @@ -61,26 +67,30 @@ def refresh_streams_schema(conn_config: Dict, streams: List[Dict]): Updates the streams schema & metadata with new discovery The given streams list of dictionaries would be mutated and updated """ - LOGGER.debug('Refreshing streams schemas ...') + LOGGER.debug("Refreshing streams schemas ...") - LOGGER.debug('Current streams schemas %s', streams) + LOGGER.debug("Current streams schemas %s", streams) # Run discovery to get the streams most up to date json schemas with open_connection(conn_config) as conn: new_discovery = { - stream['tap_stream_id']: stream - for stream in discover_db(conn, conn_config.get('filter_schemas'), [st['table_name'] for st in streams]) + stream["tap_stream_id"]: stream + for stream in discover_db( + conn, + conn_config.get("filter_schemas"), + [st["table_name"] for st in streams], + ) } - LOGGER.debug('New discovery schemas %s', new_discovery) + LOGGER.debug("New discovery schemas %s", new_discovery) # For every stream, update the schema and metadata from the corresponding discovered stream for idx, stream in enumerate(streams): - discovered_stream = new_discovery[stream['tap_stream_id']] - streams[idx]['schema'] = _merge_stream_schema(stream, discovered_stream) - streams[idx]['metadata'] = _merge_stream_metadata(stream, discovered_stream) + discovered_stream = new_discovery[stream["tap_stream_id"]] + streams[idx]["schema"] = _merge_stream_schema(stream, discovered_stream) + streams[idx]["metadata"] = _merge_stream_metadata(stream, discovered_stream) - LOGGER.debug('Updated streams schemas %s', streams) + LOGGER.debug("Updated streams schemas %s", streams) def _merge_stream_schema(stream, discovered_stream): @@ -88,13 +98,21 @@ def _merge_stream_schema(stream, discovered_stream): A discovered stream doesn't include any schema overrides from the catalog file. Merges overrides from the catalog file into the discovered schema. 
""" - discovered_schema = copy.deepcopy(discovered_stream['schema']) - - for column, column_schema in stream['schema']['properties'].items(): - if column in discovered_schema['properties'] and column_schema != discovered_schema['properties'][column]: - override = copy.deepcopy(stream['schema']['properties'][column]) - LOGGER.info('Overriding schema for %s.%s with %s', stream['tap_stream_id'], column, override) - discovered_schema['properties'][column].update(override) + discovered_schema = copy.deepcopy(discovered_stream["schema"]) + + for column, column_schema in stream["schema"]["properties"].items(): + if ( + column in discovered_schema["properties"] + and column_schema != discovered_schema["properties"][column] + ): + override = copy.deepcopy(stream["schema"]["properties"][column]) + LOGGER.info( + "Overriding schema for %s.%s with %s", + stream["tap_stream_id"], + column, + override, + ) + discovered_schema["properties"][column].update(override) return discovered_schema @@ -106,8 +124,8 @@ def _merge_stream_metadata(stream, discovered_stream): arbitrary overridden metadata from the catalog file. Merges the discovered metadata into the metadata from the catalog file. """ - stream_md = metadata_util.to_map(stream['metadata']) - discovery_md = metadata_util.to_map(discovered_stream['metadata']) + stream_md = metadata_util.to_map(stream["metadata"]) + discovery_md = metadata_util.to_map(discovered_stream["metadata"]) for breadcrumb, metadata in discovery_md.items(): if breadcrumb in stream_md: @@ -123,9 +141,11 @@ def any_logical_streams(streams, default_replication_method): Checks if streams list contains any stream with log_based method """ for stream in streams: - stream_metadata = metadata_util.to_map(stream['metadata']) - replication_method = stream_metadata.get((), {}).get('replication-method', default_replication_method) - if replication_method == 'LOG_BASED': + stream_metadata = metadata_util.to_map(stream["metadata"]) + replication_method = stream_metadata.get((), {}).get( + "replication-method", default_replication_method + ) + if replication_method == "LOG_BASED": return True return False diff --git a/tap_postgres/sync_strategies/common.py b/tap_postgres/sync_strategies/common.py index 9d6300d6..326a167a 100644 --- a/tap_postgres/sync_strategies/common.py +++ b/tap_postgres/sync_strategies/common.py @@ -1,34 +1,36 @@ import sys import simplejson as json import singer -from singer import metadata +from singer import metadata import tap_postgres.db as post_db # pylint: disable=invalid-name,missing-function-docstring def should_sync_column(md_map, field_name): - field_metadata = md_map.get(('properties', field_name), {}) - return singer.should_sync_field(field_metadata.get('inclusion'), - field_metadata.get('selected'), - True) + field_metadata = md_map.get(("properties", field_name), {}) + return singer.should_sync_field( + field_metadata.get("inclusion"), field_metadata.get("selected"), True + ) def write_schema_message(schema_message): - sys.stdout.write(json.dumps(schema_message, use_decimal=True) + '\n') + sys.stdout.write(json.dumps(schema_message, use_decimal=True) + "\n") sys.stdout.flush() def send_schema_message(stream, bookmark_properties): - s_md = metadata.to_map(stream['metadata']) - if s_md.get((), {}).get('is-view'): - key_properties = s_md.get((), {}).get('view-key-properties', []) + s_md = metadata.to_map(stream["metadata"]) + if s_md.get((), {}).get("is-view"): + key_properties = s_md.get((), {}).get("view-key-properties", []) else: - key_properties = 
s_md.get((), {}).get('table-key-properties', []) - - schema_message = {'type' : 'SCHEMA', - 'stream' : post_db.calculate_destination_stream_name(stream, s_md), - 'schema' : stream['schema'], - 'key_properties' : key_properties, - 'bookmark_properties': bookmark_properties} + key_properties = s_md.get((), {}).get("table-key-properties", []) + + schema_message = { + "type": "SCHEMA", + "stream": post_db.calculate_destination_stream_name(stream, s_md), + "schema": stream["schema"], + "key_properties": key_properties, + "bookmark_properties": bookmark_properties, + } write_schema_message(schema_message) diff --git a/tap_postgres/sync_strategies/full_table.py b/tap_postgres/sync_strategies/full_table.py index 92040e43..7e017eed 100644 --- a/tap_postgres/sync_strategies/full_table.py +++ b/tap_postgres/sync_strategies/full_table.py @@ -10,7 +10,7 @@ import tap_postgres.db as post_db -LOGGER = singer.get_logger('tap_postgres') +LOGGER = singer.get_logger("tap_postgres") UPDATE_BOOKMARK_PERIOD = 1000 @@ -20,49 +20,58 @@ def sync_view(conn_info, stream, state, desired_columns, md_map): time_extracted = utils.now() # before writing the table version to state, check if we had one to begin with - first_run = singer.get_bookmark(state, stream['tap_stream_id'], 'version') is None + first_run = singer.get_bookmark(state, stream["tap_stream_id"], "version") is None nascent_stream_version = int(time.time() * 1000) - state = singer.write_bookmark(state, - stream['tap_stream_id'], - 'version', - nascent_stream_version) + state = singer.write_bookmark( + state, stream["tap_stream_id"], "version", nascent_stream_version + ) singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) - schema_name = md_map.get(()).get('schema-name') + schema_name = md_map.get(()).get("schema-name") escaped_columns = map(post_db.prepare_columns_sql, desired_columns) activate_version_message = singer.ActivateVersionMessage( stream=post_db.calculate_destination_stream_name(stream, md_map), - version=nascent_stream_version) + version=nascent_stream_version, + ) if first_run: singer.write_message(activate_version_message) with metrics.record_counter(None) as counter: with post_db.open_connection(conn_info) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: + with conn.cursor( + cursor_factory=psycopg2.extras.DictCursor, name="stitch_cursor" + ) as cur: cur.itersize = post_db.CURSOR_ITER_SIZE - select_sql = 'SELECT {} FROM {}'.format(','.join(escaped_columns), - post_db.fully_qualified_table_name(schema_name, - stream['table_name'])) + select_sql = "SELECT {} FROM {}".format( + ",".join(escaped_columns), + post_db.fully_qualified_table_name( + schema_name, stream["table_name"] + ), + ) LOGGER.info("select %s with itersize %s", select_sql, cur.itersize) cur.execute(select_sql) rows_saved = 0 for rec in cur: - record_message = post_db.selected_row_to_singer_message(stream, - rec, - nascent_stream_version, - desired_columns, - time_extracted, - md_map) + record_message = post_db.selected_row_to_singer_message( + stream, + rec, + nascent_stream_version, + desired_columns, + time_extracted, + md_map, + ) singer.write_message(record_message) rows_saved = rows_saved + 1 if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: - singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + singer.write_message( + singer.StateMessage(value=copy.deepcopy(state)) + ) counter.increment() @@ -77,28 +86,32 @@ def sync_table(conn_info, stream, state, desired_columns, md_map): time_extracted 
= utils.now() # before writing the table version to state, check if we had one to begin with - first_run = singer.get_bookmark(state, stream['tap_stream_id'], 'version') is None + first_run = singer.get_bookmark(state, stream["tap_stream_id"], "version") is None # pick a new table version IFF we do not have an xmin in our state # the presence of an xmin indicates that we were interrupted last time through - if singer.get_bookmark(state, stream['tap_stream_id'], 'xmin') is None: + if singer.get_bookmark(state, stream["tap_stream_id"], "xmin") is None: nascent_stream_version = int(time.time() * 1000) else: - nascent_stream_version = singer.get_bookmark(state, stream['tap_stream_id'], 'version') + nascent_stream_version = singer.get_bookmark( + state, stream["tap_stream_id"], "version" + ) - state = singer.write_bookmark(state, - stream['tap_stream_id'], - 'version', - nascent_stream_version) + state = singer.write_bookmark( + state, stream["tap_stream_id"], "version", nascent_stream_version + ) singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) - schema_name = md_map.get(()).get('schema-name') + schema_name = md_map.get(()).get("schema-name") - escaped_columns = map(partial(post_db.prepare_columns_for_select_sql, md_map=md_map), desired_columns) + escaped_columns = map( + partial(post_db.prepare_columns_for_select_sql, md_map=md_map), desired_columns + ) activate_version_message = singer.ActivateVersionMessage( stream=post_db.calculate_destination_stream_name(stream, md_map), - version=nascent_stream_version) + version=nascent_stream_version, + ) if first_run: singer.write_message(activate_version_message) @@ -121,49 +134,67 @@ def sync_table(conn_info, stream, state, desired_columns, md_map): else: LOGGER.info("hstore is UNavailable") - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: + with conn.cursor( + cursor_factory=psycopg2.extras.DictCursor, name="stitch_cursor" + ) as cur: cur.itersize = post_db.CURSOR_ITER_SIZE - fq_table_name = post_db.fully_qualified_table_name(schema_name, stream['table_name']) - xmin = singer.get_bookmark(state, stream['tap_stream_id'], 'xmin') + fq_table_name = post_db.fully_qualified_table_name( + schema_name, stream["table_name"] + ) + xmin = singer.get_bookmark(state, stream["tap_stream_id"], "xmin") if xmin: - LOGGER.info("Resuming Full Table replication %s from xmin %s", nascent_stream_version, xmin) + LOGGER.info( + "Resuming Full Table replication %s from xmin %s", + nascent_stream_version, + xmin, + ) select_sql = """SELECT {}, xmin::text::bigint FROM {} where age(xmin::xid) <= age('{}'::xid) - ORDER BY xmin::text ASC""".format(','.join(escaped_columns), - fq_table_name, - xmin) + ORDER BY xmin::text ASC""".format( + ",".join(escaped_columns), fq_table_name, xmin + ) else: - LOGGER.info("Beginning new Full Table replication %s", nascent_stream_version) + LOGGER.info( + "Beginning new Full Table replication %s", + nascent_stream_version, + ) select_sql = """SELECT {}, xmin::text::bigint FROM {} - ORDER BY xmin::text ASC""".format(','.join(escaped_columns), - fq_table_name) + ORDER BY xmin::text ASC""".format( + ",".join(escaped_columns), fq_table_name + ) LOGGER.info("select %s with itersize %s", select_sql, cur.itersize) cur.execute(select_sql) rows_saved = 0 for rec in cur: - xmin = rec['xmin'] + xmin = rec["xmin"] rec = rec[:-1] - record_message = post_db.selected_row_to_singer_message(stream, - rec, - nascent_stream_version, - desired_columns, - time_extracted, - md_map) + record_message = 
post_db.selected_row_to_singer_message( + stream, + rec, + nascent_stream_version, + desired_columns, + time_extracted, + md_map, + ) singer.write_message(record_message) - state = singer.write_bookmark(state, stream['tap_stream_id'], 'xmin', xmin) + state = singer.write_bookmark( + state, stream["tap_stream_id"], "xmin", xmin + ) rows_saved = rows_saved + 1 if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: - singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + singer.write_message( + singer.StateMessage(value=copy.deepcopy(state)) + ) counter.increment() # once we have completed the full table replication, discard the xmin bookmark. # the xmin bookmark only comes into play when a full table replication is interrupted - state = singer.write_bookmark(state, stream['tap_stream_id'], 'xmin', None) + state = singer.write_bookmark(state, stream["tap_stream_id"], "xmin", None) # always send the activate version whether first run or subsequent singer.write_message(activate_version_message) diff --git a/tap_postgres/sync_strategies/incremental.py b/tap_postgres/sync_strategies/incremental.py index b98da879..a3d047f6 100644 --- a/tap_postgres/sync_strategies/incremental.py +++ b/tap_postgres/sync_strategies/incremental.py @@ -11,7 +11,7 @@ import tap_postgres.db as post_db -LOGGER = singer.get_logger('tap_postgres') +LOGGER = singer.get_logger("tap_postgres") UPDATE_BOOKMARK_PERIOD = 10000 @@ -21,8 +21,10 @@ def fetch_max_replication_key(conn_config, replication_key, schema_name, table_n with post_db.open_connection(conn_config, False) as conn: with conn.cursor() as cur: max_key_sql = """SELECT max({}) - FROM {}""".format(post_db.prepare_columns_sql(replication_key), - post_db.fully_qualified_table_name(schema_name, table_name)) + FROM {}""".format( + post_db.prepare_columns_sql(replication_key), + post_db.fully_qualified_table_name(schema_name, table_name), + ) LOGGER.info("determine max replication key value: %s", max_key_sql) cur.execute(max_key_sql) max_key = cur.fetchone()[0] @@ -34,30 +36,35 @@ def fetch_max_replication_key(conn_config, replication_key, schema_name, table_n def sync_table(conn_info, stream, state, desired_columns, md_map): time_extracted = utils.now() - stream_version = singer.get_bookmark(state, stream['tap_stream_id'], 'version') + stream_version = singer.get_bookmark(state, stream["tap_stream_id"], "version") if stream_version is None: stream_version = int(time.time() * 1000) - state = singer.write_bookmark(state, - stream['tap_stream_id'], - 'version', - stream_version) + state = singer.write_bookmark( + state, stream["tap_stream_id"], "version", stream_version + ) singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) - schema_name = md_map.get(()).get('schema-name') + schema_name = md_map.get(()).get("schema-name") - escaped_columns = map(partial(post_db.prepare_columns_for_select_sql, md_map=md_map), desired_columns) + escaped_columns = map( + partial(post_db.prepare_columns_for_select_sql, md_map=md_map), desired_columns + ) activate_version_message = singer.ActivateVersionMessage( stream=post_db.calculate_destination_stream_name(stream, md_map), - version=stream_version) - + version=stream_version, + ) singer.write_message(activate_version_message) - replication_key = md_map.get((), {}).get('replication-key') - replication_key_value = singer.get_bookmark(state, stream['tap_stream_id'], 'replication_key_value') - replication_key_sql_datatype = md_map.get(('properties', replication_key)).get('sql-datatype') + replication_key = 
md_map.get((), {}).get("replication-key") + replication_key_value = singer.get_bookmark( + state, stream["tap_stream_id"], "replication_key_value" + ) + replication_key_sql_datatype = md_map.get(("properties", replication_key)).get( + "sql-datatype" + ) hstore_available = post_db.hstore_available(conn_info) with metrics.record_counter(None) as counter: @@ -77,59 +84,77 @@ def sync_table(conn_info, stream, state, desired_columns, md_map): else: LOGGER.info("hstore is UNavailable") - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='pipelinewise') as cur: + with conn.cursor( + cursor_factory=psycopg2.extras.DictCursor, name="pipelinewise" + ) as cur: cur.itersize = post_db.CURSOR_ITER_SIZE - LOGGER.info("Beginning new incremental replication sync %s", stream_version) + LOGGER.info( + "Beginning new incremental replication sync %s", stream_version + ) if replication_key_value: select_sql = """SELECT {} FROM {} WHERE {} >= '{}'::{} - ORDER BY {} ASC""".format(','.join(escaped_columns), - post_db.fully_qualified_table_name(schema_name, - stream['table_name']), - post_db.prepare_columns_sql(replication_key), - replication_key_value, - replication_key_sql_datatype, - post_db.prepare_columns_sql(replication_key)) + ORDER BY {} ASC""".format( + ",".join(escaped_columns), + post_db.fully_qualified_table_name( + schema_name, stream["table_name"] + ), + post_db.prepare_columns_sql(replication_key), + replication_key_value, + replication_key_sql_datatype, + post_db.prepare_columns_sql(replication_key), + ) else: - #if not replication_key_value + # if not replication_key_value select_sql = """SELECT {} FROM {} - ORDER BY {} ASC""".format(','.join(escaped_columns), - post_db.fully_qualified_table_name(schema_name, - stream['table_name']), - post_db.prepare_columns_sql(replication_key)) - - LOGGER.info('select statement: %s with itersize %s', select_sql, cur.itersize) + ORDER BY {} ASC""".format( + ",".join(escaped_columns), + post_db.fully_qualified_table_name( + schema_name, stream["table_name"] + ), + post_db.prepare_columns_sql(replication_key), + ) + + LOGGER.info( + "select statement: %s with itersize %s", select_sql, cur.itersize + ) cur.execute(select_sql) rows_saved = 0 for rec in cur: - record_message = post_db.selected_row_to_singer_message(stream, - rec, - stream_version, - desired_columns, - time_extracted, - md_map) + record_message = post_db.selected_row_to_singer_message( + stream, + rec, + stream_version, + desired_columns, + time_extracted, + md_map, + ) singer.write_message(record_message) rows_saved = rows_saved + 1 - #Picking a replication_key with NULL values will result in it ALWAYS been synced which is not great - #event worse would be allowing the NULL value to enter into the state + # Picking a replication_key with NULL values will result in it ALWAYS been synced which is not great + # event worse would be allowing the NULL value to enter into the state try: if record_message.record[replication_key] is not None: - state = singer.write_bookmark(state, - stream['tap_stream_id'], - 'replication_key_value', - record_message.record[replication_key]) + state = singer.write_bookmark( + state, + stream["tap_stream_id"], + "replication_key_value", + record_message.record[replication_key], + ) except KeyError as e: # Replication key not present in table - treat like None pass if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: - singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + singer.write_message( + singer.StateMessage(value=copy.deepcopy(state)) + ) 
counter.increment() diff --git a/tap_postgres/sync_strategies/logical_replication.py b/tap_postgres/sync_strategies/logical_replication.py index 53b6ad04..1cfb18d9 100644 --- a/tap_postgres/sync_strategies/logical_replication.py +++ b/tap_postgres/sync_strategies/logical_replication.py @@ -17,16 +17,17 @@ import tap_postgres.sync_strategies.common as sync_common from tap_postgres.stream_utils import refresh_streams_schema -LOGGER = singer.get_logger('tap_postgres') +LOGGER = singer.get_logger("tap_postgres") UPDATE_BOOKMARK_PERIOD = 10000 -FALLBACK_DATETIME = '9999-12-31T23:59:59.999+00:00' -FALLBACK_DATE = '9999-12-31T00:00:00+00:00' +FALLBACK_DATETIME = "9999-12-31T23:59:59.999+00:00" +FALLBACK_DATE = "9999-12-31T00:00:00+00:00" class ReplicationSlotNotFoundError(Exception): """Custom exception when replication slot not found""" + class UnsupportedPayloadKindError(Exception): """Custom exception when waljson payload is not insert, update nor delete""" @@ -35,9 +36,11 @@ class UnsupportedPayloadKindError(Exception): def get_pg_version(conn_info): with post_db.open_connection(conn_info, False) as conn: with conn.cursor() as cur: - cur.execute("SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'") + cur.execute( + "SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'" + ) version = cur.fetchone()[0] - LOGGER.debug('Detected PostgreSQL version: %s', version) + LOGGER.debug("Detected PostgreSQL version: %s", version) return version @@ -47,7 +50,7 @@ def lsn_to_int(lsn): if not lsn: return None - file, index = lsn.split('/') + file, index = lsn.split("/") lsni = (int(file, 16) << 32) + int(index, 16) return lsni @@ -59,16 +62,16 @@ def int_to_lsn(lsni): return None # Convert the integer to binary - lsnb = '{0:b}'.format(lsni) + lsnb = "{0:b}".format(lsni) # file is the binary before the 32nd character, converted to hex if len(lsnb) > 32: - file = (format(int(lsnb[:-32], 2), 'x')).upper() + file = (format(int(lsnb[:-32], 2), "x")).upper() else: - file = '0' + file = "0" # index is the binary from the 32nd character, converted to hex - index = (format(int(lsnb[-32:], 2), 'x')).upper() + index = (format(int(lsnb[-32:], 2), "x")).upper() # Formatting lsn = "{}/{}".format(file, index) return lsn @@ -80,17 +83,17 @@ def fetch_current_lsn(conn_config): # Make sure PostgreSQL version is 9.4 or higher # Do not allow minor versions with PostgreSQL BUG #15114 if (version >= 110000) and (version < 110002): - raise Exception('PostgreSQL upgrade required to minor version 11.2') + raise Exception("PostgreSQL upgrade required to minor version 11.2") if (version >= 100000) and (version < 100007): - raise Exception('PostgreSQL upgrade required to minor version 10.7') + raise Exception("PostgreSQL upgrade required to minor version 10.7") if (version >= 90600) and (version < 90612): - raise Exception('PostgreSQL upgrade required to minor version 9.6.12') + raise Exception("PostgreSQL upgrade required to minor version 9.6.12") if (version >= 90500) and (version < 90516): - raise Exception('PostgreSQL upgrade required to minor version 9.5.16') + raise Exception("PostgreSQL upgrade required to minor version 9.5.16") if (version >= 90400) and (version < 90421): - raise Exception('PostgreSQL upgrade required to minor version 9.4.21') + raise Exception("PostgreSQL upgrade required to minor version 9.4.21") if version < 90400: - raise Exception('Logical replication not supported before PostgreSQL 9.4') + raise Exception("Logical replication not supported before 
PostgreSQL 9.4") with post_db.open_connection(conn_config, False) as conn: with conn.cursor() as cur: @@ -100,26 +103,31 @@ def fetch_current_lsn(conn_config): elif version >= 90400: cur.execute("SELECT pg_current_xlog_location() AS current_lsn") else: - raise Exception('Logical replication not supported before PostgreSQL 9.4') + raise Exception( + "Logical replication not supported before PostgreSQL 9.4" + ) current_lsn = cur.fetchone()[0] return lsn_to_int(current_lsn) def add_automatic_properties(stream, debug_lsn: bool = False): - stream['schema']['properties']['_sdc_deleted_at'] = {'type': ['null', 'string'], 'format': 'date-time'} + stream["schema"]["properties"]["_sdc_deleted_at"] = { + "type": ["null", "string"], + "format": "date-time", + } if debug_lsn: - LOGGER.debug('debug_lsn is ON') - stream['schema']['properties']['_sdc_lsn'] = {'type': ['null', 'string']} + LOGGER.debug("debug_lsn is ON") + stream["schema"]["properties"]["_sdc_lsn"] = {"type": ["null", "string"]} else: - LOGGER.debug('debug_lsn is OFF') + LOGGER.debug("debug_lsn is OFF") return stream def get_stream_version(tap_stream_id, state): - stream_version = singer.get_bookmark(state, tap_stream_id, 'version') + stream_version = singer.get_bookmark(state, tap_stream_id, "version") if stream_version is None: raise Exception("version not found for log miner {}".format(tap_stream_id)) @@ -142,7 +150,9 @@ def create_hstore_elem(conn_info, elem): query = create_hstore_elem_query(elem) cur.execute(query) res = cur.fetchone()[0] - hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {}) + hstore_elem = reduce( + tuples_to_map, [res[i : i + 2] for i in range(0, len(res), 2)], {} + ) return hstore_elem @@ -152,54 +162,59 @@ def create_array_elem(elem, sql_datatype, conn_info): with post_db.open_connection(conn_info) as conn: with conn.cursor() as cur: - if sql_datatype == 'bit[]': - cast_datatype = 'boolean[]' - elif sql_datatype == 'boolean[]': - cast_datatype = 'boolean[]' - elif sql_datatype == 'character varying[]': - cast_datatype = 'character varying[]' - elif sql_datatype == 'cidr[]': - cast_datatype = 'cidr[]' - elif sql_datatype == 'citext[]': - cast_datatype = 'text[]' - elif sql_datatype == 'date[]': - cast_datatype = 'text[]' - elif sql_datatype == 'double precision[]': - cast_datatype = 'double precision[]' - elif sql_datatype == 'hstore[]': - cast_datatype = 'text[]' - elif sql_datatype == 'integer[]': - cast_datatype = 'integer[]' - elif sql_datatype == 'inet[]': - cast_datatype = 'inet[]' - elif sql_datatype == 'json[]': - cast_datatype = 'text[]' - elif sql_datatype == 'jsonb[]': - cast_datatype = 'text[]' - elif sql_datatype == 'macaddr[]': - cast_datatype = 'macaddr[]' - elif sql_datatype == 'money[]': - cast_datatype = 'text[]' - elif sql_datatype == 'numeric[]': - cast_datatype = 'text[]' - elif sql_datatype == 'real[]': - cast_datatype = 'real[]' - elif sql_datatype == 'smallint[]': - cast_datatype = 'smallint[]' - elif sql_datatype == 'text[]': - cast_datatype = 'text[]' - elif sql_datatype in ('time without time zone[]', 'time with time zone[]'): - cast_datatype = 'text[]' - elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'): - cast_datatype = 'text[]' - elif sql_datatype == 'uuid[]': - cast_datatype = 'text[]' + if sql_datatype == "bit[]": + cast_datatype = "boolean[]" + elif sql_datatype == "boolean[]": + cast_datatype = "boolean[]" + elif sql_datatype == "character varying[]": + cast_datatype = "character varying[]" + elif sql_datatype 
== "cidr[]": + cast_datatype = "cidr[]" + elif sql_datatype == "citext[]": + cast_datatype = "text[]" + elif sql_datatype == "date[]": + cast_datatype = "text[]" + elif sql_datatype == "double precision[]": + cast_datatype = "double precision[]" + elif sql_datatype == "hstore[]": + cast_datatype = "text[]" + elif sql_datatype == "integer[]": + cast_datatype = "integer[]" + elif sql_datatype == "inet[]": + cast_datatype = "inet[]" + elif sql_datatype == "json[]": + cast_datatype = "text[]" + elif sql_datatype == "jsonb[]": + cast_datatype = "text[]" + elif sql_datatype == "macaddr[]": + cast_datatype = "macaddr[]" + elif sql_datatype == "money[]": + cast_datatype = "text[]" + elif sql_datatype == "numeric[]": + cast_datatype = "text[]" + elif sql_datatype == "real[]": + cast_datatype = "real[]" + elif sql_datatype == "smallint[]": + cast_datatype = "smallint[]" + elif sql_datatype == "text[]": + cast_datatype = "text[]" + elif sql_datatype in ("time without time zone[]", "time with time zone[]"): + cast_datatype = "text[]" + elif sql_datatype in ( + "timestamp with time zone[]", + "timestamp without time zone[]", + ): + cast_datatype = "text[]" + elif sql_datatype == "uuid[]": + cast_datatype = "text[]" else: # custom datatypes like enums - cast_datatype = 'text[]' + cast_datatype = "text[]" - sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format(elem, cast_datatype) + sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format( + elem, cast_datatype + ) cur.execute(sql_stmt) res = cur.fetchone()[0] return res @@ -207,25 +222,25 @@ def create_array_elem(elem, sql_datatype, conn_info): # pylint: disable=too-many-branches,too-many-nested-blocks,too-many-return-statements def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): - sql_datatype = og_sql_datatype.replace('[]', '') + sql_datatype = og_sql_datatype.replace("[]", "") if elem is None: return elem - if sql_datatype == 'money': + if sql_datatype == "money": return elem - if sql_datatype in ['json', 'jsonb']: + if sql_datatype in ["json", "jsonb"]: return json.loads(elem) - if sql_datatype == 'timestamp without time zone': + if sql_datatype == "timestamp without time zone": if isinstance(elem, datetime.datetime): # we don't want a datetime like datetime(9999, 12, 31, 23, 59, 59, 999999) to be returned # compare the date in UTC tz to the max allowed if elem > datetime.datetime(9999, 12, 31, 23, 59, 59, 999000): return FALLBACK_DATETIME - return elem.isoformat() + '+00:00' + return elem.isoformat() + "+00:00" with warnings.catch_warnings(): # we need to catch and handle this warning @@ -233,7 +248,7 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): # dateutil/dateutil/blob/c496b4f872b50e8845c0f46b585a1e3830ed3648/dateutil/parser/_parser.py#L1213 # otherwise ad date like this '0001-12-31 23:40:28 BC' would be parsed as # '0001-12-31T23:40:28+00:00' instead of using the fallback date - warnings.filterwarnings('error') + warnings.filterwarnings("error") # parsing dates with era is not possible at moment # github.com/dateutil/dateutil/blob/c496b4f872b50e8845c0f46b585a1e3830ed3648/dateutil/parser/_parser.py#L297 @@ -244,11 +259,11 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): if parsed > datetime.datetime(9999, 12, 31, 23, 59, 59, 999000): return FALLBACK_DATETIME - return parsed.isoformat() + '+00:00' + return parsed.isoformat() + "+00:00" except (ParserError, UnknownTimezoneWarning): return FALLBACK_DATETIME - if sql_datatype == 'timestamp with 
time zone': + if sql_datatype == "timestamp with time zone": if isinstance(elem, datetime.datetime): try: # compare the date in UTC tz to the max allowed @@ -266,7 +281,7 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): # dateutil/dateutil/blob/c496b4f872b50e8845c0f46b585a1e3830ed3648/dateutil/parser/_parser.py#L1213 # otherwise ad date like this '0001-12-31 23:40:28 BC' would be parsed as # '0001-12-31T23:40:28+00:00' instead of using the fallback date - warnings.filterwarnings('error') + warnings.filterwarnings("error") # parsing dates with era is not possible at moment # github.com/dateutil/dateutil/blob/c496b4f872b50e8845c0f46b585a1e3830ed3648/dateutil/parser/_parser.py#L297 @@ -274,8 +289,9 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): parsed = parse(elem) # compare the date in UTC tz to the max allowed - if parsed.astimezone(pytz.UTC).replace(tzinfo=None) > \ - datetime.datetime(9999, 12, 31, 23, 59, 59, 999000): + if parsed.astimezone(pytz.UTC).replace(tzinfo=None) > datetime.datetime( + 9999, 12, 31, 23, 59, 59, 999000 + ): return FALLBACK_DATETIME return parsed.isoformat() @@ -283,47 +299,52 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): except (ParserError, UnknownTimezoneWarning, OverflowError): return FALLBACK_DATETIME - if sql_datatype == 'date': + if sql_datatype == "date": if isinstance(elem, datetime.date): # logical replication gives us dates as strings UNLESS they from an array - return elem.isoformat() + 'T00:00:00+00:00' + return elem.isoformat() + "T00:00:00+00:00" try: return parse(elem).isoformat() + "+00:00" except ValueError as e: - match = re.match(r'year (\d+) is out of range', str(e)) + match = re.match(r"year (\d+) is out of range", str(e)) if match and int(match.group(1)) > 9999: - LOGGER.warning('datetimes cannot handle years past 9999, returning %s for %s', - FALLBACK_DATE, elem) + LOGGER.warning( + "datetimes cannot handle years past 9999, returning %s for %s", + FALLBACK_DATE, + elem, + ) return FALLBACK_DATE raise - if sql_datatype == 'time with time zone': + if sql_datatype == "time with time zone": # time with time zone values will be converted to UTC and time zone dropped # Replace hour=24 with hour=0 - if elem.startswith('24'): - elem = elem.replace('24', '00', 1) + if elem.startswith("24"): + elem = elem.replace("24", "00", 1) # convert to UTC - elem = elem + '00' - elem_obj = datetime.datetime.strptime(elem, '%H:%M:%S%z') + elem = elem + "00" + elem_obj = datetime.datetime.strptime(elem, "%H:%M:%S%z") if elem_obj.utcoffset() != datetime.timedelta(seconds=0): - LOGGER.warning('time with time zone values are converted to UTC: %s', og_sql_datatype) + LOGGER.warning( + "time with time zone values are converted to UTC: %s", og_sql_datatype + ) elem_obj = elem_obj.astimezone(pytz.utc) # drop time zone - elem = elem_obj.strftime('%H:%M:%S') - return parse(elem).isoformat().split('T')[1] - if sql_datatype == 'time without time zone': + elem = elem_obj.strftime("%H:%M:%S") + return parse(elem).isoformat().split("T")[1] + if sql_datatype == "time without time zone": # Replace hour=24 with hour=0 - if elem.startswith('24'): - elem = elem.replace('24', '00', 1) - return parse(elem).isoformat().split('T')[1] - if sql_datatype == 'bit': + if elem.startswith("24"): + elem = elem.replace("24", "00", 1) + return parse(elem).isoformat().split("T")[1] + if sql_datatype == "bit": # for arrays, elem will == True # for ordinary bits, elem will == '1' - return elem == '1' or elem 
is True - if sql_datatype == 'boolean': + return elem == "1" or elem is True + if sql_datatype == "boolean": return elem - if sql_datatype == 'hstore': + if sql_datatype == "hstore": return create_hstore_elem(conn_info, elem) - if 'numeric' in sql_datatype: + if "numeric" in sql_datatype: return decimal.Decimal(elem) if isinstance(elem, int): return elem @@ -332,33 +353,52 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): if isinstance(elem, str): return elem - raise Exception("do not know how to marshall value of type {}".format(elem.__class__)) + raise Exception( + "do not know how to marshall value of type {}".format(elem.__class__) + ) def selected_array_to_singer_value(elem, sql_datatype, conn_info): if isinstance(elem, list): - return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), elem)) + return list( + map( + lambda elem: selected_array_to_singer_value( + elem, sql_datatype, conn_info + ), + elem, + ) + ) return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info) def selected_value_to_singer_value(elem, sql_datatype, conn_info): # are we dealing with an array? - if sql_datatype.find('[]') > 0: + if sql_datatype.find("[]") > 0: cleaned_elem = create_array_elem(elem, sql_datatype, conn_info) - return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), - (cleaned_elem or []))) + return list( + map( + lambda elem: selected_array_to_singer_value( + elem, sql_datatype, conn_info + ), + (cleaned_elem or []), + ) + ) return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info) -def row_to_singer_message(stream, row, version, columns, time_extracted, md_map, conn_info): +def row_to_singer_message( + stream, row, version, columns, time_extracted, md_map, conn_info +): row_to_persist = () - md_map[('properties', '_sdc_deleted_at')] = {'sql-datatype': 'timestamp with time zone'} - md_map[('properties', '_sdc_lsn')] = {'sql-datatype': "character varying"} + md_map[("properties", "_sdc_deleted_at")] = { + "sql-datatype": "timestamp with time zone" + } + md_map[("properties", "_sdc_lsn")] = {"sql-datatype": "character varying"} for idx, elem in enumerate(row): - sql_datatype = md_map.get(('properties', columns[idx])).get('sql-datatype') + sql_datatype = md_map.get(("properties", columns[idx])).get("sql-datatype") if not sql_datatype: LOGGER.info("No sql-datatype found for stream %s: %s", stream, columns[idx]) @@ -373,109 +413,125 @@ def row_to_singer_message(stream, row, version, columns, time_extracted, md_map, stream=post_db.calculate_destination_stream_name(stream, md_map), record=rec, version=version, - time_extracted=time_extracted) + time_extracted=time_extracted, + ) # pylint: disable=unused-argument,too-many-locals def consume_message(streams, state, msg, time_extracted, conn_info): # Strip leading comma generated by write-in-chunks and parse valid JSON try: - payload = json.loads(msg.payload.lstrip(',')) + payload = json.loads(msg.payload.lstrip(",")) except Exception: return state lsn = msg.data_start - streams_lookup = {s['tap_stream_id']: s for s in streams} + streams_lookup = {s["tap_stream_id"]: s for s in streams} - tap_stream_id = post_db.compute_tap_stream_id(payload['schema'], payload['table']) + tap_stream_id = post_db.compute_tap_stream_id(payload["schema"], payload["table"]) if streams_lookup.get(tap_stream_id) is None: return state target_stream = streams_lookup[tap_stream_id] - if payload['kind'] not in {'insert', 'update', 'delete'}: - raise 
UnsupportedPayloadKindError("unrecognized replication operation: {}".format(payload['kind'])) + if payload["kind"] not in {"insert", "update", "delete"}: + raise UnsupportedPayloadKindError( + "unrecognized replication operation: {}".format(payload["kind"]) + ) # Get the additional fields in payload that are not in schema properties: # only inserts and updates have the list of columns that can be used to detect any different in columns diff = set() - if payload['kind'] in {'insert', 'update'}: - diff = set(payload['columnnames']).difference(target_stream['schema']['properties'].keys()) + if payload["kind"] in {"insert", "update"}: + diff = set(payload["columnnames"]).difference( + target_stream["schema"]["properties"].keys() + ) # if there is new columns in the payload that are not in the schema properties then refresh the stream schema if diff: - LOGGER.info('Detected new columns "%s", refreshing schema of stream %s', diff, target_stream['stream']) + LOGGER.info( + 'Detected new columns "%s", refreshing schema of stream %s', + diff, + target_stream["stream"], + ) # encountered a column that is not in the schema # refresh the stream schema and metadata by running discovery refresh_streams_schema(conn_info, [target_stream]) # add the automatic properties back to the stream - add_automatic_properties(target_stream, conn_info.get('debug_lsn', False)) + add_automatic_properties(target_stream, conn_info.get("debug_lsn", False)) # publish new schema - sync_common.send_schema_message(target_stream, ['lsn']) + sync_common.send_schema_message(target_stream, ["lsn"]) - stream_version = get_stream_version(target_stream['tap_stream_id'], state) - stream_md_map = metadata.to_map(target_stream['metadata']) + stream_version = get_stream_version(target_stream["tap_stream_id"], state) + stream_md_map = metadata.to_map(target_stream["metadata"]) - desired_columns = {c for c in target_stream['schema']['properties'].keys() if sync_common.should_sync_column( - stream_md_map, c)} + desired_columns = { + c + for c in target_stream["schema"]["properties"].keys() + if sync_common.should_sync_column(stream_md_map, c) + } - if payload['kind'] in {'insert', 'update'}: + if payload["kind"] in {"insert", "update"}: col_names = [] col_vals = [] - for idx, col in enumerate(payload['columnnames']): + for idx, col in enumerate(payload["columnnames"]): if col in desired_columns: col_names.append(col) - col_vals.append(payload['columnvalues'][idx]) + col_vals.append(payload["columnvalues"][idx]) - col_names.append('_sdc_deleted_at') + col_names.append("_sdc_deleted_at") col_vals.append(None) - if conn_info.get('debug_lsn'): - col_names.append('_sdc_lsn') + if conn_info.get("debug_lsn"): + col_names.append("_sdc_lsn") col_vals.append(str(lsn)) - record_message = row_to_singer_message(target_stream, - col_vals, - stream_version, - col_names, - time_extracted, - stream_md_map, - conn_info) - - elif payload['kind'] == 'delete': + record_message = row_to_singer_message( + target_stream, + col_vals, + stream_version, + col_names, + time_extracted, + stream_md_map, + conn_info, + ) + + elif payload["kind"] == "delete": col_names = [] col_vals = [] - for idx, col in enumerate(payload['oldkeys']['keynames']): + for idx, col in enumerate(payload["oldkeys"]["keynames"]): if col in desired_columns: col_names.append(col) - col_vals.append(payload['oldkeys']['keyvalues'][idx]) + col_vals.append(payload["oldkeys"]["keyvalues"][idx]) - col_names.append('_sdc_deleted_at') + col_names.append("_sdc_deleted_at") 
col_vals.append(singer.utils.strftime(time_extracted)) - if conn_info.get('debug_lsn'): - col_names.append('_sdc_lsn') + if conn_info.get("debug_lsn"): + col_names.append("_sdc_lsn") col_vals.append(str(lsn)) - record_message = row_to_singer_message(target_stream, - col_vals, - stream_version, - col_names, - time_extracted, - stream_md_map, - conn_info) + record_message = row_to_singer_message( + target_stream, + col_vals, + stream_version, + col_names, + time_extracted, + stream_md_map, + conn_info, + ) singer.write_message(record_message) - state = singer.write_bookmark(state, target_stream['tap_stream_id'], 'lsn', lsn) + state = singer.write_bookmark(state, target_stream["tap_stream_id"], "lsn", lsn) return state -def generate_replication_slot_name(dbname, tap_id=None, prefix='pipelinewise'): +def generate_replication_slot_name(dbname, tap_id=None, prefix="pipelinewise"): """Generate replication slot name with :param str dbname: Database name that will be part of the replication slot name @@ -486,39 +542,48 @@ def generate_replication_slot_name(dbname, tap_id=None, prefix='pipelinewise'): """ # Add tap_id to the end of the slot name if provided if tap_id: - tap_id = f'_{tap_id}' + tap_id = f"_{tap_id}" # Convert None to empty string else: - tap_id = '' + tap_id = "" - slot_name = f'{prefix}_{dbname}{tap_id}'.lower() + slot_name = f"{prefix}_{dbname}{tap_id}".lower() # Replace invalid characters to ensure replication slot name is in accordance with Postgres spec - return re.sub('[^a-z0-9_]', '_', slot_name) + return re.sub("[^a-z0-9_]", "_", slot_name) + def locate_replication_slot_by_cur(cursor, dbname, tap_id=None): slot_name_v15 = generate_replication_slot_name(dbname) slot_name_v16 = generate_replication_slot_name(dbname, tap_id) # Backward compatibility: try to locate existing v15 slot first. 
PPW <= 0.15.0 - cursor.execute(f"SELECT * FROM pg_replication_slots WHERE slot_name = '{slot_name_v15}'") + cursor.execute( + f"SELECT * FROM pg_replication_slots WHERE slot_name = '{slot_name_v15}'" + ) if len(cursor.fetchall()) == 1: - LOGGER.info('Using pg_replication_slot %s', slot_name_v15) + LOGGER.info("Using pg_replication_slot %s", slot_name_v15) return slot_name_v15 # v15 style replication slot not found, try to locate v16 - cursor.execute(f"SELECT * FROM pg_replication_slots WHERE slot_name = '{slot_name_v16}'") + cursor.execute( + f"SELECT * FROM pg_replication_slots WHERE slot_name = '{slot_name_v16}'" + ) if len(cursor.fetchall()) == 1: - LOGGER.info('Using pg_replication_slot %s', slot_name_v16) + LOGGER.info("Using pg_replication_slot %s", slot_name_v16) return slot_name_v16 - raise ReplicationSlotNotFoundError(f'Unable to find replication slot {slot_name_v16}') + raise ReplicationSlotNotFoundError( + f"Unable to find replication slot {slot_name_v16}" + ) def locate_replication_slot(conn_info): with post_db.open_connection(conn_info, False) as conn: with conn.cursor() as cur: - return locate_replication_slot_by_cur(cur, conn_info['dbname'], conn_info['tap_id']) + return locate_replication_slot_by_cur( + cur, conn_info["dbname"], conn_info["tap_id"] + ) # pylint: disable=anomalous-backslash-in-string @@ -534,26 +599,32 @@ def streams_to_wal2json_tables(streams): :return: tables(str): comma separated and escaped list of tables, compatible for wal2json plugin :rtype: str """ + def escape_spec_chars(string): escaped = string wal2json_special_chars = " ',.*" for ch in wal2json_special_chars: - escaped = escaped.replace(ch, f'\\{ch}') + escaped = escaped.replace(ch, f"\\{ch}") return escaped tables = [] for s in streams: - schema_name = escape_spec_chars(s['metadata'][0]['metadata']['schema-name']) - table_name = escape_spec_chars(s['table_name']) + schema_name = escape_spec_chars(s["metadata"][0]["metadata"]["schema-name"]) + table_name = escape_spec_chars(s["table_name"]) - tables.append(f'{schema_name}.{table_name}') + tables.append(f"{schema_name}.{table_name}") - return ','.join(tables) + return ",".join(tables) def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): state_comitted = state - lsn_comitted = min([get_bookmark(state_comitted, s['tap_stream_id'], 'lsn') for s in logical_streams]) + lsn_comitted = min( + [ + get_bookmark(state_comitted, s["tap_stream_id"], "lsn") + for s in logical_streams + ] + ) start_lsn = lsn_comitted lsn_to_flush = None time_extracted = utils.now() @@ -563,14 +634,16 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): lsn_received_timestamp = None lsn_processed_count = 0 start_run_timestamp = datetime.datetime.utcnow() - max_run_seconds = conn_info['max_run_seconds'] - break_at_end_lsn = conn_info['break_at_end_lsn'] - logical_poll_total_seconds = conn_info['logical_poll_total_seconds'] or 10800 # 3 hours + max_run_seconds = conn_info["max_run_seconds"] + break_at_end_lsn = conn_info["break_at_end_lsn"] + logical_poll_total_seconds = ( + conn_info["logical_poll_total_seconds"] or 10800 + ) # 3 hours poll_interval = 10 poll_timestamp = None for s in logical_streams: - sync_common.send_schema_message(s, ['lsn']) + sync_common.send_schema_message(s, ["lsn"]) version = get_pg_version(conn_info) @@ -581,26 +654,34 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): # Set session wal_sender_timeout for PG12 and above if version >= 120000: wal_sender_timeout = 10800000 # 10800000ms = 
3 hours - LOGGER.info('Set session wal_sender_timeout = %i milliseconds', wal_sender_timeout) + LOGGER.info( + "Set session wal_sender_timeout = %i milliseconds", wal_sender_timeout + ) cur.execute("SET SESSION wal_sender_timeout = {}".format(wal_sender_timeout)) try: - LOGGER.info('Request wal streaming from %s to %s (slot %s)', - int_to_lsn(start_lsn), - int_to_lsn(end_lsn), - slot) + LOGGER.info( + "Request wal streaming from %s to %s (slot %s)", + int_to_lsn(start_lsn), + int_to_lsn(end_lsn), + slot, + ) # psycopg2 2.8.4 will send a keep-alive message to postgres every status_interval - cur.start_replication(slot_name=slot, - decode=True, - start_lsn=start_lsn, - status_interval=poll_interval, - options={ - 'write-in-chunks': 1, - 'add-tables': streams_to_wal2json_tables(logical_streams) - }) + cur.start_replication( + slot_name=slot, + decode=True, + start_lsn=start_lsn, + status_interval=poll_interval, + options={ + "write-in-chunks": 1, + "add-tables": streams_to_wal2json_tables(logical_streams), + }, + ) except psycopg2.ProgrammingError as ex: - raise Exception("Unable to start replication with logical replication (slot {})".format(ex)) from ex + raise Exception( + "Unable to start replication with logical replication (slot {})".format(ex) + ) from ex lsn_received_timestamp = datetime.datetime.utcnow() poll_timestamp = datetime.datetime.utcnow() @@ -608,9 +689,11 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): while True: # Disconnect when no data received for logical_poll_total_seconds # needs to be long enough to wait for the largest single wal payload to avoid unplanned timeouts - poll_duration = (datetime.datetime.utcnow() - lsn_received_timestamp).total_seconds() + poll_duration = ( + datetime.datetime.utcnow() - lsn_received_timestamp + ).total_seconds() if poll_duration > logical_poll_total_seconds: - LOGGER.info('Breaking - %i seconds of polling with no data', poll_duration) + LOGGER.info("Breaking - %i seconds of polling with no data", poll_duration) break try: @@ -621,31 +704,47 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): if msg: if (break_at_end_lsn) and (msg.data_start > end_lsn): - LOGGER.info('Breaking - latest wal message %s is past end_lsn %s', - int_to_lsn(msg.data_start), - int_to_lsn(end_lsn)) + LOGGER.info( + "Breaking - latest wal message %s is past end_lsn %s", + int_to_lsn(msg.data_start), + int_to_lsn(end_lsn), + ) break - if datetime.datetime.utcnow() >= (start_run_timestamp + datetime.timedelta(seconds=max_run_seconds)): - LOGGER.info('Breaking - reached max_run_seconds of %i', max_run_seconds) + if datetime.datetime.utcnow() >= ( + start_run_timestamp + datetime.timedelta(seconds=max_run_seconds) + ): + LOGGER.info("Breaking - reached max_run_seconds of %i", max_run_seconds) break - state = consume_message(logical_streams, state, msg, time_extracted, conn_info) + state = consume_message( + logical_streams, state, msg, time_extracted, conn_info + ) # When using wal2json with write-in-chunks, multiple messages can have the same lsn # This is to ensure we only flush to lsn that has completed entirely if lsn_currently_processing is None: lsn_currently_processing = msg.data_start - LOGGER.info('First wal message received is %s', int_to_lsn(lsn_currently_processing)) + LOGGER.info( + "First wal message received is %s", + int_to_lsn(lsn_currently_processing), + ) # Flush Postgres wal up to lsn comitted in previous run, or first lsn received in this run lsn_to_flush = lsn_comitted if 
lsn_currently_processing < lsn_to_flush: lsn_to_flush = lsn_currently_processing - LOGGER.info('Confirming write up to %s, flush to %s', - int_to_lsn(lsn_to_flush), - int_to_lsn(lsn_to_flush)) - cur.send_feedback(write_lsn=lsn_to_flush, flush_lsn=lsn_to_flush, reply=True, force=True) + LOGGER.info( + "Confirming write up to %s, flush to %s", + int_to_lsn(lsn_to_flush), + int_to_lsn(lsn_to_flush), + ) + cur.send_feedback( + write_lsn=lsn_to_flush, + flush_lsn=lsn_to_flush, + reply=True, + force=True, + ) elif int(msg.data_start) > lsn_currently_processing: lsn_last_processed = lsn_currently_processing @@ -653,34 +752,58 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): lsn_received_timestamp = datetime.datetime.utcnow() lsn_processed_count = lsn_processed_count + 1 if lsn_processed_count >= UPDATE_BOOKMARK_PERIOD: - LOGGER.debug('Updating bookmarks for all streams to lsn = %s (%s)', - lsn_last_processed, - int_to_lsn(lsn_last_processed)) + LOGGER.debug( + "Updating bookmarks for all streams to lsn = %s (%s)", + lsn_last_processed, + int_to_lsn(lsn_last_processed), + ) for s in logical_streams: - state = singer.write_bookmark(state, s['tap_stream_id'], 'lsn', lsn_last_processed) - singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + state = singer.write_bookmark( + state, s["tap_stream_id"], "lsn", lsn_last_processed + ) + singer.write_message( + singer.StateMessage(value=copy.deepcopy(state)) + ) lsn_processed_count = 0 # Every poll_interval, update latest comitted lsn position from the state_file - if datetime.datetime.utcnow() >= (poll_timestamp + datetime.timedelta(seconds=poll_interval)): + if datetime.datetime.utcnow() >= ( + poll_timestamp + datetime.timedelta(seconds=poll_interval) + ): if lsn_currently_processing is None: - LOGGER.info('Waiting for first wal message') + LOGGER.info("Waiting for first wal message") else: - LOGGER.info('Lastest wal message received was %s', int_to_lsn(lsn_last_processed)) + LOGGER.info( + "Lastest wal message received was %s", + int_to_lsn(lsn_last_processed), + ) try: with open(state_file, mode="r", encoding="utf-8") as fh: state_comitted = json.load(fh) except Exception: - LOGGER.debug('Unable to open and parse %s', state_file) + LOGGER.debug("Unable to open and parse %s", state_file) finally: lsn_comitted = min( - [get_bookmark(state_comitted, s['tap_stream_id'], 'lsn') for s in logical_streams]) - if (lsn_currently_processing > lsn_comitted) and (lsn_comitted > lsn_to_flush): + [ + get_bookmark(state_comitted, s["tap_stream_id"], "lsn") + for s in logical_streams + ] + ) + if (lsn_currently_processing > lsn_comitted) and ( + lsn_comitted > lsn_to_flush + ): lsn_to_flush = lsn_comitted - LOGGER.info('Confirming write up to %s, flush to %s', - int_to_lsn(lsn_to_flush), - int_to_lsn(lsn_to_flush)) - cur.send_feedback(write_lsn=lsn_to_flush, flush_lsn=lsn_to_flush, reply=True, force=True) + LOGGER.info( + "Confirming write up to %s, flush to %s", + int_to_lsn(lsn_to_flush), + int_to_lsn(lsn_to_flush), + ) + cur.send_feedback( + write_lsn=lsn_to_flush, + flush_lsn=lsn_to_flush, + reply=True, + force=True, + ) poll_timestamp = datetime.datetime.utcnow() @@ -691,16 +814,22 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): if lsn_last_processed: if lsn_comitted > lsn_last_processed: lsn_last_processed = lsn_comitted - LOGGER.info('Current lsn_last_processed %s is older than lsn_comitted %s', - int_to_lsn(lsn_last_processed), - int_to_lsn(lsn_comitted)) - - LOGGER.info('Updating 
bookmarks for all streams to lsn = %s (%s)', - lsn_last_processed, - int_to_lsn(lsn_last_processed)) + LOGGER.info( + "Current lsn_last_processed %s is older than lsn_comitted %s", + int_to_lsn(lsn_last_processed), + int_to_lsn(lsn_comitted), + ) + + LOGGER.info( + "Updating bookmarks for all streams to lsn = %s (%s)", + lsn_last_processed, + int_to_lsn(lsn_last_processed), + ) for s in logical_streams: - state = singer.write_bookmark(state, s['tap_stream_id'], 'lsn', lsn_last_processed) + state = singer.write_bookmark( + state, s["tap_stream_id"], "lsn", lsn_last_processed + ) singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) return state diff --git a/tests/test_clear_state_on_replication_change.py b/tests/test_clear_state_on_replication_change.py index 86ae98b7..9c248e54 100644 --- a/tests/test_clear_state_on_replication_change.py +++ b/tests/test_clear_state_on_replication_change.py @@ -1,98 +1,285 @@ import unittest import tap_postgres -tap_stream_id = 'chicken_table' +tap_stream_id = "chicken_table" -class TestClearState(unittest.TestCase): +class TestClearState(unittest.TestCase): def test_incremental_happy(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 1, "replication_key" : 'updated_at', 'replication_key_value' : '2017-01-01T00:00:03+00:00', 'last_replication_method' : 'INCREMENTAL'}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, 'updated_at', 'INCREMENTAL') + state = { + "bookmarks": { + tap_stream_id: { + "version": 1, + "replication_key": "updated_at", + "replication_key_value": "2017-01-01T00:00:03+00:00", + "last_replication_method": "INCREMENTAL", + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, "updated_at", "INCREMENTAL" + ) self.assertEqual(nascent_state, state) def test_incremental_changing_replication_keys(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 1, "replication_key" : 'updated_at', 'replication_key_value' : '2017-01-01T00:00:03+00:00', 'last_replication_method' : 'INCREMENTAL'}}} + state = { + "bookmarks": { + tap_stream_id: { + "version": 1, + "replication_key": "updated_at", + "replication_key_value": "2017-01-01T00:00:03+00:00", + "last_replication_method": "INCREMENTAL", + } + } + } - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, 'updated_at_2', 'INCREMENTAL') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : {'last_replication_method' : 'INCREMENTAL'}}}) + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, "updated_at_2", "INCREMENTAL" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "INCREMENTAL"}}}, + ) def test_incremental_changing_replication_key_interrupted(self): - xmin = '3737373' - state = {'bookmarks' : {tap_stream_id : { 'version' : 1, 'xmin' : xmin, "replication_key" : 'updated_at', 'replication_key_value' : '2017-01-01T00:00:03+00:00', - 'last_replication_method' : 'INCREMENTAL'}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, 'updated_at_2', 'INCREMENTAL') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : { 'last_replication_method' : 'INCREMENTAL'}}}) + xmin = "3737373" + state = { + "bookmarks": { + tap_stream_id: { + "version": 1, + "xmin": xmin, + "replication_key": "updated_at", + "replication_key_value": "2017-01-01T00:00:03+00:00", + "last_replication_method": "INCREMENTAL", + } + } + } + 
nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, "updated_at_2", "INCREMENTAL" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "INCREMENTAL"}}}, + ) def test_full_table_to_incremental(self): - #interrupted full table -> incremental - xmin = '3737373' - state = {'bookmarks' : {tap_stream_id : { 'version' : 1, 'xmin' : xmin, "last_replication_method" : "FULL_TABLE"}}} - - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, 'updated_at', 'INCREMENTAL') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : {"last_replication_method" : "INCREMENTAL"}}}) + # interrupted full table -> incremental + xmin = "3737373" + state = { + "bookmarks": { + tap_stream_id: { + "version": 1, + "xmin": xmin, + "last_replication_method": "FULL_TABLE", + } + } + } - state = {'bookmarks' : {tap_stream_id : { 'version' : 1, "last_replication_method" : "FULL_TABLE"}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, 'updated_at', 'INCREMENTAL') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : {"last_replication_method" : "INCREMENTAL"}}}) + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, "updated_at", "INCREMENTAL" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "INCREMENTAL"}}}, + ) + state = { + "bookmarks": { + tap_stream_id: {"version": 1, "last_replication_method": "FULL_TABLE"} + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, "updated_at", "INCREMENTAL" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "INCREMENTAL"}}}, + ) def test_log_based_to_incremental(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 1, 'lsn' : 34343434, "last_replication_method" : "LOG_BASED"}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, 'updated_at', 'INCREMENTAL') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : {"last_replication_method" : "INCREMENTAL"}}}) + state = { + "bookmarks": { + tap_stream_id: { + "version": 1, + "lsn": 34343434, + "last_replication_method": "LOG_BASED", + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, "updated_at", "INCREMENTAL" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "INCREMENTAL"}}}, + ) - state = {'bookmarks' : {tap_stream_id : { 'version' : 1, 'lsn' : 34343434, 'xmin' : 34343, "last_replication_method" : "LOG_BASED"}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, 'updated_at', 'INCREMENTAL') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : {"last_replication_method" : "INCREMENTAL"}}}) + state = { + "bookmarks": { + tap_stream_id: { + "version": 1, + "lsn": 34343434, + "xmin": 34343, + "last_replication_method": "LOG_BASED", + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, "updated_at", "INCREMENTAL" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "INCREMENTAL"}}}, + ) - #full table tests + # full table tests def test_full_table_happy(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 88, "last_replication_method" : "FULL_TABLE"}}} - nascent_state = 
tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'FULL_TABLE') + state = { + "bookmarks": { + tap_stream_id: {"version": 88, "last_replication_method": "FULL_TABLE"} + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "FULL_TABLE" + ) self.assertEqual(nascent_state, state) def test_full_table_interrupted(self): xmin = 333333 - state = {'bookmarks' : {tap_stream_id : { 'version' : 88, "last_replication_method" : "FULL_TABLE", 'xmin' : xmin}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'FULL_TABLE') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : { "last_replication_method" : "FULL_TABLE", 'version': 88, 'xmin' : xmin}}}) + state = { + "bookmarks": { + tap_stream_id: { + "version": 88, + "last_replication_method": "FULL_TABLE", + "xmin": xmin, + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "FULL_TABLE" + ) + self.assertEqual( + nascent_state, + { + "bookmarks": { + tap_stream_id: { + "last_replication_method": "FULL_TABLE", + "version": 88, + "xmin": xmin, + } + } + }, + ) def test_incremental_to_full_table(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 88, "last_replication_method" : "INCREMENTAL", 'replication_key' : 'updated_at', 'replication_key_value' : 'i will be removed'}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'FULL_TABLE') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : { "last_replication_method" : "FULL_TABLE"}}}) + state = { + "bookmarks": { + tap_stream_id: { + "version": 88, + "last_replication_method": "INCREMENTAL", + "replication_key": "updated_at", + "replication_key_value": "i will be removed", + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "FULL_TABLE" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "FULL_TABLE"}}}, + ) def test_log_based_to_full_table(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 88, "last_replication_method" : "LOG_BASED", 'lsn' : 343434}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'FULL_TABLE') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : { "last_replication_method" : "FULL_TABLE"}}}) + state = { + "bookmarks": { + tap_stream_id: { + "version": 88, + "last_replication_method": "LOG_BASED", + "lsn": 343434, + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "FULL_TABLE" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "FULL_TABLE"}}}, + ) - - #log based tests + # log based tests def test_log_based_happy(self): lsn = 43434343 - state = {'bookmarks' : {tap_stream_id : { 'version' : 88, "last_replication_method" : "LOG_BASED", 'lsn' : lsn}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'LOG_BASED') + state = { + "bookmarks": { + tap_stream_id: { + "version": 88, + "last_replication_method": "LOG_BASED", + "lsn": lsn, + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "LOG_BASED" + ) self.assertEqual(nascent_state, state) lsn = 43434343 xmin = 11111 - state = {'bookmarks' : {tap_stream_id : { 'version' : 88, "last_replication_method" : "LOG_BASED", 'lsn' : 
lsn, 'xmin' : xmin}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'LOG_BASED') + state = { + "bookmarks": { + tap_stream_id: { + "version": 88, + "last_replication_method": "LOG_BASED", + "lsn": lsn, + "xmin": xmin, + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "LOG_BASED" + ) self.assertEqual(nascent_state, state) def test_incremental_to_log_based(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 88, "last_replication_method" : "INCREMENTAL", 'replication_key' : 'updated_at', 'replication_key_value' : 'i will be removed'}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'LOG_BASED') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : { "last_replication_method" : "LOG_BASED"}}}) + state = { + "bookmarks": { + tap_stream_id: { + "version": 88, + "last_replication_method": "INCREMENTAL", + "replication_key": "updated_at", + "replication_key_value": "i will be removed", + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "LOG_BASED" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "LOG_BASED"}}}, + ) def test_full_table_to_log_based(self): - state = {'bookmarks' : {tap_stream_id : { 'version' : 2222, "last_replication_method" : "FULL_TABLE", 'xmin' : 2}}} - nascent_state = tap_postgres.clear_state_on_replication_change(state, tap_stream_id, None, 'LOG_BASED') - self.assertEqual(nascent_state, {'bookmarks' : {tap_stream_id : { "last_replication_method" : "LOG_BASED"}}}) - + state = { + "bookmarks": { + tap_stream_id: { + "version": 2222, + "last_replication_method": "FULL_TABLE", + "xmin": 2, + } + } + } + nascent_state = tap_postgres.clear_state_on_replication_change( + state, tap_stream_id, None, "LOG_BASED" + ) + self.assertEqual( + nascent_state, + {"bookmarks": {tap_stream_id: {"last_replication_method": "LOG_BASED"}}}, + ) -if __name__== "__main__": +if __name__ == "__main__": test1 = TestClearState() test1.test_full_table_to_log_based() diff --git a/tests/test_db.py b/tests/test_db.py index 7488a0d5..5f2984e0 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -9,147 +9,180 @@ class TestDbFunctions(unittest.TestCase): def test_value_to_singer_value(self): """Test if every element converted from sql_datatype to the correct singer type""" # JSON and JSONB should be converted to dictionaries - self.assertEqual(db.selected_value_to_singer_value_impl('{"test": 123}', 'json'), {'test': 123}) - self.assertEqual(db.selected_value_to_singer_value_impl('{"test": 123}', 'jsonb'), {'test': 123}) + self.assertEqual( + db.selected_value_to_singer_value_impl('{"test": 123}', "json"), + {"test": 123}, + ) + self.assertEqual( + db.selected_value_to_singer_value_impl('{"test": 123}', "jsonb"), + {"test": 123}, + ) # time with time zone values should be converted to UTC and time zone dropped # Hour 24 should be taken as 0 - self.assertEqual(db.selected_value_to_singer_value_impl('12:00:00-0800', 'time with time zone'), '20:00:00') - self.assertEqual(db.selected_value_to_singer_value_impl('24:00:00-0800', 'time with time zone'), '08:00:00') + self.assertEqual( + db.selected_value_to_singer_value_impl( + "12:00:00-0800", "time with time zone" + ), + "20:00:00", + ) + self.assertEqual( + db.selected_value_to_singer_value_impl( + "24:00:00-0800", "time with time zone" + ), + "08:00:00", + ) # time without time zone 
values should be converted to UTC and time zone dropped - self.assertEqual(db.selected_value_to_singer_value_impl('12:00:00', 'time without time zone'), '12:00:00') + self.assertEqual( + db.selected_value_to_singer_value_impl( + "12:00:00", "time without time zone" + ), + "12:00:00", + ) # Hour 24 should be taken as 0 - self.assertEqual(db.selected_value_to_singer_value_impl('24:00:00', 'time without time zone'), '00:00:00') + self.assertEqual( + db.selected_value_to_singer_value_impl( + "24:00:00", "time without time zone" + ), + "00:00:00", + ) # timestamp with time zone should be converted to iso format - self.assertEqual(db.selected_value_to_singer_value_impl('2020-05-01T12:00:00-0800', - 'timestamp with time zone'), - '2020-05-01T12:00:00-0800') + self.assertEqual( + db.selected_value_to_singer_value_impl( + "2020-05-01T12:00:00-0800", "timestamp with time zone" + ), + "2020-05-01T12:00:00-0800", + ) # bit should be True only if elem is '1' - self.assertEqual(db.selected_value_to_singer_value_impl('1', 'bit'), True) - self.assertEqual(db.selected_value_to_singer_value_impl('0', 'bit'), False) - self.assertEqual(db.selected_value_to_singer_value_impl(1, 'bit'), False) - self.assertEqual(db.selected_value_to_singer_value_impl(0, 'bit'), False) + self.assertEqual(db.selected_value_to_singer_value_impl("1", "bit"), True) + self.assertEqual(db.selected_value_to_singer_value_impl("0", "bit"), False) + self.assertEqual(db.selected_value_to_singer_value_impl(1, "bit"), False) + self.assertEqual(db.selected_value_to_singer_value_impl(0, "bit"), False) # boolean should be True in case of numeric 1 and logical True - self.assertEqual(db.selected_value_to_singer_value_impl(1, 'boolean'), True) - self.assertEqual(db.selected_value_to_singer_value_impl(True, 'boolean'), True) - self.assertEqual(db.selected_value_to_singer_value_impl(0, 'boolean'), False) - self.assertEqual(db.selected_value_to_singer_value_impl(False, 'boolean'), False) + self.assertEqual(db.selected_value_to_singer_value_impl(1, "boolean"), True) + self.assertEqual(db.selected_value_to_singer_value_impl(True, "boolean"), True) + self.assertEqual(db.selected_value_to_singer_value_impl(0, "boolean"), False) + self.assertEqual( + db.selected_value_to_singer_value_impl(False, "boolean"), False + ) def test_prepare_columns_sql(self): - self.assertEqual(' "my_column" ', db.prepare_columns_sql('my_column')) + self.assertEqual(' "my_column" ', db.prepare_columns_sql("my_column")) def test_prepare_columns_for_select_sql_with_timestamp_ntz_column(self): self.assertEqual( - 'CASE WHEN "my_column" < \'0001-01-01 00:00:00.000\' OR ' - ' "my_column" > \'9999-12-31 23:59:59.999\' THEN \'9999-12-31 23:59:59.999\' ' + "CASE WHEN \"my_column\" < '0001-01-01 00:00:00.000' OR " + " \"my_column\" > '9999-12-31 23:59:59.999' THEN '9999-12-31 23:59:59.999' " 'ELSE "my_column" END AS "my_column" ', - db.prepare_columns_for_select_sql('my_column', - { - ('properties', 'my_column'): { - 'sql-datatype': 'timestamp without time zone' - } - } - ) + db.prepare_columns_for_select_sql( + "my_column", + { + ("properties", "my_column"): { + "sql-datatype": "timestamp without time zone" + } + }, + ), ) def test_prepare_columns_for_select_sql_with_timestamp_tz_column(self): self.assertEqual( - 'CASE WHEN "my_column" < \'0001-01-01 00:00:00.000\' OR ' - ' "my_column" > \'9999-12-31 23:59:59.999\' THEN \'9999-12-31 23:59:59.999\' ' + "CASE WHEN \"my_column\" < '0001-01-01 00:00:00.000' OR " + " \"my_column\" > '9999-12-31 23:59:59.999' THEN '9999-12-31 23:59:59.999' " 
'ELSE "my_column" END AS "my_column" ', - db.prepare_columns_for_select_sql('my_column', - { - ('properties', 'my_column'): { - 'sql-datatype': 'timestamp with time zone' - } - } - ) + db.prepare_columns_for_select_sql( + "my_column", + { + ("properties", "my_column"): { + "sql-datatype": "timestamp with time zone" + } + }, + ), ) def test_prepare_columns_for_select_sql_with_timestamp_ntz_array_column(self): self.assertEqual( ' "my_column" ', - db.prepare_columns_for_select_sql('my_column', - { - ('properties', 'my_column'): { - 'sql-datatype': 'timestamp without time zone[]' - } - } - ) + db.prepare_columns_for_select_sql( + "my_column", + { + ("properties", "my_column"): { + "sql-datatype": "timestamp without time zone[]" + } + }, + ), ) def test_prepare_columns_for_select_sql_with_timestamp_tz_array_column(self): self.assertEqual( ' "my_column" ', - db.prepare_columns_for_select_sql('my_column', - { - ('properties', 'my_column'): { - 'sql-datatype': 'timestamp with time zone[]' - } - } - ) + db.prepare_columns_for_select_sql( + "my_column", + { + ("properties", "my_column"): { + "sql-datatype": "timestamp with time zone[]" + } + }, + ), ) def test_prepare_columns_for_select_sql_with_not_timestamp_column(self): self.assertEqual( ' "my_column" ', - db.prepare_columns_for_select_sql('my_column', - { - ('properties', 'my_column'): { - 'sql-datatype': 'int' - } - } - ) + db.prepare_columns_for_select_sql( + "my_column", {("properties", "my_column"): {"sql-datatype": "int"}} + ), ) def test_prepare_columns_for_select_sql_with_column_not_in_map(self): self.assertEqual( ' "my_column" ', - db.prepare_columns_for_select_sql('my_column', - { - ('properties', 'nope'): { - 'sql-datatype': 'int' - } - } - ) + db.prepare_columns_for_select_sql( + "my_column", {("properties", "nope"): {"sql-datatype": "int"}} + ), ) def test_selected_value_to_singer_value_impl_with_null_json_returns_None(self): - output = db.selected_value_to_singer_value_impl(None, 'json') + output = db.selected_value_to_singer_value_impl(None, "json") self.assertEqual(None, output) - def test_selected_value_to_singer_value_impl_with_empty_json_returns_empty_dict(self): - output = db.selected_value_to_singer_value_impl('{}', 'json') + def test_selected_value_to_singer_value_impl_with_empty_json_returns_empty_dict( + self, + ): + output = db.selected_value_to_singer_value_impl("{}", "json") self.assertEqual({}, output) - def test_selected_value_to_singer_value_impl_with_non_empty_json_returns_equivalent_dict(self): - output = db.selected_value_to_singer_value_impl('{"key1": "A", "key2": [{"kk": "yo"}, {}]}', 'json') + def test_selected_value_to_singer_value_impl_with_non_empty_json_returns_equivalent_dict( + self, + ): + output = db.selected_value_to_singer_value_impl( + '{"key1": "A", "key2": [{"kk": "yo"}, {}]}', "json" + ) - self.assertEqual({ - 'key1': 'A', - 'key2': [{'kk': 'yo'}, {}] - }, output) + self.assertEqual({"key1": "A", "key2": [{"kk": "yo"}, {}]}, output) def test_selected_value_to_singer_value_impl_with_null_jsonb_returns_None(self): - output = db.selected_value_to_singer_value_impl(None, 'jsonb') + output = db.selected_value_to_singer_value_impl(None, "jsonb") self.assertEqual(None, output) - def test_selected_value_to_singer_value_impl_with_empty_jsonb_returns_empty_dict(self): - output = db.selected_value_to_singer_value_impl('{}', 'jsonb') + def test_selected_value_to_singer_value_impl_with_empty_jsonb_returns_empty_dict( + self, + ): + output = db.selected_value_to_singer_value_impl("{}", "jsonb") 
self.assertEqual({}, output) - def test_selected_value_to_singer_value_impl_with_non_empty_jsonb_returns_equivalent_dict(self): - output = db.selected_value_to_singer_value_impl('{"key1": "A", "key2": [{"kk": "yo"}, {}]}', 'jsonb') + def test_selected_value_to_singer_value_impl_with_non_empty_jsonb_returns_equivalent_dict( + self, + ): + output = db.selected_value_to_singer_value_impl( + '{"key1": "A", "key2": [{"kk": "yo"}, {}]}', "jsonb" + ) - self.assertEqual({ - 'key1': 'A', - 'key2': [{'kk': 'yo'}, {}] - }, output) + self.assertEqual({"key1": "A", "key2": [{"kk": "yo"}, {}]}, output) diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 1dd48aaa..7d9b39cb 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -6,591 +6,1115 @@ import tap_postgres.db as post_db from singer import get_logger, metadata from psycopg2.extensions import quote_ident + try: - from tests.utils import get_test_connection, ensure_test_table, get_test_connection_config + from tests.utils import ( + get_test_connection, + ensure_test_table, + get_test_connection_config, + ) except ImportError: from utils import get_test_connection, ensure_test_table, get_test_connection_config LOGGER = get_logger() + def do_not_dump_catalog(catalog): pass + tap_postgres.dump_catalog = do_not_dump_catalog + class TestStringTableWithPK(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : "id", "type" : "integer", "primary_key" : True, "serial" : True}, - {"name" : '"character-varying_name"', "type": "character varying"}, - {"name" : '"varchar-name"', "type": "varchar(28)"}, - {"name" : 'char_name', "type": "char(10)"}, - {"name" : '"text-name"', "type": "text"}], - "name" : TestStringTableWithPK.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "id", "type": "integer", "primary_key": True, "serial": True}, + {"name": '"character-varying_name"', "type": "character varying"}, + {"name": '"varchar-name"', "type": "varchar(28)"}, + {"name": "char_name", "type": "char(10)"}, + {"name": '"text-name"', "type": "text"}, + ], + "name": TestStringTableWithPK.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == "public-CHICKEN TIMES"] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get('table_name')) - self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get('stream')) - - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - - - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['id'], 'database-name': 'postgres', - 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'character-varying_name') : {'inclusion': 'available', 'sql-datatype' : 'character varying', 'selected-by-default' : True}, - ('properties', 'id') : {'inclusion': 'automatic', 'sql-datatype' : 'integer', 'selected-by-default' : True}, - ('properties', 'varchar-name') : {'inclusion': 'available', 'sql-datatype' : 'character varying', 'selected-by-default' : True}, - ('properties', 'text-name') : {'inclusion': 'available', 'sql-datatype' : 'text', 'selected-by-default' 
: True}, - ('properties', 'char_name'): {'selected-by-default': True, 'inclusion': 'available', 'sql-datatype': 'character'}}) - - self.assertEqual({'properties': {'id': {'type': ['integer'], - 'maximum': 2147483647, - 'minimum': -2147483648}, - 'character-varying_name': {'type': ['null', 'string']}, - 'varchar-name': {'type': ['null', 'string'], 'maxLength': 28}, - 'char_name': {'type': ['null', 'string'], 'maxLength': 10}, - 'text-name': {'type': ['null', 'string']}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, stream_dict.get('schema')) + self.assertEqual( + TestStringTableWithPK.table_name, stream_dict.get("table_name") + ) + self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get("stream")) + + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["id"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "character-varying_name"): { + "inclusion": "available", + "sql-datatype": "character varying", + "selected-by-default": True, + }, + ("properties", "id"): { + "inclusion": "automatic", + "sql-datatype": "integer", + "selected-by-default": True, + }, + ("properties", "varchar-name"): { + "inclusion": "available", + "sql-datatype": "character varying", + "selected-by-default": True, + }, + ("properties", "text-name"): { + "inclusion": "available", + "sql-datatype": "text", + "selected-by-default": True, + }, + ("properties", "char_name"): { + "selected-by-default": True, + "inclusion": "available", + "sql-datatype": "character", + }, + }, + ) + + self.assertEqual( + { + "properties": { + "id": { + "type": ["integer"], + "maximum": 2147483647, + "minimum": -2147483648, + }, + "character-varying_name": {"type": ["null", "string"]}, + "varchar-name": {"type": ["null", "string"], "maxLength": 28}, + "char_name": {"type": ["null", "string"], "maxLength": 10}, + "text-name": {"type": ["null", "string"]}, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) class TestIntegerTable(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : "id", "type" : "integer", "serial" : True}, - {"name" : 'size integer', "type" : "integer", "quoted" : True}, - {"name" : 'size smallint', "type" : "smallint", "quoted" : True}, - {"name" : 'size bigint', "type" : "bigint", "quoted" : True}], - "name" : TestIntegerTable.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "id", "type": "integer", "serial": True}, + {"name": "size integer", "type": "integer", "quoted": True}, + {"name": "size smallint", "type": "smallint", "quoted": True}, + {"name": "size bigint", "type": "bigint", "quoted": True}, + ], + "name": TestIntegerTable.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get('table_name')) - self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get('stream')) - - - 
stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': [], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'id') : {'inclusion': 'available', 'sql-datatype' : 'integer', 'selected-by-default' : True}, - ('properties', 'size integer') : {'inclusion': 'available', 'sql-datatype' : 'integer', 'selected-by-default' : True}, - ('properties', 'size smallint') : {'inclusion': 'available', 'sql-datatype' : 'smallint', 'selected-by-default' : True}, - ('properties', 'size bigint') : {'inclusion': 'available', 'sql-datatype' : 'bigint', 'selected-by-default' : True}}) - - self.assertEqual({'definitions' : BASE_RECURSIVE_SCHEMAS, - 'type': 'object', - 'properties': {'id': {'type': ['null', 'integer'], 'minimum': -2147483648, 'maximum': 2147483647}, - 'size smallint': {'type': ['null', 'integer'], 'minimum': -32768, 'maximum': 32767}, - 'size integer': {'type': ['null', 'integer'], 'minimum': -2147483648, 'maximum': 2147483647}, - 'size bigint': {'type': ['null', 'integer'], 'minimum': -9223372036854775808, 'maximum': 9223372036854775807}}}, - stream_dict.get('schema')) - + self.assertEqual( + TestStringTableWithPK.table_name, stream_dict.get("table_name") + ) + self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get("stream")) + + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": [], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "id"): { + "inclusion": "available", + "sql-datatype": "integer", + "selected-by-default": True, + }, + ("properties", "size integer"): { + "inclusion": "available", + "sql-datatype": "integer", + "selected-by-default": True, + }, + ("properties", "size smallint"): { + "inclusion": "available", + "sql-datatype": "smallint", + "selected-by-default": True, + }, + ("properties", "size bigint"): { + "inclusion": "available", + "sql-datatype": "bigint", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "definitions": BASE_RECURSIVE_SCHEMAS, + "type": "object", + "properties": { + "id": { + "type": ["null", "integer"], + "minimum": -2147483648, + "maximum": 2147483647, + }, + "size smallint": { + "type": ["null", "integer"], + "minimum": -32768, + "maximum": 32767, + }, + "size integer": { + "type": ["null", "integer"], + "minimum": -2147483648, + "maximum": 2147483647, + }, + "size bigint": { + "type": ["null", "integer"], + "minimum": -9223372036854775808, + "maximum": 9223372036854775807, + }, + }, + }, + stream_dict.get("schema"), + ) class TestDecimalPK(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_decimal', "type" : "numeric", "primary_key": True}, - {"name" : 'our_decimal_10_2', "type" : "decimal(10,2)"}, - {"name" : 'our_decimal_38_4', "type" : "decimal(38,4)"}], - "name" : TestDecimalPK.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_decimal", "type": "numeric", "primary_key": True}, + {"name": "our_decimal_10_2", "type": "decimal(10,2)"}, + {"name": "our_decimal_38_4", "type": "decimal(38,4)"}, + ], + "name": TestDecimalPK.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = 
get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_decimal'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_decimal') : {'inclusion': 'automatic', 'sql-datatype' : 'numeric', 'selected-by-default' : True}, - ('properties', 'our_decimal_38_4') : {'inclusion': 'available', 'sql-datatype' : 'numeric', 'selected-by-default' : True}, - ('properties', 'our_decimal_10_2') : {'inclusion': 'available', 'sql-datatype' : 'numeric', 'selected-by-default' : True}}) - - self.assertEqual({'properties': {'our_decimal': {'exclusiveMaximum': True, - 'exclusiveMinimum': True, - 'multipleOf': 10 ** (0 - post_db.MAX_SCALE), - 'maximum': 10 ** (post_db.MAX_PRECISION - post_db.MAX_SCALE), - 'minimum': -10 ** (post_db.MAX_PRECISION - post_db.MAX_SCALE), - 'type': ['number']}, - 'our_decimal_10_2': {'exclusiveMaximum': True, - 'exclusiveMinimum': True, - 'maximum': 100000000, - 'minimum': -100000000, - 'multipleOf': 0.01, - 'type': ['null', 'number']}, - 'our_decimal_38_4': {'exclusiveMaximum': True, - 'exclusiveMinimum': True, - 'maximum': 10000000000000000000000000000000000, - 'minimum': -10000000000000000000000000000000000, - 'multipleOf': 0.0001, - 'type': ['null', 'number']}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) - + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_decimal"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_decimal"): { + "inclusion": "automatic", + "sql-datatype": "numeric", + "selected-by-default": True, + }, + ("properties", "our_decimal_38_4"): { + "inclusion": "available", + "sql-datatype": "numeric", + "selected-by-default": True, + }, + ("properties", "our_decimal_10_2"): { + "inclusion": "available", + "sql-datatype": "numeric", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_decimal": { + "exclusiveMaximum": True, + "exclusiveMinimum": True, + "multipleOf": 10 ** (0 - post_db.MAX_SCALE), + "maximum": 10 ** (post_db.MAX_PRECISION - post_db.MAX_SCALE), + "minimum": -(10 ** (post_db.MAX_PRECISION - post_db.MAX_SCALE)), + "type": ["number"], + }, + "our_decimal_10_2": { + "exclusiveMaximum": True, + "exclusiveMinimum": True, + "maximum": 100000000, + "minimum": -100000000, + "multipleOf": 0.01, + "type": ["null", "number"], + }, + "our_decimal_38_4": { + "exclusiveMaximum": True, + "exclusiveMinimum": True, + "maximum": 10000000000000000000000000000000000, + "minimum": -10000000000000000000000000000000000, + "multipleOf": 0.0001, + "type": ["null", "number"], + }, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) class TestDatesTablePK(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_date', "type" : "DATE", 
"primary_key": True }, - {"name" : 'our_ts', "type" : "TIMESTAMP"}, - {"name" : 'our_ts_tz', "type" : "TIMESTAMP WITH TIME ZONE"}, - {"name" : 'our_time', "type" : "TIME"}, - {"name" : 'our_time_tz', "type" : "TIME WITH TIME ZONE"}], - "name" : TestDatesTablePK.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_date", "type": "DATE", "primary_key": True}, + {"name": "our_ts", "type": "TIMESTAMP"}, + {"name": "our_ts_tz", "type": "TIMESTAMP WITH TIME ZONE"}, + {"name": "our_time", "type": "TIME"}, + {"name": "our_time_tz", "type": "TIME WITH TIME ZONE"}, + ], + "name": TestDatesTablePK.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_date'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_date') : {'inclusion': 'automatic', 'sql-datatype' : 'date', 'selected-by-default' : True}, - ('properties', 'our_ts') : {'inclusion': 'available', 'sql-datatype' : 'timestamp without time zone', 'selected-by-default' : True}, - ('properties', 'our_ts_tz') : {'inclusion': 'available', 'sql-datatype' : 'timestamp with time zone', 'selected-by-default' : True}, - ('properties', 'our_time') : {'inclusion': 'available', 'sql-datatype' : 'time without time zone', 'selected-by-default' : True}, - ('properties', 'our_time_tz') : {'inclusion': 'available', 'sql-datatype' : 'time with time zone', 'selected-by-default' : True}}) - - self.assertEqual({'properties': {'our_date': {'type': ['string'], 'format' : 'date-time'}, - 'our_ts': {'type': ['null', 'string'], 'format' : 'date-time'}, - 'our_ts_tz': {'type': ['null', 'string'], 'format' : 'date-time'}, - 'our_time': {'format': 'time', 'type': ['null', 'string']}, - 'our_time_tz': {'format': 'time', 'type': ['null', 'string']}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_date"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_date"): { + "inclusion": "automatic", + "sql-datatype": "date", + "selected-by-default": True, + }, + ("properties", "our_ts"): { + "inclusion": "available", + "sql-datatype": "timestamp without time zone", + "selected-by-default": True, + }, + ("properties", "our_ts_tz"): { + "inclusion": "available", + "sql-datatype": "timestamp with time zone", + "selected-by-default": True, + }, + ("properties", "our_time"): { + "inclusion": "available", + "sql-datatype": "time without time zone", + "selected-by-default": True, + }, + ("properties", "our_time_tz"): { + "inclusion": "available", + "sql-datatype": "time with time zone", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_date": {"type": ["string"], "format": "date-time"}, + "our_ts": {"type": ["null", 
"string"], "format": "date-time"}, + "our_ts_tz": {"type": ["null", "string"], "format": "date-time"}, + "our_time": {"format": "time", "type": ["null", "string"]}, + "our_time_tz": {"format": "time", "type": ["null", "string"]}, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) + class TestFloatTablePK(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_float', "type" : "float", "primary_key": True }, - {"name" : 'our_real', "type" : "real"}, - {"name" : 'our_double', "type" : "double precision"}], - "name" : TestFloatTablePK.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_float", "type": "float", "primary_key": True}, + {"name": "our_real", "type": "real"}, + {"name": "our_double", "type": "double precision"}, + ], + "name": TestFloatTablePK.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_float'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_float') : {'inclusion': 'automatic', 'sql-datatype' : 'double precision', 'selected-by-default' : True}, - ('properties', 'our_real') : {'inclusion': 'available', 'sql-datatype' : 'real', 'selected-by-default' : True}, - ('properties', 'our_double') : {'inclusion': 'available', 'sql-datatype' : 'double precision', 'selected-by-default' : True}}) - + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_float"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_float"): { + "inclusion": "automatic", + "sql-datatype": "double precision", + "selected-by-default": True, + }, + ("properties", "our_real"): { + "inclusion": "available", + "sql-datatype": "real", + "selected-by-default": True, + }, + ("properties", "our_double"): { + "inclusion": "available", + "sql-datatype": "double precision", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_float": {"type": ["number"]}, + "our_real": {"type": ["null", "number"]}, + "our_double": {"type": ["null", "number"]}, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) - self.assertEqual({'properties': {'our_float': {'type': ['number']}, - 'our_real': {'type': ['null', 'number']}, - 'our_double': {'type': ['null', 'number']}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) class TestBoolsAndBits(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_bool', "type" : "boolean" }, - {"name" : 'our_bit', "type" : "bit" }], - "name" : TestBoolsAndBits.table_name} - 
ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_bool", "type": "boolean"}, + {"name": "our_bit", "type": "bit"}, + ], + "name": TestBoolsAndBits.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': [], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_bool') : {'inclusion': 'available', 'sql-datatype' : 'boolean', 'selected-by-default' : True}, - ('properties', 'our_bit') : {'inclusion': 'available', 'sql-datatype' : 'bit', 'selected-by-default' : True}}) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": [], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_bool"): { + "inclusion": "available", + "sql-datatype": "boolean", + "selected-by-default": True, + }, + ("properties", "our_bit"): { + "inclusion": "available", + "sql-datatype": "bit", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_bool": {"type": ["null", "boolean"]}, + "our_bit": {"type": ["null", "boolean"]}, + }, + "definitions": BASE_RECURSIVE_SCHEMAS, + "type": "object", + }, + stream_dict.get("schema"), + ) - self.assertEqual({'properties': {'our_bool': {'type': ['null', 'boolean']}, - 'our_bit': {'type': ['null', 'boolean']}}, - 'definitions' : BASE_RECURSIVE_SCHEMAS, - 'type': 'object'}, - stream_dict.get('schema')) - class TestJsonTables(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_secrets', "type" : "json" }, - {"name" : 'our_secrets_b', "type" : "jsonb" }], - "name" : TestJsonTables.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_secrets", "type": "json"}, + {"name": "our_secrets_b", "type": "jsonb"}, + ], + "name": TestJsonTables.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': [], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_secrets') : {'inclusion': 'available', 'sql-datatype' : 'json', 'selected-by-default' : True}, - ('properties', 'our_secrets_b') : {'inclusion': 'available', 'sql-datatype' : 'jsonb', 'selected-by-default' : True}}) - - - self.assertEqual({'properties': {'our_secrets': {'type': ['null', 
'object', 'array']}, - 'our_secrets_b': {'type': ['null', 'object', 'array']}}, - 'definitions' : BASE_RECURSIVE_SCHEMAS, - 'type': 'object'}, - stream_dict.get('schema')) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": [], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_secrets"): { + "inclusion": "available", + "sql-datatype": "json", + "selected-by-default": True, + }, + ("properties", "our_secrets_b"): { + "inclusion": "available", + "sql-datatype": "jsonb", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_secrets": {"type": ["null", "object", "array"]}, + "our_secrets_b": {"type": ["null", "object", "array"]}, + }, + "definitions": BASE_RECURSIVE_SCHEMAS, + "type": "object", + }, + stream_dict.get("schema"), + ) class TestUUIDTables(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_pk', "type" : "uuid", "primary_key" : True }, - {"name" : 'our_uuid', "type" : "uuid" }], - "name" : TestUUIDTables.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_pk", "type": "uuid", "primary_key": True}, + {"name": "our_uuid", "type": "uuid"}, + ], + "name": TestUUIDTables.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == "public-CHICKEN TIMES"] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_pk"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_pk"): { + "inclusion": "automatic", + "sql-datatype": "uuid", + "selected-by-default": True, + }, + ("properties", "our_uuid"): { + "inclusion": "available", + "sql-datatype": "uuid", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_uuid": {"type": ["null", "string"]}, + "our_pk": {"type": ["string"]}, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_pk'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_pk') : {'inclusion': 'automatic', 'sql-datatype' : 'uuid', 'selected-by-default' : True}, - ('properties', 'our_uuid') : {'inclusion': 'available', 'sql-datatype' : 'uuid', 'selected-by-default' : True}}) - - - self.assertEqual({'properties': {'our_uuid': {'type': ['null', 'string']}, - 'our_pk': {'type': ['string']}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) class TestHStoreTable(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_pk', 
"type" : "hstore", "primary_key" : True }, - {"name" : 'our_hstore', "type" : "hstore" }], - "name" : TestHStoreTable.table_name} - with get_test_connection() as conn: - cur = conn.cursor() - cur.execute(""" SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """) - if cur.fetchone()[0] is None: - cur.execute(""" CREATE EXTENSION hstore; """) - + table_spec = { + "columns": [ + {"name": "our_pk", "type": "hstore", "primary_key": True}, + {"name": "our_hstore", "type": "hstore"}, + ], + "name": TestHStoreTable.table_name, + } + with get_test_connection() as conn: + cur = conn.cursor() + cur.execute( + """ SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """ + ) + if cur.fetchone()[0] is None: + cur.execute(""" CREATE EXTENSION hstore; """) - ensure_test_table(table_spec) + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) with get_test_connection() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute("""INSERT INTO "CHICKEN TIMES" (our_pk, our_hstore) VALUES ('size=>"small",name=>"betty"', 'size=>"big",name=>"fred"')""") + cur.execute( + """INSERT INTO "CHICKEN TIMES" (our_pk, our_hstore) VALUES ('size=>"small",name=>"betty"', 'size=>"big",name=>"fred"')""" + ) cur.execute("""SELECT * FROM "CHICKEN TIMES" """) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_pk'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_pk') : {'inclusion': 'automatic', 'sql-datatype' : 'hstore', 'selected-by-default' : True}, - ('properties', 'our_hstore') : {'inclusion': 'available', 'sql-datatype' : 'hstore', 'selected-by-default' : True}}) - - - self.assertEqual({'properties': {'our_hstore': {'type': ['null', 'object'], 'properties' : {}}, - 'our_pk': {'type': ['object'], 'properties': {}}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_pk"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_pk"): { + "inclusion": "automatic", + "sql-datatype": "hstore", + "selected-by-default": True, + }, + ("properties", "our_hstore"): { + "inclusion": "available", + "sql-datatype": "hstore", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_hstore": { + "type": ["null", "object"], + "properties": {}, + }, + "our_pk": {"type": ["object"], "properties": {}}, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) def test_escaping_values(self): - key = 'nickname' + key = "nickname" value = "Dave's Courtyard" elem = '"{}"=>"{}"'.format(key, value) with get_test_connection() as conn: - with conn.cursor() as cur: - query = tap_postgres.sync_strategies.logical_replication.create_hstore_elem_query(elem) - 
self.assertEqual(query.as_string(cur), "SELECT hstore_to_array('\"nickname\"=>\"Dave''s Courtyard\"')") + with conn.cursor() as cur: + query = tap_postgres.sync_strategies.logical_replication.create_hstore_elem_query( + elem + ) + self.assertEqual( + query.as_string(cur), + "SELECT hstore_to_array('\"nickname\"=>\"Dave''s Courtyard\"')", + ) class TestEnumTable(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_mood_enum_pk', "type" : "mood_enum", "primary_key" : True }, - {"name" : 'our_mood_enum', "type" : "mood_enum" }], - "name" : TestHStoreTable.table_name} - with get_test_connection() as conn: - cur = conn.cursor() - cur.execute(""" DROP TYPE IF EXISTS mood_enum CASCADE """) - cur.execute(""" CREATE TYPE mood_enum AS ENUM ('sad', 'ok', 'happy'); """) + table_spec = { + "columns": [ + {"name": "our_mood_enum_pk", "type": "mood_enum", "primary_key": True}, + {"name": "our_mood_enum", "type": "mood_enum"}, + ], + "name": TestHStoreTable.table_name, + } + with get_test_connection() as conn: + cur = conn.cursor() + cur.execute(""" DROP TYPE IF EXISTS mood_enum CASCADE """) + cur.execute( + """ CREATE TYPE mood_enum AS ENUM ('sad', 'ok', 'happy'); """ + ) - ensure_test_table(table_spec) + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) with get_test_connection() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute("""INSERT INTO "CHICKEN TIMES" (our_mood_enum_pk, our_mood_enum) VALUES ('sad', 'happy')""") + cur.execute( + """INSERT INTO "CHICKEN TIMES" (our_mood_enum_pk, our_mood_enum) VALUES ('sad', 'happy')""" + ) cur.execute("""SELECT * FROM "CHICKEN TIMES" """) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_mood_enum_pk'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_mood_enum_pk') : {'inclusion': 'automatic', 'sql-datatype' : 'mood_enum', 'selected-by-default' : True}, - ('properties', 'our_mood_enum') : {'inclusion': 'available', 'sql-datatype' : 'mood_enum', 'selected-by-default' : True}}) - - - self.assertEqual({'properties': {'our_mood_enum': {'type': ['null', 'string']}, - 'our_mood_enum_pk': {'type': ['string']}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_mood_enum_pk"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_mood_enum_pk"): { + "inclusion": "automatic", + "sql-datatype": "mood_enum", + "selected-by-default": True, + }, + ("properties", "our_mood_enum"): { + "inclusion": "available", + "sql-datatype": "mood_enum", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_mood_enum": {"type": ["null", "string"]}, + "our_mood_enum_pk": {"type": ["string"]}, 
+ }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) class TestMoney(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_money_pk', "type" : "money", "primary_key" : True }, - {"name" : 'our_money', "type" : "money" }], - "name" : TestHStoreTable.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_money_pk", "type": "money", "primary_key": True}, + {"name": "our_money", "type": "money"}, + ], + "name": TestHStoreTable.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) with get_test_connection() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute("""INSERT INTO "CHICKEN TIMES" (our_money_pk, our_money) VALUES ('$1.24', '$777.63')""") + cur.execute( + """INSERT INTO "CHICKEN TIMES" (our_money_pk, our_money) VALUES ('$1.24', '$777.63')""" + ) cur.execute("""SELECT * FROM "CHICKEN TIMES" """) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_money_pk'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_money_pk') : {'inclusion': 'automatic', 'sql-datatype' : 'money', 'selected-by-default' : True}, - ('properties', 'our_money') : {'inclusion': 'available', 'sql-datatype' : 'money', 'selected-by-default' : True}}) + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_money_pk"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_money_pk"): { + "inclusion": "automatic", + "sql-datatype": "money", + "selected-by-default": True, + }, + ("properties", "our_money"): { + "inclusion": "available", + "sql-datatype": "money", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_money": {"type": ["null", "string"]}, + "our_money_pk": {"type": ["string"]}, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) - self.assertEqual({'properties': {'our_money': {'type': ['null', 'string']}, - 'our_money_pk': {'type': ['string']}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) - class TestArraysTable(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name" : 'our_int_array_pk', "type" : "integer[]", "primary_key" : True }, - {"name" : 'our_string_array', "type" : "varchar[]" }], - "name" : TestHStoreTable.table_name} - ensure_test_table(table_spec) + table_spec = { + "columns": [ + {"name": "our_int_array_pk", "type": "integer[]", "primary_key": True}, + {"name": "our_string_array", "type": "varchar[]"}, + ], + "name": TestHStoreTable.table_name, + } + ensure_test_table(table_spec) def test_catalog(self): 
conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) with get_test_connection() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute("""INSERT INTO "CHICKEN TIMES" (our_int_array_pk, our_string_array) VALUES ('{{1,2,3},{4,5,6}}', '{{"a","b","c"}}' )""") + cur.execute( + """INSERT INTO "CHICKEN TIMES" (our_int_array_pk, our_string_array) VALUES ('{{1,2,3},{4,5,6}}', '{{"a","b","c"}}' )""" + ) cur.execute("""SELECT * FROM "CHICKEN TIMES" """) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': ['our_int_array_pk'], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': False, 'row-count': 0}, - ('properties', 'our_int_array_pk') : {'inclusion': 'automatic', 'sql-datatype' : 'integer[]', 'selected-by-default' : True}, - ('properties', 'our_string_array') : {'inclusion': 'available', 'sql-datatype' : 'character varying[]', 'selected-by-default' : True}}) - - - self.assertEqual({'properties': {'our_int_array_pk': {'type': ['null', 'array'], 'items' : {'$ref': '#/definitions/sdc_recursive_integer_array'}}, - 'our_string_array': {'type': ['null', 'array'], 'items' : {'$ref': '#/definitions/sdc_recursive_string_array'}}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": ["our_int_array_pk"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "our_int_array_pk"): { + "inclusion": "automatic", + "sql-datatype": "integer[]", + "selected-by-default": True, + }, + ("properties", "our_string_array"): { + "inclusion": "available", + "sql-datatype": "character varying[]", + "selected-by-default": True, + }, + }, + ) + + self.assertEqual( + { + "properties": { + "our_int_array_pk": { + "type": ["null", "array"], + "items": { + "$ref": "#/definitions/sdc_recursive_integer_array" + }, + }, + "our_string_array": { + "type": ["null", "array"], + "items": { + "$ref": "#/definitions/sdc_recursive_string_array" + }, + }, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) class TestArraysLikeTable(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' - like_table_name = 'LIKE CHICKEN TIMES' + table_name = "CHICKEN TIMES" + like_table_name = "LIKE CHICKEN TIMES" def setUp(self): - with get_test_connection('postgres') as conn: + with get_test_connection("postgres") as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute('DROP MATERIALIZED VIEW IF EXISTS "LIKE CHICKEN TIMES"') - table_spec = {"columns": [{"name" : 'our_int_array_pk', "type" : "integer[]", "primary_key" : True }, - {"name" : 'our_text_array', "type" : "text[]" }], - "name" : TestArraysLikeTable.table_name} + table_spec = { + "columns": [ + {"name": "our_int_array_pk", "type": "integer[]", "primary_key": True}, + {"name": "our_text_array", "type": "text[]"}, + ], + "name": TestArraysLikeTable.table_name, + } 
ensure_test_table(table_spec) - with get_test_connection('postgres') as conn: + with get_test_connection("postgres") as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - create_sql = "CREATE MATERIALIZED VIEW {} AS SELECT * FROM {}\n".format(quote_ident(TestArraysLikeTable.like_table_name, cur), - quote_ident(TestArraysLikeTable.table_name, cur)) - - - cur.execute(create_sql) + create_sql = "CREATE MATERIALIZED VIEW {} AS SELECT * FROM {}\n".format( + quote_ident(TestArraysLikeTable.like_table_name, cur), + quote_ident(TestArraysLikeTable.table_name, cur), + ) + cur.execute(create_sql) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-LIKE CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-LIKE CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) with get_test_connection() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {() : {'table-key-properties': [], 'database-name': 'postgres', 'schema-name': 'public', 'is-view': True, 'row-count': 0}, - ('properties', 'our_int_array_pk') : {'inclusion': 'available', 'sql-datatype' : 'integer[]', 'selected-by-default' : True}, - ('properties', 'our_text_array') : {'inclusion': 'available', 'sql-datatype' : 'text[]', 'selected-by-default' : True}}) - self.assertEqual({'properties': {'our_int_array_pk': {'type': ['null', 'array'], 'items' : {'$ref': '#/definitions/sdc_recursive_integer_array'}}, - 'our_text_array': {'type': ['null', 'array'], 'items' : {'$ref': '#/definitions/sdc_recursive_string_array'}}}, - 'type': 'object', - 'definitions' : BASE_RECURSIVE_SCHEMAS}, - stream_dict.get('schema')) + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": [], + "database-name": "postgres", + "schema-name": "public", + "is-view": True, + "row-count": 0, + }, + ("properties", "our_int_array_pk"): { + "inclusion": "available", + "sql-datatype": "integer[]", + "selected-by-default": True, + }, + ("properties", "our_text_array"): { + "inclusion": "available", + "sql-datatype": "text[]", + "selected-by-default": True, + }, + }, + ) + self.assertEqual( + { + "properties": { + "our_int_array_pk": { + "type": ["null", "array"], + "items": { + "$ref": "#/definitions/sdc_recursive_integer_array" + }, + }, + "our_text_array": { + "type": ["null", "array"], + "items": { + "$ref": "#/definitions/sdc_recursive_string_array" + }, + }, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + stream_dict.get("schema"), + ) + class TestColumnGrants(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' - user = 'tmp_user_for_grant_tests' - password = 'password' - + table_name = "CHICKEN TIMES" + user = "tmp_user_for_grant_tests" + password = "password" + def setUp(self): - table_spec = {"columns": [{"name" : "id", "type" : "integer", "serial" : True}, - {"name" : 'size integer', "type" : "integer", "quoted" : True}, - {"name" : 'size smallint', "type" : "smallint", "quoted" : True}, - {"name" : 'size bigint', "type" : "bigint", "quoted" : True}], - "name" : TestColumnGrants.table_name} + table_spec = { + "columns": [ + {"name": 
"id", "type": "integer", "serial": True}, + {"name": "size integer", "type": "integer", "quoted": True}, + {"name": "size smallint", "type": "smallint", "quoted": True}, + {"name": "size bigint", "type": "bigint", "quoted": True}, + ], + "name": TestColumnGrants.table_name, + } ensure_test_table(table_spec) with get_test_connection() as conn: - cur = conn.cursor() - - sql = """ DROP USER IF EXISTS {} """.format(self.user, self.password) - LOGGER.info(sql) - cur.execute(sql) + cur = conn.cursor() - sql = """ CREATE USER {} WITH PASSWORD '{}' """.format(self.user, self.password) - LOGGER.info(sql) - cur.execute(sql) + sql = """ DROP USER IF EXISTS {} """.format(self.user, self.password) + LOGGER.info(sql) + cur.execute(sql) - sql = """ GRANT SELECT ("id") ON "{}" TO {}""".format(TestColumnGrants.table_name, self.user) - LOGGER.info("running sql: {}".format(sql)) - cur.execute(sql) + sql = """ CREATE USER {} WITH PASSWORD '{}' """.format( + self.user, self.password + ) + LOGGER.info(sql) + cur.execute(sql) - - + sql = """ GRANT SELECT ("id") ON "{}" TO {}""".format( + TestColumnGrants.table_name, self.user + ) + LOGGER.info("running sql: {}".format(sql)) + cur.execute(sql) def test_catalog(self): conn_config = get_test_connection_config() - conn_config['user'] = self.user - conn_config['password'] = self.password + conn_config["user"] = self.user + conn_config["password"] = self.password streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == 'public-CHICKEN TIMES'] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get('table_name')) - self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get('stream')) - - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) - - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {(): {'table-key-properties': [], - 'database-name': 'postgres', - 'schema-name': 'public', - 'is-view': False, - 'row-count': 0}, - ('properties', 'id'): {'inclusion': 'available', - 'selected-by-default': True, - 'sql-datatype': 'integer'}}) - - self.assertEqual({'definitions' : BASE_RECURSIVE_SCHEMAS, - 'type': 'object', - 'properties': {'id': {'type': ['null', 'integer'], - 'minimum': -2147483648, - 'maximum': 2147483647}}}, - stream_dict.get('schema')) + self.assertEqual( + TestStringTableWithPK.table_name, stream_dict.get("table_name") + ) + self.assertEqual(TestStringTableWithPK.table_name, stream_dict.get("stream")) + + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "table-key-properties": [], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + }, + ("properties", "id"): { + "inclusion": "available", + "selected-by-default": True, + "sql-datatype": "integer", + }, + }, + ) + + self.assertEqual( + { + "definitions": BASE_RECURSIVE_SCHEMAS, + "type": "object", + "properties": { + "id": { + "type": ["null", "integer"], + "minimum": -2147483648, + "maximum": 2147483647, + } + }, + }, + stream_dict.get("schema"), + ) diff --git a/tests/test_full_table_interruption.py b/tests/test_full_table_interruption.py index 21bc0144..f6fb12af 100644 --- a/tests/test_full_table_interruption.py +++ b/tests/test_full_table_interruption.py @@ -4,10 +4,25 @@ import 
tap_postgres.sync_strategies.common as pg_common import singer from singer import get_logger, metadata, write_bookmark + try: - from tests.utils import get_test_connection, ensure_test_table, select_all_of_stream, set_replication_method_for_stream, insert_record, get_test_connection_config + from tests.utils import ( + get_test_connection, + ensure_test_table, + select_all_of_stream, + set_replication_method_for_stream, + insert_record, + get_test_connection_config, + ) except ImportError: - from utils import get_test_connection, ensure_test_table, select_all_of_stream, set_replication_method_for_stream, insert_record, get_test_connection_config + from utils import ( + get_test_connection, + ensure_test_table, + select_all_of_stream, + set_replication_method_for_stream, + insert_record, + get_test_connection_config, + ) LOGGER = get_logger() @@ -15,10 +30,11 @@ CAUGHT_MESSAGES = [] COW_RECORD_COUNT = 0 + def singer_write_message_no_cow(message): global COW_RECORD_COUNT - if isinstance(message, singer.RecordMessage) and message.stream == 'public-COW': + if isinstance(message, singer.RecordMessage) and message.stream == "public-COW": COW_RECORD_COUNT = COW_RECORD_COUNT + 1 if COW_RECORD_COUNT > 2: raise Exception("simulated exception") @@ -26,36 +42,45 @@ def singer_write_message_no_cow(message): else: CAUGHT_MESSAGES.append(message) + def singer_write_schema_ok(message): CAUGHT_MESSAGES.append(message) + def singer_write_message_ok(message): CAUGHT_MESSAGES.append(message) + def expected_record(fixture_row): expected_record = {} - for k,v in fixture_row.items(): - expected_record[k.replace('"', '')] = v + for k, v in fixture_row.items(): + expected_record[k.replace('"', "")] = v return expected_record + def do_not_dump_catalog(catalog): pass + tap_postgres.dump_catalog = do_not_dump_catalog full_table.UPDATE_BOOKMARK_PERIOD = 1 + class LogicalInterruption(unittest.TestCase): maxDiff = None def setUp(self): - table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True}, - {"name" : 'name', "type": "character varying"}, - {"name" : 'colour', "type": "character varying"}, - {"name" : 'timestamp_ntz', "type": "timestamp without time zone"}, - {"name" : 'timestamp_tz', "type": "timestamp with time zone"}, - ], - "name" : 'COW'} + table_spec_1 = { + "columns": [ + {"name": "id", "type": "serial", "primary_key": True}, + {"name": "name", "type": "character varying"}, + {"name": "colour", "type": "character varying"}, + {"name": "timestamp_ntz", "type": "timestamp without time zone"}, + {"name": "timestamp_tz", "type": "timestamp with time zone"}, + ], + "name": "COW", + } ensure_test_table(table_spec_1) global COW_RECORD_COUNT COW_RECORD_COUNT = 0 @@ -68,32 +93,46 @@ def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - cow_stream = [s for s in streams if s['table_name'] == 'COW'][0] + cow_stream = [s for s in streams if s["table_name"] == "COW"][0] self.assertIsNotNone(cow_stream) cow_stream = select_all_of_stream(cow_stream) - cow_stream = set_replication_method_for_stream(cow_stream, 'LOG_BASED') + cow_stream = set_replication_method_for_stream(cow_stream, "LOG_BASED") with get_test_connection() as conn: conn.autocommit = True cur = conn.cursor() - cow_rec = {'name' : 'betty', 'colour' : 'blue', - 'timestamp_ntz': '2020-09-01 10:40:59', 'timestamp_tz': '2020-09-01 00:50:59+02'} - insert_record(cur, 'COW', cow_rec) - - cow_rec = {'name' : 'smelly', 'colour' : 'brow', - 'timestamp_ntz': '2020-09-01 10:40:59 
BC', 'timestamp_tz': '2020-09-01 00:50:59+02 BC'} - insert_record(cur, 'COW', cow_rec) - - cow_rec = {'name' : 'pooper', 'colour' : 'green', - 'timestamp_ntz': '30000-09-01 10:40:59', 'timestamp_tz': '10000-09-01 00:50:59+02'} - insert_record(cur, 'COW', cow_rec) + cow_rec = { + "name": "betty", + "colour": "blue", + "timestamp_ntz": "2020-09-01 10:40:59", + "timestamp_tz": "2020-09-01 00:50:59+02", + } + insert_record(cur, "COW", cow_rec) + + cow_rec = { + "name": "smelly", + "colour": "brow", + "timestamp_ntz": "2020-09-01 10:40:59 BC", + "timestamp_tz": "2020-09-01 00:50:59+02 BC", + } + insert_record(cur, "COW", cow_rec) + + cow_rec = { + "name": "pooper", + "colour": "green", + "timestamp_ntz": "30000-09-01 10:40:59", + "timestamp_tz": "10000-09-01 00:50:59+02", + } + insert_record(cur, "COW", cow_rec) state = {} - #the initial phase of cows logical replication will be a full table. - #it will sync the first record and then blow up on the 2nd record + # the initial phase of cows logical replication will be a full table. + # it will sync the first record and then blow up on the 2nd record try: - tap_postgres.do_sync(get_test_connection_config(), {'streams' : streams}, None, state) + tap_postgres.do_sync( + get_test_connection_config(), {"streams": streams}, None, state + ) except Exception: blew_up_on_cow = True @@ -101,115 +140,174 @@ def test_catalog(self): self.assertEqual(7, len(CAUGHT_MESSAGES)) - self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA') + self.assertEqual(CAUGHT_MESSAGES[0]["type"], "SCHEMA") self.assertIsInstance(CAUGHT_MESSAGES[1], singer.StateMessage) - self.assertIsNone(CAUGHT_MESSAGES[1].value['bookmarks']['public-COW'].get('xmin')) - self.assertIsNotNone(CAUGHT_MESSAGES[1].value['bookmarks']['public-COW'].get('lsn')) - end_lsn = CAUGHT_MESSAGES[1].value['bookmarks']['public-COW'].get('lsn') + self.assertIsNone( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"].get("xmin") + ) + self.assertIsNotNone( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"].get("lsn") + ) + end_lsn = CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"].get("lsn") self.assertIsInstance(CAUGHT_MESSAGES[2], singer.ActivateVersionMessage) new_version = CAUGHT_MESSAGES[2].version self.assertIsInstance(CAUGHT_MESSAGES[3], singer.RecordMessage) - self.assertEqual(CAUGHT_MESSAGES[3].record, { - 'colour': 'blue', - 'id': 1, - 'name': 'betty', - 'timestamp_ntz': '2020-09-01T10:40:59+00:00', - 'timestamp_tz': '2020-08-31T22:50:59+00:00' - }) - - self.assertEqual('public-COW', CAUGHT_MESSAGES[3].stream) + self.assertEqual( + CAUGHT_MESSAGES[3].record, + { + "colour": "blue", + "id": 1, + "name": "betty", + "timestamp_ntz": "2020-09-01T10:40:59+00:00", + "timestamp_tz": "2020-08-31T22:50:59+00:00", + }, + ) + + self.assertEqual("public-COW", CAUGHT_MESSAGES[3].stream) self.assertIsInstance(CAUGHT_MESSAGES[4], singer.StateMessage) - #xmin is set while we are processing the full table replication - self.assertIsNotNone(CAUGHT_MESSAGES[4].value['bookmarks']['public-COW']['xmin']) - self.assertEqual(CAUGHT_MESSAGES[4].value['bookmarks']['public-COW']['lsn'], end_lsn) - - self.assertEqual(CAUGHT_MESSAGES[5].record, { - 'colour': 'brow', - 'id': 2, - 'name': 'smelly', - 'timestamp_ntz': '9999-12-31T23:59:59.999000+00:00', - 'timestamp_tz': '9999-12-31T23:59:59.999000+00:00' - }) - - self.assertEqual('public-COW', CAUGHT_MESSAGES[5].stream) + # xmin is set while we are processing the full table replication + self.assertIsNotNone( + CAUGHT_MESSAGES[4].value["bookmarks"]["public-COW"]["xmin"] + ) + 
self.assertEqual( + CAUGHT_MESSAGES[4].value["bookmarks"]["public-COW"]["lsn"], end_lsn + ) + + self.assertEqual( + CAUGHT_MESSAGES[5].record, + { + "colour": "brow", + "id": 2, + "name": "smelly", + "timestamp_ntz": "9999-12-31T23:59:59.999000+00:00", + "timestamp_tz": "9999-12-31T23:59:59.999000+00:00", + }, + ) + + self.assertEqual("public-COW", CAUGHT_MESSAGES[5].stream) self.assertIsInstance(CAUGHT_MESSAGES[6], singer.StateMessage) - last_xmin = CAUGHT_MESSAGES[6].value['bookmarks']['public-COW']['xmin'] + last_xmin = CAUGHT_MESSAGES[6].value["bookmarks"]["public-COW"]["xmin"] old_state = CAUGHT_MESSAGES[6].value - #run another do_sync, should get the remaining record which effectively finishes the initial full_table - #replication portion of the logical replication + # run another do_sync, should get the remaining record which effectively finishes the initial full_table + # replication portion of the logical replication singer.write_message = singer_write_message_ok global COW_RECORD_COUNT COW_RECORD_COUNT = 0 CAUGHT_MESSAGES.clear() - tap_postgres.do_sync(get_test_connection_config(), {'streams' : streams}, None, old_state) + tap_postgres.do_sync( + get_test_connection_config(), {"streams": streams}, None, old_state + ) self.assertEqual(8, len(CAUGHT_MESSAGES)) - self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA') + self.assertEqual(CAUGHT_MESSAGES[0]["type"], "SCHEMA") self.assertIsInstance(CAUGHT_MESSAGES[1], singer.StateMessage) - self.assertEqual(CAUGHT_MESSAGES[1].value['bookmarks']['public-COW'].get('xmin'), last_xmin) - self.assertEqual(CAUGHT_MESSAGES[1].value['bookmarks']['public-COW'].get('lsn'), end_lsn) - self.assertEqual(CAUGHT_MESSAGES[1].value['bookmarks']['public-COW'].get('version'), new_version) + self.assertEqual( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"].get("xmin"), last_xmin + ) + self.assertEqual( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"].get("lsn"), end_lsn + ) + self.assertEqual( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"].get("version"), + new_version, + ) self.assertIsInstance(CAUGHT_MESSAGES[2], singer.RecordMessage) - self.assertEqual(CAUGHT_MESSAGES[2].record, { - 'colour': 'brow', - 'id': 2, - 'name': 'smelly', - 'timestamp_ntz': '9999-12-31T23:59:59.999000+00:00', - 'timestamp_tz': '9999-12-31T23:59:59.999000+00:00' - }) - - self.assertEqual('public-COW', CAUGHT_MESSAGES[2].stream) + self.assertEqual( + CAUGHT_MESSAGES[2].record, + { + "colour": "brow", + "id": 2, + "name": "smelly", + "timestamp_ntz": "9999-12-31T23:59:59.999000+00:00", + "timestamp_tz": "9999-12-31T23:59:59.999000+00:00", + }, + ) + + self.assertEqual("public-COW", CAUGHT_MESSAGES[2].stream) self.assertIsInstance(CAUGHT_MESSAGES[3], singer.StateMessage) - self.assertTrue(CAUGHT_MESSAGES[3].value['bookmarks']['public-COW'].get('xmin'),last_xmin) - self.assertEqual(CAUGHT_MESSAGES[3].value['bookmarks']['public-COW'].get('lsn'), end_lsn) - self.assertEqual(CAUGHT_MESSAGES[3].value['bookmarks']['public-COW'].get('version'), new_version) + self.assertTrue( + CAUGHT_MESSAGES[3].value["bookmarks"]["public-COW"].get("xmin"), last_xmin + ) + self.assertEqual( + CAUGHT_MESSAGES[3].value["bookmarks"]["public-COW"].get("lsn"), end_lsn + ) + self.assertEqual( + CAUGHT_MESSAGES[3].value["bookmarks"]["public-COW"].get("version"), + new_version, + ) self.assertIsInstance(CAUGHT_MESSAGES[4], singer.RecordMessage) - self.assertEqual(CAUGHT_MESSAGES[4].record, { - 'colour': 'green', - 'id': 3, - 'name': 'pooper', - 'timestamp_ntz': 
'9999-12-31T23:59:59.999000+00:00', - 'timestamp_tz': '9999-12-31T23:59:59.999000+00:00' - }) - self.assertEqual('public-COW', CAUGHT_MESSAGES[4].stream) + self.assertEqual( + CAUGHT_MESSAGES[4].record, + { + "colour": "green", + "id": 3, + "name": "pooper", + "timestamp_ntz": "9999-12-31T23:59:59.999000+00:00", + "timestamp_tz": "9999-12-31T23:59:59.999000+00:00", + }, + ) + self.assertEqual("public-COW", CAUGHT_MESSAGES[4].stream) self.assertIsInstance(CAUGHT_MESSAGES[5], singer.StateMessage) - self.assertTrue(CAUGHT_MESSAGES[5].value['bookmarks']['public-COW'].get('xmin') > last_xmin) - self.assertEqual(CAUGHT_MESSAGES[5].value['bookmarks']['public-COW'].get('lsn'), end_lsn) - self.assertEqual(CAUGHT_MESSAGES[5].value['bookmarks']['public-COW'].get('version'), new_version) - + self.assertTrue( + CAUGHT_MESSAGES[5].value["bookmarks"]["public-COW"].get("xmin") > last_xmin + ) + self.assertEqual( + CAUGHT_MESSAGES[5].value["bookmarks"]["public-COW"].get("lsn"), end_lsn + ) + self.assertEqual( + CAUGHT_MESSAGES[5].value["bookmarks"]["public-COW"].get("version"), + new_version, + ) self.assertIsInstance(CAUGHT_MESSAGES[6], singer.ActivateVersionMessage) self.assertEqual(CAUGHT_MESSAGES[6].version, new_version) self.assertIsInstance(CAUGHT_MESSAGES[7], singer.StateMessage) - self.assertIsNone(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('xmin')) - self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('lsn'), end_lsn) - self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('version'), new_version) + self.assertIsNone( + CAUGHT_MESSAGES[7].value["bookmarks"]["public-COW"].get("xmin") + ) + self.assertEqual( + CAUGHT_MESSAGES[7].value["bookmarks"]["public-COW"].get("lsn"), end_lsn + ) + self.assertEqual( + CAUGHT_MESSAGES[7].value["bookmarks"]["public-COW"].get("version"), + new_version, + ) + class FullTableInterruption(unittest.TestCase): maxDiff = None + def setUp(self): - table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True}, - {"name" : 'name', "type": "character varying"}, - {"name" : 'colour', "type": "character varying"}], - "name" : 'COW'} + table_spec_1 = { + "columns": [ + {"name": "id", "type": "serial", "primary_key": True}, + {"name": "name", "type": "character varying"}, + {"name": "colour", "type": "character varying"}, + ], + "name": "COW", + } ensure_test_table(table_spec_1) - table_spec_2 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True}, - {"name" : 'name', "type": "character varying"}, - {"name" : 'colour', "type": "character varying"}], - "name" : 'CHICKEN'} + table_spec_2 = { + "columns": [ + {"name": "id", "type": "serial", "primary_key": True}, + {"name": "name", "type": "character varying"}, + {"name": "colour", "type": "character varying"}, + ], + "name": "CHICKEN", + } ensure_test_table(table_spec_2) global COW_RECORD_COUNT @@ -223,137 +321,162 @@ def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - cow_stream = [s for s in streams if s['table_name'] == 'COW'][0] + cow_stream = [s for s in streams if s["table_name"] == "COW"][0] self.assertIsNotNone(cow_stream) cow_stream = select_all_of_stream(cow_stream) - cow_stream = set_replication_method_for_stream(cow_stream, 'FULL_TABLE') + cow_stream = set_replication_method_for_stream(cow_stream, "FULL_TABLE") - chicken_stream = [s for s in streams if s['table_name'] == 'CHICKEN'][0] + chicken_stream = [s for s in streams if s["table_name"] == "CHICKEN"][0] 
self.assertIsNotNone(chicken_stream) chicken_stream = select_all_of_stream(chicken_stream) - chicken_stream = set_replication_method_for_stream(chicken_stream, 'FULL_TABLE') + chicken_stream = set_replication_method_for_stream(chicken_stream, "FULL_TABLE") with get_test_connection() as conn: conn.autocommit = True cur = conn.cursor() - cow_rec = {'name' : 'betty', 'colour' : 'blue'} - insert_record(cur, 'COW', cow_rec) - cow_rec = {'name' : 'smelly', 'colour' : 'brow'} - insert_record(cur, 'COW', cow_rec) + cow_rec = {"name": "betty", "colour": "blue"} + insert_record(cur, "COW", cow_rec) + cow_rec = {"name": "smelly", "colour": "brow"} + insert_record(cur, "COW", cow_rec) - cow_rec = {'name' : 'pooper', 'colour' : 'green'} - insert_record(cur, 'COW', cow_rec) + cow_rec = {"name": "pooper", "colour": "green"} + insert_record(cur, "COW", cow_rec) - chicken_rec = {'name' : 'fred', 'colour' : 'red'} - insert_record(cur, 'CHICKEN', chicken_rec) + chicken_rec = {"name": "fred", "colour": "red"} + insert_record(cur, "CHICKEN", chicken_rec) state = {} - #this will sync the CHICKEN but then blow up on the COW + # this will sync the CHICKEN but then blow up on the COW try: - tap_postgres.do_sync(get_test_connection_config(), {'streams' : streams}, None, state) + tap_postgres.do_sync( + get_test_connection_config(), {"streams": streams}, None, state + ) except Exception as ex: # LOGGER.exception(ex) blew_up_on_cow = True self.assertTrue(blew_up_on_cow) - self.assertEqual(14, len(CAUGHT_MESSAGES)) - self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA') + self.assertEqual(CAUGHT_MESSAGES[0]["type"], "SCHEMA") self.assertIsInstance(CAUGHT_MESSAGES[1], singer.StateMessage) - self.assertIsNone(CAUGHT_MESSAGES[1].value['bookmarks']['public-CHICKEN'].get('xmin')) + self.assertIsNone( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-CHICKEN"].get("xmin") + ) self.assertIsInstance(CAUGHT_MESSAGES[2], singer.ActivateVersionMessage) new_version = CAUGHT_MESSAGES[2].version self.assertIsInstance(CAUGHT_MESSAGES[3], singer.RecordMessage) - self.assertEqual('public-CHICKEN', CAUGHT_MESSAGES[3].stream) + self.assertEqual("public-CHICKEN", CAUGHT_MESSAGES[3].stream) self.assertIsInstance(CAUGHT_MESSAGES[4], singer.StateMessage) - #xmin is set while we are processing the full table replication - self.assertIsNotNone(CAUGHT_MESSAGES[4].value['bookmarks']['public-CHICKEN']['xmin']) + # xmin is set while we are processing the full table replication + self.assertIsNotNone( + CAUGHT_MESSAGES[4].value["bookmarks"]["public-CHICKEN"]["xmin"] + ) self.assertIsInstance(CAUGHT_MESSAGES[5], singer.ActivateVersionMessage) self.assertEqual(CAUGHT_MESSAGES[5].version, new_version) self.assertIsInstance(CAUGHT_MESSAGES[6], singer.StateMessage) - self.assertEqual(None, singer.get_currently_syncing( CAUGHT_MESSAGES[6].value)) - #xmin is cleared at the end of the full table replication - self.assertIsNone(CAUGHT_MESSAGES[6].value['bookmarks']['public-CHICKEN']['xmin']) + self.assertEqual(None, singer.get_currently_syncing(CAUGHT_MESSAGES[6].value)) + # xmin is cleared at the end of the full table replication + self.assertIsNone( + CAUGHT_MESSAGES[6].value["bookmarks"]["public-CHICKEN"]["xmin"] + ) + # cow messages + self.assertEqual(CAUGHT_MESSAGES[7]["type"], "SCHEMA") - #cow messages - self.assertEqual(CAUGHT_MESSAGES[7]['type'], 'SCHEMA') - - self.assertEqual("public-COW", CAUGHT_MESSAGES[7]['stream']) + self.assertEqual("public-COW", CAUGHT_MESSAGES[7]["stream"]) self.assertIsInstance(CAUGHT_MESSAGES[8], singer.StateMessage) - 
self.assertIsNone(CAUGHT_MESSAGES[8].value['bookmarks']['public-COW'].get('xmin')) - self.assertEqual("public-COW", CAUGHT_MESSAGES[8].value['currently_syncing']) + self.assertIsNone( + CAUGHT_MESSAGES[8].value["bookmarks"]["public-COW"].get("xmin") + ) + self.assertEqual("public-COW", CAUGHT_MESSAGES[8].value["currently_syncing"]) self.assertIsInstance(CAUGHT_MESSAGES[9], singer.ActivateVersionMessage) cow_version = CAUGHT_MESSAGES[9].version self.assertIsInstance(CAUGHT_MESSAGES[10], singer.RecordMessage) - self.assertEqual(CAUGHT_MESSAGES[10].record['name'], 'betty') - self.assertEqual('public-COW', CAUGHT_MESSAGES[10].stream) + self.assertEqual(CAUGHT_MESSAGES[10].record["name"], "betty") + self.assertEqual("public-COW", CAUGHT_MESSAGES[10].stream) self.assertIsInstance(CAUGHT_MESSAGES[11], singer.StateMessage) - #xmin is set while we are processing the full table replication - self.assertIsNotNone(CAUGHT_MESSAGES[11].value['bookmarks']['public-COW']['xmin']) - + # xmin is set while we are processing the full table replication + self.assertIsNotNone( + CAUGHT_MESSAGES[11].value["bookmarks"]["public-COW"]["xmin"] + ) - self.assertEqual(CAUGHT_MESSAGES[12].record['name'], 'smelly') - self.assertEqual('public-COW', CAUGHT_MESSAGES[12].stream) + self.assertEqual(CAUGHT_MESSAGES[12].record["name"], "smelly") + self.assertEqual("public-COW", CAUGHT_MESSAGES[12].stream) old_state = CAUGHT_MESSAGES[13].value - #run another do_sync + # run another do_sync singer.write_message = singer_write_message_ok CAUGHT_MESSAGES.clear() global COW_RECORD_COUNT COW_RECORD_COUNT = 0 - tap_postgres.do_sync(get_test_connection_config(), {'streams' : streams}, None, old_state) + tap_postgres.do_sync( + get_test_connection_config(), {"streams": streams}, None, old_state + ) - self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA') + self.assertEqual(CAUGHT_MESSAGES[0]["type"], "SCHEMA") self.assertIsInstance(CAUGHT_MESSAGES[1], singer.StateMessage) # because we were interrupted, we do not switch versions - self.assertEqual(CAUGHT_MESSAGES[1].value['bookmarks']['public-COW']['version'], cow_version) - self.assertIsNotNone(CAUGHT_MESSAGES[1].value['bookmarks']['public-COW']['xmin']) - self.assertEqual("public-COW", singer.get_currently_syncing(CAUGHT_MESSAGES[1].value)) + self.assertEqual( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"]["version"], cow_version + ) + self.assertIsNotNone( + CAUGHT_MESSAGES[1].value["bookmarks"]["public-COW"]["xmin"] + ) + self.assertEqual( + "public-COW", singer.get_currently_syncing(CAUGHT_MESSAGES[1].value) + ) self.assertIsInstance(CAUGHT_MESSAGES[2], singer.RecordMessage) - self.assertEqual(CAUGHT_MESSAGES[2].record['name'], 'smelly') - self.assertEqual('public-COW', CAUGHT_MESSAGES[2].stream) - + self.assertEqual(CAUGHT_MESSAGES[2].record["name"], "smelly") + self.assertEqual("public-COW", CAUGHT_MESSAGES[2].stream) - #after record: activate version, state with no xmin or currently syncing + # after record: activate version, state with no xmin or currently syncing self.assertIsInstance(CAUGHT_MESSAGES[3], singer.StateMessage) - #we still have an xmin for COW because are not yet done with the COW table - self.assertIsNotNone(CAUGHT_MESSAGES[3].value['bookmarks']['public-COW']['xmin']) - self.assertEqual(singer.get_currently_syncing( CAUGHT_MESSAGES[3].value), 'public-COW') + # we still have an xmin for COW because are not yet done with the COW table + self.assertIsNotNone( + CAUGHT_MESSAGES[3].value["bookmarks"]["public-COW"]["xmin"] + ) + self.assertEqual( + 
singer.get_currently_syncing(CAUGHT_MESSAGES[3].value), "public-COW" + ) self.assertIsInstance(CAUGHT_MESSAGES[4], singer.RecordMessage) - self.assertEqual(CAUGHT_MESSAGES[4].record['name'], 'pooper') - self.assertEqual('public-COW', CAUGHT_MESSAGES[4].stream) + self.assertEqual(CAUGHT_MESSAGES[4].record["name"], "pooper") + self.assertEqual("public-COW", CAUGHT_MESSAGES[4].stream) self.assertIsInstance(CAUGHT_MESSAGES[5], singer.StateMessage) - self.assertIsNotNone(CAUGHT_MESSAGES[5].value['bookmarks']['public-COW']['xmin']) - self.assertEqual(singer.get_currently_syncing( CAUGHT_MESSAGES[5].value), 'public-COW') - - - #xmin is cleared because we are finished the full table replication + self.assertIsNotNone( + CAUGHT_MESSAGES[5].value["bookmarks"]["public-COW"]["xmin"] + ) + self.assertEqual( + singer.get_currently_syncing(CAUGHT_MESSAGES[5].value), "public-COW" + ) + + # xmin is cleared because we are finished the full table replication self.assertIsInstance(CAUGHT_MESSAGES[6], singer.ActivateVersionMessage) self.assertEqual(CAUGHT_MESSAGES[6].version, cow_version) self.assertIsInstance(CAUGHT_MESSAGES[7], singer.StateMessage) - self.assertIsNone(singer.get_currently_syncing( CAUGHT_MESSAGES[7].value)) - self.assertIsNone(CAUGHT_MESSAGES[7].value['bookmarks']['public-CHICKEN']['xmin']) - self.assertIsNone(singer.get_currently_syncing( CAUGHT_MESSAGES[7].value)) + self.assertIsNone(singer.get_currently_syncing(CAUGHT_MESSAGES[7].value)) + self.assertIsNone( + CAUGHT_MESSAGES[7].value["bookmarks"]["public-CHICKEN"]["xmin"] + ) + self.assertIsNone(singer.get_currently_syncing(CAUGHT_MESSAGES[7].value)) -if __name__== "__main__": +if __name__ == "__main__": test1 = LogicalInterruption() test1.setUp() test1.test_catalog() diff --git a/tests/test_logical_replication.py b/tests/test_logical_replication.py index ae0f5832..c9b41bd2 100644 --- a/tests/test_logical_replication.py +++ b/tests/test_logical_replication.py @@ -22,7 +22,10 @@ def __init__(self, existing_slot_name): def execute(self, sql): """Simulating to run an SQL query If the query is selecting the existing_slot_name then the replication slot found""" - if sql == f"SELECT * FROM pg_replication_slots WHERE slot_name = '{self.existing_slot_name}'": + if ( + sql + == f"SELECT * FROM pg_replication_slots WHERE slot_name = '{self.existing_slot_name}'" + ): self.replication_slot_found = True def fetchall(self): @@ -37,118 +40,165 @@ class TestLogicalReplication(unittest.TestCase): maxDiff = None def setUp(self): - self.WalMessage = namedtuple('WalMessage', ['payload', 'data_start']) + self.WalMessage = namedtuple("WalMessage", ["payload", "data_start"]) def test_streams_to_wal2json_tables(self): """Validate if table names are escaped to wal2json format""" streams = [ - {'metadata': [{'metadata': {'schema-name': 'public'}}], - 'table_name': 'dummy_table'}, - {'metadata': [{'metadata': {'schema-name': 'public'}}], - 'table_name': 'CaseSensitiveTable'}, - {'metadata': [{'metadata': {'schema-name': 'public'}}], - 'table_name': 'Case Sensitive Table With Space'}, - {'metadata': [{'metadata': {'schema-name': 'CaseSensitiveSchema'}}], - 'table_name': 'dummy_table'}, - {'metadata': [{'metadata': {'schema-name': 'Case Sensitive Schema With Space'}}], - 'table_name': 'CaseSensitiveTable'}, - {'metadata': [{'metadata': {'schema-name': 'Case Sensitive Schema With Space'}}], - 'table_name': 'Case Sensitive Table With Space'}, - {'metadata': [{'metadata': {'schema-name': 'public'}}], - 'table_name': 'table_with_comma_,'}, - {'metadata': [{'metadata': 
{'schema-name': 'public'}}], - 'table_name': "table_with_quote_'"} + { + "metadata": [{"metadata": {"schema-name": "public"}}], + "table_name": "dummy_table", + }, + { + "metadata": [{"metadata": {"schema-name": "public"}}], + "table_name": "CaseSensitiveTable", + }, + { + "metadata": [{"metadata": {"schema-name": "public"}}], + "table_name": "Case Sensitive Table With Space", + }, + { + "metadata": [{"metadata": {"schema-name": "CaseSensitiveSchema"}}], + "table_name": "dummy_table", + }, + { + "metadata": [ + {"metadata": {"schema-name": "Case Sensitive Schema With Space"}} + ], + "table_name": "CaseSensitiveTable", + }, + { + "metadata": [ + {"metadata": {"schema-name": "Case Sensitive Schema With Space"}} + ], + "table_name": "Case Sensitive Table With Space", + }, + { + "metadata": [{"metadata": {"schema-name": "public"}}], + "table_name": "table_with_comma_,", + }, + { + "metadata": [{"metadata": {"schema-name": "public"}}], + "table_name": "table_with_quote_'", + }, ] - self.assertEqual(logical_replication.streams_to_wal2json_tables(streams), - 'public.dummy_table,' - 'public.CaseSensitiveTable,' - 'public.Case\\ Sensitive\\ Table\\ With\\ Space,' - 'CaseSensitiveSchema.dummy_table,' - 'Case\\ Sensitive\\ Schema\\ With\\ Space.CaseSensitiveTable,' - 'Case\\ Sensitive\\ Schema\\ With\\ Space.Case\\ Sensitive\\ Table\\ With\\ Space,' - 'public.table_with_comma_\\,,' - "public.table_with_quote_\\'") + self.assertEqual( + logical_replication.streams_to_wal2json_tables(streams), + "public.dummy_table," + "public.CaseSensitiveTable," + "public.Case\\ Sensitive\\ Table\\ With\\ Space," + "CaseSensitiveSchema.dummy_table," + "Case\\ Sensitive\\ Schema\\ With\\ Space.CaseSensitiveTable," + "Case\\ Sensitive\\ Schema\\ With\\ Space.Case\\ Sensitive\\ Table\\ With\\ Space," + "public.table_with_comma_\\,," + "public.table_with_quote_\\'", + ) def test_generate_replication_slot_name(self): """Validate if the replication slot name generated correctly""" # Provide only database name - self.assertEqual(logical_replication.generate_replication_slot_name('some_db'), - 'pipelinewise_some_db') + self.assertEqual( + logical_replication.generate_replication_slot_name("some_db"), + "pipelinewise_some_db", + ) # Provide database name and tap_id - self.assertEqual(logical_replication.generate_replication_slot_name('some_db', - 'some_tap'), - 'pipelinewise_some_db_some_tap') + self.assertEqual( + logical_replication.generate_replication_slot_name("some_db", "some_tap"), + "pipelinewise_some_db_some_tap", + ) # Provide database name, tap_id and prefix - self.assertEqual(logical_replication.generate_replication_slot_name('some_db', - 'some_tap', - prefix='custom_prefix'), - 'custom_prefix_some_db_some_tap') + self.assertEqual( + logical_replication.generate_replication_slot_name( + "some_db", "some_tap", prefix="custom_prefix" + ), + "custom_prefix_some_db_some_tap", + ) # Replication slot name should be lowercase - self.assertEqual(logical_replication.generate_replication_slot_name('SoMe_DB', - 'SoMe_TaP'), - 'pipelinewise_some_db_some_tap') + self.assertEqual( + logical_replication.generate_replication_slot_name("SoMe_DB", "SoMe_TaP"), + "pipelinewise_some_db_some_tap", + ) # Invalid characters should be replaced by underscores - self.assertEqual(logical_replication.generate_replication_slot_name('some-db', - 'some-tap'), - 'pipelinewise_some_db_some_tap') + self.assertEqual( + logical_replication.generate_replication_slot_name("some-db", "some-tap"), + "pipelinewise_some_db_some_tap", + ) - 
self.assertEqual(logical_replication.generate_replication_slot_name('some.db', - 'some.tap'), - 'pipelinewise_some_db_some_tap') + self.assertEqual( + logical_replication.generate_replication_slot_name("some.db", "some.tap"), + "pipelinewise_some_db_some_tap", + ) def test_locate_replication_slot_by_cur(self): """Validate if both v15 and v16 style replication slot located correctly""" # Should return v15 style slot name if v15 style replication slot exists - cursor = PostgresCurReplicationSlotMock(existing_slot_name='pipelinewise_some_db') - self.assertEqual(logical_replication.locate_replication_slot_by_cur(cursor, - 'some_db', - 'some_tap'), - 'pipelinewise_some_db') + cursor = PostgresCurReplicationSlotMock( + existing_slot_name="pipelinewise_some_db" + ) + self.assertEqual( + logical_replication.locate_replication_slot_by_cur( + cursor, "some_db", "some_tap" + ), + "pipelinewise_some_db", + ) # Should return v16 style slot name if v16 style replication slot exists - cursor = PostgresCurReplicationSlotMock(existing_slot_name='pipelinewise_some_db_some_tap') - self.assertEqual(logical_replication.locate_replication_slot_by_cur(cursor, - 'some_db', - 'some_tap'), - 'pipelinewise_some_db_some_tap') + cursor = PostgresCurReplicationSlotMock( + existing_slot_name="pipelinewise_some_db_some_tap" + ) + self.assertEqual( + logical_replication.locate_replication_slot_by_cur( + cursor, "some_db", "some_tap" + ), + "pipelinewise_some_db_some_tap", + ) # Should return v15 style replication slot if tap_id not provided and the v15 slot exists - cursor = PostgresCurReplicationSlotMock(existing_slot_name='pipelinewise_some_db') - self.assertEqual(logical_replication.locate_replication_slot_by_cur(cursor, - 'some_db'), - 'pipelinewise_some_db') + cursor = PostgresCurReplicationSlotMock( + existing_slot_name="pipelinewise_some_db" + ) + self.assertEqual( + logical_replication.locate_replication_slot_by_cur(cursor, "some_db"), + "pipelinewise_some_db", + ) # Should raise an exception if no v15 or v16 style replication slot found cursor = PostgresCurReplicationSlotMock(existing_slot_name=None) with self.assertRaises(logical_replication.ReplicationSlotNotFoundError): - self.assertEqual(logical_replication.locate_replication_slot_by_cur(cursor, - 'some_db', - 'some_tap'), - 'pipelinewise_some_db_some_tap') + self.assertEqual( + logical_replication.locate_replication_slot_by_cur( + cursor, "some_db", "some_tap" + ), + "pipelinewise_some_db_some_tap", + ) def test_consume_with_message_payload_is_not_json_expect_same_state(self): - output = logical_replication.consume_message([], - {}, - self.WalMessage(payload='this is an invalid json message', - data_start=None), - None, - {} - ) + output = logical_replication.consume_message( + [], + {}, + self.WalMessage(payload="this is an invalid json message", data_start=None), + None, + {}, + ) self.assertDictEqual({}, output) - def test_consume_with_message_stream_in_payload_is_not_selected_expect_same_state(self): + def test_consume_with_message_stream_in_payload_is_not_selected_expect_same_state( + self, + ): output = logical_replication.consume_message( - [{'tap_stream_id': 'myschema-mytable'}], + [{"tap_stream_id": "myschema-mytable"}], {}, - self.WalMessage(payload='{"schema": "myschema", "table": "notmytable"}', - data_start='some lsn'), + self.WalMessage( + payload='{"schema": "myschema", "table": "notmytable"}', + data_start="some lsn", + ), None, - {} + {}, ) self.assertDictEqual({}, output) @@ -156,61 +206,51 @@ def 
test_consume_with_message_stream_in_payload_is_not_selected_expect_same_stat def test_consume_with_payload_kind_is_not_supported_expect_exception(self): with self.assertRaises(UnsupportedPayloadKindError): logical_replication.consume_message( - [{'tap_stream_id': 'myschema-mytable'}], + [{"tap_stream_id": "myschema-mytable"}], {}, - self.WalMessage(payload='{"kind":"truncate", "schema": "myschema", "table": "mytable"}', - data_start='some lsn'), + self.WalMessage( + payload='{"kind":"truncate", "schema": "myschema", "table": "mytable"}', + data_start="some lsn", + ), None, - {} + {}, ) - @patch('tap_postgres.logical_replication.singer.write_message') - @patch('tap_postgres.logical_replication.sync_common.send_schema_message') - @patch('tap_postgres.logical_replication.refresh_streams_schema') - def test_consume_message_with_new_column_in_payload_will_refresh_schema(self, - refresh_schema_mock, - send_schema_mock, - write_message_mock): + @patch("tap_postgres.logical_replication.singer.write_message") + @patch("tap_postgres.logical_replication.sync_common.send_schema_message") + @patch("tap_postgres.logical_replication.refresh_streams_schema") + def test_consume_message_with_new_column_in_payload_will_refresh_schema( + self, refresh_schema_mock, send_schema_mock, write_message_mock + ): streams = [ { - 'tap_stream_id': 'myschema-mytable', - 'stream': 'mytable', - 'schema': { - 'properties': { - 'id': {}, - 'date_created': {} - } - }, - 'metadata': [ + "tap_stream_id": "myschema-mytable", + "stream": "mytable", + "schema": {"properties": {"id": {}, "date_created": {}}}, + "metadata": [ { - 'breadcrumb': [], - 'metadata': { - 'is-view': False, - 'table-key-properties': ['id'], - 'schema-name': 'myschema' - } + "breadcrumb": [], + "metadata": { + "is-view": False, + "table-key-properties": ["id"], + "schema-name": "myschema", + }, }, { - "breadcrumb": [ - "properties", - "id" - ], + "breadcrumb": ["properties", "id"], "metadata": { "sql-datatype": "integer", "inclusion": "automatic", - } + }, }, { - "breadcrumb": [ - "properties", - "date_created" - ], + "breadcrumb": ["properties", "date_created"], "metadata": { "sql-datatype": "datetime", "inclusion": "available", - "selected": True - } - } + "selected": True, + }, + }, ], } ] @@ -218,338 +258,421 @@ def test_consume_message_with_new_column_in_payload_will_refresh_schema(self, return_v = logical_replication.consume_message( streams, { - 'bookmarks': { + "bookmarks": { "myschema-mytable": { "last_replication_method": "LOG_BASED", "lsn": None, "version": 1000, - "xmin": None + "xmin": None, } } }, - self.WalMessage(payload='{"kind": "insert", ' - '"schema": "myschema", ' - '"table": "mytable",' - '"columnnames": ["id", "date_created", "new_col"],' - '"columnnames": [1, null, "some random text"]' - '}', - data_start='some lsn'), + self.WalMessage( + payload='{"kind": "insert", ' + '"schema": "myschema", ' + '"table": "mytable",' + '"columnnames": ["id", "date_created", "new_col"],' + '"columnnames": [1, null, "some random text"]' + "}", + data_start="some lsn", + ), None, - {} + {}, ) - self.assertDictEqual(return_v, - { - 'bookmarks': { - "myschema-mytable": { - "last_replication_method": "LOG_BASED", - "lsn": "some lsn", - "version": 1000, - "xmin": None - } - } - }) + self.assertDictEqual( + return_v, + { + "bookmarks": { + "myschema-mytable": { + "last_replication_method": "LOG_BASED", + "lsn": "some lsn", + "version": 1000, + "xmin": None, + } + } + }, + ) refresh_schema_mock.assert_called_once_with({}, [streams[0]]) 
send_schema_mock.assert_called_once() write_message_mock.assert_called_once() - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_expect_iso_format(self): - output = logical_replication.selected_value_to_singer_value_impl('2020-09-01 20:10:56', - 'timestamp without time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_expect_iso_format( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "2020-09-01 20:10:56", "timestamp without time zone", None + ) - self.assertEqual('2020-09-01T20:10:56+00:00', output) + self.assertEqual("2020-09-01T20:10:56+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_expect_iso_format(self): - output = logical_replication.selected_value_to_singer_value_impl(datetime(2020, 9, 1, 20, 10, 59), - 'timestamp without time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_expect_iso_format( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + datetime(2020, 9, 1, 20, 10, 59), "timestamp without time zone", None + ) - self.assertEqual('2020-09-01T20:10:59+00:00', output) + self.assertEqual("2020-09-01T20:10:59+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_out_of_range_1(self): + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_out_of_range_1( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp without tz as string where year is > 9999 should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('10000-09-01 20:10:56', - 'timestamp without time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "10000-09-01 20:10:56", "timestamp without time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_out_of_range_2(self): + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_out_of_range_2( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp without tz as string where year is < 0001 should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('0000-09-01 20:10:56', - 'timestamp without time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "0000-09-01 20:10:56", "timestamp without time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_BC(self): + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_BC( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp without tz as string where era is BC should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56 BC', - 'timestamp without time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "1000-09-01 20:10:56 BC", "timestamp without time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def 
test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_AC(self): + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_AC( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp without tz as string where era is AC should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56 AC', - 'timestamp without time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "1000-09-01 20:10:56 AC", "timestamp without time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_min(self): - output = logical_replication.selected_value_to_singer_value_impl('0001-01-01 00:00:00.000123', - 'timestamp without time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_min( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "0001-01-01 00:00:00.000123", "timestamp without time zone", None + ) - self.assertEqual('0001-01-01T00:00:00.000123+00:00', output) + self.assertEqual("0001-01-01T00:00:00.000123+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_max(self): - output = logical_replication.selected_value_to_singer_value_impl('9999-12-31 23:59:59.999999', - 'timestamp without time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_max( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "9999-12-31 23:59:59.999999", "timestamp without time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_min(self): - output = logical_replication.selected_value_to_singer_value_impl(datetime(1, 1, 1, 0, 0, 0, 123), - 'timestamp without time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_min( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + datetime(1, 1, 1, 0, 0, 0, 123), "timestamp without time zone", None + ) - self.assertEqual('0001-01-01T00:00:00.000123+00:00', output) + self.assertEqual("0001-01-01T00:00:00.000123+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_max(self): - output = logical_replication.selected_value_to_singer_value_impl(datetime(9999, 12, 31, 23, 59, 59, 999999), - 'timestamp without time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_max( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + datetime(9999, 12, 31, 23, 59, 59, 999999), + "timestamp without time zone", + None, + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_expect_iso_format(self): - output = logical_replication.selected_value_to_singer_value_impl('2020-09-01 20:10:56+05', - 'timestamp with time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_expect_iso_format( + self, + ): + output = 
logical_replication.selected_value_to_singer_value_impl( + "2020-09-01 20:10:56+05", "timestamp with time zone", None + ) - self.assertEqual('2020-09-01T20:10:56+05:00', output) + self.assertEqual("2020-09-01T20:10:56+05:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_expect_iso_format(self): - output = logical_replication.selected_value_to_singer_value_impl(datetime(2020, 9, 1, 23, 10, 59, - tzinfo=tzoffset(None, -3600)), - 'timestamp with time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_expect_iso_format( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + datetime(2020, 9, 1, 23, 10, 59, tzinfo=tzoffset(None, -3600)), + "timestamp with time zone", + None, + ) - self.assertEqual('2020-09-01T23:10:59-01:00', output) + self.assertEqual("2020-09-01T23:10:59-01:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_out_of_range_1(self): + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_out_of_range_1( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp with tz as string where year is > 9999 should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('10000-09-01 20:10:56+06', - 'timestamp with time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "10000-09-01 20:10:56+06", "timestamp with time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_out_of_range_2(self): + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_out_of_range_2( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp with tz as string where year is < 0001 should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('0000-09-01 20:10:56+01', - 'timestamp with time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "0000-09-01 20:10:56+01", "timestamp with time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_BC(self): + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_BC( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp with tz as string where era is BC should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56+05 BC', - 'timestamp with time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "1000-09-01 20:10:56+05 BC", "timestamp with time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_AC(self): + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_AC( + self, + ): """ Test selected_value_to_singer_value_impl with timestamp with tz as string where era is AC should fallback to max datetime allowed """ - output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56-09 AC', - 'timestamp 
with time zone', - None) + output = logical_replication.selected_value_to_singer_value_impl( + "1000-09-01 20:10:56-09 AC", "timestamp with time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_min(self): - output = logical_replication.selected_value_to_singer_value_impl('0001-01-01 00:00:00.000123+04', - 'timestamp with time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_min( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "0001-01-01 00:00:00.000123+04", "timestamp with time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_max(self): - output = logical_replication.selected_value_to_singer_value_impl('9999-12-31 23:59:59.999999-03', - 'timestamp with time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_max( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "9999-12-31 23:59:59.999999-03", "timestamp with time zone", None + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_min(self): - output = logical_replication.selected_value_to_singer_value_impl(datetime(1, 1, 1, 0, 0, 0, 123, - tzinfo=tzoffset(None, 14400)), - 'timestamp with time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_min( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + datetime(1, 1, 1, 0, 0, 0, 123, tzinfo=tzoffset(None, 14400)), + "timestamp with time zone", + None, + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_max(self): - output = logical_replication.selected_value_to_singer_value_impl(datetime(9999, 12, 31, 23, 59, 59, 999999, - tzinfo=tzoffset(None, -14400)), - 'timestamp with time zone', - None) + def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_max( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=tzoffset(None, -14400)), + "timestamp with time zone", + None, + ) - self.assertEqual('9999-12-31T23:59:59.999+00:00', output) + self.assertEqual("9999-12-31T23:59:59.999+00:00", output) - def test_selected_value_to_singer_value_impl_with_date_value_as_string_expect_iso_format(self): - output = logical_replication.selected_value_to_singer_value_impl('2021-09-07', 'date', None) + def test_selected_value_to_singer_value_impl_with_date_value_as_string_expect_iso_format( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "2021-09-07", "date", None + ) - self.assertEqual('2021-09-07T00:00:00+00:00', output) + self.assertEqual("2021-09-07T00:00:00+00:00", output) - def test_selected_value_to_singer_value_impl_with_date_value_as_string_out_of_range(self): + def test_selected_value_to_singer_value_impl_with_date_value_as_string_out_of_range( + self, + ): """ Test selected_value_to_singer_value_impl 
with date as string where year is > 9999 (which is valid in postgres) should fallback to max date allowed """ - output = logical_replication.selected_value_to_singer_value_impl('10000-09-01', 'date', None) + output = logical_replication.selected_value_to_singer_value_impl( + "10000-09-01", "date", None + ) - self.assertEqual('9999-12-31T00:00:00+00:00', output) + self.assertEqual("9999-12-31T00:00:00+00:00", output) def test_row_to_singer_message(self): stream = { - 'stream': 'my_stream', + "stream": "my_stream", } row = [ - '2020-01-01 10:30:45', - '2020-01-01 10:30:45 BC', - '50000-01-01 10:30:45', + "2020-01-01 10:30:45", + "2020-01-01 10:30:45 BC", + "50000-01-01 10:30:45", datetime(2020, 1, 1, 10, 30, 45), - '2020-01-01 10:30:45-02', - '0000-01-01 10:30:45-02', - '2020-01-01 10:30:45-02 AC', + "2020-01-01 10:30:45-02", + "0000-01-01 10:30:45-02", + "2020-01-01 10:30:45-02 AC", datetime(2020, 1, 1, 10, 30, 45, tzinfo=tzoffset(None, 3600)), ] columns = [ - 'c_timestamp_ntz_1', - 'c_timestamp_ntz_2', - 'c_timestamp_ntz_3', - 'c_timestamp_ntz_4', - 'c_timestamp_tz_1', - 'c_timestamp_tz_2', - 'c_timestamp_tz_3', - 'c_timestamp_tz_4', + "c_timestamp_ntz_1", + "c_timestamp_ntz_2", + "c_timestamp_ntz_3", + "c_timestamp_ntz_4", + "c_timestamp_tz_1", + "c_timestamp_tz_2", + "c_timestamp_tz_3", + "c_timestamp_tz_4", ] md_map = { - (): {'schema-name': 'my_schema'}, - ('properties', 'c_timestamp_ntz_1'): {'sql-datatype': 'timestamp without time zone'}, - ('properties', 'c_timestamp_ntz_2'): {'sql-datatype': 'timestamp without time zone'}, - ('properties', 'c_timestamp_ntz_3'): {'sql-datatype': 'timestamp without time zone'}, - ('properties', 'c_timestamp_ntz_4'): {'sql-datatype': 'timestamp without time zone'}, - ('properties', 'c_timestamp_tz_1'): {'sql-datatype': 'timestamp with time zone'}, - ('properties', 'c_timestamp_tz_2'): {'sql-datatype': 'timestamp with time zone'}, - ('properties', 'c_timestamp_tz_3'): {'sql-datatype': 'timestamp with time zone'}, - ('properties', 'c_timestamp_tz_4'): {'sql-datatype': 'timestamp with time zone'}, + (): {"schema-name": "my_schema"}, + ("properties", "c_timestamp_ntz_1"): { + "sql-datatype": "timestamp without time zone" + }, + ("properties", "c_timestamp_ntz_2"): { + "sql-datatype": "timestamp without time zone" + }, + ("properties", "c_timestamp_ntz_3"): { + "sql-datatype": "timestamp without time zone" + }, + ("properties", "c_timestamp_ntz_4"): { + "sql-datatype": "timestamp without time zone" + }, + ("properties", "c_timestamp_tz_1"): { + "sql-datatype": "timestamp with time zone" + }, + ("properties", "c_timestamp_tz_2"): { + "sql-datatype": "timestamp with time zone" + }, + ("properties", "c_timestamp_tz_3"): { + "sql-datatype": "timestamp with time zone" + }, + ("properties", "c_timestamp_tz_4"): { + "sql-datatype": "timestamp with time zone" + }, } - output = logical_replication.row_to_singer_message(stream, - row, - 1000, - columns, - datetime(2020, 9, 1, 10, 10, 10, tzinfo=tzoffset(None, 0)), - md_map, - None) - - self.assertEqual('my_schema-my_stream', output.stream) - self.assertDictEqual({ - 'c_timestamp_ntz_1': '2020-01-01T10:30:45+00:00', - 'c_timestamp_ntz_2': '9999-12-31T23:59:59.999+00:00', - 'c_timestamp_ntz_3': '9999-12-31T23:59:59.999+00:00', - 'c_timestamp_ntz_4': '2020-01-01T10:30:45+00:00', - 'c_timestamp_tz_1': '2020-01-01T10:30:45-02:00', - 'c_timestamp_tz_2': '9999-12-31T23:59:59.999+00:00', - 'c_timestamp_tz_3': '9999-12-31T23:59:59.999+00:00', - 'c_timestamp_tz_4': '2020-01-01T10:30:45+01:00', - }, output.record) + output = 
logical_replication.row_to_singer_message( + stream, + row, + 1000, + columns, + datetime(2020, 9, 1, 10, 10, 10, tzinfo=tzoffset(None, 0)), + md_map, + None, + ) + + self.assertEqual("my_schema-my_stream", output.stream) + self.assertDictEqual( + { + "c_timestamp_ntz_1": "2020-01-01T10:30:45+00:00", + "c_timestamp_ntz_2": "9999-12-31T23:59:59.999+00:00", + "c_timestamp_ntz_3": "9999-12-31T23:59:59.999+00:00", + "c_timestamp_ntz_4": "2020-01-01T10:30:45+00:00", + "c_timestamp_tz_1": "2020-01-01T10:30:45-02:00", + "c_timestamp_tz_2": "9999-12-31T23:59:59.999+00:00", + "c_timestamp_tz_3": "9999-12-31T23:59:59.999+00:00", + "c_timestamp_tz_4": "2020-01-01T10:30:45+01:00", + }, + output.record, + ) self.assertEqual(1000, output.version) - self.assertEqual(datetime(2020, 9, 1, 10, 10, 10, tzinfo=tzoffset(None, 0)), output.time_extracted) + self.assertEqual( + datetime(2020, 9, 1, 10, 10, 10, tzinfo=tzoffset(None, 0)), + output.time_extracted, + ) def test_selected_value_to_singer_value_impl_with_null_json_returns_None(self): - output = logical_replication.selected_value_to_singer_value_impl(None, - 'json', - None) + output = logical_replication.selected_value_to_singer_value_impl( + None, "json", None + ) self.assertEqual(None, output) - def test_selected_value_to_singer_value_impl_with_empty_json_returns_empty_dict(self): - output = logical_replication.selected_value_to_singer_value_impl('{}', - 'json', - None) + def test_selected_value_to_singer_value_impl_with_empty_json_returns_empty_dict( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "{}", "json", None + ) self.assertEqual({}, output) - def test_selected_value_to_singer_value_impl_with_non_empty_json_returns_equivalent_dict(self): - output = logical_replication.selected_value_to_singer_value_impl('{"key1": "A", "key2": [{"kk": "yo"}, {}]}', - 'json', - None) + def test_selected_value_to_singer_value_impl_with_non_empty_json_returns_equivalent_dict( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + '{"key1": "A", "key2": [{"kk": "yo"}, {}]}', "json", None + ) - self.assertEqual({ - 'key1': 'A', - 'key2': [{'kk': 'yo'}, {}] - }, output) + self.assertEqual({"key1": "A", "key2": [{"kk": "yo"}, {}]}, output) def test_selected_value_to_singer_value_impl_with_null_jsonb_returns_None(self): - output = logical_replication.selected_value_to_singer_value_impl(None, - 'jsonb', - None) + output = logical_replication.selected_value_to_singer_value_impl( + None, "jsonb", None + ) self.assertEqual(None, output) - def test_selected_value_to_singer_value_impl_with_empty_jsonb_returns_empty_dict(self): - output = logical_replication.selected_value_to_singer_value_impl('{}', - 'jsonb', - None) + def test_selected_value_to_singer_value_impl_with_empty_jsonb_returns_empty_dict( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + "{}", "jsonb", None + ) self.assertEqual({}, output) - def test_selected_value_to_singer_value_impl_with_non_empty_jsonb_returns_equivalent_dict(self): - output = logical_replication.selected_value_to_singer_value_impl('{"key1": "A", "key2": [{"kk": "yo"}, {}]}', - 'jsonb', - None) + def test_selected_value_to_singer_value_impl_with_non_empty_jsonb_returns_equivalent_dict( + self, + ): + output = logical_replication.selected_value_to_singer_value_impl( + '{"key1": "A", "key2": [{"kk": "yo"}, {}]}', "jsonb", None + ) - self.assertEqual({ - 'key1': 'A', - 'key2': [{'kk': 'yo'}, {}] - }, output) + self.assertEqual({"key1": "A", "key2": 
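The json and jsonb cases in this class reduce to one rule, sketched here with an invented helper name (the real conversion lives in the logical_replication module under test): SQL NULL stays None, and any other payload arrives as text and is decoded with the standard library.

import json


def json_value_to_singer(value):
    """None for SQL NULL; otherwise decode the json/jsonb text payload."""
    if value is None:
        return None
    return json.loads(value)


assert json_value_to_singer(None) is None
assert json_value_to_singer("{}") == {}
assert json_value_to_singer('{"key1": "A", "key2": [{"kk": "yo"}, {}]}') == {
    "key1": "A",
    "key2": [{"kk": "yo"}, {}],
}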
[{"kk": "yo"}, {}]}, output) diff --git a/tests/test_streams_utils.py b/tests/test_streams_utils.py index 2743e0c1..ebd6a876 100644 --- a/tests/test_streams_utils.py +++ b/tests/test_streams_utils.py @@ -8,7 +8,11 @@ from tap_postgres import stream_utils try: - from tests.utils import get_test_connection, ensure_test_table, get_test_connection_config + from tests.utils import ( + get_test_connection, + ensure_test_table, + get_test_connection_config, + ) except ImportError: from utils import get_test_connection, ensure_test_table, get_test_connection_config @@ -22,16 +26,20 @@ def do_not_dump_catalog(catalog): class TestInit(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name": "id", "type": "integer", "primary_key": True, "serial": True}, - {"name": '"character-varying_name"', "type": "character varying"}, - {"name": '"varchar-name"', "type": "varchar(28)"}, - {"name": 'char_name', "type": "char(10)"}, - {"name": '"text-name"', "type": "text"}, - {"name": "json_name", "type": "jsonb"}], - "name": self.table_name} + table_spec = { + "columns": [ + {"name": "id", "type": "integer", "primary_key": True, "serial": True}, + {"name": '"character-varying_name"', "type": "character varying"}, + {"name": '"varchar-name"', "type": "varchar(28)"}, + {"name": "char_name", "type": "char(10)"}, + {"name": '"text-name"', "type": "text"}, + {"name": "json_name", "type": "jsonb"}, + ], + "name": self.table_name, + } ensure_test_table(table_spec) @@ -40,72 +48,96 @@ def test_refresh_streams_schema(self): streams = [ { - 'table_name': self.table_name, - 'stream': self.table_name, - 'tap_stream_id': f'public-{self.table_name}', - 'schema': {'properties': {'json_name': {'type': ['null', 'string']}}}, - 'metadata': [ + "table_name": self.table_name, + "stream": self.table_name, + "tap_stream_id": f"public-{self.table_name}", + "schema": {"properties": {"json_name": {"type": ["null", "string"]}}}, + "metadata": [ { - 'breadcrumb': [], - 'metadata': { - 'replication-method': 'LOG_BASED', - 'table-key-properties': ['some_id'], - 'row-count': 1000, - } + "breadcrumb": [], + "metadata": { + "replication-method": "LOG_BASED", + "table-key-properties": ["some_id"], + "row-count": 1000, + }, }, { - 'breadcrumb': ['properties', 'char_name'], - 'metadata': { - 'arbitrary_field_metadata': 'should be preserved' - } - } - ] + "breadcrumb": ["properties", "char_name"], + "metadata": {"arbitrary_field_metadata": "should be preserved"}, + }, + ], } ] stream_utils.refresh_streams_schema(conn_config, streams) self.assertEqual(len(streams), 1) - self.assertEqual(self.table_name, streams[0].get('table_name')) - self.assertEqual(self.table_name, streams[0].get('stream')) - - streams[0]['metadata'].sort(key=lambda md: md['breadcrumb']) - - self.assertEqual(metadata.to_map(streams[0]['metadata']), { - (): {'table-key-properties': ['id'], - 'database-name': 'postgres', - 'schema-name': 'public', - 'is-view': False, - 'row-count': 0, - 'replication-method': 'LOG_BASED' - }, - ('properties', 'character-varying_name'): {'inclusion': 'available', - 'sql-datatype': 'character varying', - 'selected-by-default': True}, - ('properties', 'id'): {'inclusion': 'automatic', - 'sql-datatype': 'integer', - 'selected-by-default': True}, - ('properties', 'varchar-name'): {'inclusion': 'available', - 'sql-datatype': 'character varying', - 'selected-by-default': True}, - ('properties', 'text-name'): {'inclusion': 'available', - 'sql-datatype': 'text', - 
'selected-by-default': True}, - ('properties', 'char_name'): {'selected-by-default': True, - 'inclusion': 'available', - 'sql-datatype': 'character', - 'arbitrary_field_metadata': 'should be preserved'}, - ('properties', 'json_name'): {'selected-by-default': True, - 'inclusion': 'available', - 'sql-datatype': 'jsonb'}}) - - self.assertEqual({'properties': {'id': {'type': ['integer'], - 'maximum': 2147483647, - 'minimum': -2147483648}, - 'character-varying_name': {'type': ['null', 'string']}, - 'varchar-name': {'type': ['null', 'string'], 'maxLength': 28}, - 'char_name': {'type': ['null', 'string'], 'maxLength': 10}, - 'text-name': {'type': ['null', 'string']}, - 'json_name': {'type': ['null', 'string']}}, - 'type': 'object', - 'definitions': BASE_RECURSIVE_SCHEMAS}, streams[0].get('schema')) + self.assertEqual(self.table_name, streams[0].get("table_name")) + self.assertEqual(self.table_name, streams[0].get("stream")) + + streams[0]["metadata"].sort(key=lambda md: md["breadcrumb"]) + + self.assertEqual( + metadata.to_map(streams[0]["metadata"]), + { + (): { + "table-key-properties": ["id"], + "database-name": "postgres", + "schema-name": "public", + "is-view": False, + "row-count": 0, + "replication-method": "LOG_BASED", + }, + ("properties", "character-varying_name"): { + "inclusion": "available", + "sql-datatype": "character varying", + "selected-by-default": True, + }, + ("properties", "id"): { + "inclusion": "automatic", + "sql-datatype": "integer", + "selected-by-default": True, + }, + ("properties", "varchar-name"): { + "inclusion": "available", + "sql-datatype": "character varying", + "selected-by-default": True, + }, + ("properties", "text-name"): { + "inclusion": "available", + "sql-datatype": "text", + "selected-by-default": True, + }, + ("properties", "char_name"): { + "selected-by-default": True, + "inclusion": "available", + "sql-datatype": "character", + "arbitrary_field_metadata": "should be preserved", + }, + ("properties", "json_name"): { + "selected-by-default": True, + "inclusion": "available", + "sql-datatype": "jsonb", + }, + }, + ) + + self.assertEqual( + { + "properties": { + "id": { + "type": ["integer"], + "maximum": 2147483647, + "minimum": -2147483648, + }, + "character-varying_name": {"type": ["null", "string"]}, + "varchar-name": {"type": ["null", "string"], "maxLength": 28}, + "char_name": {"type": ["null", "string"], "maxLength": 10}, + "text-name": {"type": ["null", "string"]}, + "json_name": {"type": ["null", "string"]}, + }, + "type": "object", + "definitions": BASE_RECURSIVE_SCHEMAS, + }, + streams[0].get("schema"), + ) diff --git a/tests/test_unsupported_pk.py b/tests/test_unsupported_pk.py index 0ab72832..265aac85 100644 --- a/tests/test_unsupported_pk.py +++ b/tests/test_unsupported_pk.py @@ -3,62 +3,133 @@ from singer import get_logger, metadata import tap_postgres -from tests.utils import get_test_connection, ensure_test_table, get_test_connection_config +from tests.utils import ( + get_test_connection, + ensure_test_table, + get_test_connection_config, +) LOGGER = get_logger() + def do_not_dump_catalog(catalog): pass + tap_postgres.dump_catalog = do_not_dump_catalog + class Unsupported(unittest.TestCase): maxDiff = None - table_name = 'CHICKEN TIMES' + table_name = "CHICKEN TIMES" def setUp(self): - table_spec = {"columns": [{"name": "interval_col", "type": "INTERVAL"}, - {"name": "bit_string_col", "type": "bit(5)"}, - {"name": "bytea_col", "type": "bytea"}, - {"name": "point_col", "type": "point"}, - {"name": "line_col", "type": "line"}, - {"name": 
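The expectations in this test amount to a simple precedence rule when the schema is refreshed: whatever discovery produces wins, and any metadata already on the stream that discovery does not emit (the replication method, the arbitrary field metadata) is carried over. A minimal sketch of that rule — merge_stream_metadata is an invented name, and this assumes the per-breadcrumb merge is all the test is pinning down:

def merge_stream_metadata(existing_md, discovered_md):
    """Per-breadcrumb merge: start from what the catalog already had and let
    freshly discovered keys (row-count, sql-datatype, ...) override it."""
    merged = dict(existing_md)
    merged.update(discovered_md)
    return merged


# Matches the table-level expectation asserted above.
assert merge_stream_metadata(
    {"replication-method": "LOG_BASED", "table-key-properties": ["some_id"], "row-count": 1000},
    {"table-key-properties": ["id"], "row-count": 0, "schema-name": "public",
     "database-name": "postgres", "is-view": False},
) == {
    "replication-method": "LOG_BASED",
    "table-key-properties": ["id"],
    "row-count": 0,
    "schema-name": "public",
    "database-name": "postgres",
    "is-view": False,
}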
"lseg_col", "type": "lseg"}, - {"name": "box_col", "type": "box"}, - {"name": "polygon_col", "type": "polygon"}, - {"name": "circle_col", "type": "circle"}, - {"name": "xml_col", "type": "xml"}, - {"name": "composite_col", "type": "person_composite"}, - {"name": "int_range_col", "type": "int4range"}, - ], - "name": Unsupported.table_name} + table_spec = { + "columns": [ + {"name": "interval_col", "type": "INTERVAL"}, + {"name": "bit_string_col", "type": "bit(5)"}, + {"name": "bytea_col", "type": "bytea"}, + {"name": "point_col", "type": "point"}, + {"name": "line_col", "type": "line"}, + {"name": "lseg_col", "type": "lseg"}, + {"name": "box_col", "type": "box"}, + {"name": "polygon_col", "type": "polygon"}, + {"name": "circle_col", "type": "circle"}, + {"name": "xml_col", "type": "xml"}, + {"name": "composite_col", "type": "person_composite"}, + {"name": "int_range_col", "type": "int4range"}, + ], + "name": Unsupported.table_name, + } with get_test_connection() as conn: cur = conn.cursor() cur.execute(""" DROP TYPE IF EXISTS person_composite CASCADE """) - cur.execute(""" CREATE TYPE person_composite AS (age int, name text) """) + cur.execute( + """ CREATE TYPE person_composite AS (age int, name text) """ + ) ensure_test_table(table_spec) def test_catalog(self): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) - chicken_streams = [s for s in streams if s['tap_stream_id'] == "public-CHICKEN TIMES"] + chicken_streams = [ + s for s in streams if s["tap_stream_id"] == "public-CHICKEN TIMES" + ] self.assertEqual(len(chicken_streams), 1) stream_dict = chicken_streams[0] - stream_dict.get('metadata').sort(key=lambda md: md['breadcrumb']) + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) - self.assertEqual(metadata.to_map(stream_dict.get('metadata')), - {(): {'is-view': False, 'table-key-properties': [], 'row-count': 0, 'schema-name': 'public', 'database-name': 'postgres'}, - ('properties', 'bytea_col'): {'sql-datatype': 'bytea', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'bit_string_col'): {'sql-datatype': 'bit(5)', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'line_col'): {'sql-datatype': 'line', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'xml_col'): {'sql-datatype': 'xml', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'int_range_col'): {'sql-datatype': 'int4range', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'circle_col'): {'sql-datatype': 'circle', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'polygon_col'): {'sql-datatype': 'polygon', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'box_col'): {'sql-datatype': 'box', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'lseg_col'): {'sql-datatype': 'lseg', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'composite_col'): {'sql-datatype': 'person_composite', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'interval_col'): {'sql-datatype': 'interval', 'selected-by-default': False, 'inclusion': 'unsupported'}, - ('properties', 'point_col'): {'sql-datatype': 'point', 'selected-by-default': False, 'inclusion': 'unsupported'}} + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "is-view": False, + "table-key-properties": [], + "row-count": 0, + "schema-name": 
"public", + "database-name": "postgres", + }, + ("properties", "bytea_col"): { + "sql-datatype": "bytea", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "bit_string_col"): { + "sql-datatype": "bit(5)", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "line_col"): { + "sql-datatype": "line", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "xml_col"): { + "sql-datatype": "xml", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "int_range_col"): { + "sql-datatype": "int4range", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "circle_col"): { + "sql-datatype": "circle", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "polygon_col"): { + "sql-datatype": "polygon", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "box_col"): { + "sql-datatype": "box", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "lseg_col"): { + "sql-datatype": "lseg", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "composite_col"): { + "sql-datatype": "person_composite", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "interval_col"): { + "sql-datatype": "interval", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "point_col"): { + "sql-datatype": "point", + "selected-by-default": False, + "inclusion": "unsupported", + }, + }, ) diff --git a/tests/utils.py b/tests/utils.py index 881cfd3c..391e39e8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,58 +10,77 @@ LOGGER = get_logger() -def get_test_connection_config(target_db='postgres'): - missing_envs = [x for x in [os.getenv('TAP_POSTGRES_HOST'), - os.getenv('TAP_POSTGRES_USER'), - os.getenv('TAP_POSTGRES_PASSWORD'), - os.getenv('TAP_POSTGRES_PORT')] if x == None] - if len(missing_envs) != 0: - raise Exception("set TAP_POSTGRES_HOST, TAP_POSTGRES_USER, TAP_POSTGRES_PASSWORD, TAP_POSTGRES_PORT") - conn_config = {'host': os.environ.get('TAP_POSTGRES_HOST'), - 'user': os.environ.get('TAP_POSTGRES_USER'), - 'password': os.environ.get('TAP_POSTGRES_PASSWORD'), - 'port': os.environ.get('TAP_POSTGRES_PORT'), - 'dbname': target_db} +def get_test_connection_config(target_db="postgres"): + missing_envs = [ + x + for x in [ + os.getenv("TAP_POSTGRES_HOST"), + os.getenv("TAP_POSTGRES_USER"), + os.getenv("TAP_POSTGRES_PASSWORD"), + os.getenv("TAP_POSTGRES_PORT"), + ] + if x == None + ] + if len(missing_envs) != 0: + raise Exception( + "set TAP_POSTGRES_HOST, TAP_POSTGRES_USER, TAP_POSTGRES_PASSWORD, TAP_POSTGRES_PORT" + ) + + conn_config = { + "host": os.environ.get("TAP_POSTGRES_HOST"), + "user": os.environ.get("TAP_POSTGRES_USER"), + "password": os.environ.get("TAP_POSTGRES_PASSWORD"), + "port": os.environ.get("TAP_POSTGRES_PORT"), + "dbname": target_db, + } return conn_config -def get_test_connection(target_db='postgres'): + +def get_test_connection(target_db="postgres"): conn_config = get_test_connection_config(target_db) - conn_string = "host='{}' dbname='{}' user='{}' password='{}' port='{}'".format(conn_config['host'], - conn_config['dbname'], - conn_config['user'], - conn_config['password'], - conn_config['port']) - LOGGER.info("connecting to {}".format(conn_config['host'])) + conn_string = "host='{}' dbname='{}' user='{}' password='{}' port='{}'".format( + conn_config["host"], + 
conn_config["dbname"], + conn_config["user"], + conn_config["password"], + conn_config["port"], + ) + LOGGER.info("connecting to {}".format(conn_config["host"])) conn = psycopg2.connect(conn_string) conn.autocommit = True return conn + def build_col_sql(col, cur): - if col.get('quoted'): - col_sql = "{} {}".format(quote_ident(col['name'], cur), col['type']) + if col.get("quoted"): + col_sql = "{} {}".format(quote_ident(col["name"], cur), col["type"]) else: - col_sql = "{} {}".format(col['name'], col['type']) + col_sql = "{} {}".format(col["name"], col["type"]) return col_sql + def build_table(table, cur): - create_sql = "CREATE TABLE {}\n".format(quote_ident(table['name'], cur)) - col_sql = map(lambda c: build_col_sql(c, cur), table['columns']) - pks = [c['name'] for c in table['columns'] if c.get('primary_key')] + create_sql = "CREATE TABLE {}\n".format(quote_ident(table["name"], cur)) + col_sql = map(lambda c: build_col_sql(c, cur), table["columns"]) + pks = [c["name"] for c in table["columns"] if c.get("primary_key")] if len(pks) != 0: - pk_sql = ",\n CONSTRAINT {} PRIMARY KEY({})".format(quote_ident(table['name'] + "_pk", cur), " ,".join(pks)) + pk_sql = ",\n CONSTRAINT {} PRIMARY KEY({})".format( + quote_ident(table["name"] + "_pk", cur), " ,".join(pks) + ) else: - pk_sql = "" + pk_sql = "" sql = "{} ( {} {})".format(create_sql, ",\n".join(col_sql), pk_sql) return sql -def ensure_test_table(table_spec, target_db='postgres'): + +def ensure_test_table(table_spec, target_db="postgres"): with get_test_connection(target_db) as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: sql = """SELECT * @@ -69,43 +88,50 @@ def ensure_test_table(table_spec, target_db='postgres'): WHERE table_schema = 'public' AND table_name = %s""" - cur.execute(sql, - [table_spec['name']]) + cur.execute(sql, [table_spec["name"]]) old_table = cur.fetchall() if len(old_table) != 0: - cur.execute('DROP TABLE {} cascade'.format(quote_ident(table_spec['name'], cur))) + cur.execute( + "DROP TABLE {} cascade".format(quote_ident(table_spec["name"], cur)) + ) sql = build_table(table_spec, cur) LOGGER.info("create table sql: %s", sql) cur.execute(sql) + def unselect_column(our_stream, col): - md = metadata.to_map(our_stream['metadata']) - md.get(('properties', col))['selected'] = False - our_stream['metadata'] = metadata.to_list(md) + md = metadata.to_map(our_stream["metadata"]) + md.get(("properties", col))["selected"] = False + our_stream["metadata"] = metadata.to_list(md) return our_stream + def set_replication_method_for_stream(stream, method): - new_md = metadata.to_map(stream['metadata']) + new_md = metadata.to_map(stream["metadata"]) old_md = new_md.get(()) - old_md.update({'replication-method': method}) + old_md.update({"replication-method": method}) - stream['metadata'] = metadata.to_list(new_md) + stream["metadata"] = metadata.to_list(new_md) return stream + def select_all_of_stream(stream): - new_md = metadata.to_map(stream['metadata']) + new_md = metadata.to_map(stream["metadata"]) old_md = new_md.get(()) - old_md.update({'selected': True}) - for col_name, col_schema in stream['schema']['properties'].items(): - #explicitly select column if it is not automatic - if new_md.get(('properties', col_name)).get('inclusion') != 'automatic' and new_md.get(('properties', col_name)).get('inclusion') != 'unsupported': - old_md = new_md.get(('properties', col_name)) - old_md.update({'selected' : True}) - - stream['metadata'] = metadata.to_list(new_md) + old_md.update({"selected": True}) + for col_name, 
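Taken together, these helpers cover the usual test flow: create a table from a spec, run discovery, then shape the catalog entry before syncing. A hypothetical sequence — table and column names are invented, and the TAP_POSTGRES_* environment variables must be set — might look like this:

import tap_postgres
from tests.utils import (
    ensure_test_table,
    get_test_connection_config,
    select_all_of_stream,
    set_replication_method_for_stream,
    unselect_column,
)

# Create (or recreate) the table the test will sync against.
table_spec = {
    "name": "EXAMPLE TABLE",
    "columns": [
        {"name": "id", "type": "integer", "primary_key": True, "serial": True},
        {"name": "payload", "type": "jsonb"},
    ],
}
ensure_test_table(table_spec)

# Discover it, then adjust selection and replication method on the catalog entry.
streams = tap_postgres.do_discovery(get_test_connection_config())
stream = next(s for s in streams if s["tap_stream_id"] == "public-EXAMPLE TABLE")
stream = select_all_of_stream(stream)
stream = set_replication_method_for_stream(stream, "FULL_TABLE")
stream = unselect_column(stream, "payload")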
col_schema in stream["schema"]["properties"].items(): + # explicitly select column if it is not automatic + if ( + new_md.get(("properties", col_name)).get("inclusion") != "automatic" + and new_md.get(("properties", col_name)).get("inclusion") != "unsupported" + ): + old_md = new_md.get(("properties", col_name)) + old_md.update({"selected": True}) + + stream["metadata"] = metadata.to_list(new_md) return stream @@ -115,18 +141,18 @@ def crud_up_value(value): elif isinstance(value, int): return str(value) elif isinstance(value, float): - if (value == float('+inf')): + if value == float("+inf"): return "'+Inf'" - elif (value == float('-inf')): + elif value == float("-inf"): return "'-Inf'" - elif (math.isnan(value)): + elif math.isnan(value): return "'NaN'" else: return "{:f}".format(value) elif isinstance(value, decimal.Decimal): return "{:f}".format(value) elif value is None: - return 'NULL' + return "NULL" elif isinstance(value, datetime.datetime) and value.tzinfo is None: return "TIMESTAMP '{}'".format(str(value)) elif isinstance(value, datetime.datetime): @@ -136,18 +162,20 @@ def crud_up_value(value): else: raise Exception("crud_up_value does not yet support {}".format(value.__class__)) + def insert_record(cursor, table_name, data): our_keys = list(data.keys()) our_keys.sort() - our_values = list(map( lambda k: data.get(k), our_keys)) - + our_values = list(map(lambda k: data.get(k), our_keys)) columns_sql = ", \n".join(map(lambda k: quote_ident(k, cursor), our_keys)) value_sql = ",".join(["%s" for i in range(len(our_keys))]) insert_sql = """ INSERT INTO {} ( {} ) - VALUES ( {} )""".format(quote_ident(table_name, cursor), columns_sql, value_sql) + VALUES ( {} )""".format( + quote_ident(table_name, cursor), columns_sql, value_sql + ) LOGGER.info("INSERT: {}".format(insert_sql)) cursor.execute(insert_sql, list(map(crud_up_value, our_values))) @@ -170,25 +198,30 @@ def verify_crud_messages(that, caught_messages, pks): that.assertTrue(isinstance(caught_messages[12], singer.StateMessage)) that.assertTrue(isinstance(caught_messages[13], singer.StateMessage)) - #schema includes scn && _sdc_deleted_at because we selected logminer as our replication method - that.assertEqual({"type" : ['integer']}, caught_messages[0].schema.get('properties').get('scn') ) - that.assertEqual({"type" : ['null', 'string'], "format" : "date-time"}, caught_messages[0].schema.get('properties').get('_sdc_deleted_at') ) + # schema includes scn && _sdc_deleted_at because we selected logminer as our replication method + that.assertEqual( + {"type": ["integer"]}, caught_messages[0].schema.get("properties").get("scn") + ) + that.assertEqual( + {"type": ["null", "string"], "format": "date-time"}, + caught_messages[0].schema.get("properties").get("_sdc_deleted_at"), + ) that.assertEqual(pks, caught_messages[0].key_properties) - #verify first STATE message - bookmarks_1 = caught_messages[2].value.get('bookmarks')['ROOT-CHICKEN'] + # verify first STATE message + bookmarks_1 = caught_messages[2].value.get("bookmarks")["ROOT-CHICKEN"] that.assertIsNotNone(bookmarks_1) - bookmarks_1_scn = bookmarks_1.get('scn') - bookmarks_1_version = bookmarks_1.get('version') + bookmarks_1_scn = bookmarks_1.get("scn") + bookmarks_1_version = bookmarks_1.get("version") that.assertIsNotNone(bookmarks_1_scn) that.assertIsNotNone(bookmarks_1_version) - #verify STATE message after UPDATE - bookmarks_2 = caught_messages[6].value.get('bookmarks')['ROOT-CHICKEN'] + # verify STATE message after UPDATE + bookmarks_2 = 
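insert_record quotes the column identifiers, builds a parameterised VALUES clause, and runs every value through crud_up_value first. A minimal usage sketch, reusing the helpers above — the table name and values are examples only:

from tests.utils import ensure_test_table, get_test_connection, insert_record

ensure_test_table(
    {
        "name": "EXAMPLE COUNTS",
        "columns": [
            {"name": "id", "type": "integer", "primary_key": True, "serial": True},
            {"name": "size", "type": "integer"},
        ],
    }
)

with get_test_connection() as conn:
    with conn.cursor() as cur:
        # Keys are sorted inside insert_record, so dict ordering does not matter.
        insert_record(cur, "EXAMPLE COUNTS", {"size": 42, "id": 1})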
caught_messages[6].value.get("bookmarks")["ROOT-CHICKEN"] that.assertIsNotNone(bookmarks_2) - bookmarks_2_scn = bookmarks_2.get('scn') - bookmarks_2_version = bookmarks_2.get('version') + bookmarks_2_scn = bookmarks_2.get("scn") + bookmarks_2_version = bookmarks_2.get("version") that.assertIsNotNone(bookmarks_2_scn) that.assertIsNotNone(bookmarks_2_version) that.assertGreater(bookmarks_2_scn, bookmarks_1_scn) From 326fb81321c40964bad96c622ad0f0744b94e711 Mon Sep 17 00:00:00 2001 From: Jonas Kalderstam Date: Thu, 20 Oct 2022 15:34:26 +0200 Subject: [PATCH 2/5] Initial test for range conversion --- tests/test_conversions.py | 72 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/test_conversions.py diff --git a/tests/test_conversions.py b/tests/test_conversions.py new file mode 100644 index 00000000..dbc9696d --- /dev/null +++ b/tests/test_conversions.py @@ -0,0 +1,72 @@ +import unittest + +from singer import get_logger, metadata + +import tap_postgres +from tests.utils import ( + get_test_connection, + ensure_test_table, + get_test_connection_config, +) + +LOGGER = get_logger() + + +def do_not_dump_catalog(catalog): + pass + + +tap_postgres.dump_catalog = do_not_dump_catalog + + +class Conversions(unittest.TestCase): + maxDiff = None + table_name = "COW FEEDS" + + def setUp(self): + table_spec = { + "columns": [ + {"name": "int_range_col", "type": "int4range"}, + ], + "name": Conversions.table_name, + } + + ensure_test_table(table_spec) + + def test_catalog(self): + conn_config = get_test_connection_config() + streams = tap_postgres.do_discovery(conn_config) + cow_streams = [s for s in streams if s["tap_stream_id"] == "public-COW FEEDS"] + + self.assertEqual(len(cow_streams), 1) + stream_dict = cow_streams[0] + stream_dict.get("metadata").sort(key=lambda md: md["breadcrumb"]) + + # TODO fix it + self.assertEqual( + metadata.to_map(stream_dict.get("metadata")), + { + (): { + "is-view": False, + "table-key-properties": [], + "row-count": 0, + "schema-name": "public", + "database-name": "postgres", + }, + ("properties", "int_range_col"): { + "sql-datatype": "int4range", + "selected-by-default": False, + "inclusion": "unsupported", + }, + ("properties", "int_range_col_lower"): { + "sql-datatype": "int", + "selected-by-default": True, + "inclusion": "automatic", + }, + ("properties", "int_range_col_upper"): { + "sql-datatype": "int", + "selected-by-default": True, + "inclusion": "automatic", + }, + }, + ) From da337e4bdd51083ff19e660c99170c96452a9da9 Mon Sep 17 00:00:00 2001 From: Jonas Kalderstam Date: Thu, 20 Oct 2022 15:38:08 +0200 Subject: [PATCH 3/5] Fix CI --- .github/workflows/ci.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fb6f3686..136402ff 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -35,7 +35,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Setup virtual environment - run: make venv + #run: make venv + run: pip install poetry && poetry install - name: Pylinting run: make pylint From ad3d23386e52ce95e8018f80174bb856499ab108 Mon Sep 17 00:00:00 2001 From: Jonas Kalderstam Date: Thu, 20 Oct 2022 15:41:13 +0200 Subject: [PATCH 4/5] Fixed lint --- tap_postgres/sync_strategies/incremental.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tap_postgres/sync_strategies/incremental.py b/tap_postgres/sync_strategies/incremental.py index a3d047f6..d81c188d 100644 --- 
a/tap_postgres/sync_strategies/incremental.py +++ b/tap_postgres/sync_strategies/incremental.py @@ -147,7 +147,7 @@ def sync_table(conn_info, stream, state, desired_columns, md_map): "replication_key_value", record_message.record[replication_key], ) - except KeyError as e: + except KeyError: # Replication key not present in table - treat like None pass From 7f9e17f1d114fc69c3ed6efd758dea3c1ac2310c Mon Sep 17 00:00:00 2001 From: Jonas Kalderstam Date: Thu, 20 Oct 2022 15:55:06 +0200 Subject: [PATCH 5/5] Just a test --- tap_postgres/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tap_postgres/__init__.py b/tap_postgres/__init__.py index c986134b..77f9d434 100644 --- a/tap_postgres/__init__.py +++ b/tap_postgres/__init__.py @@ -294,6 +294,22 @@ def register_type_adapters(conn_config): """ with post_db.open_connection(conn_config) as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + # typname + ## int4range + ## tstzrange + ## geometry + ## point + + # int4range + cur.execute("SELECT typarray FROM pg_type where typname = 'int4range'") + int4range_oid = cur.fetchone() + if int4range_oid: + psycopg2.extensions.register_type( + psycopg2.extensions.new_array_type( + (int4range_oid[0],), "INT4RANGE", psycopg2.STRING + ) + ) + # citext[] cur.execute("SELECT typarray FROM pg_type where typname = 'citext'") citext_array_oid = cur.fetchone()
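The register_type_adapters hunk above mirrors the existing citext[] registration: it looks up the type's array OID via typarray and registers it with psycopg2 as a plain string type. In canonical text form an int4range reads like '[5,11)'. The split into *_lower/*_upper integer columns that the new test_conversions.py reaches for is still marked '# TODO fix it'; one possible shape for that conversion — parse_int4range is an invented helper, not something this patch implements — is sketched below.

import re

_RANGE_RE = re.compile(r"^([\[\(])(-?\d*),(-?\d*)([\]\)])$")


def parse_int4range(text):
    """Split canonical int4range text, e.g. '[5,11)', into (lower, upper).
    Unbounded sides and the special 'empty' range come back as None."""
    if text is None or text == "empty":
        return None, None
    match = _RANGE_RE.match(text)
    if not match:
        raise ValueError(f"not an int4range literal: {text!r}")
    lower = int(match.group(2)) if match.group(2) else None
    upper = int(match.group(3)) if match.group(3) else None
    # Postgres normalises discrete ranges to inclusive-lower/exclusive-upper
    # ('[l,u)'), so no bound adjustment is needed for canonical output.
    return lower, upper


assert parse_int4range("[5,11)") == (5, 11)
assert parse_int4range("(,)") == (None, None)
assert parse_int4range("empty") == (None, None)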