diff --git a/ui/sdk/src/hamilton_sdk/adapters.py b/ui/sdk/src/hamilton_sdk/adapters.py
index 2fddfcf74..88eef4a1f 100644
--- a/ui/sdk/src/hamilton_sdk/adapters.py
+++ b/ui/sdk/src/hamilton_sdk/adapters.py
@@ -1,4 +1,3 @@
-import asyncio
 import datetime
 import hashlib
 import logging
@@ -148,7 +147,6 @@ def pre_graph_execute(
             dag_template_id = self.dag_template_id_cache[fg_id]
         else:
             raise ValueError("DAG template ID not found in cache. This should never happen.")
-
         tracking_state = TrackingState(run_id)
         self.tracking_states[run_id] = tracking_state  # cache
         tracking_state.clock_start()
@@ -386,7 +384,7 @@ def post_graph_execute(
     )


-class AsyncHamiltonAdapter(
+class AsyncHamiltonTracker(
     base.BasePostGraphConstructAsync,
     base.BasePreGraphExecuteAsync,
     base.BasePreNodeExecuteAsync,
@@ -396,13 +394,13 @@ class AsyncHamiltonAdapter(
     def __init__(
         self,
         project_id: int,
-        api_key: str,
         username: str,
         dag_name: str,
         tags: Dict[str, str] = None,
         client_factory: Callable[
-            [str, str, str], clients.HamiltonClient
+            [str, str, str], clients.BasicAsynchronousHamiltonClient
         ] = clients.BasicAsynchronousHamiltonClient,
+        api_key: str = os.environ.get("HAMILTON_API_KEY", ""),
         hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL),
         hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL),
     ):
@@ -416,11 +414,22 @@ def __init__(
         driver.validate_tags(self.base_tags)
         self.dag_name = dag_name
         self.hamilton_ui_url = hamilton_ui_url
-        logger.debug("Validating authentication against Hamilton BE API...")
-        asyncio.run(self.client.validate_auth())
-        logger.debug(f"Ensuring project {self.project_id} exists...")
+        self.dag_template_id_cache = {}
+        self.tracking_states = {}
+        self.dw_run_ids = {}
+        self.task_runs = {}
+        self.initialized = False
+        super().__init__()
+
+    async def ainit(self):
+        """You must call this to initialize the tracker."""
+        if self.initialized:
+            return self
+        logger.info("Validating authentication against Hamilton BE API...")
+        await self.client.validate_auth()
+        logger.info(f"Ensuring project {self.project_id} exists...")
         try:
-            asyncio.run(self.client.project_exists(self.project_id))
+            await self.client.project_exists(self.project_id)
         except clients.UnauthorizedException:
             logger.exception(
                 f"Authentication failed. Please check your username and try again. "
" @@ -433,11 +442,11 @@ def __init__( f"You can do so at {self.hamilton_ui_url}/dashboard/projects" ) raise - self.dag_template_id_cache = {} - self.tracking_states = {} - self.dw_run_ids = {} - self.task_runs = {} - super().__init__() + logger.info("Initializing Hamilton tracker.") + await self.client.ainit() + logger.info("Initialized Hamilton tracker.") + self.initialized = True + return self async def post_graph_construct( self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any] @@ -476,7 +485,6 @@ async def pre_graph_execute( overrides: Dict[str, Any], ): logger.debug("pre_graph_execute %s", run_id) - self.run_id = run_id fg_id = id(graph) if fg_id in self.dag_template_id_cache: dag_template_id = self.dag_template_id_cache[fg_id] @@ -511,7 +519,7 @@ async def pre_node_execute( task_update = dict( node_template_name=node_.name, - node_name=get_node_name(node_.name, task_id), + node_name=get_node_name(node_, task_id), realized_dependencies=[dep.name for dep in node_.dependencies], status=task_run.status, start_time=task_run.start_time, @@ -519,7 +527,7 @@ async def pre_node_execute( ) await self.client.update_tasks( self.dw_run_ids[run_id], - attributes=[None], + attributes=[], task_updates=[task_update], in_samples=[task_run.is_in_sample], ) @@ -532,6 +540,7 @@ async def post_node_execute( error: Optional[Exception], result: Any, task_id: Optional[str] = None, + **future_kwargs, ): logger.debug("post_node_execute %s", run_id) task_run = self.task_runs[run_id][node_.name] diff --git a/ui/sdk/src/hamilton_sdk/api/clients.py b/ui/sdk/src/hamilton_sdk/api/clients.py index 4de4d961b..9a43a67dd 100644 --- a/ui/sdk/src/hamilton_sdk/api/clients.py +++ b/ui/sdk/src/hamilton_sdk/api/clients.py @@ -1,4 +1,5 @@ import abc +import asyncio import datetime import functools import logging @@ -30,6 +31,33 @@ def __init__(self, path: str, user: str): super().__init__(message) +def create_batch(batch: dict, dag_run_id: int): + attributes = defaultdict(list) + task_updates = defaultdict(list) + for item in batch: + if item["dag_run_id"] == dag_run_id: + for attr in item["attributes"]: + if attr is None: + continue + attributes[attr["node_name"]].append(attr) + for task_update in item["task_updates"]: + if task_update is None: + continue + task_updates[task_update["node_name"]].append(task_update) + + # We do not care about disambiguating here -- only one named attribute should be logged + + attributes_list = [] + for node_name in attributes: + attributes_list.extend(attributes[node_name]) + # in this case we do care about order so we don't send double the updates. + task_updates_list = [ + functools.reduce(lambda x, y: {**x, **y}, task_updates[node_name]) + for node_name in task_updates + ] + return attributes_list, task_updates_list + + class HamiltonClient: @abc.abstractmethod def validate_auth(self): @@ -220,30 +248,7 @@ def flush(self, batch): # group by dag_run_id -- just incase someone does something weird? 
         dag_run_ids = set([item["dag_run_id"] for item in batch])
         for dag_run_id in dag_run_ids:
-            attributes = defaultdict(list)
-            task_updates = defaultdict(list)
-            for item in batch:
-                if item["dag_run_id"] == dag_run_id:
-                    for attr in item["attributes"]:
-                        if attr is None:
-                            continue
-                        attributes[attr["node_name"]].append(attr)
-                    for task_update in item["task_updates"]:
-                        if task_update is None:
-                            continue
-                        task_updates[task_update["node_name"]].append(task_update)
-
-            # We do not care about disambiguating here -- only one named attribute should be logged
-
-            attributes_list = []
-            for node_name in attributes:
-                attributes_list.extend(attributes[node_name])
-            # in this case we do care about order so we don't send double the updates.
-            task_updates_list = [
-                functools.reduce(lambda x, y: {**x, **y}, task_updates[node_name])
-                for node_name in task_updates
-            ]
-
+            attributes_list, task_updates_list = create_batch(batch, dag_run_id)
             response = requests.put(
                 f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}",
                 json={
@@ -514,6 +519,65 @@ def __init__(self, api_key: str, username: str, h_api_url: str, base_path: str =
         self.api_key = api_key
         self.username = username
         self.base_url = h_api_url + base_path
+        self.flush_interval = 5
+        self.data_queue = asyncio.Queue()
+        self.running = True
+        self.max_batch_size = 100
+
+    async def ainit(self):
+        asyncio.create_task(self.worker())
+
+    async def flush(self, batch):
+        """Flush the batch (send it to the backend or process it)."""
+        logger.debug(f"Flushing batch: {len(batch)}")
+        # group by dag_run_id -- just incase someone does something weird?
+        dag_run_ids = set([item["dag_run_id"] for item in batch])
+        for dag_run_id in dag_run_ids:
+            attributes_list, task_updates_list = create_batch(batch, dag_run_id)
+            async with aiohttp.ClientSession() as session:
+                async with session.put(
+                    f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}",
+                    json={
+                        "attributes": make_json_safe(attributes_list),
+                        "task_updates": make_json_safe(task_updates_list),
+                    },
+                    headers=self._common_headers(),
+                ) as response:
+                    try:
+                        response.raise_for_status()
+                        logger.debug(f"Updated tasks for DAG run {dag_run_id}")
+                    except aiohttp.ClientResponseError:
+                        logger.exception(f"Failed to update tasks for DAG run {dag_run_id}")
+                        # zraise
+
+    async def worker(self):
+        """Worker task that processes the queue."""
+        batch = []
+        last_flush_time = time.time()
+        logger.debug("Starting worker")
+        while True:
+            logger.debug(
+                f"Awaiting item from queue -- current batched # of items are: {len(batch)}"
+            )
+            try:
+                item = await asyncio.wait_for(self.data_queue.get(), timeout=self.flush_interval)
+                batch.append(item)
+            except asyncio.TimeoutError:
+                # This is fine, we just keep waiting
+                pass
+            else:
+                if item is None:
+                    await self.flush(batch)
+                    return
+
+            # Check if batch is full or flush interval has passed
+            if (
+                len(batch) >= self.max_batch_size
+                or (time.time() - last_flush_time) >= self.flush_interval
+            ):
+                await self.flush(batch)
+                batch = []
+                last_flush_time = time.time()

     def _common_headers(self) -> Dict[str, Any]:
         """Yields the common headers for all requests.
@@ -728,26 +792,14 @@ async def update_tasks(
             f"Updating tasks for DAG run {dag_run_id} with {len(attributes)} "
             f"attributes and {len(task_updates)} task updates"
         )
-        url = f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}"
-        headers = self._common_headers()
-        data = {
-            "attributes": make_json_safe(attributes),
-            "task_updates": make_json_safe(task_updates),
-        }
-
-        async with aiohttp.ClientSession() as session:
-            async with session.put(url, json=data, headers=headers) as response:
-                try:
-                    response.raise_for_status()
-                    logger.debug(f"Updated tasks for DAG run {dag_run_id}")
-                except aiohttp.ClientResponseError:
-                    logger.exception(f"Failed to update tasks for DAG run {dag_run_id}")
-                    raise
+        await self.data_queue.put(
+            {"dag_run_id": dag_run_id, "attributes": attributes, "task_updates": task_updates}
+        )

     async def log_dag_run_end(self, dag_run_id: int, status: str):
         logger.debug(f"Logging end of DAG run {dag_run_id} with status {status}")
         url = f"{self.base_url}/dag_runs/{dag_run_id}/"
-        data = (make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}),)
+        data = make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()})
         headers = self._common_headers()
         async with aiohttp.ClientSession() as session:
             async with session.put(url, json=data, headers=headers) as response:
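
Usage note for reviewers (not part of the diff): after this change `AsyncHamiltonTracker` does no I/O in its constructor. Callers are expected to `await tracker.ainit()` once before execution; that call validates auth, checks that the project exists, and starts the async client's background flush worker. `update_tasks` then only enqueues onto `data_queue`, and `worker()` flushes a batch whenever `max_batch_size` (100) items accumulate or `flush_interval` (5 seconds) elapses. Below is a minimal sketch of the intended call pattern; the `hamilton_sdk` import path and the example project id are assumptions, and the driver wiring is unchanged by this PR so it is elided.

```python
# Sketch only -- not part of this diff. Assumes the SDK is importable as
# `hamilton_sdk` and that the referenced project already exists in the UI.
import asyncio

from hamilton_sdk import adapters


async def main():
    tracker = adapters.AsyncHamiltonTracker(
        project_id=1,                     # hypothetical project id
        username="user@example.com",      # hypothetical user
        dag_name="my_dag",
    )
    # Construction no longer performs I/O: auth validation, the project-exists
    # check, and starting the client's background flush task all happen here.
    tracker = await tracker.ainit()
    # ... build an async Hamilton driver with `tracker` as an adapter and execute ...


if __name__ == "__main__":
    asyncio.run(main())
```

As far as the visible hunks go, nothing ever enqueues the `None` sentinel that makes `worker()` return, so the flush task keeps running until the event loop shuts down.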