From 1e7495c2df51c80e7ad13e83c30f29b874b97ca0 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 26 Jun 2024 16:49:48 -0700 Subject: [PATCH] Moves over h_async to hamilton.async_driver It's time we support this. Decisions: 1. Did not want to keep it in hamilton.driver, as that's getting too bloated and there's a name collision 2. Moved over docs to point to the right place (as much as I could) 3. Left it backwards compatible with a warning to update the import --- .ci/setup.sh | 5 - .circleci/config.yml | 24 - .../langchain_snippets/hamilton_async.py | 4 +- docs/reference/drivers/AsyncDriver.rst | 4 +- .../graph-adapters/AsyncGraphAdapter.rst | 2 +- examples/async/fastapi_example.py | 12 +- .../scenario_1/fastapi_server.py | 5 +- hamilton/async_driver.py | 482 +++++++++++++++++ hamilton/experimental/h_async.py | 483 +----------------- plugin_tests/h_async/__init__.py | 0 plugin_tests/h_async/conftest.py | 4 - plugin_tests/h_async/requirements-test.txt | 1 - plugin_tests/h_async/resources/__init__.py | 0 requirements-test.txt | 1 + .../resources/simple_async_module.py | 0 .../test_async_driver.py | 21 +- tests/test_telemetry.py | 9 +- 17 files changed, 514 insertions(+), 543 deletions(-) create mode 100644 hamilton/async_driver.py delete mode 100644 plugin_tests/h_async/__init__.py delete mode 100644 plugin_tests/h_async/conftest.py delete mode 100644 plugin_tests/h_async/requirements-test.txt delete mode 100644 plugin_tests/h_async/resources/__init__.py rename {plugin_tests/h_async => tests}/resources/simple_async_module.py (100%) rename plugin_tests/h_async/test_h_async.py => tests/test_async_driver.py (93%) diff --git a/.ci/setup.sh b/.ci/setup.sh index f123d62f0..9056e15d8 100755 --- a/.ci/setup.sh +++ b/.ci/setup.sh @@ -21,11 +21,6 @@ if [[ ${TASK} != "pre-commit" ]]; then -r requirements-test.txt fi -if [[ ${TASK} == "async" ]]; then - pip install \ - -r plugin_tests/h_async/requirements-test.txt -fi - if [[ ${TASK} == "pyspark" ]]; then if [[ ${OPERATING_SYSTEM} == "Linux" ]]; then sudo apt-get install \ diff --git a/.circleci/config.yml b/.circleci/config.yml index efbfe69bb..dfa860cb8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -155,27 +155,3 @@ workflows: name: integrations-py312 python-version: '3.12' task: integrations - - test: - requires: - - check_for_changes - name: asyncio-py39 - python-version: '3.9' - task: async - - test: - requires: - - check_for_changes - name: asyncio-py310 - python-version: '3.10' - task: async - - test: - requires: - - check_for_changes - name: asyncio-py311 - python-version: '3.11' - task: async - - test: - requires: - - check_for_changes - name: asyncio-py312 - python-version: '3.12' - task: async diff --git a/docs/code-comparisons/langchain_snippets/hamilton_async.py b/docs/code-comparisons/langchain_snippets/hamilton_async.py index 342631e13..e76258639 100644 --- a/docs/code-comparisons/langchain_snippets/hamilton_async.py +++ b/docs/code-comparisons/langchain_snippets/hamilton_async.py @@ -38,9 +38,9 @@ async def joke_response( import hamilton_async from hamilton import base - from hamilton.experimental import h_async + from hamilton import async_driver - dr = h_async.AsyncDriver( + dr = async_driver.AsyncDriver( {}, hamilton_async, result_builder=base.DictResult() diff --git a/docs/reference/drivers/AsyncDriver.rst b/docs/reference/drivers/AsyncDriver.rst index 01286be66..a91b0c45b 100644 --- a/docs/reference/drivers/AsyncDriver.rst +++ b/docs/reference/drivers/AsyncDriver.rst @@ -2,7 +2,7 @@ AsyncDriver ___________ Use this 
driver in an async context. E.g. for use with FastAPI. -.. autoclass:: hamilton.experimental.h_async.AsyncDriver +.. autoclass:: hamilton.async_driver.AsyncDriver :special-members: __init__ :members: @@ -11,6 +11,6 @@ Async Builder Builds a driver in an async context -- use ``await builder....build()``. -.. autoclass:: hamilton.experimental.h_async.Builder +.. autoclass:: hamilton.async_driver.Builder :special-members: __init__ :members: diff --git a/docs/reference/graph-adapters/AsyncGraphAdapter.rst b/docs/reference/graph-adapters/AsyncGraphAdapter.rst index a22d43853..c2f9bab0c 100644 --- a/docs/reference/graph-adapters/AsyncGraphAdapter.rst +++ b/docs/reference/graph-adapters/AsyncGraphAdapter.rst @@ -3,7 +3,7 @@ h_async.AsyncGraphAdapter ========================= -.. autoclass:: hamilton.experimental.h_async.AsyncGraphAdapter +.. autoclass:: hamilton.async_driver.AsyncGraphAdapter :special-members: __init__ :members: :inherited-members: diff --git a/examples/async/fastapi_example.py b/examples/async/fastapi_example.py index bf73b0d0f..441f67d7e 100644 --- a/examples/async/fastapi_example.py +++ b/examples/async/fastapi_example.py @@ -7,13 +7,13 @@ from aiohttp import client_exceptions from hamilton_sdk import adapters -from hamilton.experimental import h_async +from hamilton import async_driver logger = logging.getLogger(__name__) # can instantiate a driver once for the life of the app: -dr_with_tracking: h_async.AsyncDriver = None -dr_without_tracking: h_async.AsyncDriver = None +dr_with_tracking: async_driver.AsyncDriver = None +dr_without_tracking: async_driver.AsyncDriver = None async def _tracking_server_running(): @@ -36,7 +36,7 @@ async def lifespan(app: fastapi.FastAPI): """ global dr_with_tracking global dr_without_tracking - builder = h_async.Builder().with_modules(async_module) + builder = async_driver.Builder().with_modules(async_module) is_server_running = await _tracking_server_running() dr_without_tracking = await builder.build() dr_with_tracking = ( @@ -66,8 +66,6 @@ async def lifespan(app: fastapi.FastAPI): async def call_without_tracker(request: fastapi.Request) -> dict: """Handler for pipeline call -- this does not track in the Hamilton UI""" input_data = {"request": request} - # Can instantiate a driver within a request as well: - # dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult()) result = await dr_without_tracking.execute(["pipeline"], inputs=input_data) # dr.visualize_execution(["pipeline"], "./pipeline.dot", {"format": "png"}, inputs=input_data) return result @@ -77,8 +75,6 @@ async def call_without_tracker(request: fastapi.Request) -> dict: async def call_with_tracker(request: fastapi.Request) -> dict: """Handler for pipeline call -- this does track in the Hamilton UI.""" input_data = {"request": request} - # Can instantiate a driver within a request as well: - # dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult()) if dr_with_tracking is None: raise ValueError( "Tracking driver not initialized -- you must have the tracking server running at app startup to use this endpoint." 
diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py index 132fd17a2..2d572d9b5 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/fastapi_server.py @@ -17,8 +17,7 @@ import pandas as pd import pydantic -from hamilton import base -from hamilton.experimental import h_async +from hamilton import async_driver, base app = fastapi.FastAPI() @@ -56,7 +55,7 @@ def fake_model_predict(df: pd.DataFrame) -> pd.Series: # We instantiate an async driver once for the life of the app. We use the AsyncDriver here because under the hood # FastAPI is async. If you were using Flask, you could use the regular Hamilton driver without issue. -dr = h_async.AsyncDriver( +dr = async_driver.AsyncDriver( {}, # no config/invariant inputs in this example. features, # the module that contains the common feature definitions. result_builder=base.SimplePythonDataFrameGraphAdapter(), diff --git a/hamilton/async_driver.py b/hamilton/async_driver.py new file mode 100644 index 000000000..ec5319ea3 --- /dev/null +++ b/hamilton/async_driver.py @@ -0,0 +1,482 @@ +import asyncio +import inspect +import logging +import sys +import time +import typing +import uuid +from types import ModuleType +from typing import Any, Dict, Optional, Tuple + +import hamilton.lifecycle.base as lifecycle_base +from hamilton import base, driver, graph, lifecycle, node, telemetry +from hamilton.execution.graph_functions import create_error_message +from hamilton.io.materialization import ExtractorFactory, MaterializerFactory + +logger = logging.getLogger(__name__) + + +async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]: + """Util to await a dictionary of tasks as asyncio.gather is kind of garbage""" + keys = sorted(task_dict.keys()) + coroutines = [task_dict[key] for key in keys] + coroutines_gathered = await asyncio.gather(*coroutines) + return dict(zip(keys, coroutines_gathered)) + + +async def process_value(val: Any) -> Any: + """Helper function to process the value of a potential awaitable. + This is very simple -- all it does is await the value if its not already resolved. + + :param val: Value to process. + :return: The value (awaited if it is a coroutine, raw otherwise). + """ + if not inspect.isawaitable(val): + return val + return await val + + +class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilder): + """Graph adapter for use with the :class:`AsyncDriver` class.""" + + def __init__( + self, + result_builder: base.ResultMixin = None, + async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None, + ): + """Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class. + + Some things to note: + + 1. This executes everything at the end (recursively). E.G. the final DAG nodes are awaited + 2. This does *not* work with decorators when the async function is being decorated. That is\ + because that function is called directly within the decorator, so we cannot await it. 
+ """ + super(AsyncGraphAdapter, self).__init__() + self.adapter = ( + async_lifecycle_adapters + if async_lifecycle_adapters is not None + else lifecycle_base.LifecycleAdapterSet() + ) + self.result_builder = result_builder if result_builder else base.PandasDataFrameResult() + self.is_initialized = False + + def do_node_execute( + self, + *, + run_id: str, + node_: node.Node, + kwargs: typing.Dict[str, typing.Any], + task_id: Optional[str] = None, + ) -> typing.Any: + """Executes a node. Note this doesn't actually execute it -- rather, it returns a task. + This does *not* use async def, as we want it to be awaited on later -- this await is done + in processing parameters of downstream functions/final results. We can ensure that as + we also run the driver that this corresponds to. + + Note that this assumes that everything is awaitable, even if it isn't. + In that case, it just wraps it in one. + + :param task_id: + :param node_: + :param run_id: + :param node: Node to wrap + :param kwargs: Keyword arguments (either coroutines or raw values) to call it with + :return: A task + """ + callabl = node_.callable + + async def new_fn(fn=callabl, **fn_kwargs): + task_dict = {key: process_value(value) for key, value in fn_kwargs.items()} + fn_kwargs = await await_dict_of_tasks(task_dict) + error = None + result = None + success = True + pre_node_execute_errored = False + try: + if self.adapter.does_hook("pre_node_execute", is_async=True): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "pre_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + task_id=task_id, + ) + except Exception as e: + pre_node_execute_errored = True + raise e + # TODO -- consider how to use node execution methods in the future + # This is messy as it is a method called within a method... + # if self.adapter.does_method("do_node_execute", is_async=False): + # result = self.adapter.call_lifecycle_method_sync( + # "do_node_execute", + # run_id=run_id, + # node_=node_, + # kwargs=kwargs, + # task_id=task_id, + # ) + # else: + + result = ( + await fn(**fn_kwargs) if asyncio.iscoroutinefunction(fn) else fn(**fn_kwargs) + ) + except Exception as e: + success = False + error = e + step = "[pre-node-execute:async]" if pre_node_execute_errored else "" + message = create_error_message(kwargs, node_, step) + logger.exception(message) + raise + finally: + if not pre_node_execute_errored and self.adapter.does_hook( + "post_node_execute", is_async=True + ): + try: + await self.adapter.call_all_lifecycle_hooks_async( + "post_node_execute", + run_id=run_id, + node_=node_, + kwargs=fn_kwargs, + success=success, + error=error, + result=result, + task_id=task_id, + ) + except Exception: + message = create_error_message(kwargs, node_, "[post-node-execute]") + logger.exception(message) + raise + + return result + + coroutine = new_fn(**kwargs) + task = asyncio.create_task(coroutine) + return task + + def build_result(self, **outputs: Any) -> Any: + return self.result_builder.build_result(**outputs) + + +def separate_sync_from_async( + adapters: typing.List[lifecycle.LifecycleAdapter], +) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]: + """Separates the sync and async adapters from a list of adapters. + Note this only works with hooks -- we'll be dealing with methods later. 
+ + :param adapters: List of adapters + :return: Tuple of sync adapters, async adapters + """ + + adapter_set = lifecycle_base.LifecycleAdapterSet(*adapters) + # this is using internal(ish) fields (.sync_hooks/.async_hooks) -- we should probably expose it + # For now this is OK + # Note those are dict[hook_name, list[hook]], so we have to flatten + return ( + [sync_adapter for adapters in adapter_set.sync_hooks.values() for sync_adapter in adapters], + [ + async_adapter + for adapters in adapter_set.async_hooks.values() + for async_adapter in adapters + ], + ) + + +class AsyncDriver(driver.Driver): + """Async driver. This is a driver that uses the AsyncGraphAdapter to execute the graph. + + .. code-block:: python + + dr = async_driver.AsyncDriver({}, async_module, result_builder=base.DictResult()) + df = await dr.execute([...], inputs=...) + + """ + + def __init__( + self, + config, + *modules, + result_builder: Optional[base.ResultMixin] = None, + adapters: typing.List[lifecycle.LifecycleAdapter] = None, + ): + """Instantiates an asynchronous driver. + + You will also need to call `ainit` to initialize the driver if you have any hooks/adapters. + + Note that this is not the desired API -- you should be using the :py:class:`hamilton.async_driver.Builder` class to create the driver. + + :param config: Config to build the graph + :param modules: Modules to crawl for fns/graph nodes + :param adapters: Adapters to use for lifecycle methods. + :param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run. + """ + if adapters is not None: + sync_adapters, async_adapters = separate_sync_from_async(adapters) + else: + # separate out so we know what the driver + sync_adapters = [] + async_adapters = [] + # we'll need to use this in multiple contexts so we'll keep it around for later + result_builders = [adapter for adapter in adapters if isinstance(adapter, base.ResultMixin)] + if result_builder is not None: + result_builders.append(result_builder) + if len(result_builders) > 1: + raise ValueError( + "You cannot pass more than one result builder to the async driver. " + "Please pass in a single result builder" + ) + # it will be defaulted by the graph adapter + result_builder = result_builders[0] if len(result_builders) == 1 else None + super(AsyncDriver, self).__init__( + config, + *modules, + adapter=[ + # We pass in the async adapters here as this can call node-level hooks + # Otherwise we trust the driver/fn graph to call sync adapters + AsyncGraphAdapter( + result_builder=result_builder, + async_lifecycle_adapters=lifecycle_base.LifecycleAdapterSet(*async_adapters), + ), + # We pass in the sync adapters here as this can call + *sync_adapters, + *async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later + ], + ) + self.initialized = False + + async def ainit(self) -> "AsyncDriver": + """Initializes the driver when using async. This only exists for backwards compatibility. + In Hamilton 2.0, we will be using an asynchronous constructor. + See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5. 
+ """ + if self.initialized: + # this way it can be called twice + return self + if self.adapter.does_hook("post_graph_construct", is_async=True): + await self.adapter.call_all_lifecycle_hooks_async( + "post_graph_construct", + graph=self.graph, + modules=self.graph_modules, + config=self.config, + ) + await self.adapter.ainit() + self.initialized = True + return self + + async def raw_execute( + self, + final_vars: typing.List[str], + overrides: Dict[str, Any] = None, + display_graph: bool = False, # don't care + inputs: Dict[str, Any] = None, + _fn_graph: graph.FunctionGraph = None, + ) -> Dict[str, Any]: + """Executes the graph, returning a dictionary of strings (node keys) to final results. + + :param final_vars: Variables to execute (+ upstream) + :param overrides: Overrides for nodes + :param display_graph: whether or not to display graph -- this is not supported. + :param inputs: Inputs for DAG runtime calculation + :param _fn_graph: Function graph for compatibility with superclass -- unused + :return: A dict of key -> result + """ + assert _fn_graph is None, ( + "_fn_graph must not be provided " + "-- the only reason you'd do this is to use materialize(), which is not supported yet.." + ) + run_id = str(uuid.uuid4()) + nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs) + memoized_computation = dict() # memoized storage + if self.adapter.does_hook("pre_graph_execute"): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "pre_graph_execute", + run_id=run_id, + graph=self.graph, + final_vars=final_vars, + inputs=inputs, + overrides=overrides, + ) + results = None + error = None + success = False + try: + self.graph.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id) + if display_graph: + raise ValueError( + "display_graph=True is not supported for the async graph adapter. " + "Instead you should be using visualize_execution." + ) + task_dict = { + key: asyncio.create_task(process_value(memoized_computation[key])) + for key in final_vars + } + results = await await_dict_of_tasks(task_dict) + success = True + except Exception as e: + error = e + success = False + raise e + finally: + if self.adapter.does_hook("post_graph_execute", is_async=None): + await self.adapter.call_all_lifecycle_hooks_sync_and_async( + "post_graph_execute", + run_id=run_id, + graph=self.graph, + success=success, + error=error, + results=results, + ) + return results + + async def execute( + self, + final_vars: typing.List[str], + overrides: Dict[str, Any] = None, + display_graph: bool = False, + inputs: Dict[str, Any] = None, + ) -> Any: + """Executes computation. + + :param final_vars: the final list of variables we want to compute. + :param overrides: values that will override "nodes" in the DAG. + :param display_graph: DEPRECATED. Whether we want to display the graph being computed. + :param inputs: Runtime inputs to the DAG. + :return: an object consisting of the variables requested, matching the type returned by the GraphAdapter. + See constructor for how the GraphAdapter is initialized. The default one right now returns a pandas + dataframe. + """ + if display_graph: + raise ValueError( + "display_graph=True is not supported for the async graph adapter. " + "Instead you should be using visualize_execution." + ) + start_time = time.time() + run_successful = True + error = None + try: + outputs = await self.raw_execute(final_vars, overrides, display_graph, inputs=inputs) + # Currently we don't allow async build results, but we could. 
+ if self.adapter.does_method("do_build_result", is_async=False): + return self.adapter.call_lifecycle_method_sync("do_build_result", outputs=outputs) + return outputs + except Exception as e: + run_successful = False + logger.error(driver.SLACK_ERROR_MESSAGE) + error = telemetry.sanitize_error(*sys.exc_info()) + raise e + finally: + duration = time.time() - start_time + # ensure we can capture telemetry in async friendly way. + if telemetry.is_telemetry_enabled(): + + async def make_coroutine(): + self.capture_execute_telemetry( + error, final_vars, inputs, overrides, run_successful, duration + ) + + try: + # we don't have to await because we are running within the event loop. + asyncio.create_task(make_coroutine()) + except Exception as e: + if logger.isEnabledFor(logging.DEBUG): + logger.error(f"Encountered error submitting async telemetry:\n{e}") + + def capture_constructor_telemetry( + self, + error: Optional[str], + modules: Tuple[ModuleType], + config: Dict[str, Any], + adapter: base.HamiltonGraphAdapter, + ): + """Ensures we capture constructor telemetry the right way in an async context. + + This is a simpler wrapper around what's in the driver class. + + :param error: sanitized error string, if any. + :param modules: tuple of modules to build DAG from. + :param config: config to create the driver. + :param adapter: adapter class object. + """ + if telemetry.is_telemetry_enabled(): + try: + # check whether the event loop has been started yet or not + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.run_in_executor( + None, + super(AsyncDriver, self).capture_constructor_telemetry, + error, + modules, + config, + adapter, + ) + else: + + async def make_coroutine(): + super(AsyncDriver, self).capture_constructor_telemetry( + error, modules, config, adapter + ) + + loop.run_until_complete(make_coroutine()) + except Exception as e: + if logger.isEnabledFor(logging.DEBUG): + logger.error(f"Encountered error submitting async telemetry:\n{e}") + + +class Builder(driver.Builder): + """Builder for the async driver. This is equivalent to the standard builder, but has a more limited API. + Note this does not support dynamic execution or materializers (for now). + + Here is an example of how you might use it to get the tracker working: + + .. code-block:: python + + from hamilton_sdk import tracker + tracker_async = adapters.AsyncHamiltonTracker( + project_id=1, + username="elijah", + dag_name="async_tracker", + ) + dr = ( + await async_driver + .Builder() + .with_modules(async_module) + .with_adapters(tracking_async) + .build() + ) + """ + + def __init__(self): + super(Builder, self).__init__() + + def _not_supported(self, method_name: str, additional_message: str = ""): + raise ValueError( + f"Builder().{method_name}() is not supported for the async driver. 
{additional_message}" + ) + + def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder": + self._not_supported("enable_dynamic_execution") + + def with_materializers( + self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory] + ) -> "Builder": + self._not_supported("with_materializers") + + def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder": + self._not_supported( + "with_adapter", + "Use with_adapters instead to pass in the tracker (or other async hooks/methods)", + ) + + async def build(self): + adapter = self.adapters if self.adapters is not None else [] + if self.legacy_graph_adapter is not None: + adapter.append(self.legacy_graph_adapter) + + result_builder = ( + self.result_builder if self.result_builder is not None else base.DictResult() + ) + + return await AsyncDriver( + self.config, *self.modules, adapters=adapter, result_builder=result_builder + ).ainit() diff --git a/hamilton/experimental/h_async.py b/hamilton/experimental/h_async.py index d05535972..b803c7666 100644 --- a/hamilton/experimental/h_async.py +++ b/hamilton/experimental/h_async.py @@ -1,483 +1,12 @@ -import asyncio -import inspect import logging -import sys -import time -import typing -import uuid -from types import ModuleType -from typing import Any, Dict, Optional, Tuple -import hamilton.lifecycle.base as lifecycle_base -from hamilton import base, driver, graph, lifecycle, node, telemetry -from hamilton.execution.graph_functions import create_error_message -from hamilton.io.materialization import ExtractorFactory, MaterializerFactory +import hamilton.async_driver logger = logging.getLogger(__name__) +logger.warning( + "This module is deprecated and will be removed in Hamilton 2.0 " + "Please use `hamilton.async_driver` instead. " +) -async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]: - """Util to await a dictionary of tasks as asyncio.gather is kind of garbage""" - keys = sorted(task_dict.keys()) - coroutines = [task_dict[key] for key in keys] - coroutines_gathered = await asyncio.gather(*coroutines) - return dict(zip(keys, coroutines_gathered)) - - -async def process_value(val: Any) -> Any: - """Helper function to process the value of a potential awaitable. - This is very simple -- all it does is await the value if its not already resolved. - - :param val: Value to process. - :return: The value (awaited if it is a coroutine, raw otherwise). - """ - if not inspect.isawaitable(val): - return val - return await val - - -class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilder): - """Graph adapter for use with the :class:`AsyncDriver` class.""" - - def __init__( - self, - result_builder: base.ResultMixin = None, - async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None, - ): - """Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class. - - Some things to note: - - 1. This executes everything at the end (recursively). E.G. the final DAG nodes are awaited - 2. This does *not* work with decorators when the async function is being decorated. That is\ - because that function is called directly within the decorator, so we cannot await it. 
- """ - super(AsyncGraphAdapter, self).__init__() - self.adapter = ( - async_lifecycle_adapters - if async_lifecycle_adapters is not None - else lifecycle_base.LifecycleAdapterSet() - ) - self.result_builder = result_builder if result_builder else base.PandasDataFrameResult() - self.is_initialized = False - - def do_node_execute( - self, - *, - run_id: str, - node_: node.Node, - kwargs: typing.Dict[str, typing.Any], - task_id: Optional[str] = None, - ) -> typing.Any: - """Executes a node. Note this doesn't actually execute it -- rather, it returns a task. - This does *not* use async def, as we want it to be awaited on later -- this await is done - in processing parameters of downstream functions/final results. We can ensure that as - we also run the driver that this corresponds to. - - Note that this assumes that everything is awaitable, even if it isn't. - In that case, it just wraps it in one. - - :param task_id: - :param node_: - :param run_id: - :param node: Node to wrap - :param kwargs: Keyword arguments (either coroutines or raw values) to call it with - :return: A task - """ - callabl = node_.callable - - async def new_fn(fn=callabl, **fn_kwargs): - task_dict = {key: process_value(value) for key, value in fn_kwargs.items()} - fn_kwargs = await await_dict_of_tasks(task_dict) - error = None - result = None - success = True - pre_node_execute_errored = False - try: - if self.adapter.does_hook("pre_node_execute", is_async=True): - try: - await self.adapter.call_all_lifecycle_hooks_async( - "pre_node_execute", - run_id=run_id, - node_=node_, - kwargs=fn_kwargs, - task_id=task_id, - ) - except Exception as e: - pre_node_execute_errored = True - raise e - # TODO -- consider how to use node execution methods in the future - # This is messy as it is a method called within a method... - # if self.adapter.does_method("do_node_execute", is_async=False): - # result = self.adapter.call_lifecycle_method_sync( - # "do_node_execute", - # run_id=run_id, - # node_=node_, - # kwargs=kwargs, - # task_id=task_id, - # ) - # else: - - result = ( - await fn(**fn_kwargs) if asyncio.iscoroutinefunction(fn) else fn(**fn_kwargs) - ) - except Exception as e: - success = False - error = e - step = "[pre-node-execute:async]" if pre_node_execute_errored else "" - message = create_error_message(kwargs, node_, step) - logger.exception(message) - raise - finally: - if not pre_node_execute_errored and self.adapter.does_hook( - "post_node_execute", is_async=True - ): - try: - await self.adapter.call_all_lifecycle_hooks_async( - "post_node_execute", - run_id=run_id, - node_=node_, - kwargs=fn_kwargs, - success=success, - error=error, - result=result, - task_id=task_id, - ) - except Exception: - message = create_error_message(kwargs, node_, "[post-node-execute]") - logger.exception(message) - raise - - return result - - coroutine = new_fn(**kwargs) - task = asyncio.create_task(coroutine) - return task - - def build_result(self, **outputs: Any) -> Any: - return self.result_builder.build_result(**outputs) - - -def separate_sync_from_async( - adapters: typing.List[lifecycle.LifecycleAdapter], -) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]: - """Separates the sync and async adapters from a list of adapters. - Note this only works with hooks -- we'll be dealing with methods later. 
- - :param adapters: List of adapters - :return: Tuple of sync adapters, async adapters - """ - - adapter_set = lifecycle_base.LifecycleAdapterSet(*adapters) - # this is using internal(ish) fields (.sync_hooks/.async_hooks) -- we should probably expose it - # For now this is OK - # Note those are dict[hook_name, list[hook]], so we have to flatten - return ( - [sync_adapter for adapters in adapter_set.sync_hooks.values() for sync_adapter in adapters], - [ - async_adapter - for adapters in adapter_set.async_hooks.values() - for async_adapter in adapters - ], - ) - - -class AsyncDriver(driver.Driver): - """Async driver. This is a driver that uses the AsyncGraphAdapter to execute the graph. - - .. code-block:: python - - dr = h_async.AsyncDriver({}, async_module, result_builder=base.DictResult()) - df = await dr.execute([...], inputs=...) - - """ - - def __init__( - self, - config, - *modules, - result_builder: Optional[base.ResultMixin] = None, - adapters: typing.List[lifecycle.LifecycleAdapter] = None, - ): - """Instantiates an asynchronous driver. - - You will also need to call `ainit` to initialize the driver if you have any hooks/adapters. - - Note that this is not the desired API -- you should be using the :py:class:`hamilton.experimental.h_async.Builder` class to create the driver. - - :param config: Config to build the graph - :param modules: Modules to crawl for fns/graph nodes - :param adapters: Adapters to use for lifecycle methods. - :param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run. - """ - if adapters is not None: - sync_adapters, async_adapters = separate_sync_from_async(adapters) - else: - # separate out so we know what the driver - sync_adapters = [] - async_adapters = [] - # we'll need to use this in multiple contexts so we'll keep it around for later - result_builders = [adapter for adapter in adapters if isinstance(adapter, base.ResultMixin)] - if result_builder is not None: - result_builders.append(result_builder) - if len(result_builders) > 1: - raise ValueError( - "You cannot pass more than one result builder to the async driver. " - "Please pass in a single result builder" - ) - # it will be defaulted by the graph adapter - result_builder = result_builders[0] if len(result_builders) == 1 else None - super(AsyncDriver, self).__init__( - config, - *modules, - adapter=[ - # We pass in the async adapters here as this can call node-level hooks - # Otherwise we trust the driver/fn graph to call sync adapters - AsyncGraphAdapter( - result_builder=result_builder, - async_lifecycle_adapters=lifecycle_base.LifecycleAdapterSet(*async_adapters), - ), - # We pass in the sync adapters here as this can call - *sync_adapters, - *async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later - ], - ) - self.initialized = False - - async def ainit(self) -> "AsyncDriver": - """Initializes the driver when using async. This only exists for backwards compatibility. - In Hamilton 2.0, we will be using an asynchronous constructor. - See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5. 
- """ - if self.initialized: - # this way it can be called twice - return self - if self.adapter.does_hook("post_graph_construct", is_async=True): - await self.adapter.call_all_lifecycle_hooks_async( - "post_graph_construct", - graph=self.graph, - modules=self.graph_modules, - config=self.config, - ) - await self.adapter.ainit() - self.initialized = True - return self - - async def raw_execute( - self, - final_vars: typing.List[str], - overrides: Dict[str, Any] = None, - display_graph: bool = False, # don't care - inputs: Dict[str, Any] = None, - _fn_graph: graph.FunctionGraph = None, - ) -> Dict[str, Any]: - """Executes the graph, returning a dictionary of strings (node keys) to final results. - - :param final_vars: Variables to execute (+ upstream) - :param overrides: Overrides for nodes - :param display_graph: whether or not to display graph -- this is not supported. - :param inputs: Inputs for DAG runtime calculation - :param _fn_graph: Function graph for compatibility with superclass -- unused - :return: A dict of key -> result - """ - assert _fn_graph is None, ( - "_fn_graph must not be provided " - "-- the only reason you'd do this is to use materialize(), which is not supported yet.." - ) - run_id = str(uuid.uuid4()) - nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs) - memoized_computation = dict() # memoized storage - if self.adapter.does_hook("pre_graph_execute"): - await self.adapter.call_all_lifecycle_hooks_sync_and_async( - "pre_graph_execute", - run_id=run_id, - graph=self.graph, - final_vars=final_vars, - inputs=inputs, - overrides=overrides, - ) - results = None - error = None - success = False - try: - self.graph.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id) - if display_graph: - raise ValueError( - "display_graph=True is not supported for the async graph adapter. " - "Instead you should be using visualize_execution." - ) - task_dict = { - key: asyncio.create_task(process_value(memoized_computation[key])) - for key in final_vars - } - results = await await_dict_of_tasks(task_dict) - success = True - except Exception as e: - error = e - success = False - raise e - finally: - if self.adapter.does_hook("post_graph_execute", is_async=None): - await self.adapter.call_all_lifecycle_hooks_sync_and_async( - "post_graph_execute", - run_id=run_id, - graph=self.graph, - success=success, - error=error, - results=results, - ) - return results - - async def execute( - self, - final_vars: typing.List[str], - overrides: Dict[str, Any] = None, - display_graph: bool = False, - inputs: Dict[str, Any] = None, - ) -> Any: - """Executes computation. - - :param final_vars: the final list of variables we want to compute. - :param overrides: values that will override "nodes" in the DAG. - :param display_graph: DEPRECATED. Whether we want to display the graph being computed. - :param inputs: Runtime inputs to the DAG. - :return: an object consisting of the variables requested, matching the type returned by the GraphAdapter. - See constructor for how the GraphAdapter is initialized. The default one right now returns a pandas - dataframe. - """ - if display_graph: - raise ValueError( - "display_graph=True is not supported for the async graph adapter. " - "Instead you should be using visualize_execution." - ) - start_time = time.time() - run_successful = True - error = None - try: - outputs = await self.raw_execute(final_vars, overrides, display_graph, inputs=inputs) - # Currently we don't allow async build results, but we could. 
- if self.adapter.does_method("do_build_result", is_async=False): - return self.adapter.call_lifecycle_method_sync("do_build_result", outputs=outputs) - return outputs - except Exception as e: - run_successful = False - logger.error(driver.SLACK_ERROR_MESSAGE) - error = telemetry.sanitize_error(*sys.exc_info()) - raise e - finally: - duration = time.time() - start_time - # ensure we can capture telemetry in async friendly way. - if telemetry.is_telemetry_enabled(): - - async def make_coroutine(): - self.capture_execute_telemetry( - error, final_vars, inputs, overrides, run_successful, duration - ) - - try: - # we don't have to await because we are running within the event loop. - asyncio.create_task(make_coroutine()) - except Exception as e: - if logger.isEnabledFor(logging.DEBUG): - logger.error(f"Encountered error submitting async telemetry:\n{e}") - - def capture_constructor_telemetry( - self, - error: Optional[str], - modules: Tuple[ModuleType], - config: Dict[str, Any], - adapter: base.HamiltonGraphAdapter, - ): - """Ensures we capture constructor telemetry the right way in an async context. - - This is a simpler wrapper around what's in the driver class. - - :param error: sanitized error string, if any. - :param modules: tuple of modules to build DAG from. - :param config: config to create the driver. - :param adapter: adapter class object. - """ - if telemetry.is_telemetry_enabled(): - try: - # check whether the event loop has been started yet or not - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.run_in_executor( - None, - super(AsyncDriver, self).capture_constructor_telemetry, - error, - modules, - config, - adapter, - ) - else: - - async def make_coroutine(): - super(AsyncDriver, self).capture_constructor_telemetry( - error, modules, config, adapter - ) - - loop.run_until_complete(make_coroutine()) - except Exception as e: - if logger.isEnabledFor(logging.DEBUG): - logger.error(f"Encountered error submitting async telemetry:\n{e}") - - -class Builder(driver.Builder): - """Builder for the async driver. This is equivalent to the standard builder, but has a more limited API. - Note this does not support dynamic execution or materializers (for now). - - Here is an example of how you might use it to get the tracker working: - - .. code-block:: python - - from hamilton_sdk import tracker - tracker_async = adapters.AsyncHamiltonTracker( - project_id=1, - username="elijah", - dag_name="async_tracker", - ) - dr = ( - await h_async - .Builder() - .with_modules(async_module) - .with_adapters(tracking_async) - .build() - ) - """ - - def __init__(self): - super(Builder, self).__init__() - - def _not_supported(self, method_name: str, additional_message: str = ""): - raise ValueError( - f"Builder().{method_name}() is not supported for the async driver. 
{additional_message}" - ) - - def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder": - self._not_supported("enable_dynamic_execution") - - def with_materializers( - self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory] - ) -> "Builder": - self._not_supported("with_materializers") - - def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder": - self._not_supported( - "with_adapter", - "Use with_adapters instead to pass in the tracker (or other async hooks/methods)", - ) - - async def build(self): - adapter = self.adapters if self.adapters is not None else [] - if self.legacy_graph_adapter is not None: - adapter.append(self.legacy_graph_adapter) - - result_builder = self.result_builder if self.result_builder is not None else base.DictResult() - - return await AsyncDriver( - self.config, - *self.modules, - adapters=adapter, - result_builder=result_builder - ).ainit() +AsyncDriver = hamilton.async_driver.AsyncDriver diff --git a/plugin_tests/h_async/__init__.py b/plugin_tests/h_async/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/plugin_tests/h_async/conftest.py b/plugin_tests/h_async/conftest.py deleted file mode 100644 index bc5ef5b5a..000000000 --- a/plugin_tests/h_async/conftest.py +++ /dev/null @@ -1,4 +0,0 @@ -from hamilton import telemetry - -# disable telemetry for all tests! -telemetry.disable_telemetry() diff --git a/plugin_tests/h_async/requirements-test.txt b/plugin_tests/h_async/requirements-test.txt deleted file mode 100644 index 2d73dba5b..000000000 --- a/plugin_tests/h_async/requirements-test.txt +++ /dev/null @@ -1 +0,0 @@ -pytest-asyncio diff --git a/plugin_tests/h_async/resources/__init__.py b/plugin_tests/h_async/resources/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/requirements-test.txt b/requirements-test.txt index 9d79f161d..d7eed9c6f 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -22,6 +22,7 @@ polars pyarrow pyreadstat # for SPSS data loader pytest +pytest-asyncio pytest-cov PyYAML scikit-learn diff --git a/plugin_tests/h_async/resources/simple_async_module.py b/tests/resources/simple_async_module.py similarity index 100% rename from plugin_tests/h_async/resources/simple_async_module.py rename to tests/resources/simple_async_module.py diff --git a/plugin_tests/h_async/test_h_async.py b/tests/test_async_driver.py similarity index 93% rename from plugin_tests/h_async/test_h_async.py rename to tests/test_async_driver.py index 0c0626755..4f840ba1b 100644 --- a/plugin_tests/h_async/test_h_async.py +++ b/tests/test_async_driver.py @@ -4,8 +4,7 @@ import pandas as pd import pytest -from hamilton import base -from hamilton.experimental import h_async +from hamilton import async_driver, base from hamilton.lifecycle.base import ( BasePostGraphConstruct, BasePostGraphConstructAsync, @@ -30,36 +29,36 @@ async def async_identity(n: int) -> int: @pytest.mark.asyncio async def test_await_dict_of_coroutines(): tasks = {n: async_identity(n) for n in range(0, 10)} - results = await h_async.await_dict_of_tasks(tasks) + results = await async_driver.await_dict_of_tasks(tasks) assert results == {n: await async_identity(n) for n in range(0, 10)} @pytest.mark.asyncio async def test_await_dict_of_tasks(): tasks = {n: asyncio.create_task(async_identity(n)) for n in range(0, 10)} - results = await h_async.await_dict_of_tasks(tasks) + results = await async_driver.await_dict_of_tasks(tasks) assert results == {n: await async_identity(n) for n in 
range(0, 10)} # The following are not parameterized as we need to use the event loop -- fixtures will complicate this @pytest.mark.asyncio async def test_process_value_raw(): - assert await h_async.process_value(1) == 1 + assert await async_driver.process_value(1) == 1 @pytest.mark.asyncio async def test_process_value_coroutine(): - assert await h_async.process_value(async_identity(1)) == 1 + assert await async_driver.process_value(async_identity(1)) == 1 @pytest.mark.asyncio async def test_process_value_task(): - assert await h_async.process_value(asyncio.create_task(async_identity(1))) == 1 + assert await async_driver.process_value(asyncio.create_task(async_identity(1))) == 1 @pytest.mark.asyncio async def test_driver_end_to_end(): - dr = h_async.AsyncDriver({}, simple_async_module) + dr = async_driver.AsyncDriver({}, simple_async_module) all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] result = await dr.raw_execute(final_vars=all_vars, inputs={"external_input": 1}) result["a"] = result["a"].to_dict() # convert to dict for comparison @@ -85,7 +84,7 @@ async def test_driver_end_to_end(): @mock.patch("hamilton.telemetry.send_event_json") @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) async def test_driver_end_to_end_telemetry(send_event_json): - dr = h_async.AsyncDriver({}, simple_async_module, result_builder=base.DictResult()) + dr = async_driver.AsyncDriver({}, simple_async_module, result_builder=base.DictResult()) with mock.patch("hamilton.telemetry.g_telemetry_enabled", False): # don't count this telemetry tracking invocation all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] @@ -156,7 +155,7 @@ async def post_graph_construct(self, **kwargs): adapter = AsyncTrackingAdapter(tracked_calls) - dr = await h_async.AsyncDriver( + dr = await async_driver.AsyncDriver( {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] ).ainit() all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] @@ -220,7 +219,7 @@ def post_graph_construct(self, **kwargs): adapter = AsyncTrackingAdapter(tracked_calls) - dr = await h_async.AsyncDriver( + dr = await async_driver.AsyncDriver( {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter] ).ainit() all_vars = [var.name for var in dr.list_available_variables() if var.name != "return_df"] diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index e42c51b93..1f7303637 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -7,8 +7,7 @@ import pytest -from hamilton import base, node, telemetry -from hamilton.experimental import h_async +from hamilton import async_driver, base, node, telemetry from hamilton.lifecycle import base as lifecycle_base @@ -163,8 +162,8 @@ class CustomResultBuilder(base.ResultMixin): "hamilton.base.DefaultAdapter", ), ( - h_async.AsyncGraphAdapter(base.DictResult()), - "hamilton.experimental.h_async.AsyncGraphAdapter", + async_driver.AsyncGraphAdapter(base.DictResult()), + "hamilton.async_driver.AsyncGraphAdapter", ), (CustomAdapter(base.DictResult()), "custom_adapter"), ], @@ -189,7 +188,7 @@ def test_get_adapter_name(adapter, expected): "hamilton.base.StrictIndexTypePandasDataFrameResult", ), (base.SimplePythonGraphAdapter(CustomResultBuilder()), "custom_builder"), - (h_async.AsyncGraphAdapter(base.DictResult()), "hamilton.base.DictResult"), + (async_driver.AsyncGraphAdapter(base.DictResult()), "hamilton.base.DictResult"), 
(CustomAdapter(base.DictResult()), "hamilton.base.DictResult"), (CustomAdapter(CustomResultBuilder()), "custom_builder"), ],
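For anyone reviewing the migration path, here is a minimal sketch of what user code looks like once this patch is applied. It is not part of the diff: the module `my_async_module`, its `pipeline` output, and the `external_input` input are hypothetical placeholders; only the `hamilton.async_driver` import path, the `Builder`/`AsyncDriver` entry points, and the deprecation warning from `hamilton.experimental.h_async` come from the changes above.

.. code-block:: python

    import asyncio

    from hamilton import async_driver, base

    import my_async_module  # hypothetical module containing async Hamilton functions


    async def main() -> None:
        # Preferred path after this patch: build via the async Builder, which awaits
        # ainit() for you so any async lifecycle adapters get initialized.
        dr = await async_driver.Builder().with_modules(my_async_module).build()
        result = await dr.execute(["pipeline"], inputs={"external_input": 1})
        print(result)

        # Constructor-style usage, mirroring the updated docs/examples in this diff.
        # Without adapters, calling ainit() is not required.
        dr2 = async_driver.AsyncDriver({}, my_async_module, result_builder=base.DictResult())
        print(await dr2.execute(["pipeline"], inputs={"external_input": 1}))

        # The old import still resolves for now, but logs a warning pointing at
        # hamilton.async_driver:
        # from hamilton.experimental import h_async
        # dr3 = h_async.AsyncDriver({}, my_async_module)


    if __name__ == "__main__":
        asyncio.run(main())

The Builder route is the one to prefer when passing adapters (e.g. the AsyncHamiltonTracker shown in the Builder docstring), since it is the path that runs the async initialization hooks.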