Skip to content

Commit

Permalink
[ISSUE airbytehq#23994] make MessageGrouper use AirbyteEntrypoint (ai…
Browse files Browse the repository at this point in the history
…rbytehq#25402)

* [ISSUE airbytehq#23994] make MessageGrouper use AirbyteEntrypoint

* [ISSUE airbytehq#23994] code review
  • Loading branch information
maxi297 authored Apr 24, 2023
1 parent 08ab4a5 commit 4d65fa1
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
StreamReadSlices,
StreamReadSlicesInner,
)
from airbyte_cdk.entrypoint import AirbyteEntrypoint
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer
Expand Down Expand Up @@ -167,7 +168,7 @@ def _read_stream(self, source: DeclarativeSource, config: Mapping[str, Any], con
# the generator can raise an exception
# iterate over the generated messages. if next raise an exception, catch it and yield it as an AirbyteLogMessage
try:
yield from source.read(logger=self.logger, config=config, catalog=configured_catalog, state={})
yield from AirbyteEntrypoint(source).read(source.spec(self.logger), config, configured_catalog, {})
except Exception as e:
error_message = f"{e.args[0] if len(e.args) > 0 else str(e)}"
yield AirbyteTracedException.from_exception(e, message=error_message).as_airbyte_message()
Expand Down
89 changes: 57 additions & 32 deletions airbyte-cdk/python/airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import os.path
import sys
import tempfile
from typing import Iterable, List
from typing import Any, Iterable, List, Mapping

from airbyte_cdk.connector import TConfig
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
from airbyte_cdk.logger import init_logger
from airbyte_cdk.models import AirbyteMessage, Status, Type
from airbyte_cdk.models.airbyte_protocol import ConnectorSpecification
from airbyte_cdk.sources import Source
from airbyte_cdk.sources.source import TCatalog, TState
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit, split_config
from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
Expand Down Expand Up @@ -85,45 +87,68 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]:
raw_config = self.source.read_config(parsed_args.config)
config = self.source.configure(raw_config, temp_dir)

# Now that we have the config, we can use it to get a list of ai airbyte_secrets
# that we should filter in logging to avoid leaking secrets
config_secrets = get_secrets(source_spec.connectionSpecification, config)
update_secrets(config_secrets)

# Remove internal flags from config before validating so
# jsonschema's additionalProperties flag wont fail the validation
connector_config, _ = split_config(config)
if self.source.check_config_against_spec or cmd == "check":
try:
check_config_against_spec_or_exit(connector_config, source_spec)
except AirbyteTracedException as traced_exc:
connection_status = traced_exc.as_connection_status_message()
if connection_status and cmd == "check":
yield connection_status.json(exclude_unset=True)
return
raise traced_exc

if cmd == "check":
check_result = self.source.check(self.logger, config)
if check_result.status == Status.SUCCEEDED:
self.logger.info("Check succeeded")
else:
self.logger.error("Check failed")

output_message = AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=check_result).json(exclude_unset=True)
yield output_message
yield from map(AirbyteEntrypoint.airbyte_message_to_string, self.check(source_spec, config))
elif cmd == "discover":
catalog = self.source.discover(self.logger, config)
yield AirbyteMessage(type=Type.CATALOG, catalog=catalog).json(exclude_unset=True)
yield from map(AirbyteEntrypoint.airbyte_message_to_string, self.discover(source_spec, config))
elif cmd == "read":
config_catalog = self.source.read_catalog(parsed_args.catalog)
state = self.source.read_state(parsed_args.state)
generator = self.source.read(self.logger, config, config_catalog, state)
for message in generator:
yield message.json(exclude_unset=True)

yield from map(AirbyteEntrypoint.airbyte_message_to_string, self.read(source_spec, config, config_catalog, state))
else:
raise Exception("Unexpected command " + cmd)

def check(self, source_spec: ConnectorSpecification, config: TConfig) -> Iterable[AirbyteMessage]:
    """Validate ``config`` and run the source's connection check.

    The config's secrets are registered for log filtering first. Unlike
    ``discover`` and ``read``, the config is always validated here, regardless
    of ``self.source.check_config_against_spec``.

    :param source_spec: connector specification; its ``connectionSpecification``
        is used to locate secrets and the spec itself is passed to validation
    :param config: connector configuration to validate and check
    :return: a generator yielding a single message: either the connection status
        derived from a validation failure, or a ``CONNECTION_STATUS`` message
        wrapping the result of ``self.source.check``
    :raises AirbyteTracedException: when validation fails and the exception
        cannot be turned into a connection status message
    """
    self.set_up_secret_filter(config, source_spec.connectionSpecification)
    try:
        self.validate_connection(source_spec, config)
    except AirbyteTracedException as traced_exc:
        connection_status = traced_exc.as_connection_status_message()
        if connection_status:
            # The validation failure is reported to the caller as a regular message.
            yield connection_status
            return
        # No status message is available for this failure: propagate it rather than
        # swallowing it and checking a config that has just been rejected.
        raise

    check_result = self.source.check(self.logger, config)
    if check_result.status == Status.SUCCEEDED:
        self.logger.info("Check succeeded")
    else:
        self.logger.error("Check failed")

    yield AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=check_result)

def discover(self, source_spec: ConnectorSpecification, config: TConfig) -> Iterable[AirbyteMessage]:
    """Yield the source's catalog wrapped in a single ``CATALOG`` message.

    :param source_spec: connector specification; its ``connectionSpecification``
        is used to register secrets for log filtering
    :param config: connector configuration handed to ``self.source.discover``
    :return: a generator yielding exactly one ``AirbyteMessage``
    """
    self.set_up_secret_filter(config, source_spec.connectionSpecification)

    # Validation is opt-in for discover: it only runs when the source asks for it.
    should_validate = self.source.check_config_against_spec
    if should_validate:
        self.validate_connection(source_spec, config)

    discovered = self.source.discover(self.logger, config)
    catalog_message = AirbyteMessage(type=Type.CATALOG, catalog=discovered)
    yield catalog_message

def read(self, source_spec: ConnectorSpecification, config: TConfig, catalog: TCatalog, state: TState) -> Iterable[AirbyteMessage]:
    """Stream every message produced by ``self.source.read``.

    Secrets found in ``config`` are registered for log filtering first; the config
    is then validated only if ``self.source.check_config_against_spec`` is truthy.

    :param source_spec: connector specification; its ``connectionSpecification``
        is used to register secrets
    :param config: connector configuration
    :param catalog: catalog forwarded unchanged to the source
    :param state: state forwarded unchanged to the source
    :return: a generator delegating to the source's ``read``
    """
    self.set_up_secret_filter(config, source_spec.connectionSpecification)

    should_validate = self.source.check_config_against_spec
    if should_validate:
        self.validate_connection(source_spec, config)

    source_messages = self.source.read(self.logger, config, catalog, state)
    yield from source_messages

@staticmethod
def validate_connection(source_spec: ConnectorSpecification, config: Mapping[str, Any]) -> None:
    """Validate the connector part of ``config`` against ``source_spec``.

    :param source_spec: connector specification to validate against
    :param config: full configuration, possibly including internal flags
    """
    # Internal flags are stripped first so that jsonschema's additionalProperties
    # setting does not reject the config because of them.
    config_without_internal_flags, _ = split_config(config)
    check_config_against_spec_or_exit(config_without_internal_flags, source_spec)

@staticmethod
def set_up_secret_filter(config, connection_specification: Mapping[str, Any]):
    """Register the secrets found in ``config`` so they can be filtered from logs.

    :param config: connector configuration to extract secrets from
    :param connection_specification: the spec's connection specification, passed
        to ``get_secrets`` together with the config
    """
    # With the config at hand, collect its airbyte_secrets and register them
    # with the logging filter to avoid leaking secrets.
    update_secrets(get_secrets(connection_specification, config))

@staticmethod
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str:
return airbyte_message.json(exclude_unset=True)


def launch(source: Source, args: List[str]):
source_entrypoint = AirbyteEntrypoint(source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import dataclasses
import json
import logging
from unittest import mock
from unittest.mock import patch

Expand All @@ -30,6 +31,7 @@
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
ConnectorSpecification,
DestinationSyncMode,
Level,
SyncMode,
Expand Down Expand Up @@ -82,6 +84,16 @@
},
],
"check": {"type": "CheckStream", "stream_names": ["lists"]},
"spec": {
"connection_specification": {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"required": [],
"properties": {},
"additionalProperties": True
},
"type": "Spec"
}
}

RESOLVE_MANIFEST_CONFIG = {
Expand Down Expand Up @@ -300,6 +312,16 @@ def test_resolve_manifest(valid_resolve_manifest_config_file):
},
],
"check": {"type": "CheckStream", "stream_names": ["lists"]},
"spec": {
"connection_specification": {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"required": [],
"properties": {},
"additionalProperties": True
},
"type": "Spec"
}
}
assert resolved_manifest.record.data["manifest"] == expected_resolved_manifest
assert resolved_manifest.record.stream == "resolve_manifest"
Expand Down Expand Up @@ -364,6 +386,15 @@ class MockManifestDeclarativeSource:
def read(self, logger, config, catalog, state):
    """Test double: fail unconditionally to simulate a source whose read raises."""
    failure = ValueError("error_message")
    raise failure

def spec(self, logger: logging.Logger) -> ConnectorSpecification:
    """Test double: return a mock specification whose ``connectionSpecification`` is empty."""
    # Keyword arguments passed to Mock are set as attributes on the created mock.
    return mock.Mock(connectionSpecification={})

@property
def check_config_against_spec(self):
    """Test double: always report that config validation against the spec is disabled."""
    validation_enabled = False
    return validation_enabled

stack_trace = "a stack trace"
mock_from_exception.return_value = stack_trace

Expand Down
Loading

0 comments on commit 4d65fa1

Please sign in to comment.