Updates the Hamilton SDK to work with async
This adds a basic queuing system -- we launch a periodic task that runs
and flushes a queue. The task times out at the flush interval and flushes
everything it has accumulated. This will often flush after a request is
done, but it will always flush.
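A minimal sketch of that flush loop, for illustration only -- the actual implementation is the worker() coroutine added to BasicAsynchronousHamiltonClient in clients.py below; names here are simplified, and the defaults mirror the diff:

import asyncio
import time

async def worker(queue: asyncio.Queue, flush, flush_interval: float = 5.0, max_batch_size: int = 100):
    # Accumulate items; flush when the batch fills up or the interval elapses.
    # A None item is the shutdown signal: flush what is left and exit.
    batch = []
    last_flush = time.time()
    while True:
        try:
            item = await asyncio.wait_for(queue.get(), timeout=flush_interval)
        except asyncio.TimeoutError:
            pass  # nothing arrived in time -- fall through to the flush check
        else:
            if item is None:
                await flush(batch)
                return
            batch.append(item)
        if len(batch) >= max_batch_size or (time.time() - last_flush) >= flush_interval:
            await flush(batch)
            batch = []
            last_flush = time.time()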
elijahbenizzy committed Jun 26, 2024
1 parent 9383e17 commit 6549130
Showing 2 changed files with 118 additions and 57 deletions.
43 changes: 26 additions & 17 deletions ui/sdk/src/hamilton_sdk/adapters.py
@@ -1,4 +1,3 @@
import asyncio
import datetime
import hashlib
import logging
@@ -148,7 +147,6 @@ def pre_graph_execute(
dag_template_id = self.dag_template_id_cache[fg_id]
else:
raise ValueError("DAG template ID not found in cache. This should never happen.")

tracking_state = TrackingState(run_id)
self.tracking_states[run_id] = tracking_state # cache
tracking_state.clock_start()
@@ -386,7 +384,7 @@ def post_graph_execute(
)


class AsyncHamiltonAdapter(
class AsyncHamiltonTracker(
base.BasePostGraphConstructAsync,
base.BasePreGraphExecuteAsync,
base.BasePreNodeExecuteAsync,
@@ -396,13 +394,13 @@ class AsyncHamiltonAdapter(
def __init__(
self,
project_id: int,
api_key: str,
username: str,
dag_name: str,
tags: Dict[str, str] = None,
client_factory: Callable[
[str, str, str], clients.HamiltonClient
[str, str, str], clients.BasicAsynchronousHamiltonClient
] = clients.BasicAsynchronousHamiltonClient,
api_key: str = os.environ.get("HAMILTON_API_KEY", ""),
hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL),
hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL),
):
@@ -416,11 +414,22 @@ def __init__(
driver.validate_tags(self.base_tags)
self.dag_name = dag_name
self.hamilton_ui_url = hamilton_ui_url
logger.debug("Validating authentication against Hamilton BE API...")
asyncio.run(self.client.validate_auth())
logger.debug(f"Ensuring project {self.project_id} exists...")
self.dag_template_id_cache = {}
self.tracking_states = {}
self.dw_run_ids = {}
self.task_runs = {}
self.initialized = False
super().__init__()

async def ainit(self):
if self.initialized:
return self
"""You must call this to initialize the tracker."""
logger.info("Validating authentication against Hamilton BE API...")
await self.client.validate_auth()
logger.info(f"Ensuring project {self.project_id} exists...")
try:
asyncio.run(self.client.project_exists(self.project_id))
await self.client.project_exists(self.project_id)
except clients.UnauthorizedException:
logger.exception(
f"Authentication failed. Please check your username and try again. "
@@ -433,11 +442,11 @@ def __init__(
f"You can do so at {self.hamilton_ui_url}/dashboard/projects"
)
raise
self.dag_template_id_cache = {}
self.tracking_states = {}
self.dw_run_ids = {}
self.task_runs = {}
super().__init__()
logger.info("Initializing Hamilton tracker.")
await self.client.ainit()
logger.info("Initialized Hamilton tracker.")
self.initialized = True
return self

async def post_graph_construct(
self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any]
@@ -476,7 +485,6 @@ async def pre_graph_execute(
overrides: Dict[str, Any],
):
logger.debug("pre_graph_execute %s", run_id)
self.run_id = run_id
fg_id = id(graph)
if fg_id in self.dag_template_id_cache:
dag_template_id = self.dag_template_id_cache[fg_id]
@@ -511,15 +519,15 @@ async def pre_node_execute(

task_update = dict(
node_template_name=node_.name,
node_name=get_node_name(node_.name, task_id),
node_name=get_node_name(node_, task_id),
realized_dependencies=[dep.name for dep in node_.dependencies],
status=task_run.status,
start_time=task_run.start_time,
end_time=None,
)
await self.client.update_tasks(
self.dw_run_ids[run_id],
attributes=[None],
attributes=[],
task_updates=[task_update],
in_samples=[task_run.is_in_sample],
)
@@ -532,6 +540,7 @@ async def post_node_execute(
error: Optional[Exception],
result: Any,
task_id: Optional[str] = None,
**future_kwargs,
):
logger.debug("post_node_execute %s", run_id)
task_run = self.task_runs[run_id][node_.name]
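With the adapters.py changes above, the async tracker is constructed synchronously and then initialized explicitly via ainit(). A rough usage sketch, assuming the hamilton_sdk.adapters import path and hypothetical project/user values:

import asyncio
from hamilton_sdk import adapters  # assumed import path for the SDK package

async def main():
    tracker = adapters.AsyncHamiltonTracker(
        project_id=1,                    # hypothetical project id
        username="you@example.com",      # hypothetical username
        dag_name="my_dag",
    )
    # ainit() must be awaited before use: it validates auth, checks the
    # project exists, and starts the background flush worker.
    tracker = await tracker.ainit()
    # ... pass the tracker to your async Hamilton driver as an adapter ...

asyncio.run(main())

As in the constructor signature above, api_key falls back to the HAMILTON_API_KEY environment variable; the exact wiring into an async Hamilton driver depends on your driver setup.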
132 changes: 92 additions & 40 deletions ui/sdk/src/hamilton_sdk/api/clients.py
@@ -1,4 +1,5 @@
import abc
import asyncio
import datetime
import functools
import logging
@@ -30,6 +31,33 @@ def __init__(self, path: str, user: str):
super().__init__(message)


def create_batch(batch: dict, dag_run_id: int):
attributes = defaultdict(list)
task_updates = defaultdict(list)
for item in batch:
if item["dag_run_id"] == dag_run_id:
for attr in item["attributes"]:
if attr is None:
continue
attributes[attr["node_name"]].append(attr)
for task_update in item["task_updates"]:
if task_update is None:
continue
task_updates[task_update["node_name"]].append(task_update)

# We do not care about disambiguating here -- only one named attribute should be logged

attributes_list = []
for node_name in attributes:
attributes_list.extend(attributes[node_name])
# in this case we do care about order so we don't send double the updates.
task_updates_list = [
functools.reduce(lambda x, y: {**x, **y}, task_updates[node_name])
for node_name in task_updates
]
return attributes_list, task_updates_list


class HamiltonClient:
@abc.abstractmethod
def validate_auth(self):
@@ -220,30 +248,7 @@ def flush(self, batch):
# group by dag_run_id -- just incase someone does something weird?
dag_run_ids = set([item["dag_run_id"] for item in batch])
for dag_run_id in dag_run_ids:
attributes = defaultdict(list)
task_updates = defaultdict(list)
for item in batch:
if item["dag_run_id"] == dag_run_id:
for attr in item["attributes"]:
if attr is None:
continue
attributes[attr["node_name"]].append(attr)
for task_update in item["task_updates"]:
if task_update is None:
continue
task_updates[task_update["node_name"]].append(task_update)

# We do not care about disambiguating here -- only one named attribute should be logged

attributes_list = []
for node_name in attributes:
attributes_list.extend(attributes[node_name])
# in this case we do care about order so we don't send double the updates.
task_updates_list = [
functools.reduce(lambda x, y: {**x, **y}, task_updates[node_name])
for node_name in task_updates
]

attributes_list, task_updates_list = create_batch(batch, dag_run_id)
response = requests.put(
f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}",
json={
@@ -514,6 +519,65 @@ def __init__(self, api_key: str, username: str, h_api_url: str, base_path: str =
self.api_key = api_key
self.username = username
self.base_url = h_api_url + base_path
self.flush_interval = 5
self.data_queue = asyncio.Queue()
self.running = True
self.max_batch_size = 100

async def ainit(self):
asyncio.create_task(self.worker())

async def flush(self, batch):
"""Flush the batch (send it to the backend or process it)."""
logger.debug(f"Flushing batch: {len(batch)}") # Replace with actual processing logic
# group by dag_run_id -- just incase someone does something weird?
dag_run_ids = set([item["dag_run_id"] for item in batch])
for dag_run_id in dag_run_ids:
attributes_list, task_updates_list = create_batch(batch, dag_run_id)
async with aiohttp.ClientSession() as session:
async with session.put(
f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}",
json={
"attributes": make_json_safe(attributes_list),
"task_updates": make_json_safe(task_updates_list),
},
headers=self._common_headers(),
) as response:
try:
response.raise_for_status()
logger.debug(f"Updated tasks for DAG run {dag_run_id}")
except HTTPError:
logger.exception(f"Failed to update tasks for DAG run {dag_run_id}")
# zraise

async def worker(self):
"""Worker thread to process the queue"""
batch = []
last_flush_time = time.time()
logger.debug("Starting worker")
while True:
logger.debug(
f"Awaiting item from queue -- current batched # of items are: {len(batch)}"
)
try:
item = await asyncio.wait_for(self.data_queue.get(), timeout=self.flush_interval)
batch.append(item)
except asyncio.TimeoutError:
# This is fine, we just keep waiting
pass
else:
if item is None:
await self.flush(batch)
return

# Check if batch is full or flush interval has passed
if (
len(batch) >= self.max_batch_size
or (time.time() - last_flush_time) >= self.flush_interval
):
await self.flush(batch)
batch = []
last_flush_time = time.time()

def _common_headers(self) -> Dict[str, Any]:
"""Yields the common headers for all requests.
@@ -728,26 +792,14 @@ async def update_tasks(
f"Updating tasks for DAG run {dag_run_id} with {len(attributes)} "
f"attributes and {len(task_updates)} task updates"
)
url = f"{self.base_url}/dag_runs_bulk?dag_run_id={dag_run_id}"
headers = self._common_headers()
data = {
"attributes": make_json_safe(attributes),
"task_updates": make_json_safe(task_updates),
}

async with aiohttp.ClientSession() as session:
async with session.put(url, json=data, headers=headers) as response:
try:
response.raise_for_status()
logger.debug(f"Updated tasks for DAG run {dag_run_id}")
except aiohttp.ClientResponseError:
logger.exception(f"Failed to update tasks for DAG run {dag_run_id}")
raise
await self.data_queue.put(
{"dag_run_id": dag_run_id, "attributes": attributes, "task_updates": task_updates}
)

async def log_dag_run_end(self, dag_run_id: int, status: str):
logger.debug(f"Logging end of DAG run {dag_run_id} with status {status}")
url = f"{self.base_url}/dag_runs/{dag_run_id}/"
data = (make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}),)
data = make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()})
headers = self._common_headers()
async with aiohttp.ClientSession() as session:
async with session.put(url, json=data, headers=headers) as response:

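One note on the new create_batch helper factored out above: for a given dag_run_id it concatenates attributes per node, but dict-merges the queued task updates for each node via functools.reduce, so only a single, latest-state update per node is sent. A small sketch of that merge, with hypothetical field values:

import functools

# Two queued updates for the same node, e.g. a "start" and an "end" record.
updates = [
    {"node_name": "my_node", "status": "RUNNING", "start_time": "t0", "end_time": None},
    {"node_name": "my_node", "status": "SUCCESS", "end_time": "t1"},
]
# Later updates win on shared keys, so the merged record reflects the final state.
merged = functools.reduce(lambda x, y: {**x, **y}, updates)
# merged == {"node_name": "my_node", "status": "SUCCESS", "start_time": "t0", "end_time": "t1"}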