From f07efc2e7dce3d65cb0456921c3772e238153aa2 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Mon, 24 Jun 2024 21:48:44 -0700 Subject: [PATCH 1/6] Fixes up async hooks Fixes some issues around asynchonrous hooks. This also adds an API for whether we have a synchronous *or* an asynchronous hook. Adds an async API for calling all lifecycle hooks whether they are sync or async. --- hamilton/driver.py | 1 + hamilton/lifecycle/base.py | 56 +++++++++++++++++++------- tests/lifecycle/test_lifecycle_base.py | 3 ++ 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/hamilton/driver.py b/hamilton/driver.py index 1808536c9..abc576e2c 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -403,6 +403,7 @@ def __init__( if _graph_executor is None: _graph_executor = DefaultGraphExecutor(self.adapter) self.graph_executor = _graph_executor + self.config = config except Exception as e: error = telemetry.sanitize_error(*sys.exc_info()) logger.error(SLACK_ERROR_MESSAGE) diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 0d69079e2..a9af34f8e 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -745,11 +745,17 @@ def __init__(self, *adapters: LifecycleAdapter): :param adapters: Adapters to group together """ - self._adapters = list(adapters) + self._adapters = self._uniqify_adapters(adapters) self.sync_hooks, self.async_hooks = self._get_lifecycle_hooks() self.sync_methods, self.async_methods = self._get_lifecycle_methods() self.sync_validators = self._get_lifecycle_validators() + def _uniqify_adapters(self, adapters: List[LifecycleAdapter]) -> List[LifecycleAdapter]: + seen = set() + return [ + adapter for adapter in adapters if not (id(adapter) in seen or seen.add(id(adapter))) + ] + def _get_lifecycle_validators( self, ) -> Dict[str, List[LifecycleAdapter]]: @@ -811,7 +817,7 @@ def _get_lifecycle_methods( {method: list(adapters) for method, adapters in async_methods.items()}, ) - def does_hook(self, hook_name: str, 
is_async: bool) -> bool: + def does_hook(self, hook_name: str, is_async: Optional[bool] = None) -> bool: """Whether or not a hook is implemented by any of the adapters in this group. If this hook is not registered, this will raise a ValueError. @@ -819,21 +825,22 @@ def does_hook(self, hook_name: str, is_async: bool) -> bool: :param is_async: Whether you want the async version or not :return: True if this adapter set does this hook, False otherwise """ - if is_async and hook_name not in REGISTERED_ASYNC_HOOKS: + either = is_async is None + if (is_async or either) and hook_name not in REGISTERED_ASYNC_HOOKS: raise ValueError( f"Hook {hook_name} is not registered as an asynchronous lifecycle hook. " f"Registered hooks are {REGISTERED_ASYNC_HOOKS}" ) - if not is_async and hook_name not in REGISTERED_SYNC_HOOKS: + if ((not is_async) or either) and hook_name not in REGISTERED_SYNC_HOOKS: raise ValueError( f"Hook {hook_name} is not registered as a synchronous lifecycle hook. " f"Registered hooks are {REGISTERED_SYNC_HOOKS}" ) - if not is_async: - return hook_name in self.sync_hooks - return hook_name in self.async_hooks + has_async = hook_name in self.async_hooks + has_sync = hook_name in self.sync_hooks + return (has_async or has_sync) if either else has_async if is_async else has_sync - def does_method(self, method_name: str, is_async: bool) -> bool: + def does_method(self, method_name: str, is_async: Optional[bool] = None) -> bool: """Whether a method is implemented by any of the adapters in this group. If this method is not registered, this will raise a ValueError. 
@@ -841,19 +848,20 @@ def does_method(self, method_name: str, is_async: bool) -> bool: :param is_async: Whether you want the async version or not :return: True if this adapter set does this method, False otherwise """ - if is_async and method_name not in REGISTERED_ASYNC_METHODS: + either = is_async is None + if (is_async or either) and method_name not in REGISTERED_ASYNC_METHODS: raise ValueError( f"Method {method_name} is not registered as an asynchronous lifecycle method. " f"Registered methods are {REGISTERED_ASYNC_METHODS}" ) - if not is_async and method_name not in REGISTERED_SYNC_METHODS: + if ((not is_async) or either) and method_name not in REGISTERED_SYNC_METHODS: raise ValueError( f"Method {method_name} is not registered as a synchronous lifecycle method. " f"Registered methods are {REGISTERED_SYNC_METHODS}" ) - if not is_async: - return method_name in self.sync_methods - return method_name in self.async_methods + has_async = method_name in self.async_methods + has_sync = method_name in self.sync_methods + return (has_async or has_sync) if either else has_async if is_async else has_sync def does_validation(self, validator_name: str) -> bool: """Whether a validator is implemented by any of the adapters in this group. 
@@ -875,7 +883,7 @@ def call_all_lifecycle_hooks_sync(self, hook_name: str, **kwargs): :param hook_name: Name of the hooks to call :param kwargs: Keyword arguments to pass into the hook """ - for adapter in self.sync_hooks[hook_name]: + for adapter in self.sync_hooks.get(hook_name, []): getattr(adapter, hook_name)(**kwargs) async def call_all_lifecycle_hooks_async(self, hook_name: str, **kwargs): @@ -885,10 +893,19 @@ async def call_all_lifecycle_hooks_async(self, hook_name: str, **kwargs): :param kwargs: Keyword arguments to pass into the hook """ futures = [] - for adapter in self.async_hooks[hook_name]: + for adapter in self.async_hooks.get(hook_name, []): futures.append(getattr(adapter, hook_name)(**kwargs)) await asyncio.gather(*futures) + async def call_all_lifecycle_hooks_sync_and_async(self, hook_name: str, **kwargs): + """Calls all the lifecycle hooks whether they are sync or async + + :param hook_name: name of the hook + :param kwargs: keyword arguments for the hook + """ + self.call_all_lifecycle_hooks_sync(hook_name, **kwargs) + await self.call_all_lifecycle_hooks_async(hook_name, **kwargs) + def call_lifecycle_method_sync(self, method_name: str, **kwargs) -> Any: """Calls a lifecycle method in this group, by method name. @@ -947,3 +964,12 @@ def adapters(self) -> List[LifecycleAdapter]: :return: A list of adapters """ return self._adapters + + async def ainit(self): + """Asynchronously initializes the adapters in this group. This is so we can avoid having an async constructor + -- it is an implicit contract -- the async adapters are allowed one ainit() method that will be called by the driver. 
+ """ + for adapter in self.adapters: + print(adapter) + if hasattr(adapter, "ainit"): + await adapter.ainit() diff --git a/tests/lifecycle/test_lifecycle_base.py b/tests/lifecycle/test_lifecycle_base.py index 0c2456065..4de5cfffb 100644 --- a/tests/lifecycle/test_lifecycle_base.py +++ b/tests/lifecycle/test_lifecycle_base.py @@ -215,7 +215,10 @@ def post_graph_execute( assert adapter_set.does_hook("pre_do_anything", is_async=False) assert adapter_set.does_hook("post_graph_execute", is_async=False) + # either sync or async + assert adapter_set.does_hook("post_graph_execute", is_async=None) assert not adapter_set.does_hook("pre_node_execute", is_async=False) + assert not adapter_set.does_hook("pre_node_execute", is_async=None) assert not adapter_set.does_hook("pre_node_execute", is_async=True) adapter_set.call_all_lifecycle_hooks_sync("pre_do_anything") From 9383e1769b8cedd1a63f3903bed1e801b233e837 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Mon, 24 Jun 2024 21:56:31 -0700 Subject: [PATCH 2/6] Fixes up asynchronous driver for Hamilton 1. Adds a builder API specifically for the asynchonous driver 2. Fixes up the ASyncGraphAdapter to call hooks in a clean way This does something clever -- the graph adapter takes in the hooks, which then run async while the function is running. Synchronous hooks will work, but they won't be called at quite the right time. --- docs/reference/drivers/AsyncDriver.rst | 11 +- hamilton/experimental/h_async.py | 313 ++++++++++++++++++++++--- hamilton/lifecycle/base.py | 15 +- plugin_tests/h_async/test_h_async.py | 148 ++++++++++++ 4 files changed, 454 insertions(+), 33 deletions(-) diff --git a/docs/reference/drivers/AsyncDriver.rst b/docs/reference/drivers/AsyncDriver.rst index 75a38b7c4..01286be66 100644 --- a/docs/reference/drivers/AsyncDriver.rst +++ b/docs/reference/drivers/AsyncDriver.rst @@ -1,7 +1,16 @@ AsyncDriver -______________ +___________ Use this driver in an async context. E.g. for use with FastAPI. .. 
autoclass:: hamilton.experimental.h_async.AsyncDriver :special-members: __init__ :members: + +Async Builder +------------- + +Builds a driver in an async context -- use ``await builder....build()``. + +.. autoclass:: hamilton.experimental.h_async.Builder + :special-members: __init__ + :members: diff --git a/hamilton/experimental/h_async.py b/hamilton/experimental/h_async.py index c1ad34d5f..d05535972 100644 --- a/hamilton/experimental/h_async.py +++ b/hamilton/experimental/h_async.py @@ -3,17 +3,20 @@ import logging import sys import time -import types import typing +import uuid from types import ModuleType from typing import Any, Dict, Optional, Tuple -from hamilton import base, driver, node, telemetry +import hamilton.lifecycle.base as lifecycle_base +from hamilton import base, driver, graph, lifecycle, node, telemetry +from hamilton.execution.graph_functions import create_error_message +from hamilton.io.materialization import ExtractorFactory, MaterializerFactory logger = logging.getLogger(__name__) -async def await_dict_of_tasks(task_dict: Dict[str, types.CoroutineType]) -> Dict[str, Any]: +async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]: """Util to await a dictionary of tasks as asyncio.gather is kind of garbage""" keys = sorted(task_dict.keys()) coroutines = [task_dict[key] for key in keys] @@ -33,10 +36,14 @@ async def process_value(val: Any) -> Any: return await val -class AsyncGraphAdapter(base.SimplePythonDataFrameGraphAdapter): +class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilder): """Graph adapter for use with the :class:`AsyncDriver` class.""" - def __init__(self, result_builder: base.ResultMixin = None): + def __init__( + self, + result_builder: base.ResultMixin = None, + async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None, + ): """Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class. 
Some things to note: @@ -46,9 +53,22 @@ def __init__(self, result_builder: base.ResultMixin = None): because that function is called directly within the decorator, so we cannot await it. """ super(AsyncGraphAdapter, self).__init__() + self.adapter = ( + async_lifecycle_adapters + if async_lifecycle_adapters is not None + else lifecycle_base.LifecycleAdapterSet() + ) self.result_builder = result_builder if result_builder else base.PandasDataFrameResult() + self.is_initialized = False - def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any: + def do_node_execute( + self, + *, + run_id: str, + node_: node.Node, + kwargs: typing.Dict[str, typing.Any], + task_id: Optional[str] = None, + ) -> typing.Any: """Executes a node. Note this doesn't actually execute it -- rather, it returns a task. This does *not* use async def, as we want it to be awaited on later -- this await is done in processing parameters of downstream functions/final results. We can ensure that as @@ -57,34 +77,111 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> Note that this assumes that everything is awaitable, even if it isn't. In that case, it just wraps it in one. 
+ :param task_id: + :param node_: + :param run_id: :param node: Node to wrap :param kwargs: Keyword arguments (either coroutines or raw values) to call it with :return: A task """ - callabl = node.callable + callabl = node_.callable async def new_fn(fn=callabl, **fn_kwargs): task_dict = {key: process_value(value) for key, value in fn_kwargs.items()} fn_kwargs = await await_dict_of_tasks(task_dict) - if inspect.iscoroutinefunction(fn): - return await fn(**fn_kwargs) - return fn(**fn_kwargs) + error = None + result = None + success = True + pre_node_execute_errored = False + try: + if self.adapter.does_hook("pre_node_execute", is_async=True): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "pre_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + task_id=task_id, + ) + except Exception as e: + pre_node_execute_errored = True + raise e + # TODO -- consider how to use node execution methods in the future + # This is messy as it is a method called within a method... 
+ # if self.adapter.does_method("do_node_execute", is_async=False): + # result = self.adapter.call_lifecycle_method_sync( + # "do_node_execute", + # run_id=run_id, + # node_=node_, + # kwargs=kwargs, + # task_id=task_id, + # ) + # else: + + result = ( + await fn(**fn_kwargs) if asyncio.iscoroutinefunction(fn) else fn(**fn_kwargs) + ) + except Exception as e: + success = False + error = e + step = "[pre-node-execute:async]" if pre_node_execute_errored else "" + message = create_error_message(kwargs, node_, step) + logger.exception(message) + raise + finally: + if not pre_node_execute_errored and self.adapter.does_hook( + "post_node_execute", is_async=True + ): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "post_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + success=success, + error=error, + result=result, + task_id=task_id, + ) + except Exception: + message = create_error_message(kwargs, node_, "[post-node-execute]") + logger.exception(message) + raise + + return result coroutine = new_fn(**kwargs) task = asyncio.create_task(coroutine) return task - def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: - """Currently this is a no-op -- it just delegates to the resultsbuilder. - That said, we *could* make it async, but it feels wrong -- this will just be - called after `raw_execute`. - - :param outputs: Outputs (awaited) from the graph. - :return: The final results. - """ + def build_result(self, **outputs: Any) -> Any: return self.result_builder.build_result(**outputs) +def separate_sync_from_async( + adapters: typing.List[lifecycle.LifecycleAdapter], +) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]: + """Separates the sync and async adapters from a list of adapters. + Note this only works with hooks -- we'll be dealing with methods later. 
+ + :param adapters: List of adapters + :return: Tuple of sync adapters, async adapters + """ + + adapter_set = lifecycle_base.LifecycleAdapterSet(*adapters) + # this is using internal(ish) fields (.sync_hooks/.async_hooks) -- we should probably expose it + # For now this is OK + # Note those are dict[hook_name, list[hook]], so we have to flatten + return ( + [sync_adapter for adapters in adapter_set.sync_hooks.values() for sync_adapter in adapters], + [ + async_adapter + for adapters in adapter_set.async_hooks.values() + for async_adapter in adapters + ], + ) + + class AsyncDriver(driver.Driver): """Async driver. This is a driver that uses the AsyncGraphAdapter to execute the graph. @@ -95,16 +192,76 @@ class AsyncDriver(driver.Driver): """ - def __init__(self, config, *modules, result_builder: Optional[base.ResultMixin] = None): + def __init__( + self, + config, + *modules, + result_builder: Optional[base.ResultMixin] = None, + adapters: typing.List[lifecycle.LifecycleAdapter] = None, + ): """Instantiates an asynchronous driver. + You will also need to call `ainit` to initialize the driver if you have any hooks/adapters. + + Note that this is not the desired API -- you should be using the :py:class:`hamilton.experimental.h_async.Builder` class to create the driver. + :param config: Config to build the graph :param modules: Modules to crawl for fns/graph nodes + :param adapters: Adapters to use for lifecycle methods. :param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run. 
""" + if adapters is not None: + sync_adapters, async_adapters = separate_sync_from_async(adapters) + else: + # separate out so we know what the driver + sync_adapters = [] + async_adapters = [] + # we'll need to use this in multiple contexts so we'll keep it around for later + result_builders = [adapter for adapter in adapters if isinstance(adapter, base.ResultMixin)] + if result_builder is not None: + result_builders.append(result_builder) + if len(result_builders) > 1: + raise ValueError( + "You cannot pass more than one result builder to the async driver. " + "Please pass in a single result builder" + ) + # it will be defaulted by the graph adapter + result_builder = result_builders[0] if len(result_builders) == 1 else None super(AsyncDriver, self).__init__( - config, *modules, adapter=AsyncGraphAdapter(result_builder=result_builder) + config, + *modules, + adapter=[ + # We pass in the async adapters here as this can call node-level hooks + # Otherwise we trust the driver/fn graph to call sync adapters + AsyncGraphAdapter( + result_builder=result_builder, + async_lifecycle_adapters=lifecycle_base.LifecycleAdapterSet(*async_adapters), + ), + # We pass in the sync adapters here as this can call + *sync_adapters, + *async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later + ], ) + self.initialized = False + + async def ainit(self) -> "AsyncDriver": + """Initializes the driver when using async. This only exists for backwards compatibility. + In Hamilton 2.0, we will be using an asynchronous constructor. + See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5. 
+ """ + if self.initialized: + # this way it can be called twice + return self + if self.adapter.does_hook("post_graph_construct", is_async=True): + await self.adapter.call_all_lifecycle_hooks_async( + "post_graph_construct", + graph=self.graph, + modules=self.graph_modules, + config=self.config, + ) + await self.adapter.ainit() + self.initialized = True + return self async def raw_execute( self, @@ -112,7 +269,7 @@ async def raw_execute( overrides: Dict[str, Any] = None, display_graph: bool = False, # don't care inputs: Dict[str, Any] = None, - run_id: str = None, + _fn_graph: graph.FunctionGraph = None, ) -> Dict[str, Any]: """Executes the graph, returning a dictionary of strings (node keys) to final results. @@ -120,20 +277,56 @@ async def raw_execute( :param overrides: Overrides for nodes :param display_graph: whether or not to display graph -- this is not supported. :param inputs: Inputs for DAG runtime calculation + :param _fn_graph: Function graph for compatibility with superclass -- unused :return: A dict of key -> result """ + assert _fn_graph is None, ( + "_fn_graph must not be provided " + "-- the only reason you'd do this is to use materialize(), which is not supported yet.." + ) + run_id = str(uuid.uuid4()) nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs) memoized_computation = dict() # memoized storage - self.graph.execute(nodes, memoized_computation, overrides, inputs) - if display_graph: - raise ValueError( - "display_graph=True is not supported for the async graph adapter. " - "Instead you should be using visualize_execution." 
+ if self.adapter.does_hook("pre_graph_execute"): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "pre_graph_execute", + run_id=run_id, + graph=self.graph, + final_vars=final_vars, + inputs=inputs, + overrides=overrides, ) - task_dict = { - key: asyncio.create_task(process_value(memoized_computation[key])) for key in final_vars - } - return await await_dict_of_tasks(task_dict) + results = None + error = None + success = False + try: + self.graph.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id) + if display_graph: + raise ValueError( + "display_graph=True is not supported for the async graph adapter. " + "Instead you should be using visualize_execution." + ) + task_dict = { + key: asyncio.create_task(process_value(memoized_computation[key])) + for key in final_vars + } + results = await await_dict_of_tasks(task_dict) + success = True + except Exception as e: + error = e + success = False + raise e + finally: + if self.adapter.does_hook("post_graph_execute", is_async=None): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "post_graph_execute", + run_id=run_id, + graph=self.graph, + success=success, + error=error, + results=results, + ) + return results async def execute( self, @@ -228,3 +421,63 @@ async def make_coroutine(): except Exception as e: if logger.isEnabledFor(logging.DEBUG): logger.error(f"Encountered error submitting async telemetry:\n{e}") + + +class Builder(driver.Builder): + """Builder for the async driver. This is equivalent to the standard builder, but has a more limited API. + Note this does not support dynamic execution or materializers (for now). + + Here is an example of how you might use it to get the tracker working: + + .. 
code-block:: python + + from hamilton_sdk import tracker + tracker_async = adapters.AsyncHamiltonTracker( + project_id=1, + username="elijah", + dag_name="async_tracker", + ) + dr = ( + await h_async + .Builder() + .with_modules(async_module) + .with_adapters(tracking_async) + .build() + ) + """ + + def __init__(self): + super(Builder, self).__init__() + + def _not_supported(self, method_name: str, additional_message: str = ""): + raise ValueError( + f"Builder().{method_name}() is not supported for the async driver. {additional_message}" + ) + + def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder": + self._not_supported("enable_dynamic_execution") + + def with_materializers( + self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory] + ) -> "Builder": + self._not_supported("with_materializers") + + def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder": + self._not_supported( + "with_adapter", + "Use with_adapters instead to pass in the tracker (or other async hooks/methods)", + ) + + async def build(self): + adapter = self.adapters if self.adapters is not None else [] + if self.legacy_graph_adapter is not None: + adapter.append(self.legacy_graph_adapter) + + result_builder = self.result_builder if self.result_builder is not None else base.DictResult() + + return await AsyncDriver( + self.config, + *self.modules, + adapters=adapter, + result_builder=result_builder + ).ainit() diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index a9af34f8e..ef9d07387 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -751,6 +751,14 @@ def __init__(self, *adapters: LifecycleAdapter): self.sync_validators = self._get_lifecycle_validators() def _uniqify_adapters(self, adapters: List[LifecycleAdapter]) -> List[LifecycleAdapter]: + """Removes duplicate adapters from the list of adapters -- this often happens on how they're passed in + and we don't want to have the same 
adapter twice. Specifically, this came up due to parsing/splitting out adapters + with async lifecycle hooks -- there were cases in which we were passed duplicates. This was compounded as we would pass + adapters to other adapter sets and end up further duplicating. + + TODO -- remove this and ensure that no case passes in duplicates. + """ + seen = set() return [ adapter for adapter in adapters if not (id(adapter) in seen or seen.add(id(adapter))) @@ -967,9 +975,12 @@ def adapters(self) -> List[LifecycleAdapter]: async def ainit(self): """Asynchronously initializes the adapters in this group. This is so we can avoid having an async constructor - -- it is an implicit contract -- the async adapters are allowed one ainit() method that will be called by the driver. + -- it is an implicit internal-facing contract -- the async adapters are allowed one ainit() + method that will be called by the driver. + + Note this is not public-facing -- E.G. you cannot expect to define this on your own adapters. We may consider adding + a ``pre_do_anything`` async hook and removing this, but for now this should suffice. 
""" for adapter in self.adapters: - print(adapter) if hasattr(adapter, "ainit"): await adapter.ainit() diff --git a/plugin_tests/h_async/test_h_async.py b/plugin_tests/h_async/test_h_async.py index 6f670afed..0c0626755 100644 --- a/plugin_tests/h_async/test_h_async.py +++ b/plugin_tests/h_async/test_h_async.py @@ -6,6 +6,18 @@ from hamilton import base from hamilton.experimental import h_async +from hamilton.lifecycle.base import ( + BasePostGraphConstruct, + BasePostGraphConstructAsync, + BasePostGraphExecute, + BasePostGraphExecuteAsync, + BasePostNodeExecute, + BasePostNodeExecuteAsync, + BasePreGraphExecute, + BasePreGraphExecuteAsync, + BasePreNodeExecute, + BasePreNodeExecuteAsync, +) from .resources import simple_async_module @@ -102,3 +114,139 @@ async def test_driver_end_to_end_telemetry(send_event_json): await asyncio.gather(*[t for t in tasks if t != current_task]) assert send_event_json.called assert len(send_event_json.call_args_list) == 2 + + +@pytest.mark.asyncio +async def test_async_driver_end_to_end_async_lifecycle_methods(): + tracked_calls = [] + + class AsyncTrackingAdapter( + BasePostGraphConstructAsync, + BasePreGraphExecuteAsync, + BasePostGraphExecuteAsync, + BasePreNodeExecuteAsync, + BasePostNodeExecuteAsync, + ): + def __init__(self, calls: list, pause_time: float = 0.01): + self.pause_time = pause_time + self.calls = calls + + async def _pause(self): + return await asyncio.sleep(self.pause_time) + + async def pre_graph_execute(self, **kwargs): + await self._pause() + self.calls.append(("pre_graph_execute", kwargs)) + + async def post_graph_execute(self, **kwargs): + await self._pause() + self.calls.append(("post_graph_execute", kwargs)) + + async def pre_node_execute(self, **kwargs): + await self._pause() + self.calls.append(("pre_node_execute", kwargs)) + + async def post_node_execute(self, **kwargs): + await self._pause() + self.calls.append(("post_node_execute", kwargs)) + + async def post_graph_construct(self, **kwargs): + await 
self._pause() + self.calls.append(("post_graph_construct", kwargs)) + + adapter = AsyncTrackingAdapter(tracked_calls) + + dr = await h_async.AsyncDriver( + {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] + ).ainit() + all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] + result = await dr.execute(final_vars=all_vars, inputs={"external_input": 1}) + hooks_called = [call[0] for call in tracked_calls] + assert set(hooks_called) == { + "pre_graph_execute", + "post_graph_execute", + "pre_node_execute", + "post_node_execute", + "post_graph_construct", + } + result["a"] = result["a"].to_dict() + result["b"] = result["b"].to_dict() + assert result == { + "a": pd.Series([1, 2, 3]).to_dict(), + "another_async_func": 8, + "async_func_with_param": 4, + "b": pd.Series([4, 5, 6]).to_dict(), + "external_input": 1, + "non_async_func_with_decorator": {"result_1": 9, "result_2": 5}, + "result_1": 9, + "result_2": 5, + "result_3": 1, + "result_4": 2, + "return_dict": {"result_3": 1, "result_4": 2}, + "simple_async_func": 2, + "simple_non_async_func": 7, + } + + +@pytest.mark.asyncio +async def test_async_driver_end_to_end_sync_lifecycle_methods(): + tracked_calls = [] + + class AsyncTrackingAdapter( + BasePostGraphConstruct, + BasePreGraphExecute, + BasePostGraphExecute, + BasePreNodeExecute, + BasePostNodeExecute, + ): + def __init__(self, calls: list, pause_time: float = 0.01): + self.pause_time = pause_time + self.calls = calls + + def pre_graph_execute(self, **kwargs): + self.calls.append(("pre_graph_execute", kwargs)) + + def post_graph_execute(self, **kwargs): + self.calls.append(("post_graph_execute", kwargs)) + + def pre_node_execute(self, **kwargs): + self.calls.append(("pre_node_execute", kwargs)) + + def post_node_execute(self, **kwargs): + self.calls.append(("post_node_execute", kwargs)) + + def post_graph_construct(self, **kwargs): + self.calls.append(("post_graph_construct", kwargs)) + + adapter = 
AsyncTrackingAdapter(tracked_calls) + + dr = await h_async.AsyncDriver( + {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] + ).ainit() + all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] + result = await dr.execute(final_vars=all_vars, inputs={"external_input": 1}) + hooks_called = [call[0] for call in tracked_calls] + assert set(hooks_called) == { + "pre_graph_execute", + "post_graph_execute", + "pre_node_execute", + "post_node_execute", + "post_graph_construct", + } + result["a"] = result["a"].to_dict() + result["b"] = result["b"].to_dict() + assert result == { + "a": pd.Series([1, 2, 3]).to_dict(), + "another_async_func": 8, + "async_func_with_param": 4, + "b": pd.Series([4, 5, 6]).to_dict(), + "external_input": 1, + "non_async_func_with_decorator": {"result_1": 9, "result_2": 5}, + "result_1": 9, + "result_2": 5, + "result_3": 1, + "result_4": 2, + "return_dict": {"result_3": 1, "result_4": 2}, + "simple_async_func": 2, + "simple_non_async_func": 7, + } From 6549130cf47d35d58e011a123de34ead9aee87f3 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Mon, 24 Jun 2024 22:01:19 -0700 Subject: [PATCH 3/6] Updates the Hamilton SDK to work with async This has a basic queuing system -- we laucnh a periodic task that runs and flushes out a queue -- it times out at the flush interval and flushes everything it has. This will often flush after a request is done but it will always flush. 
--- ui/sdk/src/hamilton_sdk/adapters.py | 43 ++++---- ui/sdk/src/hamilton_sdk/api/clients.py | 132 +++++++++++++++++-------- 2 files changed, 118 insertions(+), 57 deletions(-) diff --git a/ui/sdk/src/hamilton_sdk/adapters.py b/ui/sdk/src/hamilton_sdk/adapters.py index 2fddfcf74..88eef4a1f 100644 --- a/ui/sdk/src/hamilton_sdk/adapters.py +++ b/ui/sdk/src/hamilton_sdk/adapters.py @@ -1,4 +1,3 @@ -import asyncio import datetime import hashlib import logging @@ -148,7 +147,6 @@ def pre_graph_execute( dag_template_id = self.dag_template_id_cache[fg_id] else: raise ValueError("DAG template ID not found in cache. This should never happen.") - tracking_state = TrackingState(run_id) self.tracking_states[run_id] = tracking_state # cache tracking_state.clock_start() @@ -386,7 +384,7 @@ def post_graph_execute( ) -class AsyncHamiltonAdapter( +class AsyncHamiltonTracker( base.BasePostGraphConstructAsync, base.BasePreGraphExecuteAsync, base.BasePreNodeExecuteAsync, @@ -396,13 +394,13 @@ class AsyncHamiltonAdapter( def __init__( self, project_id: int, - api_key: str, username: str, dag_name: str, tags: Dict[str, str] = None, client_factory: Callable[ - [str, str, str], clients.HamiltonClient + [str, str, str], clients.BasicAsynchronousHamiltonClient ] = clients.BasicAsynchronousHamiltonClient, + api_key: str = os.environ.get("HAMILTON_API_KEY", ""), hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL), hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL), ): @@ -416,11 +414,22 @@ def __init__( driver.validate_tags(self.base_tags) self.dag_name = dag_name self.hamilton_ui_url = hamilton_ui_url - logger.debug("Validating authentication against Hamilton BE API...") - asyncio.run(self.client.validate_auth()) - logger.debug(f"Ensuring project {self.project_id} exists...") + self.dag_template_id_cache = {} + self.tracking_states = {} + self.dw_run_ids = {} + self.task_runs = {} + self.initialized = False + super().__init__() + + async 
def ainit(self): + if self.initialized: + return self + """You must call this to initialize the tracker.""" + logger.info("Validating authentication against Hamilton BE API...") + await self.client.validate_auth() + logger.info(f"Ensuring project {self.project_id} exists...") try: - asyncio.run(self.client.project_exists(self.project_id)) + await self.client.project_exists(self.project_id) except clients.UnauthorizedException: logger.exception( f"Authentication failed. Please check your username and try again. " @@ -433,11 +442,11 @@ def __init__( f"You can do so at {self.hamilton_ui_url}/dashboard/projects" ) raise - self.dag_template_id_cache = {} - self.tracking_states = {} - self.dw_run_ids = {} - self.task_runs = {} - super().__init__() + logger.info("Initializing Hamilton tracker.") + await self.client.ainit() + logger.info("Initialized Hamilton tracker.") + self.initialized = True + return self async def post_graph_construct( self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any] @@ -476,7 +485,6 @@ async def pre_graph_execute( overrides: Dict[str, Any], ): logger.debug("pre_graph_execute %s", run_id) - self.run_id = run_id fg_id = id(graph) if fg_id in self.dag_template_id_cache: dag_template_id = self.dag_template_id_cache[fg_id] @@ -511,7 +519,7 @@ async def pre_node_execute( task_update = dict( node_template_name=node_.name, - node_name=get_node_name(node_.name, task_id), + node_name=get_node_name(node_, task_id), realized_dependencies=[dep.name for dep in node_.dependencies], status=task_run.status, start_time=task_run.start_time, @@ -519,7 +527,7 @@ async def pre_node_execute( ) await self.client.update_tasks( self.dw_run_ids[run_id], - attributes=[None], + attributes=[], task_updates=[task_update], in_samples=[task_run.is_in_sample], ) @@ -532,6 +540,7 @@ async def post_node_execute( error: Optional[Exception], result: Any, task_id: Optional[str] = None, + **future_kwargs, ): logger.debug("post_node_execute %s", run_id) 
task_run = self.task_runs[run_id][node_.name] diff --git a/ui/sdk/src/hamilton_sdk/api/clients.py b/ui/sdk/src/hamilton_sdk/api/clients.py index 4de4d961b..9a43a67dd 100644 --- a/ui/sdk/src/hamilton_sdk/api/clients.py +++ b/ui/sdk/src/hamilton_sdk/api/clients.py @@ -1,4 +1,5 @@ import abc +import asyncio import datetime import functools import logging @@ -30,6 +31,33 @@ def __init__(self, path: str, user: str): super().__init__(message) +def create_batch(batch: dict, dag_run_id: int): + attributes = defaultdict(list) + task_updates = defaultdict(list) + for item in batch: + if item["dag_run_id"] == dag_run_id: + for attr in item["attributes"]: + if attr is None: + continue + attributes[attr["node_name"]].append(attr) + for task_update in item["task_updates"]: + if task_update is None: + continue + task_updates[task_update["node_name"]].append(task_update) + + # We do not care about disambiguating here -- only one named attribute should be logged + + attributes_list = [] + for node_name in attributes: + attributes_list.extend(attributes[node_name]) + # in this case we do care about order so we don't send double the updates. + task_updates_list = [ + functools.reduce(lambda x, y: {**x, **y}, task_updates[node_name]) + for node_name in task_updates + ] + return attributes_list, task_updates_list + + class HamiltonClient: @abc.abstractmethod def validate_auth(self): @@ -220,30 +248,7 @@ def flush(self, batch): # group by dag_run_id -- just incase someone does something weird? 
dag_run_ids = set([item["dag_run_id"] for item in batch]) for dag_run_id in dag_run_ids: - attributes = defaultdict(list) - task_updates = defaultdict(list) - for item in batch: - if item["dag_run_id"] == dag_run_id: - for attr in item["attributes"]: - if attr is None: - continue - attributes[attr["node_name"]].append(attr) - for task_update in item["task_updates"]: - if task_update is None: - continue - task_updates[task_update["node_name"]].append(task_update) - - # We do not care about disambiguating here -- only one named attribute should be logged - - attributes_list = [] - for node_name in attributes: - attributes_list.extend(attributes[node_name]) - # in this case we do care about order so we don't send double the updates. - task_updates_list = [ - functools.reduce(lambda x, y: {**x, **y}, task_updates[node_name]) - for node_name in task_updates - ] - + attributes_list, task_updates_list = create_batch(batch, dag_run_id) response = requests.put( f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}", json={ @@ -514,6 +519,65 @@ def __init__(self, api_key: str, username: str, h_api_url: str, base_path: str = self.api_key = api_key self.username = username self.base_url = h_api_url + base_path + self.flush_interval = 5 + self.data_queue = asyncio.Queue() + self.running = True + self.max_batch_size = 100 + + async def ainit(self): + asyncio.create_task(self.worker()) + + async def flush(self, batch): + """Flush the batch (send it to the backend or process it).""" + logger.debug(f"Flushing batch: {len(batch)}") # Replace with actual processing logic + # group by dag_run_id -- just incase someone does something weird? 
+ dag_run_ids = set([item["dag_run_id"] for item in batch]) + for dag_run_id in dag_run_ids: + attributes_list, task_updates_list = create_batch(batch, dag_run_id) + async with aiohttp.ClientSession() as session: + async with session.put( + f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}", + json={ + "attributes": make_json_safe(attributes_list), + "task_updates": make_json_safe(task_updates_list), + }, + headers=self._common_headers(), + ) as response: + try: + response.raise_for_status() + logger.debug(f"Updated tasks for DAG run {dag_run_id}") + except HTTPError: + logger.exception(f"Failed to update tasks for DAG run {dag_run_id}") + # zraise + + async def worker(self): + """Worker thread to process the queue""" + batch = [] + last_flush_time = time.time() + logger.debug("Starting worker") + while True: + logger.debug( + f"Awaiting item from queue -- current batched # of items are: {len(batch)}" + ) + try: + item = await asyncio.wait_for(self.data_queue.get(), timeout=self.flush_interval) + batch.append(item) + except asyncio.TimeoutError: + # This is fine, we just keep waiting + pass + else: + if item is None: + await self.flush(batch) + return + + # Check if batch is full or flush interval has passed + if ( + len(batch) >= self.max_batch_size + or (time.time() - last_flush_time) >= self.flush_interval + ): + await self.flush(batch) + batch = [] + last_flush_time = time.time() def _common_headers(self) -> Dict[str, Any]: """Yields the common headers for all requests. 
@@ -728,26 +792,14 @@ async def update_tasks( f"Updating tasks for DAG run {dag_run_id} with {len(attributes)} " f"attributes and {len(task_updates)} task updates" ) - url = f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}" - headers = self._common_headers() - data = { - "attributes": make_json_safe(attributes), - "task_updates": make_json_safe(task_updates), - } - - async with aiohttp.ClientSession() as session: - async with session.put(url, json=data, headers=headers) as response: - try: - response.raise_for_status() - logger.debug(f"Updated tasks for DAG run {dag_run_id}") - except aiohttp.ClientResponseError: - logger.exception(f"Failed to update tasks for DAG run {dag_run_id}") - raise + await self.data_queue.put( + {"dag_run_id": dag_run_id, "attributes": attributes, "task_updates": task_updates} + ) async def log_dag_run_end(self, dag_run_id: int, status: str): logger.debug(f"Logging end of DAG run {dag_run_id} with status {status}") url = f"{self.base_url}/dag_runs/{dag_run_id}/" - data = (make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}),) + data = make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}) headers = self._common_headers() async with aiohttp.ClientSession() as session: async with session.put(url, json=data, headers=headers) as response: From c330c852635c143638d437c7a1afc389d1d9a3bf Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Mon, 24 Jun 2024 22:03:34 -0700 Subject: [PATCH 4/6] Fixes up the async example to use the tracker But only if the server is running --- examples/async/README.md | 24 ++++++++++ examples/async/async_module.py | 10 ++-- examples/async/fastapi_example.py | 79 ++++++++++++++++++++++++++++--- 3 files changed, 102 insertions(+), 11 deletions(-) diff --git a/examples/async/README.md b/examples/async/README.md index 140f906ce..8ef0fb490 100644 --- a/examples/async/README.md +++ b/examples/async/README.md @@ -29,6 +29,30 @@ You should get the following 
 result: {"pipeline":{"computation1":false,"computation2":true}}
 ```
 
+## Tracking
+
+This has an additional endpoint that will use the async tracker if the [ui](https://hamilton.dagworks.io/en/latest/concepts/ui/)
+is running on port 8241 -- see [fastapi_example.py](fastapi_example.py) for the code.
+If it is not running it will proceed anyway without tracking.
+
+You can run it with:
+
+```bash
+curl -X 'POST' \
+  'http://localhost:8000/execute' \
+  -H 'accept: application/json' \
+  -d '{}'
+```
+
+Recall, to get the server running, you'll have to run the following:
+
+```bash
+pip install sf-hamilton[ui]
+hamilton ui
+```
+
+This assumes a project (1) exists -- if you want a different one you can go to the UI and create one and/or set it in the code.
+
 ## How it works
 
 
diff --git a/examples/async/async_module.py b/examples/async/async_module.py
index 33432b8c7..ad52c7a1b 100644
--- a/examples/async/async_module.py
+++ b/examples/async/async_module.py
@@ -24,17 +24,17 @@ def bar(request_raw: dict) -> str:
     return request_raw.get("bar", "baz")
 
 
-async def computation1(foo: str, some_data: dict) -> bool:
-    await asyncio.sleep(1)
-    return False
-
-
 async def some_data() -> dict:
     async with aiohttp.ClientSession() as session:
         async with session.get("http://httpbin.org/get") as resp:
             return await resp.json()
 
 
+async def computation1(foo: str, some_data: dict) -> bool:
+    await asyncio.sleep(1)
+    return False
+
+
 async def computation2(bar: str) -> bool:
     await asyncio.sleep(1)
     return True
diff --git a/examples/async/fastapi_example.py b/examples/async/fastapi_example.py
index 88c87ca07..bf73b0d0f 100644
--- a/examples/async/fastapi_example.py
+++ b/examples/async/fastapi_example.py
@@ -1,22 +1,89 @@
+import logging
+from contextlib import asynccontextmanager
+
+import aiohttp
 import async_module
 import fastapi
+from aiohttp import client_exceptions
+from hamilton_sdk import adapters
 
-from hamilton import base
 from hamilton.experimental import h_async
 
-app = fastapi.FastAPI()
+logger = logging.getLogger(__name__) # can instantiate a driver once for the life of the app: -dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult()) +dr_with_tracking: h_async.AsyncDriver = None +dr_without_tracking: h_async.AsyncDriver = None + + +async def _tracking_server_running(): + """Quickly tells if the tracking server is up and running""" + async with aiohttp.ClientSession() as session: + try: + async with session.get("http://localhost:8241/api/v1/ping") as response: + if response.status == 200: + return True + else: + return False + except client_exceptions.ClientConnectionError: + return False + + +@asynccontextmanager +async def lifespan(app: fastapi.FastAPI): + """Fast API lifespan context manager for setting up the driver and tracking adapters + This has to be done async as there are initializers + """ + global dr_with_tracking + global dr_without_tracking + builder = h_async.Builder().with_modules(async_module) + is_server_running = await _tracking_server_running() + dr_without_tracking = await builder.build() + dr_with_tracking = ( + await builder.with_adapters( + adapters.AsyncHamiltonTracker( + project_id=1, + username="elijah", + dag_name="async_tracker", + ) + ).build() + if is_server_running + else None + ) + if not is_server_running: + logger.warning( + "Tracking server is not running, skipping telemetry. To run the telemetry server, run hamilton ui. " + "Note you must have a project with ID 1 if it is running -- if not, you can change the project " + "ID in this file or create a new one from the UI. Then make sure to restart this server." 
+        )
+    yield
+
+
+app = fastapi.FastAPI(lifespan=lifespan)
 
 
 @app.post("/execute")
-async def call(request: fastapi.Request) -> dict:
-    """Handler for pipeline call"""
+async def call_without_tracker(request: fastapi.Request) -> dict:
+    """Handler for pipeline call -- this does not track in the Hamilton UI"""
+    input_data = {"request": request}
+    # Can instantiate a driver within a request as well:
+    # dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult())
+    result = await dr_without_tracking.execute(["pipeline"], inputs=input_data)
+    # dr.visualize_execution(["pipeline"], "./pipeline.dot", {"format": "png"}, inputs=input_data)
+    return result
+
+
+@app.post("/execute_tracker")
+async def call_with_tracker(request: fastapi.Request) -> dict:
+    """Handler for pipeline call -- this does track in the Hamilton UI."""
     input_data = {"request": request}
     # Can instantiate a driver within a request as well:
     # dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult())
-    result = await dr.execute(["pipeline"], inputs=input_data)
+    if dr_with_tracking is None:
+        raise ValueError(
+            "Tracking driver not initialized -- you must have the tracking server running at app startup to use this endpoint."
+        )
+    result = await dr_with_tracking.execute(["pipeline"], inputs=input_data)
     # dr.visualize_execution(["pipeline"], "./pipeline.dot", {"format": "png"}, inputs=input_data)
     return result

From 1e7495c2df51c80e7ad13e83c30f29b874b97ca0 Mon Sep 17 00:00:00 2001
From: elijahbenizzy
Date: Wed, 26 Jun 2024 16:49:48 -0700
Subject: [PATCH 5/6] Moves over h_async to hamilton.async_driver

It's time we support this. Decisions:

1. Did not want to keep it in hamilton.driver, that's getting too bloated and there's name collision
2. Moved over docs to point to the right place (as much as I could)
3.
Left it backwards compatible with a warning to update the import --- .ci/setup.sh | 5 - .circleci/config.yml | 24 - .../langchain_snippets/hamilton_async.py | 4 +- docs/reference/drivers/AsyncDriver.rst | 4 +- .../graph-adapters/AsyncGraphAdapter.rst | 2 +- examples/async/fastapi_example.py | 12 +- .../scenario_1/fastapi_server.py | 5 +- hamilton/async_driver.py | 482 +++++++++++++++++ hamilton/experimental/h_async.py | 483 +----------------- plugin_tests/h_async/__init__.py | 0 plugin_tests/h_async/conftest.py | 4 - plugin_tests/h_async/requirements-test.txt | 1 - plugin_tests/h_async/resources/__init__.py | 0 requirements-test.txt | 1 + .../resources/simple_async_module.py | 0 .../test_async_driver.py | 21 +- tests/test_telemetry.py | 9 +- 17 files changed, 514 insertions(+), 543 deletions(-) create mode 100644 hamilton/async_driver.py delete mode 100644 plugin_tests/h_async/__init__.py delete mode 100644 plugin_tests/h_async/conftest.py delete mode 100644 plugin_tests/h_async/requirements-test.txt delete mode 100644 plugin_tests/h_async/resources/__init__.py rename {plugin_tests/h_async => tests}/resources/simple_async_module.py (100%) rename plugin_tests/h_async/test_h_async.py => tests/test_async_driver.py (93%) diff --git a/.ci/setup.sh b/.ci/setup.sh index f123d62f0..9056e15d8 100755 --- a/.ci/setup.sh +++ b/.ci/setup.sh @@ -21,11 +21,6 @@ if [[ ${TASK} != "pre-commit" ]]; then -r requirements-test.txt fi -if [[ ${TASK} == "async" ]]; then - pip install \ - -r plugin_tests/h_async/requirements-test.txt -fi - if [[ ${TASK} == "pyspark" ]]; then if [[ ${OPERATING_SYSTEM} == "Linux" ]]; then sudo apt-get install \ diff --git a/.circleci/config.yml b/.circleci/config.yml index efbfe69bb..dfa860cb8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -155,27 +155,3 @@ workflows: name: integrations-py312 python-version: '3.12' task: integrations - - test: - requires: - - check_for_changes - name: asyncio-py39 - python-version: '3.9' - task: async - - 
test: - requires: - - check_for_changes - name: asyncio-py310 - python-version: '3.10' - task: async - - test: - requires: - - check_for_changes - name: asyncio-py311 - python-version: '3.11' - task: async - - test: - requires: - - check_for_changes - name: asyncio-py312 - python-version: '3.12' - task: async diff --git a/docs/code-comparisons/langchain_snippets/hamilton_async.py b/docs/code-comparisons/langchain_snippets/hamilton_async.py index 342631e13..e76258639 100644 --- a/docs/code-comparisons/langchain_snippets/hamilton_async.py +++ b/docs/code-comparisons/langchain_snippets/hamilton_async.py @@ -38,9 +38,9 @@ async def joke_response( import hamilton_async from hamilton import base - from hamilton.experimental import h_async + from hamilton import async_driver - dr = h_async.AsyncDriver( + dr = async_driver.AsyncDriver( {}, hamilton_async, result_builder=base.DictResult() diff --git a/docs/reference/drivers/AsyncDriver.rst b/docs/reference/drivers/AsyncDriver.rst index 01286be66..a91b0c45b 100644 --- a/docs/reference/drivers/AsyncDriver.rst +++ b/docs/reference/drivers/AsyncDriver.rst @@ -2,7 +2,7 @@ AsyncDriver ___________ Use this driver in an async context. E.g. for use with FastAPI. -.. autoclass:: hamilton.experimental.h_async.AsyncDriver +.. autoclass:: hamilton.async_driver.AsyncDriver :special-members: __init__ :members: @@ -11,6 +11,6 @@ Async Builder Builds a driver in an async context -- use ``await builder....build()``. -.. autoclass:: hamilton.experimental.h_async.Builder +.. autoclass:: hamilton.async_driver.Builder :special-members: __init__ :members: diff --git a/docs/reference/graph-adapters/AsyncGraphAdapter.rst b/docs/reference/graph-adapters/AsyncGraphAdapter.rst index a22d43853..c2f9bab0c 100644 --- a/docs/reference/graph-adapters/AsyncGraphAdapter.rst +++ b/docs/reference/graph-adapters/AsyncGraphAdapter.rst @@ -3,7 +3,7 @@ h_async.AsyncGraphAdapter ========================= -.. 
autoclass:: hamilton.experimental.h_async.AsyncGraphAdapter +.. autoclass:: hamilton.async_driver.AsyncGraphAdapter :special-members: __init__ :members: :inherited-members: diff --git a/examples/async/fastapi_example.py b/examples/async/fastapi_example.py index bf73b0d0f..441f67d7e 100644 --- a/examples/async/fastapi_example.py +++ b/examples/async/fastapi_example.py @@ -7,13 +7,13 @@ from aiohttp import client_exceptions from hamilton_sdk import adapters -from hamilton.experimental import h_async +from hamilton import async_driver logger = logging.getLogger(__name__) # can instantiate a driver once for the life of the app: -dr_with_tracking: h_async.AsyncDriver = None -dr_without_tracking: h_async.AsyncDriver = None +dr_with_tracking: async_driver.AsyncDriver = None +dr_without_tracking: async_driver.AsyncDriver = None async def _tracking_server_running(): @@ -36,7 +36,7 @@ async def lifespan(app: fastapi.FastAPI): """ global dr_with_tracking global dr_without_tracking - builder = h_async.Builder().with_modules(async_module) + builder = async_driver.Builder().with_modules(async_module) is_server_running = await _tracking_server_running() dr_without_tracking = await builder.build() dr_with_tracking = ( @@ -66,8 +66,6 @@ async def lifespan(app: fastapi.FastAPI): async def call_without_tracker(request: fastapi.Request) -> dict: """Handler for pipeline call -- this does not track in the Hamilton UI""" input_data = {"request": request} - # Can instantiate a driver within a request as well: - # dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult()) result = await dr_without_tracking.execute(["pipeline"], inputs=input_data) # dr.visualize_execution(["pipeline"], "./pipeline.dot", {"format": "png"}, inputs=input_data) return result @@ -77,8 +75,6 @@ async def call_without_tracker(request: fastapi.Request) -> dict: async def call_with_tracker(request: fastapi.Request) -> dict: """Handler for pipeline call -- this does track in the Hamilton UI.""" 
input_data = {"request": request} - # Can instantiate a driver within a request as well: - # dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult()) if dr_with_tracking is None: raise ValueError( "Tracking driver not initialized -- you must have the tracking server running at app startup to use this endpoint." diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py index 132fd17a2..2d572d9b5 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py @@ -17,8 +17,7 @@ import pandas as pd import pydantic -from hamilton import base -from hamilton.experimental import h_async +from hamilton import async_driver, base app = fastapi.FastAPI() @@ -56,7 +55,7 @@ def fake_model_predict(df: pd.DataFrame) -> pd.Series: # We instantiate an async driver once for the life of the app. We use the AsyncDriver here because under the hood # FastAPI is async. If you were using Flask, you could use the regular Hamilton driver without issue. -dr = h_async.AsyncDriver( +dr = async_driver.AsyncDriver( {}, # no config/invariant inputs in this example. features, # the module that contains the common feature definitions. 
result_builder=base.SimplePythonDataFrameGraphAdapter(), diff --git a/hamilton/async_driver.py b/hamilton/async_driver.py new file mode 100644 index 000000000..ec5319ea3 --- /dev/null +++ b/hamilton/async_driver.py @@ -0,0 +1,482 @@ +import asyncio +import inspect +import logging +import sys +import time +import typing +import uuid +from types import ModuleType +from typing import Any, Dict, Optional, Tuple + +import hamilton.lifecycle.base as lifecycle_base +from hamilton import base, driver, graph, lifecycle, node, telemetry +from hamilton.execution.graph_functions import create_error_message +from hamilton.io.materialization import ExtractorFactory, MaterializerFactory + +logger = logging.getLogger(__name__) + + +async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]: + """Util to await a dictionary of tasks as asyncio.gather is kind of garbage""" + keys = sorted(task_dict.keys()) + coroutines = [task_dict[key] for key in keys] + coroutines_gathered = await asyncio.gather(*coroutines) + return dict(zip(keys, coroutines_gathered)) + + +async def process_value(val: Any) -> Any: + """Helper function to process the value of a potential awaitable. + This is very simple -- all it does is await the value if its not already resolved. + + :param val: Value to process. + :return: The value (awaited if it is a coroutine, raw otherwise). + """ + if not inspect.isawaitable(val): + return val + return await val + + +class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilder): + """Graph adapter for use with the :class:`AsyncDriver` class.""" + + def __init__( + self, + result_builder: base.ResultMixin = None, + async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None, + ): + """Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class. + + Some things to note: + + 1. This executes everything at the end (recursively). E.G. the final DAG nodes are awaited + 2. 
This does *not* work with decorators when the async function is being decorated. That is\ + because that function is called directly within the decorator, so we cannot await it. + """ + super(AsyncGraphAdapter, self).__init__() + self.adapter = ( + async_lifecycle_adapters + if async_lifecycle_adapters is not None + else lifecycle_base.LifecycleAdapterSet() + ) + self.result_builder = result_builder if result_builder else base.PandasDataFrameResult() + self.is_initialized = False + + def do_node_execute( + self, + *, + run_id: str, + node_: node.Node, + kwargs: typing.Dict[str, typing.Any], + task_id: Optional[str] = None, + ) -> typing.Any: + """Executes a node. Note this doesn't actually execute it -- rather, it returns a task. + This does *not* use async def, as we want it to be awaited on later -- this await is done + in processing parameters of downstream functions/final results. We can ensure that as + we also run the driver that this corresponds to. + + Note that this assumes that everything is awaitable, even if it isn't. + In that case, it just wraps it in one. + + :param task_id: + :param node_: + :param run_id: + :param node: Node to wrap + :param kwargs: Keyword arguments (either coroutines or raw values) to call it with + :return: A task + """ + callabl = node_.callable + + async def new_fn(fn=callabl, **fn_kwargs): + task_dict = {key: process_value(value) for key, value in fn_kwargs.items()} + fn_kwargs = await await_dict_of_tasks(task_dict) + error = None + result = None + success = True + pre_node_execute_errored = False + try: + if self.adapter.does_hook("pre_node_execute", is_async=True): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "pre_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + task_id=task_id, + ) + except Exception as e: + pre_node_execute_errored = True + raise e + # TODO -- consider how to use node execution methods in the future + # This is messy as it is a method called within a method... 
+ # if self.adapter.does_method("do_node_execute", is_async=False): + # result = self.adapter.call_lifecycle_method_sync( + # "do_node_execute", + # run_id=run_id, + # node_=node_, + # kwargs=kwargs, + # task_id=task_id, + # ) + # else: + + result = ( + await fn(**fn_kwargs) if asyncio.iscoroutinefunction(fn) else fn(**fn_kwargs) + ) + except Exception as e: + success = False + error = e + step = "[pre-node-execute:async]" if pre_node_execute_errored else "" + message = create_error_message(kwargs, node_, step) + logger.exception(message) + raise + finally: + if not pre_node_execute_errored and self.adapter.does_hook( + "post_node_execute", is_async=True + ): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "post_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + success=success, + error=error, + result=result, + task_id=task_id, + ) + except Exception: + message = create_error_message(kwargs, node_, "[post-node-execute]") + logger.exception(message) + raise + + return result + + coroutine = new_fn(**kwargs) + task = asyncio.create_task(coroutine) + return task + + def build_result(self, **outputs: Any) -> Any: + return self.result_builder.build_result(**outputs) + + +def separate_sync_from_async( + adapters: typing.List[lifecycle.LifecycleAdapter], +) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]: + """Separates the sync and async adapters from a list of adapters. + Note this only works with hooks -- we'll be dealing with methods later. 
+ + :param adapters: List of adapters + :return: Tuple of sync adapters, async adapters + """ + + adapter_set = lifecycle_base.LifecycleAdapterSet(*adapters) + # this is using internal(ish) fields (.sync_hooks/.async_hooks) -- we should probably expose it + # For now this is OK + # Note those are dict[hook_name, list[hook]], so we have to flatten + return ( + [sync_adapter for adapters in adapter_set.sync_hooks.values() for sync_adapter in adapters], + [ + async_adapter + for adapters in adapter_set.async_hooks.values() + for async_adapter in adapters + ], + ) + + +class AsyncDriver(driver.Driver): + """Async driver. This is a driver that uses the AsyncGraphAdapter to execute the graph. + + .. code-block:: python + + dr = async_driver.AsyncDriver({}, async_module, result_builder=base.DictResult()) + df = await dr.execute([...], inputs=...) + + """ + + def __init__( + self, + config, + *modules, + result_builder: Optional[base.ResultMixin] = None, + adapters: typing.List[lifecycle.LifecycleAdapter] = None, + ): + """Instantiates an asynchronous driver. + + You will also need to call `ainit` to initialize the driver if you have any hooks/adapters. + + Note that this is not the desired API -- you should be using the :py:class:`hamilton.async_driver.Builder` class to create the driver. + + :param config: Config to build the graph + :param modules: Modules to crawl for fns/graph nodes + :param adapters: Adapters to use for lifecycle methods. + :param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run. 
+ """ + if adapters is not None: + sync_adapters, async_adapters = separate_sync_from_async(adapters) + else: + # separate out so we know what the driver + sync_adapters = [] + async_adapters = [] + # we'll need to use this in multiple contexts so we'll keep it around for later + result_builders = [adapter for adapter in adapters if isinstance(adapter, base.ResultMixin)] + if result_builder is not None: + result_builders.append(result_builder) + if len(result_builders) > 1: + raise ValueError( + "You cannot pass more than one result builder to the async driver. " + "Please pass in a single result builder" + ) + # it will be defaulted by the graph adapter + result_builder = result_builders[0] if len(result_builders) == 1 else None + super(AsyncDriver, self).__init__( + config, + *modules, + adapter=[ + # We pass in the async adapters here as this can call node-level hooks + # Otherwise we trust the driver/fn graph to call sync adapters + AsyncGraphAdapter( + result_builder=result_builder, + async_lifecycle_adapters=lifecycle_base.LifecycleAdapterSet(*async_adapters), + ), + # We pass in the sync adapters here as this can call + *sync_adapters, + *async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later + ], + ) + self.initialized = False + + async def ainit(self) -> "AsyncDriver": + """Initializes the driver when using async. This only exists for backwards compatibility. + In Hamilton 2.0, we will be using an asynchronous constructor. + See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5. 
+ """ + if self.initialized: + # this way it can be called twice + return self + if self.adapter.does_hook("post_graph_construct", is_async=True): + await self.adapter.call_all_lifecycle_hooks_async( + "post_graph_construct", + graph=self.graph, + modules=self.graph_modules, + config=self.config, + ) + await self.adapter.ainit() + self.initialized = True + return self + + async def raw_execute( + self, + final_vars: typing.List[str], + overrides: Dict[str, Any] = None, + display_graph: bool = False, # don't care + inputs: Dict[str, Any] = None, + _fn_graph: graph.FunctionGraph = None, + ) -> Dict[str, Any]: + """Executes the graph, returning a dictionary of strings (node keys) to final results. + + :param final_vars: Variables to execute (+ upstream) + :param overrides: Overrides for nodes + :param display_graph: whether or not to display graph -- this is not supported. + :param inputs: Inputs for DAG runtime calculation + :param _fn_graph: Function graph for compatibility with superclass -- unused + :return: A dict of key -> result + """ + assert _fn_graph is None, ( + "_fn_graph must not be provided " + "-- the only reason you'd do this is to use materialize(), which is not supported yet.." + ) + run_id = str(uuid.uuid4()) + nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs) + memoized_computation = dict() # memoized storage + if self.adapter.does_hook("pre_graph_execute"): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "pre_graph_execute", + run_id=run_id, + graph=self.graph, + final_vars=final_vars, + inputs=inputs, + overrides=overrides, + ) + results = None + error = None + success = False + try: + self.graph.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id) + if display_graph: + raise ValueError( + "display_graph=True is not supported for the async graph adapter. " + "Instead you should be using visualize_execution." 
+ ) + task_dict = { + key: asyncio.create_task(process_value(memoized_computation[key])) + for key in final_vars + } + results = await await_dict_of_tasks(task_dict) + success = True + except Exception as e: + error = e + success = False + raise e + finally: + if self.adapter.does_hook("post_graph_execute", is_async=None): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "post_graph_execute", + run_id=run_id, + graph=self.graph, + success=success, + error=error, + results=results, + ) + return results + + async def execute( + self, + final_vars: typing.List[str], + overrides: Dict[str, Any] = None, + display_graph: bool = False, + inputs: Dict[str, Any] = None, + ) -> Any: + """Executes computation. + + :param final_vars: the final list of variables we want to compute. + :param overrides: values that will override "nodes" in the DAG. + :param display_graph: DEPRECATED. Whether we want to display the graph being computed. + :param inputs: Runtime inputs to the DAG. + :return: an object consisting of the variables requested, matching the type returned by the GraphAdapter. + See constructor for how the GraphAdapter is initialized. The default one right now returns a pandas + dataframe. + """ + if display_graph: + raise ValueError( + "display_graph=True is not supported for the async graph adapter. " + "Instead you should be using visualize_execution." + ) + start_time = time.time() + run_successful = True + error = None + try: + outputs = await self.raw_execute(final_vars, overrides, display_graph, inputs=inputs) + # Currently we don't allow async build results, but we could. 
+ if self.adapter.does_method("do_build_result", is_async=False): + return self.adapter.call_lifecycle_method_sync("do_build_result", outputs=outputs) + return outputs + except Exception as e: + run_successful = False + logger.error(driver.SLACK_ERROR_MESSAGE) + error = telemetry.sanitize_error(*sys.exc_info()) + raise e + finally: + duration = time.time() - start_time + # ensure we can capture telemetry in async friendly way. + if telemetry.is_telemetry_enabled(): + + async def make_coroutine(): + self.capture_execute_telemetry( + error, final_vars, inputs, overrides, run_successful, duration + ) + + try: + # we don't have to await because we are running within the event loop. + asyncio.create_task(make_coroutine()) + except Exception as e: + if logger.isEnabledFor(logging.DEBUG): + logger.error(f"Encountered error submitting async telemetry:\n{e}") + + def capture_constructor_telemetry( + self, + error: Optional[str], + modules: Tuple[ModuleType], + config: Dict[str, Any], + adapter: base.HamiltonGraphAdapter, + ): + """Ensures we capture constructor telemetry the right way in an async context. + + This is a simpler wrapper around what's in the driver class. + + :param error: sanitized error string, if any. + :param modules: tuple of modules to build DAG from. + :param config: config to create the driver. + :param adapter: adapter class object. 
+ """ + if telemetry.is_telemetry_enabled(): + try: + # check whether the event loop has been started yet or not + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.run_in_executor( + None, + super(AsyncDriver, self).capture_constructor_telemetry, + error, + modules, + config, + adapter, + ) + else: + + async def make_coroutine(): + super(AsyncDriver, self).capture_constructor_telemetry( + error, modules, config, adapter + ) + + loop.run_until_complete(make_coroutine()) + except Exception as e: + if logger.isEnabledFor(logging.DEBUG): + logger.error(f"Encountered error submitting async telemetry:\n{e}") + + +class Builder(driver.Builder): + """Builder for the async driver. This is equivalent to the standard builder, but has a more limited API. + Note this does not support dynamic execution or materializers (for now). + + Here is an example of how you might use it to get the tracker working: + + .. code-block:: python + + from hamilton_sdk import tracker + tracker_async = adapters.AsyncHamiltonTracker( + project_id=1, + username="elijah", + dag_name="async_tracker", + ) + dr = ( + await async_driver + .Builder() + .with_modules(async_module) + .with_adapters(tracking_async) + .build() + ) + """ + + def __init__(self): + super(Builder, self).__init__() + + def _not_supported(self, method_name: str, additional_message: str = ""): + raise ValueError( + f"Builder().{method_name}() is not supported for the async driver. 
{additional_message}" + ) + + def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder": + self._not_supported("enable_dynamic_execution") + + def with_materializers( + self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory] + ) -> "Builder": + self._not_supported("with_materializers") + + def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder": + self._not_supported( + "with_adapter", + "Use with_adapters instead to pass in the tracker (or other async hooks/methods)", + ) + + async def build(self): + adapter = self.adapters if self.adapters is not None else [] + if self.legacy_graph_adapter is not None: + adapter.append(self.legacy_graph_adapter) + + result_builder = ( + self.result_builder if self.result_builder is not None else base.DictResult() + ) + + return await AsyncDriver( + self.config, *self.modules, adapters=adapter, result_builder=result_builder + ).ainit() diff --git a/hamilton/experimental/h_async.py b/hamilton/experimental/h_async.py index d05535972..b803c7666 100644 --- a/hamilton/experimental/h_async.py +++ b/hamilton/experimental/h_async.py @@ -1,483 +1,12 @@ -import asyncio -import inspect import logging -import sys -import time -import typing -import uuid -from types import ModuleType -from typing import Any, Dict, Optional, Tuple -import hamilton.lifecycle.base as lifecycle_base -from hamilton import base, driver, graph, lifecycle, node, telemetry -from hamilton.execution.graph_functions import create_error_message -from hamilton.io.materialization import ExtractorFactory, MaterializerFactory +import hamilton.async_driver logger = logging.getLogger(__name__) +logger.warning( + "This module is deprecated and will be removed in Hamilton 2.0 " + "Please use `hamilton.async_driver` instead. 
" +) -async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]: - """Util to await a dictionary of tasks as asyncio.gather is kind of garbage""" - keys = sorted(task_dict.keys()) - coroutines = [task_dict[key] for key in keys] - coroutines_gathered = await asyncio.gather(*coroutines) - return dict(zip(keys, coroutines_gathered)) - - -async def process_value(val: Any) -> Any: - """Helper function to process the value of a potential awaitable. - This is very simple -- all it does is await the value if its not already resolved. - - :param val: Value to process. - :return: The value (awaited if it is a coroutine, raw otherwise). - """ - if not inspect.isawaitable(val): - return val - return await val - - -class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilder): - """Graph adapter for use with the :class:`AsyncDriver` class.""" - - def __init__( - self, - result_builder: base.ResultMixin = None, - async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None, - ): - """Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class. - - Some things to note: - - 1. This executes everything at the end (recursively). E.G. the final DAG nodes are awaited - 2. This does *not* work with decorators when the async function is being decorated. That is\ - because that function is called directly within the decorator, so we cannot await it. - """ - super(AsyncGraphAdapter, self).__init__() - self.adapter = ( - async_lifecycle_adapters - if async_lifecycle_adapters is not None - else lifecycle_base.LifecycleAdapterSet() - ) - self.result_builder = result_builder if result_builder else base.PandasDataFrameResult() - self.is_initialized = False - - def do_node_execute( - self, - *, - run_id: str, - node_: node.Node, - kwargs: typing.Dict[str, typing.Any], - task_id: Optional[str] = None, - ) -> typing.Any: - """Executes a node. 
Note this doesn't actually execute it -- rather, it returns a task. - This does *not* use async def, as we want it to be awaited on later -- this await is done - in processing parameters of downstream functions/final results. We can ensure that as - we also run the driver that this corresponds to. - - Note that this assumes that everything is awaitable, even if it isn't. - In that case, it just wraps it in one. - - :param task_id: - :param node_: - :param run_id: - :param node: Node to wrap - :param kwargs: Keyword arguments (either coroutines or raw values) to call it with - :return: A task - """ - callabl = node_.callable - - async def new_fn(fn=callabl, **fn_kwargs): - task_dict = {key: process_value(value) for key, value in fn_kwargs.items()} - fn_kwargs = await await_dict_of_tasks(task_dict) - error = None - result = None - success = True - pre_node_execute_errored = False - try: - if self.adapter.does_hook("pre_node_execute", is_async=True): - try: - await self.adapter.call_all_lifecycle_hooks_async( - "pre_node_execute", - run_id=run_id, - node_=node_, - kwargs=fn_kwargs, - task_id=task_id, - ) - except Exception as e: - pre_node_execute_errored = True - raise e - # TODO -- consider how to use node execution methods in the future - # This is messy as it is a method called within a method... 
- # if self.adapter.does_method("do_node_execute", is_async=False): - # result = self.adapter.call_lifecycle_method_sync( - # "do_node_execute", - # run_id=run_id, - # node_=node_, - # kwargs=kwargs, - # task_id=task_id, - # ) - # else: - - result = ( - await fn(**fn_kwargs) if asyncio.iscoroutinefunction(fn) else fn(**fn_kwargs) - ) - except Exception as e: - success = False - error = e - step = "[pre-node-execute:async]" if pre_node_execute_errored else "" - message = create_error_message(kwargs, node_, step) - logger.exception(message) - raise - finally: - if not pre_node_execute_errored and self.adapter.does_hook( - "post_node_execute", is_async=True - ): - try: - await self.adapter.call_all_lifecycle_hooks_async( - "post_node_execute", - run_id=run_id, - node_=node_, - kwargs=fn_kwargs, - success=success, - error=error, - result=result, - task_id=task_id, - ) - except Exception: - message = create_error_message(kwargs, node_, "[post-node-execute]") - logger.exception(message) - raise - - return result - - coroutine = new_fn(**kwargs) - task = asyncio.create_task(coroutine) - return task - - def build_result(self, **outputs: Any) -> Any: - return self.result_builder.build_result(**outputs) - - -def separate_sync_from_async( - adapters: typing.List[lifecycle.LifecycleAdapter], -) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]: - """Separates the sync and async adapters from a list of adapters. - Note this only works with hooks -- we'll be dealing with methods later. 
- - :param adapters: List of adapters - :return: Tuple of sync adapters, async adapters - """ - - adapter_set = lifecycle_base.LifecycleAdapterSet(*adapters) - # this is using internal(ish) fields (.sync_hooks/.async_hooks) -- we should probably expose it - # For now this is OK - # Note those are dict[hook_name, list[hook]], so we have to flatten - return ( - [sync_adapter for adapters in adapter_set.sync_hooks.values() for sync_adapter in adapters], - [ - async_adapter - for adapters in adapter_set.async_hooks.values() - for async_adapter in adapters - ], - ) - - -class AsyncDriver(driver.Driver): - """Async driver. This is a driver that uses the AsyncGraphAdapter to execute the graph. - - .. code-block:: python - - dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult()) - df = await dr.execute([...], inputs=...) - - """ - - def __init__( - self, - config, - *modules, - result_builder: Optional[base.ResultMixin] = None, - adapters: typing.List[lifecycle.LifecycleAdapter] = None, - ): - """Instantiates an asynchronous driver. - - You will also need to call `ainit` to initialize the driver if you have any hooks/adapters. - - Note that this is not the desired API -- you should be using the :py:class:`hamilton.experimental.h_async.Builder` class to create the driver. - - :param config: Config to build the graph - :param modules: Modules to crawl for fns/graph nodes - :param adapters: Adapters to use for lifecycle methods. - :param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run. 
- """ - if adapters is not None: - sync_adapters, async_adapters = separate_sync_from_async(adapters) - else: - # separate out so we know what the driver - sync_adapters = [] - async_adapters = [] - # we'll need to use this in multiple contexts so we'll keep it around for later - result_builders = [adapter for adapter in adapters if isinstance(adapter, base.ResultMixin)] - if result_builder is not None: - result_builders.append(result_builder) - if len(result_builders) > 1: - raise ValueError( - "You cannot pass more than one result builder to the async driver. " - "Please pass in a single result builder" - ) - # it will be defaulted by the graph adapter - result_builder = result_builders[0] if len(result_builders) == 1 else None - super(AsyncDriver, self).__init__( - config, - *modules, - adapter=[ - # We pass in the async adapters here as this can call node-level hooks - # Otherwise we trust the driver/fn graph to call sync adapters - AsyncGraphAdapter( - result_builder=result_builder, - async_lifecycle_adapters=lifecycle_base.LifecycleAdapterSet(*async_adapters), - ), - # We pass in the sync adapters here as this can call - *sync_adapters, - *async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later - ], - ) - self.initialized = False - - async def ainit(self) -> "AsyncDriver": - """Initializes the driver when using async. This only exists for backwards compatibility. - In Hamilton 2.0, we will be using an asynchronous constructor. - See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5. 
- """ - if self.initialized: - # this way it can be called twice - return self - if self.adapter.does_hook("post_graph_construct", is_async=True): - await self.adapter.call_all_lifecycle_hooks_async( - "post_graph_construct", - graph=self.graph, - modules=self.graph_modules, - config=self.config, - ) - await self.adapter.ainit() - self.initialized = True - return self - - async def raw_execute( - self, - final_vars: typing.List[str], - overrides: Dict[str, Any] = None, - display_graph: bool = False, # don't care - inputs: Dict[str, Any] = None, - _fn_graph: graph.FunctionGraph = None, - ) -> Dict[str, Any]: - """Executes the graph, returning a dictionary of strings (node keys) to final results. - - :param final_vars: Variables to execute (+ upstream) - :param overrides: Overrides for nodes - :param display_graph: whether or not to display graph -- this is not supported. - :param inputs: Inputs for DAG runtime calculation - :param _fn_graph: Function graph for compatibility with superclass -- unused - :return: A dict of key -> result - """ - assert _fn_graph is None, ( - "_fn_graph must not be provided " - "-- the only reason you'd do this is to use materialize(), which is not supported yet.." - ) - run_id = str(uuid.uuid4()) - nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs) - memoized_computation = dict() # memoized storage - if self.adapter.does_hook("pre_graph_execute"): - await self.adapter.call_all_lifecycle_hooks_sync_and_async( - "pre_graph_execute", - run_id=run_id, - graph=self.graph, - final_vars=final_vars, - inputs=inputs, - overrides=overrides, - ) - results = None - error = None - success = False - try: - self.graph.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id) - if display_graph: - raise ValueError( - "display_graph=True is not supported for the async graph adapter. " - "Instead you should be using visualize_execution." 
- ) - task_dict = { - key: asyncio.create_task(process_value(memoized_computation[key])) - for key in final_vars - } - results = await await_dict_of_tasks(task_dict) - success = True - except Exception as e: - error = e - success = False - raise e - finally: - if self.adapter.does_hook("post_graph_execute", is_async=None): - await self.adapter.call_all_lifecycle_hooks_sync_and_async( - "post_graph_execute", - run_id=run_id, - graph=self.graph, - success=success, - error=error, - results=results, - ) - return results - - async def execute( - self, - final_vars: typing.List[str], - overrides: Dict[str, Any] = None, - display_graph: bool = False, - inputs: Dict[str, Any] = None, - ) -> Any: - """Executes computation. - - :param final_vars: the final list of variables we want to compute. - :param overrides: values that will override "nodes" in the DAG. - :param display_graph: DEPRECATED. Whether we want to display the graph being computed. - :param inputs: Runtime inputs to the DAG. - :return: an object consisting of the variables requested, matching the type returned by the GraphAdapter. - See constructor for how the GraphAdapter is initialized. The default one right now returns a pandas - dataframe. - """ - if display_graph: - raise ValueError( - "display_graph=True is not supported for the async graph adapter. " - "Instead you should be using visualize_execution." - ) - start_time = time.time() - run_successful = True - error = None - try: - outputs = await self.raw_execute(final_vars, overrides, display_graph, inputs=inputs) - # Currently we don't allow async build results, but we could. 
- if self.adapter.does_method("do_build_result", is_async=False): - return self.adapter.call_lifecycle_method_sync("do_build_result", outputs=outputs) - return outputs - except Exception as e: - run_successful = False - logger.error(driver.SLACK_ERROR_MESSAGE) - error = telemetry.sanitize_error(*sys.exc_info()) - raise e - finally: - duration = time.time() - start_time - # ensure we can capture telemetry in async friendly way. - if telemetry.is_telemetry_enabled(): - - async def make_coroutine(): - self.capture_execute_telemetry( - error, final_vars, inputs, overrides, run_successful, duration - ) - - try: - # we don't have to await because we are running within the event loop. - asyncio.create_task(make_coroutine()) - except Exception as e: - if logger.isEnabledFor(logging.DEBUG): - logger.error(f"Encountered error submitting async telemetry:\n{e}") - - def capture_constructor_telemetry( - self, - error: Optional[str], - modules: Tuple[ModuleType], - config: Dict[str, Any], - adapter: base.HamiltonGraphAdapter, - ): - """Ensures we capture constructor telemetry the right way in an async context. - - This is a simpler wrapper around what's in the driver class. - - :param error: sanitized error string, if any. - :param modules: tuple of modules to build DAG from. - :param config: config to create the driver. - :param adapter: adapter class object. 
- """ - if telemetry.is_telemetry_enabled(): - try: - # check whether the event loop has been started yet or not - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.run_in_executor( - None, - super(AsyncDriver, self).capture_constructor_telemetry, - error, - modules, - config, - adapter, - ) - else: - - async def make_coroutine(): - super(AsyncDriver, self).capture_constructor_telemetry( - error, modules, config, adapter - ) - - loop.run_until_complete(make_coroutine()) - except Exception as e: - if logger.isEnabledFor(logging.DEBUG): - logger.error(f"Encountered error submitting async telemetry:\n{e}") - - -class Builder(driver.Builder): - """Builder for the async driver. This is equivalent to the standard builder, but has a more limited API. - Note this does not support dynamic execution or materializers (for now). - - Here is an example of how you might use it to get the tracker working: - - .. code-block:: python - - from hamilton_sdk import tracker - tracker_async = adapters.AsyncHamiltonTracker( - project_id=1, - username="elijah", - dag_name="async_tracker", - ) - dr = ( - await h_async - .Builder() - .with_modules(async_module) - .with_adapters(tracking_async) - .build() - ) - """ - - def __init__(self): - super(Builder, self).__init__() - - def _not_supported(self, method_name: str, additional_message: str = ""): - raise ValueError( - f"Builder().{method_name}() is not supported for the async driver. 
{additional_message}" - ) - - def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder": - self._not_supported("enable_dynamic_execution") - - def with_materializers( - self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory] - ) -> "Builder": - self._not_supported("with_materializers") - - def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder": - self._not_supported( - "with_adapter", - "Use with_adapters instead to pass in the tracker (or other async hooks/methods)", - ) - - async def build(self): - adapter = self.adapters if self.adapters is not None else [] - if self.legacy_graph_adapter is not None: - adapter.append(self.legacy_graph_adapter) - - result_builder = self.result_builder if self.result_builder is not None else base.DictResult() - - return await AsyncDriver( - self.config, - *self.modules, - adapters=adapter, - result_builder=result_builder - ).ainit() +AsyncDriver = hamilton.async_driver.AsyncDriver diff --git a/plugin_tests/h_async/__init__.py b/plugin_tests/h_async/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/plugin_tests/h_async/conftest.py b/plugin_tests/h_async/conftest.py deleted file mode 100644 index bc5ef5b5a..000000000 --- a/plugin_tests/h_async/conftest.py +++ /dev/null @@ -1,4 +0,0 @@ -from hamilton import telemetry - -# disable telemetry for all tests! 
-telemetry.disable_telemetry() diff --git a/plugin_tests/h_async/requirements-test.txt b/plugin_tests/h_async/requirements-test.txt deleted file mode 100644 index 2d73dba5b..000000000 --- a/plugin_tests/h_async/requirements-test.txt +++ /dev/null @@ -1 +0,0 @@ -pytest-asyncio diff --git a/plugin_tests/h_async/resources/__init__.py b/plugin_tests/h_async/resources/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/requirements-test.txt b/requirements-test.txt index 9d79f161d..d7eed9c6f 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -22,6 +22,7 @@ polars pyarrow pyreadstat # for SPSS data loader pytest +pytest-asyncio pytest-cov PyYAML scikit-learn diff --git a/plugin_tests/h_async/resources/simple_async_module.py b/tests/resources/simple_async_module.py similarity index 100% rename from plugin_tests/h_async/resources/simple_async_module.py rename to tests/resources/simple_async_module.py diff --git a/plugin_tests/h_async/test_h_async.py b/tests/test_async_driver.py similarity index 93% rename from plugin_tests/h_async/test_h_async.py rename to tests/test_async_driver.py index 0c0626755..4f840ba1b 100644 --- a/plugin_tests/h_async/test_h_async.py +++ b/tests/test_async_driver.py @@ -4,8 +4,7 @@ import pandas as pd import pytest -from hamilton import base -from hamilton.experimental import h_async +from hamilton import async_driver, base from hamilton.lifecycle.base import ( BasePostGraphConstruct, BasePostGraphConstructAsync, @@ -30,36 +29,36 @@ async def async_identity(n: int) -> int: @pytest.mark.asyncio async def test_await_dict_of_coroutines(): tasks = {n: async_identity(n) for n in range(0, 10)} - results = await h_async.await_dict_of_tasks(tasks) + results = await async_driver.await_dict_of_tasks(tasks) assert results == {n: await async_identity(n) for n in range(0, 10)} @pytest.mark.asyncio async def test_await_dict_of_tasks(): tasks = {n: asyncio.create_task(async_identity(n)) for n in range(0, 10)} - results = 
await h_async.await_dict_of_tasks(tasks) + results = await async_driver.await_dict_of_tasks(tasks) assert results == {n: await async_identity(n) for n in range(0, 10)} # The following are not parameterized as we need to use the event loop -- fixtures will complicate this @pytest.mark.asyncio async def test_process_value_raw(): - assert await h_async.process_value(1) == 1 + assert await async_driver.process_value(1) == 1 @pytest.mark.asyncio async def test_process_value_coroutine(): - assert await h_async.process_value(async_identity(1)) == 1 + assert await async_driver.process_value(async_identity(1)) == 1 @pytest.mark.asyncio async def test_process_value_task(): - assert await h_async.process_value(asyncio.create_task(async_identity(1))) == 1 + assert await async_driver.process_value(asyncio.create_task(async_identity(1))) == 1 @pytest.mark.asyncio async def test_driver_end_to_end(): - dr = h_async.AsyncDriver({}, simple_async_module) + dr = async_driver.AsyncDriver({}, simple_async_module) all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] result = await dr.raw_execute(final_vars=all_vars, inputs={"external_input": 1}) result["a"] = result["a"].to_dict() # convert to dict for comparison @@ -85,7 +84,7 @@ async def test_driver_end_to_end(): @mock.patch("hamilton.telemetry.send_event_json") @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) async def test_driver_end_to_end_telemetry(send_event_json): - dr = h_async.AsyncDriver({}, simple_async_module, result_builder=base.DictResult()) + dr = async_driver.AsyncDriver({}, simple_async_module, result_builder=base.DictResult()) with mock.patch("hamilton.telemetry.g_telemetry_enabled", False): # don't count this telemetry tracking invocation all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] @@ -156,7 +155,7 @@ async def post_graph_construct(self, **kwargs): adapter = AsyncTrackingAdapter(tracked_calls) - dr = await 
h_async.AsyncDriver( + dr = await async_driver.AsyncDriver( {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] ).ainit() all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] @@ -220,7 +219,7 @@ def post_graph_construct(self, **kwargs): adapter = AsyncTrackingAdapter(tracked_calls) - dr = await h_async.AsyncDriver( + dr = await async_driver.AsyncDriver( {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] ).ainit() all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index e42c51b93..1f7303637 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -7,8 +7,7 @@ import pytest -from hamilton import base, node, telemetry -from hamilton.experimental import h_async +from hamilton import async_driver, base, node, telemetry from hamilton.lifecycle import base as lifecycle_base @@ -163,8 +162,8 @@ class CustomResultBuilder(base.ResultMixin): "hamilton.base.DefaultAdapter", ), ( - h_async.AsyncGraphAdapter(base.DictResult()), - "hamilton.experimental.h_async.AsyncGraphAdapter", + async_driver.AsyncGraphAdapter(base.DictResult()), + "hamilton.experimental.async_driver.AsyncGraphAdapter", ), (CustomAdapter(base.DictResult()), "custom_adapter"), ], @@ -189,7 +188,7 @@ def test_get_adapter_name(adapter, expected): "hamilton.base.StrictIndexTypePandasDataFrameResult", ), (base.SimplePythonGraphAdapter(CustomResultBuilder()), "custom_builder"), - (h_async.AsyncGraphAdapter(base.DictResult()), "hamilton.base.DictResult"), + (async_driver.AsyncGraphAdapter(base.DictResult()), "hamilton.base.DictResult"), (CustomAdapter(base.DictResult()), "hamilton.base.DictResult"), (CustomAdapter(CustomResultBuilder()), "custom_builder"), ], From c969458204d6dd5ef665e340069946388eb29efb Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 26 Jun 2024 17:00:15 -0700 Subject: [PATCH 6/6] Fixes 
regression in which adapters was none, instead of an empty list --- hamilton/async_driver.py | 14 ++++++++------ tests/test_telemetry.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/hamilton/async_driver.py b/hamilton/async_driver.py index ec5319ea3..3471b7c73 100644 --- a/hamilton/async_driver.py +++ b/hamilton/async_driver.py @@ -205,18 +205,20 @@ def __init__( Note that this is not the desired API -- you should be using the :py:class:`hamilton.async_driver.Builder` class to create the driver. + This will only (currently) work properly with asynchronous lifecycle hooks, and does not support methods or validators. + You can still pass in synchronous lifecycle hooks, but they may behave strangely. + :param config: Config to build the graph :param modules: Modules to crawl for fns/graph nodes :param adapters: Adapters to use for lifecycle methods. :param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run. 
""" - if adapters is not None: - sync_adapters, async_adapters = separate_sync_from_async(adapters) - else: - # separate out so we know what the driver - sync_adapters = [] - async_adapters = [] + if adapters is None: + adapters = [] + sync_adapters, async_adapters = separate_sync_from_async(adapters) + # we'll need to use this in multiple contexts so we'll keep it around for later + result_builders = [adapter for adapter in adapters if isinstance(adapter, base.ResultMixin)] if result_builder is not None: result_builders.append(result_builder) diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index 1f7303637..101ae29a6 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -163,7 +163,7 @@ class CustomResultBuilder(base.ResultMixin): ), ( async_driver.AsyncGraphAdapter(base.DictResult()), - "hamilton.experimental.async_driver.AsyncGraphAdapter", + "hamilton.async_driver.AsyncGraphAdapter", ), (CustomAdapter(base.DictResult()), "custom_adapter"), ],