Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 28, 2024
1 parent aad2b78 commit ace5823
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 21 deletions.
20 changes: 12 additions & 8 deletions tests/examples/test_ips_lotf.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Mock version of IPS LotF workflow for testing purposes."""

import dataclasses
import znflow
import random

import pytest

import znflow


@dataclasses.dataclass
class AddData(znflow.Node):
file: str

def run(self):
if self.file is None:
raise ValueError("File is None")
Expand All @@ -17,6 +21,7 @@ def run(self):
def atoms(self):
return "Atoms"


@dataclasses.dataclass
class TrainModel(znflow.Node):
data: str
Expand All @@ -28,6 +33,7 @@ def run(self):
self.model = "Model"
print(f"Model: {self.model}")


@dataclasses.dataclass
class MD(znflow.Node):
model: str
Expand All @@ -39,6 +45,7 @@ def run(self):
self.atoms = "Atoms"
print(f"Atoms: {self.atoms}")


@dataclasses.dataclass
class EvaluateModel(znflow.Node):
model: str
Expand All @@ -52,10 +59,8 @@ def run(self):
self.metrics = random.random()
print(f"Metrics: {self.metrics}")

@pytest.mark.parametrize(
"deployment",
["vanilla_deployment", "dask_deployment"]
)

@pytest.mark.parametrize("deployment", ["vanilla_deployment", "dask_deployment"])
def test_lotf(deployment, request):
deployment = request.getfixturevalue(deployment)

Expand All @@ -72,6 +77,5 @@ def test_lotf(deployment, request):
if znflow.resolve(metrics.metrics) == pytest.approx(0.623, 1e-3):
# break loop after 6th iteration
break

assert len(graph) == 22

assert len(graph) == 22
3 changes: 1 addition & 2 deletions znflow/deployment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(self, nodes: t.Optional[t.List] = None):
else:
# convert nodes to UUIDs
nodes = [node.uuid for node in nodes]

for node_uuid in nodes:
node_available = self.graph.nodes[node_uuid].get("available", False)
if self.graph.immutable_nodes and node_available:
Expand All @@ -27,4 +27,3 @@ def set_graph(self, graph: "DiGraph"):
@abc.abstractmethod
def _run_node(self, node_uuid):
pass

17 changes: 7 additions & 10 deletions znflow/deployment/dask_depl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,18 @@ def _run_node(self, node_uuid):
predecessors = list(self.graph.predecessors(node_uuid))
for predecessor in predecessors:
self._run_node(predecessor)

node_available = self.graph.nodes[node_uuid].get("available", False)
if self.graph.immutable_nodes and node_available:
return


self.results[node_uuid] = self.client.submit(
node_submit,
node=node,
predecessors={
x: self.results[x] for x in self.results if x in predecessors
},
pure=False,
key=f"{node.__class__.__name__}-{node_uuid}",
)
node_submit,
node=node,
predecessors={x: self.results[x] for x in self.results if x in predecessors},
pure=False,
key=f"{node.__class__.__name__}-{node_uuid}",
)
if self.graph.immutable_nodes:
self.graph.nodes[node_uuid]["available"] = True

Expand Down
1 change: 0 additions & 1 deletion znflow/deployment/vanilla.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses

from networkx import predecessor

from znflow import handler

Expand Down

0 comments on commit ace5823

Please sign in to comment.