diff --git a/examples/async/README.md b/examples/async/README.md
index 140f906ce..e5fbe3790 100644
--- a/examples/async/README.md
+++ b/examples/async/README.md
@@ -29,6 +29,12 @@ You should get the following result:
 {"pipeline":{"computation1":false,"computation2":true}}
 ```
 
+## Tracking
+
+This example uses the async tracker if the [ui](https://hamilton.dagworks.io/en/latest/concepts/ui/)
+is running on port 8241 -- see [fastapi_example.py](fastapi_example.py) for the code.
+If it is not running, the example proceeds without tracking.
+
 ## How it works
 
diff --git a/examples/async/async_module.py b/examples/async/async_module.py
index 33432b8c7..ad52c7a1b 100644
--- a/examples/async/async_module.py
+++ b/examples/async/async_module.py
@@ -24,17 +24,17 @@ def bar(request_raw: dict) -> str:
     return request_raw.get("bar", "baz")
 
 
-async def computation1(foo: str, some_data: dict) -> bool:
-    await asyncio.sleep(1)
-    return False
-
-
 async def some_data() -> dict:
     async with aiohttp.ClientSession() as session:
         async with session.get("http://httpbin.org/get") as resp:
             return await resp.json()
 
 
+async def computation1(foo: str, some_data: dict) -> bool:
+    await asyncio.sleep(1)
+    return False
+
+
 async def computation2(bar: str) -> bool:
     await asyncio.sleep(1)
     return True
diff --git a/examples/async/fastapi_example.py b/examples/async/fastapi_example.py
index 88c87ca07..17118ab98 100644
--- a/examples/async/fastapi_example.py
+++ b/examples/async/fastapi_example.py
@@ -1,13 +1,61 @@
+import logging
+from contextlib import asynccontextmanager
+
+import aiohttp
 import async_module
 import fastapi
+import hamilton_sdk.adapters
+from aiohttp import client_exceptions
 
-from hamilton import base
 from hamilton.experimental import h_async
 
-app = fastapi.FastAPI()
+logger = logging.getLogger(__name__)
 
 # can instantiate a driver once for the life of the app:
-dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult())
+dr = None
+
+
+async def _tracking_server_running():
+    """Quickly tells if the tracking server is up and running"""
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.get("http://localhost:8241/api/v1/ping") as response:
+                if response.status == 200:
+                    return True
+                else:
+                    return False
+        except client_exceptions.ClientConnectionError:
+            return False
+
+
+@asynccontextmanager
+async def lifespan(app: fastapi.FastAPI):
+    global dr
+    is_server_running = await _tracking_server_running()
+    if not is_server_running:
+        logger.warning(
+            "Tracking server is not running, skipping telemetry. To run the telemetry server, run hamilton ui. "
" + "Note you must have a project with ID 1 if it is running -- if not, you can change the project " + "ID in this file or create a new one from the UI" + ) + adapters = [] + if is_server_running: + tracker_async = hamilton_sdk.adapters.AsyncHamiltonTracker( + project_id=1, + username="elijah", + dag_name="async_tracker", + ) + tracker_sync = hamilton_sdk.adapters.HamiltonTracker( + project_id=1, + username="elijah", + dag_name="sync_tracker_dont_use_this", + ) + adapters = [tracker_async, tracker_sync] + dr = await h_async.Builder().with_modules(async_module).with_adapters(*adapters).build() + yield + + +app = fastapi.FastAPI(lifespan=lifespan) @app.post("/execute") diff --git a/hamilton/driver.py b/hamilton/driver.py index 1808536c9..abc576e2c 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -403,6 +403,7 @@ def __init__( if _graph_executor is None: _graph_executor = DefaultGraphExecutor(self.adapter) self.graph_executor = _graph_executor + self.config = config except Exception as e: error = telemetry.sanitize_error(*sys.exc_info()) logger.error(SLACK_ERROR_MESSAGE) diff --git a/hamilton/experimental/h_async.py b/hamilton/experimental/h_async.py index c1ad34d5f..406632735 100644 --- a/hamilton/experimental/h_async.py +++ b/hamilton/experimental/h_async.py @@ -3,17 +3,20 @@ import logging import sys import time -import types import typing +import uuid from types import ModuleType from typing import Any, Dict, Optional, Tuple -from hamilton import base, driver, node, telemetry +import hamilton.lifecycle.base as lifecycle_base +from hamilton import base, driver, graph, lifecycle, node, telemetry +from hamilton.execution.graph_functions import create_error_message +from hamilton.io.materialization import ExtractorFactory, MaterializerFactory logger = logging.getLogger(__name__) -async def await_dict_of_tasks(task_dict: Dict[str, types.CoroutineType]) -> Dict[str, Any]: +async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]: """Util to await a dictionary of tasks as asyncio.gather is kind of garbage""" keys = sorted(task_dict.keys()) coroutines = [task_dict[key] for key in keys] @@ -33,10 +36,14 @@ async def process_value(val: Any) -> Any: return await val -class AsyncGraphAdapter(base.SimplePythonDataFrameGraphAdapter): +class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilder): """Graph adapter for use with the :class:`AsyncDriver` class.""" - def __init__(self, result_builder: base.ResultMixin = None): + def __init__( + self, + result_builder: base.ResultMixin = None, + async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None, + ): """Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class. Some things to note: @@ -46,9 +53,22 @@ def __init__(self, result_builder: base.ResultMixin = None): because that function is called directly within the decorator, so we cannot await it. """ super(AsyncGraphAdapter, self).__init__() + self.adapter = ( + async_lifecycle_adapters + if async_lifecycle_adapters is not None + else lifecycle_base.LifecycleAdapterSet() + ) self.result_builder = result_builder if result_builder else base.PandasDataFrameResult() + self.is_initialized = False - def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any: + def do_node_execute( + self, + *, + run_id: str, + node_: node.Node, + kwargs: typing.Dict[str, typing.Any], + task_id: Optional[str] = None, + ) -> typing.Any: """Executes a node. 
Note this doesn't actually execute it -- rather, it returns a task. This does *not* use async def, as we want it to be awaited on later -- this await is done in processing parameters of downstream functions/final results. We can ensure that as @@ -57,34 +77,111 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> Note that this assumes that everything is awaitable, even if it isn't. In that case, it just wraps it in one. + :param task_id: + :param node_: + :param run_id: :param node: Node to wrap :param kwargs: Keyword arguments (either coroutines or raw values) to call it with :return: A task """ - callabl = node.callable + callabl = node_.callable async def new_fn(fn=callabl, **fn_kwargs): task_dict = {key: process_value(value) for key, value in fn_kwargs.items()} fn_kwargs = await await_dict_of_tasks(task_dict) - if inspect.iscoroutinefunction(fn): - return await fn(**fn_kwargs) - return fn(**fn_kwargs) + error = None + result = None + success = True + pre_node_execute_errored = False + try: + if self.adapter.does_hook("pre_node_execute", is_async=True): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "pre_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + task_id=task_id, + ) + except Exception as e: + pre_node_execute_errored = True + raise e + # TODO -- consider how to use node execution methods in the future + # This is messy as it is a method called within a method... + # if self.adapter.does_method("do_node_execute", is_async=False): + # result = self.adapter.call_lifecycle_method_sync( + # "do_node_execute", + # run_id=run_id, + # node_=node_, + # kwargs=kwargs, + # task_id=task_id, + # ) + # else: + + result = ( + await fn(**fn_kwargs) if asyncio.iscoroutinefunction(fn) else fn(**fn_kwargs) + ) + except Exception as e: + success = False + error = e + step = "[pre-node-execute:async]" if pre_node_execute_errored else "" + message = create_error_message(kwargs, node_, step) + logger.exception(message) + raise + finally: + if not pre_node_execute_errored and self.adapter.does_hook( + "post_node_execute", is_async=True + ): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "post_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + success=success, + error=error, + result=result, + task_id=task_id, + ) + except Exception: + message = create_error_message(kwargs, node_, "[post-node-execute]") + logger.exception(message) + raise + + return result coroutine = new_fn(**kwargs) task = asyncio.create_task(coroutine) return task - def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: - """Currently this is a no-op -- it just delegates to the resultsbuilder. - That said, we *could* make it async, but it feels wrong -- this will just be - called after `raw_execute`. - - :param outputs: Outputs (awaited) from the graph. - :return: The final results. - """ + def build_result(self, **outputs: Any) -> Any: return self.result_builder.build_result(**outputs) +def separate_sync_from_async( + adapters: typing.List[lifecycle.LifecycleAdapter], +) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]: + """Separates the sync and async adapters from a list of adapters. + Note this only works with hooks -- we'll be dealing with methods later. 
+ + :param adapters: List of adapters + :return: Tuple of sync adapters, async adapters + """ + + adapter_set = lifecycle_base.LifecycleAdapterSet(*adapters) + # this is using internal(ish) fields (.sync_hooks/.async_hooks) -- we should probably expose it + # For now this is OK + # Note those are dict[hook_name, list[hook]], so we have to flatten + return ( + [sync_adapter for adapters in adapter_set.sync_hooks.values() for sync_adapter in adapters], + [ + async_adapter + for adapters in adapter_set.async_hooks.values() + for async_adapter in adapters + ], + ) + + class AsyncDriver(driver.Driver): """Async driver. This is a driver that uses the AsyncGraphAdapter to execute the graph. @@ -95,16 +192,75 @@ class AsyncDriver(driver.Driver): """ - def __init__(self, config, *modules, result_builder: Optional[base.ResultMixin] = None): + def __init__( + self, + config, + *modules, + result_builder: Optional[base.ResultMixin] = None, + adapters: typing.List[lifecycle.LifecycleAdapter] = None, + ): """Instantiates an asynchronous driver. + You will also need to call `ainit` to initialize the driver if you have any hooks/adapters. + :param config: Config to build the graph :param modules: Modules to crawl for fns/graph nodes + :param adapters: Adapters to use for lifecycle methods. :param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run. """ + if adapters is not None: + sync_adapters, async_adapters = separate_sync_from_async(adapters) + else: + # separate out so we know what the driver + sync_adapters = [] + async_adapters = [] + # we'll need to use this in multiple contexts so we'll keep it around for later + result_builders = [adapter for adapter in adapters if isinstance(adapter, base.ResultMixin)] + if result_builder is not None: + result_builders.append(result_builder) + if len(result_builders) > 1: + raise ValueError( + "You cannot pass more than one result builder to the async driver. " + "Please pass in a single result builder" + ) + elif len(result_builders) == 0: + result_builders = [base.DictResult()] + result_builder = result_builders[0] super(AsyncDriver, self).__init__( - config, *modules, adapter=AsyncGraphAdapter(result_builder=result_builder) + config, + *modules, + adapter=[ + # We pass in the async adapters here as this can call node-level hooks + # Otherwise we trust the driver/fn graph to call sync adapters + AsyncGraphAdapter( + result_builder=result_builder, + async_lifecycle_adapters=lifecycle_base.LifecycleAdapterSet(*async_adapters), + ), + # We pass in the sync adapters here as this can call + *sync_adapters, + *async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later + ], ) + self.initialized = False + + async def ainit(self) -> "AsyncDriver": + """Initializes the driver when using async. This only exists for backwards compatibility. + In Hamilton 2.0, we will be using an asynchronous constructor. + See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5. 
+ """ + if self.initialized: + # this way it can be called twice + return self + if self.adapter.does_hook("post_graph_construct", is_async=True): + await self.adapter.call_all_lifecycle_hooks_async( + "post_graph_construct", + graph=self.graph, + modules=self.graph_modules, + config=self.config, + ) + await self.adapter.ainit() + self.initialized = True + return self async def raw_execute( self, @@ -112,7 +268,7 @@ async def raw_execute( overrides: Dict[str, Any] = None, display_graph: bool = False, # don't care inputs: Dict[str, Any] = None, - run_id: str = None, + _fn_graph: graph.FunctionGraph = None, ) -> Dict[str, Any]: """Executes the graph, returning a dictionary of strings (node keys) to final results. @@ -120,20 +276,56 @@ async def raw_execute( :param overrides: Overrides for nodes :param display_graph: whether or not to display graph -- this is not supported. :param inputs: Inputs for DAG runtime calculation + :param _fn_graph: Function graph for compatibility with superclass -- unused :return: A dict of key -> result """ + assert _fn_graph is None, ( + "_fn_graph must not be provided " + "-- the only reason you'd do this is to use materialize(), which is not supported yet.." + ) + run_id = str(uuid.uuid4()) nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs) memoized_computation = dict() # memoized storage - self.graph.execute(nodes, memoized_computation, overrides, inputs) - if display_graph: - raise ValueError( - "display_graph=True is not supported for the async graph adapter. " - "Instead you should be using visualize_execution." + if self.adapter.does_hook("pre_graph_execute"): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "pre_graph_execute", + run_id=run_id, + graph=self.graph, + final_vars=final_vars, + inputs=inputs, + overrides=overrides, ) - task_dict = { - key: asyncio.create_task(process_value(memoized_computation[key])) for key in final_vars - } - return await await_dict_of_tasks(task_dict) + results = None + error = None + success = False + try: + self.graph.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id) + if display_graph: + raise ValueError( + "display_graph=True is not supported for the async graph adapter. " + "Instead you should be using visualize_execution." + ) + task_dict = { + key: asyncio.create_task(process_value(memoized_computation[key])) + for key in final_vars + } + results = await await_dict_of_tasks(task_dict) + success = True + except Exception as e: + error = e + success = False + raise e + finally: + if self.adapter.does_hook("post_graph_execute", is_async=None): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "post_graph_execute", + run_id=run_id, + graph=self.graph, + success=success, + error=error, + results=results, + ) + return results async def execute( self, @@ -228,3 +420,40 @@ async def make_coroutine(): except Exception as e: if logger.isEnabledFor(logging.DEBUG): logger.error(f"Encountered error submitting async telemetry:\n{e}") + + +class Builder(driver.Builder): + """Builder for the async driver""" + + def __init__(self): + super(Builder, self).__init__() + + def _not_supported(self, method_name: str, additional_message: str = ""): + raise ValueError( + f"Builder().{method_name}() is not supported for the async driver. 
{additional_message}" + ) + + def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder": + self._not_supported("enable_dynamic_execution") + + def with_materializers( + self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory] + ) -> "Builder": + self._not_supported("with_materializers") + + def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder": + self._not_supported( + "with_adapter", + "Use with_adapters instead to pass in the tracker (or other async hooks/methods)", + ) + + async def build(self): + adapter = self.adapters if self.adapters is not None else [] + if self.legacy_graph_adapter is not None: + adapter.append(self.legacy_graph_adapter) + + return await AsyncDriver( + self.config, + *self.modules, + adapters=adapter, + ).ainit() diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 0d69079e2..8e6854d8f 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -745,11 +745,17 @@ def __init__(self, *adapters: LifecycleAdapter): :param adapters: Adapters to group together """ - self._adapters = list(adapters) + self._adapters = self._uniqify_adapters(adapters) self.sync_hooks, self.async_hooks = self._get_lifecycle_hooks() self.sync_methods, self.async_methods = self._get_lifecycle_methods() self.sync_validators = self._get_lifecycle_validators() + def _uniqify_adapters(self, adapters: List[LifecycleAdapter]) -> List[LifecycleAdapter]: + seen = set() + return [ + adapter for adapter in adapters if not (id(adapter) in seen or seen.add(id(adapter))) + ] + def _get_lifecycle_validators( self, ) -> Dict[str, List[LifecycleAdapter]]: @@ -811,7 +817,7 @@ def _get_lifecycle_methods( {method: list(adapters) for method, adapters in async_methods.items()}, ) - def does_hook(self, hook_name: str, is_async: bool) -> bool: + def does_hook(self, hook_name: str, is_async: Optional[bool] = None) -> bool: """Whether or not a hook is implemented by any of the adapters in this group. If this hook is not registered, this will raise a ValueError. @@ -819,21 +825,22 @@ def does_hook(self, hook_name: str, is_async: bool) -> bool: :param is_async: Whether you want the async version or not :return: True if this adapter set does this hook, False otherwise """ - if is_async and hook_name not in REGISTERED_ASYNC_HOOKS: + either = is_async is None + if (is_async or either) and hook_name not in REGISTERED_ASYNC_HOOKS: raise ValueError( f"Hook {hook_name} is not registered as an asynchronous lifecycle hook. " f"Registered hooks are {REGISTERED_ASYNC_HOOKS}" ) - if not is_async and hook_name not in REGISTERED_SYNC_HOOKS: + if ((not is_async) or either) and hook_name not in REGISTERED_SYNC_HOOKS: raise ValueError( f"Hook {hook_name} is not registered as a synchronous lifecycle hook. " f"Registered hooks are {REGISTERED_SYNC_HOOKS}" ) - if not is_async: - return hook_name in self.sync_hooks - return hook_name in self.async_hooks + has_async = hook_name in self.async_hooks + has_sync = hook_name in self.sync_hooks + return (has_async or has_sync) if either else has_async if is_async else has_sync - def does_method(self, method_name: str, is_async: bool) -> bool: + def does_method(self, method_name: str, is_async: Optional[bool] = None) -> bool: """Whether a method is implemented by any of the adapters in this group. If this method is not registered, this will raise a ValueError. 
@@ -841,19 +848,20 @@ def does_method(self, method_name: str, is_async: bool) -> bool: :param is_async: Whether you want the async version or not :return: True if this adapter set does this method, False otherwise """ - if is_async and method_name not in REGISTERED_ASYNC_METHODS: + either = is_async is None + if (is_async or either) and method_name not in REGISTERED_ASYNC_METHODS: raise ValueError( f"Method {method_name} is not registered as an asynchronous lifecycle method. " f"Registered methods are {REGISTERED_ASYNC_METHODS}" ) - if not is_async and method_name not in REGISTERED_SYNC_METHODS: + if ((not is_async) or either) and method_name not in REGISTERED_SYNC_METHODS: raise ValueError( f"Method {method_name} is not registered as a synchronous lifecycle method. " f"Registered methods are {REGISTERED_SYNC_METHODS}" ) - if not is_async: - return method_name in self.sync_methods - return method_name in self.async_methods + has_async = method_name in self.async_methods + has_sync = method_name in self.sync_methods + return (has_async or has_sync) if either else has_async if is_async else has_sync def does_validation(self, validator_name: str) -> bool: """Whether a validator is implemented by any of the adapters in this group. @@ -875,7 +883,7 @@ def call_all_lifecycle_hooks_sync(self, hook_name: str, **kwargs): :param hook_name: Name of the hooks to call :param kwargs: Keyword arguments to pass into the hook """ - for adapter in self.sync_hooks[hook_name]: + for adapter in self.sync_hooks.get(hook_name, []): getattr(adapter, hook_name)(**kwargs) async def call_all_lifecycle_hooks_async(self, hook_name: str, **kwargs): @@ -885,10 +893,14 @@ async def call_all_lifecycle_hooks_async(self, hook_name: str, **kwargs): :param kwargs: Keyword arguments to pass into the hook """ futures = [] - for adapter in self.async_hooks[hook_name]: + for adapter in self.async_hooks.get(hook_name, []): futures.append(getattr(adapter, hook_name)(**kwargs)) await asyncio.gather(*futures) + async def call_all_lifecycle_hooks_sync_and_async(self, hook_name: str, **kwargs): + self.call_all_lifecycle_hooks_sync(hook_name, **kwargs) + await self.call_all_lifecycle_hooks_async(hook_name, **kwargs) + def call_lifecycle_method_sync(self, method_name: str, **kwargs) -> Any: """Calls a lifecycle method in this group, by method name. @@ -947,3 +959,12 @@ def adapters(self) -> List[LifecycleAdapter]: :return: A list of adapters """ return self._adapters + + async def ainit(self): + """Asynchronously initializes the adapters in this group. This is so we can avoid having an async constructor + -- it is an implicit contract -- the async adapters are allowed one ainit() method that will be called by the driver. 
+ """ + for adapter in self.adapters: + print(adapter) + if hasattr(adapter, "ainit"): + await adapter.ainit() diff --git a/plugin_tests/h_async/test_h_async.py b/plugin_tests/h_async/test_h_async.py index 6f670afed..0c0626755 100644 --- a/plugin_tests/h_async/test_h_async.py +++ b/plugin_tests/h_async/test_h_async.py @@ -6,6 +6,18 @@ from hamilton import base from hamilton.experimental import h_async +from hamilton.lifecycle.base import ( + BasePostGraphConstruct, + BasePostGraphConstructAsync, + BasePostGraphExecute, + BasePostGraphExecuteAsync, + BasePostNodeExecute, + BasePostNodeExecuteAsync, + BasePreGraphExecute, + BasePreGraphExecuteAsync, + BasePreNodeExecute, + BasePreNodeExecuteAsync, +) from .resources import simple_async_module @@ -102,3 +114,139 @@ async def test_driver_end_to_end_telemetry(send_event_json): await asyncio.gather(*[t for t in tasks if t != current_task]) assert send_event_json.called assert len(send_event_json.call_args_list) == 2 + + +@pytest.mark.asyncio +async def test_async_driver_end_to_end_async_lifecycle_methods(): + tracked_calls = [] + + class AsyncTrackingAdapter( + BasePostGraphConstructAsync, + BasePreGraphExecuteAsync, + BasePostGraphExecuteAsync, + BasePreNodeExecuteAsync, + BasePostNodeExecuteAsync, + ): + def __init__(self, calls: list, pause_time: float = 0.01): + self.pause_time = pause_time + self.calls = calls + + async def _pause(self): + return await asyncio.sleep(self.pause_time) + + async def pre_graph_execute(self, **kwargs): + await self._pause() + self.calls.append(("pre_graph_execute", kwargs)) + + async def post_graph_execute(self, **kwargs): + await self._pause() + self.calls.append(("post_graph_execute", kwargs)) + + async def pre_node_execute(self, **kwargs): + await self._pause() + self.calls.append(("pre_node_execute", kwargs)) + + async def post_node_execute(self, **kwargs): + await self._pause() + self.calls.append(("post_node_execute", kwargs)) + + async def post_graph_construct(self, **kwargs): + await self._pause() + self.calls.append(("post_graph_construct", kwargs)) + + adapter = AsyncTrackingAdapter(tracked_calls) + + dr = await h_async.AsyncDriver( + {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] + ).ainit() + all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] + result = await dr.execute(final_vars=all_vars, inputs={"external_input": 1}) + hooks_called = [call[0] for call in tracked_calls] + assert set(hooks_called) == { + "pre_graph_execute", + "post_graph_execute", + "pre_node_execute", + "post_node_execute", + "post_graph_construct", + } + result["a"] = result["a"].to_dict() + result["b"] = result["b"].to_dict() + assert result == { + "a": pd.Series([1, 2, 3]).to_dict(), + "another_async_func": 8, + "async_func_with_param": 4, + "b": pd.Series([4, 5, 6]).to_dict(), + "external_input": 1, + "non_async_func_with_decorator": {"result_1": 9, "result_2": 5}, + "result_1": 9, + "result_2": 5, + "result_3": 1, + "result_4": 2, + "return_dict": {"result_3": 1, "result_4": 2}, + "simple_async_func": 2, + "simple_non_async_func": 7, + } + + +@pytest.mark.asyncio +async def test_async_driver_end_to_end_sync_lifecycle_methods(): + tracked_calls = [] + + class AsyncTrackingAdapter( + BasePostGraphConstruct, + BasePreGraphExecute, + BasePostGraphExecute, + BasePreNodeExecute, + BasePostNodeExecute, + ): + def __init__(self, calls: list, pause_time: float = 0.01): + self.pause_time = pause_time + self.calls = calls + + def pre_graph_execute(self, **kwargs): + 
self.calls.append(("pre_graph_execute", kwargs)) + + def post_graph_execute(self, **kwargs): + self.calls.append(("post_graph_execute", kwargs)) + + def pre_node_execute(self, **kwargs): + self.calls.append(("pre_node_execute", kwargs)) + + def post_node_execute(self, **kwargs): + self.calls.append(("post_node_execute", kwargs)) + + def post_graph_construct(self, **kwargs): + self.calls.append(("post_graph_construct", kwargs)) + + adapter = AsyncTrackingAdapter(tracked_calls) + + dr = await h_async.AsyncDriver( + {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] + ).ainit() + all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] + result = await dr.execute(final_vars=all_vars, inputs={"external_input": 1}) + hooks_called = [call[0] for call in tracked_calls] + assert set(hooks_called) == { + "pre_graph_execute", + "post_graph_execute", + "pre_node_execute", + "post_node_execute", + "post_graph_construct", + } + result["a"] = result["a"].to_dict() + result["b"] = result["b"].to_dict() + assert result == { + "a": pd.Series([1, 2, 3]).to_dict(), + "another_async_func": 8, + "async_func_with_param": 4, + "b": pd.Series([4, 5, 6]).to_dict(), + "external_input": 1, + "non_async_func_with_decorator": {"result_1": 9, "result_2": 5}, + "result_1": 9, + "result_2": 5, + "result_3": 1, + "result_4": 2, + "return_dict": {"result_3": 1, "result_4": 2}, + "simple_async_func": 2, + "simple_non_async_func": 7, + } diff --git a/tests/lifecycle/test_lifecycle_base.py b/tests/lifecycle/test_lifecycle_base.py index 0c2456065..4de5cfffb 100644 --- a/tests/lifecycle/test_lifecycle_base.py +++ b/tests/lifecycle/test_lifecycle_base.py @@ -215,7 +215,10 @@ def post_graph_execute( assert adapter_set.does_hook("pre_do_anything", is_async=False) assert adapter_set.does_hook("post_graph_execute", is_async=False) + # either sync or async + assert adapter_set.does_hook("post_graph_execute", is_async=None) assert not adapter_set.does_hook("pre_node_execute", is_async=False) + assert not adapter_set.does_hook("pre_node_execute", is_async=None) assert not adapter_set.does_hook("pre_node_execute", is_async=True) adapter_set.call_all_lifecycle_hooks_sync("pre_do_anything") diff --git a/ui/sdk/src/hamilton_sdk/adapters.py b/ui/sdk/src/hamilton_sdk/adapters.py index fc951fcf3..5003840e0 100644 --- a/ui/sdk/src/hamilton_sdk/adapters.py +++ b/ui/sdk/src/hamilton_sdk/adapters.py @@ -1,4 +1,3 @@ -import asyncio import datetime import hashlib import logging @@ -148,7 +147,6 @@ def pre_graph_execute( dag_template_id = self.dag_template_id_cache[fg_id] else: raise ValueError("DAG template ID not found in cache. 
This should never happen.") - tracking_state = TrackingState(run_id) self.tracking_states[run_id] = tracking_state # cache tracking_state.clock_start() @@ -355,7 +353,7 @@ def post_graph_execute( ) -class AsyncHamiltonAdapter( +class AsyncHamiltonTracker( base.BasePostGraphConstructAsync, base.BasePreGraphExecuteAsync, base.BasePreNodeExecuteAsync, @@ -365,13 +363,13 @@ class AsyncHamiltonAdapter( def __init__( self, project_id: int, - api_key: str, username: str, dag_name: str, tags: Dict[str, str] = None, client_factory: Callable[ - [str, str, str], clients.HamiltonClient + [str, str, str], clients.BasicAsynchronousHamiltonClient ] = clients.BasicAsynchronousHamiltonClient, + api_key: str = os.environ.get("HAMILTON_API_KEY", ""), hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL), hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL), ): @@ -385,11 +383,36 @@ def __init__( driver.validate_tags(self.base_tags) self.dag_name = dag_name self.hamilton_ui_url = hamilton_ui_url - logger.debug("Validating authentication against Hamilton BE API...") - asyncio.run(self.client.validate_auth()) - logger.debug(f"Ensuring project {self.project_id} exists...") + # try: + # asyncio.run(self.client.project_exists(self.project_id)) + # except clients.UnauthorizedException: + # logger.exception( + # f"Authentication failed. Please check your username and try again. " + # f"Username: {self.username}" + # ) + # raise + # except clients.ResourceDoesNotExistException: + # logger.error( + # f"Project {self.project_id} does not exist/is accessible. Please create it first in the UI! " + # f"You can do so at {self.hamilton_ui_url}/dashboard/projects" + # ) + # raise + self.dag_template_id_cache = {} + self.tracking_states = {} + self.dw_run_ids = {} + self.task_runs = {} + self.initialized = False + super().__init__() + + async def ainit(self): + if self.initialized: + return self + """You must call this to initialize the tracker.""" + logger.warning("Validating authentication against Hamilton BE API...") + await self.client.validate_auth() + logger.warning(f"Ensuring project {self.project_id} exists...") try: - asyncio.run(self.client.project_exists(self.project_id)) + await self.client.project_exists(self.project_id) except clients.UnauthorizedException: logger.exception( f"Authentication failed. Please check your username and try again. 
" @@ -402,11 +425,11 @@ def __init__( f"You can do so at {self.hamilton_ui_url}/dashboard/projects" ) raise - self.dag_template_id_cache = {} - self.tracking_states = {} - self.dw_run_ids = {} - self.task_runs = {} - super().__init__() + logger.warning("Initializing Hamilton tracker.") + await self.client.ainit() + logger.warning("Initialized Hamilton tracker.") + self.initialized = True + return self async def post_graph_construct( self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any] @@ -445,7 +468,6 @@ async def pre_graph_execute( overrides: Dict[str, Any], ): logger.debug("pre_graph_execute %s", run_id) - self.run_id = run_id fg_id = id(graph) if fg_id in self.dag_template_id_cache: dag_template_id = self.dag_template_id_cache[fg_id] @@ -480,7 +502,7 @@ async def pre_node_execute( task_update = dict( node_template_name=node_.name, - node_name=get_node_name(node_.name, task_id), + node_name=get_node_name(node_, task_id), realized_dependencies=[dep.name for dep in node_.dependencies], status=task_run.status, start_time=task_run.start_time, @@ -488,7 +510,7 @@ async def pre_node_execute( ) await self.client.update_tasks( self.dw_run_ids[run_id], - attributes=[None], + attributes=[], task_updates=[task_update], in_samples=[task_run.is_in_sample], ) @@ -501,6 +523,7 @@ async def post_node_execute( error: Optional[Exception], result: Any, task_id: Optional[str] = None, + **future_kwargs, ): logger.debug("post_node_execute %s", run_id) task_run = self.task_runs[run_id][node_.name] diff --git a/ui/sdk/src/hamilton_sdk/api/clients.py b/ui/sdk/src/hamilton_sdk/api/clients.py index a066e1975..e3afb137a 100644 --- a/ui/sdk/src/hamilton_sdk/api/clients.py +++ b/ui/sdk/src/hamilton_sdk/api/clients.py @@ -1,4 +1,5 @@ import abc +import asyncio import datetime import logging import queue @@ -30,6 +31,30 @@ def __init__(self, path: str, user: str): super().__init__(message) +def create_batch(batch: dict, dag_run_id: int): + attributes = defaultdict(list) + task_updates = defaultdict(list) + for item in batch: + if item["dag_run_id"] == dag_run_id: + for attr in item["attributes"]: + if attr is None: + continue + attributes[attr["node_name"]].append(attr) + for task_update in item["task_updates"]: + if task_update is None: + continue + task_updates[task_update["node_name"]].append(task_update) + + # this assumes correct ordering of the attributes and task_updates + attributes_list = [ + reduce(lambda x, y: {**x, **y}, attributes[node_name]) for node_name in attributes + ] + task_updates_list = [ + reduce(lambda x, y: {**x, **y}, task_updates[node_name]) for node_name in task_updates + ] + return attributes_list, task_updates_list + + class HamiltonClient: @abc.abstractmethod def validate_auth(self): @@ -220,28 +245,7 @@ def flush(self, batch): # group by dag_run_id -- just incase someone does something weird? 
dag_run_ids = set([item["dag_run_id"] for item in batch]) for dag_run_id in dag_run_ids: - attributes = defaultdict(list) - task_updates = defaultdict(list) - for item in batch: - if item["dag_run_id"] == dag_run_id: - for attr in item["attributes"]: - if attr is None: - continue - attributes[attr["node_name"]].append(attr) - for task_update in item["task_updates"]: - if task_update is None: - continue - task_updates[task_update["node_name"]].append(task_update) - - # this assumes correct ordering of the attributes and task_updates - attributes_list = [ - reduce(lambda x, y: {**x, **y}, attributes[node_name]) for node_name in attributes - ] - task_updates_list = [ - reduce(lambda x, y: {**x, **y}, task_updates[node_name]) - for node_name in task_updates - ] - + attributes_list, task_updates_list = create_batch(batch, dag_run_id) response = requests.put( f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}", json={ @@ -512,6 +516,65 @@ def __init__(self, api_key: str, username: str, h_api_url: str, base_path: str = self.api_key = api_key self.username = username self.base_url = h_api_url + base_path + self.flush_interval = 5 + self.data_queue = asyncio.Queue() + self.running = True + self.max_batch_size = 100 + + async def ainit(self): + asyncio.create_task(self.worker()) + + async def flush(self, batch): + """Flush the batch (send it to the backend or process it).""" + logger.debug(f"Flushing batch: {len(batch)}") # Replace with actual processing logic + # group by dag_run_id -- just incase someone does something weird? + dag_run_ids = set([item["dag_run_id"] for item in batch]) + for dag_run_id in dag_run_ids: + attributes_list, task_updates_list = create_batch(batch, dag_run_id) + async with aiohttp.ClientSession() as session: + async with session.put( + f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}", + json={ + "attributes": make_json_safe(attributes_list), + "task_updates": make_json_safe(task_updates_list), + }, + headers=self._common_headers(), + ) as response: + try: + response.raise_for_status() + logger.debug(f"Updated tasks for DAG run {dag_run_id}") + except HTTPError: + logger.exception(f"Failed to update tasks for DAG run {dag_run_id}") + # zraise + + async def worker(self): + """Worker thread to process the queue""" + batch = [] + last_flush_time = time.time() + logger.debug("Starting worker") + while True: + logger.debug( + f"Awaiting item from queue -- current batched # of items are: {len(batch)}" + ) + try: + item = await asyncio.wait_for(self.data_queue.get(), timeout=self.flush_interval) + batch.append(item) + except asyncio.TimeoutError: + # This is fine, we just keep waiting + pass + else: + if item is None: + await self.flush(batch) + return + + # Check if batch is full or flush interval has passed + if ( + len(batch) >= self.max_batch_size + or (time.time() - last_flush_time) >= self.flush_interval + ): + await self.flush(batch) + batch = [] + last_flush_time = time.time() def _common_headers(self) -> Dict[str, Any]: """Yields the common headers for all requests. 
@@ -726,26 +789,14 @@ async def update_tasks(
             f"Updating tasks for DAG run {dag_run_id} with {len(attributes)} "
             f"attributes and {len(task_updates)} task updates"
         )
-        url = f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}"
-        headers = self._common_headers()
-        data = {
-            "attributes": make_json_safe(attributes),
-            "task_updates": make_json_safe(task_updates),
-        }
-
-        async with aiohttp.ClientSession() as session:
-            async with session.put(url, json=data, headers=headers) as response:
-                try:
-                    response.raise_for_status()
-                    logger.debug(f"Updated tasks for DAG run {dag_run_id}")
-                except aiohttp.ClientResponseError:
-                    logger.exception(f"Failed to update tasks for DAG run {dag_run_id}")
-                    raise
+        await self.data_queue.put(
+            {"dag_run_id": dag_run_id, "attributes": attributes, "task_updates": task_updates}
+        )
 
     async def log_dag_run_end(self, dag_run_id: int, status: str):
         logger.debug(f"Logging end of DAG run {dag_run_id} with status {status}")
         url = f"{self.base_url}/dag_runs/{dag_run_id}/"
-        data = (make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}),)
+        data = make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()})
         headers = self._common_headers()
         async with aiohttp.ClientSession() as session:
             async with session.put(url, json=data, headers=headers) as response:
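
A minimal standalone sketch of the same wiring, for readers who want to try the new async tracker outside of FastAPI. It reuses the names from `fastapi_example.py` above (`project_id=1`, the `async_module` example module, and the `pipeline` output shown in the README); it assumes the Hamilton UI is already running with that project and is illustrative only, not part of this change.

```python
import asyncio

import async_module
import hamilton_sdk.adapters

from hamilton.experimental import h_async


async def main() -> dict:
    # Tracker pointed at an existing project in the Hamilton UI (project 1, as in the example above).
    tracker = hamilton_sdk.adapters.AsyncHamiltonTracker(
        project_id=1,
        username="elijah",
        dag_name="async_tracker",
    )
    # Builder().build() must be awaited: it calls AsyncDriver.ainit(), which runs the async
    # post_graph_construct hooks and the tracker's auth check / background queue setup.
    dr = await h_async.Builder().with_modules(async_module).with_adapters(tracker).build()
    # Same DAG as the FastAPI endpoint: "pipeline" depends on computation1/computation2.
    return await dr.execute(["pipeline"], inputs={"request_raw": {"foo": "foo", "bar": "bar"}})


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Note that the async client appears to batch task updates on a background asyncio task (flushed on an interval or when the batch fills), so a short-lived script like this may exit before the last batch is sent; the long-running FastAPI app in the example does not have that issue.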