Skip to content

Commit

Permalink
Add validate_node function
Browse files Browse the repository at this point in the history
* Check if target node is valid
* Create sub-group, if it doesn't exist
  • Loading branch information
GDYendell committed Mar 10, 2017
1 parent 91c7830 commit 65a9c9a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/vdsgen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,35 @@ def test_create_vds_maps(self, target_mock, source_mock, map_mock):
self.assertEqual([map_mock.return_value]*6, map_list)


class ValidateNodeTest(unittest.TestCase):

def setUp(self):
self.file_mock = MagicMock()

def test_validate_node_creates(self):
self.file_mock.get.return_value = None

vdsgen.validate_node(self.file_mock, "entry/detector/detector1")

self.file_mock.create_group.assert_called_once_with("entry/detector")

def test_validate_node_exists_then_no_op(self):
self.file_mock.get.return_value = "Group"

vdsgen.validate_node(self.file_mock, "entry/detector/detector1")

self.file_mock.create_group.assert_not_called()

def test_validate_node_invalid_then_error(self):

with self.assertRaises(ValueError):
vdsgen.validate_node(self.file_mock, "/entry/detector/detector1")
with self.assertRaises(ValueError):
vdsgen.validate_node(self.file_mock, "entry/detector/detector1/")
with self.assertRaises(ValueError):
vdsgen.validate_node(self.file_mock, "/entry/detector/detector1/")


class MainTest(unittest.TestCase):

file_mock = MagicMock()
Expand Down
22 changes: 22 additions & 0 deletions vdsgen/vdsgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,27 @@ def create_vds_maps(source, vds_data, source_node, target_node):
return map_list


def validate_node(vds_file, target_node):
"""Check if it is possible to create the given node.
Check the target node is valid (no leading or trailing slashes)
Create any sub-group of the target node if it doesn't exist.
Args:
vds_file(h5py.File): File to check for node
target_node(str): Full path to node
"""
if target_node.startswith("/") or target_node.endswith("/"):
raise ValueError("Target node should have no leading or trailing "
"slashes, got {}".format(target_node))

if "/" in target_node:
sub_group = target_node.rsplit("/", 1)[0]
if vds_file.get(sub_group) is None:
vds_file.create_group(sub_group)


def generate_vds(path, prefix=None, files=None, output=None, source=None,
source_node=None, target_node=None,
stripe_spacing=None, module_spacing=None):
Expand Down Expand Up @@ -312,6 +333,7 @@ def generate_vds(path, prefix=None, files=None, output=None, source=None,

logging.info("Creating VDS at %s", output_file)
with h5.File(output_file, "w", libver="latest") as vds_file:
validate_node(vds_file, target_node)
vds_file.create_virtual_dataset(VMlist=map_list, fill_value=0x1)

logging.info("Creation successful!")
Expand Down

0 comments on commit 65a9c9a

Please sign in to comment.