Skip to content

Commit

Permalink
fix not running nodes multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Feb 27, 2024
1 parent 2a43260 commit 2cdafa8
Showing 1 changed file with 47 additions and 76 deletions.
123 changes: 47 additions & 76 deletions znflow/deployment/dask_depl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class DaskDeployment(DeploymentBase):
)

def run(self, nodes: t.Optional[list] = None):
if nodes != "ABC":
if nodes is None:
for node_uuid in self.graph.reverse():
assert self.graph.immutable_nodes
node = self.graph.nodes[node_uuid]["value"]
Expand Down Expand Up @@ -89,78 +89,49 @@ def run(self, nodes: t.Optional[list] = None):
self.graph._update_node_attributes(node, handler.UpdateConnectors())
else:
node.result = self.results[node.uuid].result().result

# wait(self.results.values())


# @dataclasses.dataclass
# class DaskDeployment:
# """ZnFlow deployment using Dask.

# Attributes
# ----------
# graph: DiGraph
# the znflow graph containing the nodes.
# client: Client, optional
# the Dask client.
# results: Dict[uuid, Future]
# a dictionary of {uuid: Future} shape that is filled after the graph is submitted.

# """

# graph: "DiGraph"
# client: Client = dataclasses.field(default_factory=Client)
# results: typing.Dict[uuid.UUID, Future] = dataclasses.field(
# default_factory=dict, init=False
# )

# def submit_graph(self):
# """Submit the graph to Dask.

# When submitting to Dask, a Node is serialized, processed and a
# copy can be returned.

# This requires:
# - the connections to be updated to the respective Nodes coming from Dask futures.
# - the Node to be returned from the workers and passed to all successors.
# """
# for node_uuid in self.graph.reverse():
# node = self.graph.nodes[node_uuid]["value"]
# predecessors = list(self.graph.predecessors(node.uuid))

# if len(predecessors) == 0:
# self.results[node.uuid] = self.client.submit( # TODO how to name
# node_submit, node=node, pure=False
# )
# else:
# self.results[node.uuid] = self.client.submit(
# node_submit,
# node=node,
# predecessors={
# x: self.results[x] for x in self.results if x in predecessors
# },
# pure=False,
# )

# def get_results(self, obj: typing.Union[Node, list, dict, NodeView], /):
# """Get the results from Dask based on the original object.

# Parameters
# ----------
# obj: any
# either a single Node or multiple Nodes from the submitted graph.

# Returns
# -------
# any:
# Returns an instance of obj which is updated with the results from Dask.

# """
# from znflow import DiGraph

# if isinstance(obj, NodeView):
# data = LoadNodeFromDeploymentResults()(dict(obj), results=self.results)
# return {x: v["value"] for x, v in data.items()}
# elif isinstance(obj, DiGraph):
# raise NotImplementedError
# return LoadNodeFromDeploymentResults()(obj, results=self.results)

else:
for node_uuid in self.graph.reverse():
assert self.graph.immutable_nodes
node = self.graph.nodes[node_uuid]["value"]
if node in nodes:
predecessors = list(self.graph.predecessors(node.uuid))

if len(predecessors) == 0:
if node.uuid not in self.results:
self.results[node.uuid] = self.client.submit( # TODO how to name
node_submit, node=node, pure=False
)
else:
if node.uuid not in self.results:
for predecessor in predecessors:
# submit the predecessors first
_node = self.graph.nodes[predecessor]["value"]
if _node.uuid not in self.results:
self.results[predecessor] = self.client.submit(
node_submit, node=_node, pure=False
)

self.results[node.uuid] = self.client.submit(
node_submit,
node=node,
predecessors={
x: self.results[x]
for x in self.results
if x in predecessors
},
pure=False,
)
# load the results when done
for node_uuid in self.graph.reverse():
node = self.graph.nodes[node_uuid]["value"]
try:
future = self.results[node.uuid]
print(future.result())
if isinstance(node, Node):
node.__dict__.update(self.results[node.uuid].result().__dict__)
self.graph._update_node_attributes(node, handler.UpdateConnectors())
else:
node.result = self.results[node.uuid].result().result
except KeyError:
pass

0 comments on commit 2cdafa8

Please sign in to comment.