Skip to content

Commit

Permalink
start viewer serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
marsipu committed Jun 22, 2024
1 parent 89e3136 commit 5e34498
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
36 changes: 35 additions & 1 deletion mne_pipeline_hd/gui/node/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def add_input(
PortItem
port qgraphics item.
"""
# port names must be unique
if name in self._inputs:
logging.warning(f"Input port {name} already exists.")
return
port = Port(self, name, "in", multi_connection, accepted_ports)
self._inputs[port.name] = port
if self.scene():
Expand Down Expand Up @@ -252,7 +256,10 @@ def add_output(
PortItem
port qgraphics item.
"""

# port names must be unique
if name in self._outputs:
logging.warning(f"Output port {name} already exists.")
return
port = Port(self, name, "out", multi_connection, accepted_ports)
self._outputs[port.name] = port
if self.scene():
Expand Down Expand Up @@ -381,6 +388,33 @@ def delete(self):
self.scene().removeItem(self)
del self

def to_dict(self):
node_dict = {
"name": self.name,
"class": self.__class__.__name__,
"pos": self.xy_pos,
"inputs": self.inputs,
"outputs": self.outputs,
"connections": {
"inputs": {
p.name: {
nid: [cp.name for cp in cpts]
for nid, cpts in p.connected_ports.items()
}
for p in self.inputs
},
"outputs": {
p.name: {
nid: [cp.name for cp in cpts]
for nid, cpts in p.connected_ports.items()
}
for p in self.outputs
},
},
}

return node_dict

# ----------------------------------------------------------------------------------
# Qt methods
# ----------------------------------------------------------------------------------
Expand Down
18 changes: 12 additions & 6 deletions mne_pipeline_hd/gui/node/node_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,24 +248,30 @@ def node(self, node_idx=None, node_name=None, node_id=None):
return self.nodes[node_id]

def to_dict(self):
# ToDo: Implement this
graph_dict = {"nodes": dict(), "connections": dict()}
viewer_dict = {node_id: node.to_dict() for node_id, node in self.nodes.items()}

return graph_dict
return viewer_dict

def from_dict(self, graph_dict):
def from_dict(self, viewer_dict):
# ToDo: Implement this
for node_id, node_data in graph_dict["nodes"].items():
for node_id, node_data in viewer_dict["nodes"].items():
node = self.add_node(node_data["type"])
node.from_dict(node_data)

for conn_id, conn_data in graph_dict["connections"].items():
for conn_id, conn_data in viewer_dict["connections"].items():
start_port = self.nodes[conn_data["start_node"]].outputs[
conn_data["start_port"]
]
end_port = self.nodes[conn_data["end_node"]].inputs[conn_data["end_port"]]
start_port.connect_to(end_port)

def clear(self):
"""
Clear the node graph.
"""
for node in self.nodes.values():
self.remove_node(node)

# ----------------------------------------------------------------------------------
# Qt methods
# ----------------------------------------------------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions mne_pipeline_hd/tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,12 @@ def test_nodes_basic_interaction(nodeviewer):
)
# Check if connection was sliced
assert len(node1.output(1).connected_ports) == 0


def test_node_serialization(qtbot, nodeviewer):
viewer_dict = nodeviewer.to_dict()
qtbot.wait(2000)
nodeviewer.clear()
qtbot.wait(1000)
nodeviewer.from_dict(viewer_dict)
qtbot.wait(10000)

0 comments on commit 5e34498

Please sign in to comment.