diff --git a/tests/app_test.py b/tests/app_test.py new file mode 100644 index 0000000..2a173a0 --- /dev/null +++ b/tests/app_test.py @@ -0,0 +1,173 @@ +import unittest +from pkg_resources import require + +require("mock") +from mock import MagicMock, patch, call + +from vdsgen import app + +app_patch_path = "vdsgen.app" +parser_patch_path = app_patch_path + ".ArgumentParser" +VDSGenerator_patch_path = app_patch_path + ".VDSGenerator" + + +class ParseArgsTest(unittest.TestCase): + + @patch(VDSGenerator_patch_path) + @patch(app_patch_path + '.ArgumentDefaultsHelpFormatter') + @patch(parser_patch_path) + def test_parser(self, parser_init_mock, formatter_mock, gen_mock): + parser_mock = parser_init_mock.return_value + add_mock = parser_mock.add_argument + add_group_mock = parser_mock.add_argument_group + add_exclusive_group_mock = parser_mock.add_mutually_exclusive_group + parse_mock = parser_mock.parse_args + parse_mock.return_value = MagicMock(empty=False, files=None) + empty_mock = MagicMock() + other_mock = MagicMock() + add_group_mock.side_effect = [empty_mock, other_mock] + exclusive_group_mock = add_exclusive_group_mock.return_value + expected_message = """ +------------------------------------------------------------------------------- +A script to create a virtual dataset composed of multiple raw HDF5 files. + +The minimum required arguments are and either -p or -f . + +For example: + + > ../vdsgen/app.py /scratch/images -p stripe_ + > ../vdsgen/app.py /scratch/images -f stripe_1.hdf5 stripe_2.hdf5 + +You can create an empty VDS, for raw files that don't exist yet, with the -e +flag; you will then need to provide --shape and --data_type, though defaults +are provided for these. 
+------------------------------------------------------------------------------- +""" + + args = app.parse_args() + + parser_init_mock.assert_called_once_with( + usage=expected_message, + formatter_class=formatter_mock) + add_exclusive_group_mock.assert_called_with(required=True) + exclusive_group_mock.add_argument.assert_has_calls( + [call("-p", "--prefix", type=str, default=None, dest="prefix", + help="Prefix of files to search for - e.g 'stripe_' " + "to combine 'stripe_1.hdf5' and 'stripe_2.hdf5'."), + call("-f", "--files", nargs="*", type=str, default=None, + dest="files", + help="Explicit names of raw files in .")]) + + add_mock.assert_called_with( + "path", type=str, help="Root folder of source files and VDS.") + + add_group_mock.assert_has_calls([call()] * 2) + empty_mock.add_argument.assert_has_calls( + [call("-e", "--empty", action="store_true", dest="empty", + help="Make empty VDS pointing to datasets " + "that don't exist yet."), + call("--shape", type=int, nargs="*", default=[1, 256, 2048], + dest="shape", + help="Shape of dataset - 'frames height width', where " + "frames is N dimensional."), + call("-t", "--data_type", type=str, default="uint16", + dest="data_type", help="Data type of raw datasets.")]) + other_mock.add_argument.assert_has_calls( + [call("-o", "--output", type=str, default=None, dest="output", + help="Output file name. 
If None then generated as input " + "file prefix with vds suffix."), + call("-s", "--stripe_spacing", type=int, dest="stripe_spacing", + default=gen_mock.stripe_spacing, + help="Spacing between two stripes in a module."), + call("-m", "--module_spacing", type=int, dest="module_spacing", + default=gen_mock.module_spacing, + help="Spacing between two modules."), + call("--source_node", type=str, dest="source_node", + default=gen_mock.source_node, + help="Data node in source HDF5 files."), + call("--target_node", type=str, + default=gen_mock.target_node, dest="target_node", + help="Data node in VDS file."), + call("-l", "--log_level", type=int, dest="log_level", + default=gen_mock.log_level, + help="Logging level (off=3, info=2, debug=1).")]) + + parse_mock.assert_called_once_with() + self.assertEqual(parse_mock.return_value, args) + + @patch(parser_patch_path + '.error') + @patch(parser_patch_path + '.parse_args', + return_value=MagicMock(empty=True, files=None)) + def test_empty_and_not_files_then_error(self, parse_mock, error_mock): + + app.parse_args() + + error_mock.assert_called_once_with( + "To make an empty VDS you must explicitly define --files for the " + "eventual raw datasets.") + + @patch(parser_patch_path + '.error') + @patch(parser_patch_path + '.parse_args', + return_value=MagicMock(empty=True, files=["file"])) + def test_only_one_file_then_error(self, parse_mock, error_mock): + + app.parse_args() + + error_mock.assert_called_once_with( + "Must define at least two files to combine.") + + +class MainTest(unittest.TestCase): + @patch(VDSGenerator_patch_path) + @patch(app_patch_path + '.parse_args', + return_value=MagicMock( + path="/test/path", prefix="stripe_", empty=True, + files=["file1.hdf5", "file2.hdf5"], output="vds", + shape=[3, 256, 2048], data_type="int16", + source_node="data", target_node="full_frame", + stripe_spacing=3, module_spacing=127, + log_level=2)) + def test_main_empty(self, parse_mock, init_mock): + gen_mock = 
init_mock.return_value + args_mock = parse_mock.return_value + + app.main() + + parse_mock.assert_called_once_with() + init_mock.assert_called_once_with( + args_mock.path, + prefix=args_mock.prefix, files=args_mock.files, + output=args_mock.output, + source=dict(shape=args_mock.shape, dtype=args_mock.data_type), + source_node=args_mock.source_node, + target_node=args_mock.target_node, + stripe_spacing=args_mock.stripe_spacing, + module_spacing=args_mock.module_spacing, + log_level=args_mock.log_level) + + gen_mock.generate_vds.assert_called_once_with() + + @patch(VDSGenerator_patch_path) + @patch(app_patch_path + '.parse_args', + return_value=MagicMock( + path="/test/path", prefix="stripe_", empty=False, + files=["file1.hdf5", "file2.hdf5"], output="vds", + frames=3, height=256, width=2048, data_type="int16", + source_node="data", target_node="full_frame", + stripe_spacing=3, module_spacing=127, + log_level=2)) + def test_main_not_empty(self, parse_mock, generate_mock): + args_mock = parse_mock.return_value + + app.main() + + parse_mock.assert_called_once_with() + generate_mock.assert_called_once_with( + args_mock.path, + prefix=args_mock.prefix, output="vds", files=args_mock.files, + source=None, + source_node=args_mock.source_node, + stripe_spacing=args_mock.stripe_spacing, + target_node=args_mock.target_node, + module_spacing=args_mock.module_spacing, + log_level=args_mock.log_level) diff --git a/tests/vdsgen_test.py b/tests/vdsgen_test.py deleted file mode 100644 index fbdd497..0000000 --- a/tests/vdsgen_test.py +++ /dev/null @@ -1,190 +0,0 @@ -import unittest - -from pkg_resources import require -require("mock") -from mock import MagicMock, patch, ANY, call -vdsgen_patch_path = "vdsgen.vdsgen" -parser_patch_path = "argparse.ArgumentParser" -h5py_patch_path = "h5py" - -import os -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "h5py")) - -from vdsgen import vdsgen - - -class ParseArgsTest(unittest.TestCase): - - 
@patch(parser_patch_path + '.add_argument') - @patch(parser_patch_path + '.parse_args') - def test_no_args_given(self, parse_mock, add_mock): - args = vdsgen.parse_args() - - add_mock.has_calls(call("path", type=str, - help="Path to folder containing HDF5 files."), - call("prefix", type=str, - help="Root name of images - e.g 'stripe_' to " - "combine the images 'stripe_1.hdf5', " - "'stripe_2.hdf5' and 'stripe_3.hdf5' " - "located at .")) - parse_mock.assert_called_once_with() - self.assertEqual(parse_mock.return_value, args) - - -class FindFilesTest(unittest.TestCase): - - @patch('os.listdir', - return_value=["stripe_1.h5", "stripe_2.h5", "stripe_3.h5", - "stripe_4.h5", "stripe_5.h5", "stripe_6.h5"]) - def test_given_files_then_return(self, _): - expected_files = ["/test/path/stripe_1.h5", "/test/path/stripe_2.h5", - "/test/path/stripe_3.h5", "/test/path/stripe_4.h5", - "/test/path/stripe_5.h5", "/test/path/stripe_6.h5"] - - files = vdsgen.find_files("/test/path", "stripe_") - - self.assertEqual(expected_files, files) - - @patch('os.listdir', - return_value=["stripe_4.h5", "stripe_1.h5", "stripe_6.h5", - "stripe_3.h5", "stripe_2.h5", "stripe_5.h5"]) - def test_given_files_out_of_order_then_return(self, _): - expected_files = ["/test/path/stripe_1.h5", "/test/path/stripe_2.h5", - "/test/path/stripe_3.h5", "/test/path/stripe_4.h5", - "/test/path/stripe_5.h5", "/test/path/stripe_6.h5"] - - files = vdsgen.find_files("/test/path", "stripe_") - - self.assertEqual(expected_files, files) - - @patch('os.listdir', return_value=["stripe_1.h5"]) - def test_given_one_file_then_error(self, _): - - with self.assertRaises(IOError): - vdsgen.find_files("/test/path", "stripe_") - - @patch('os.listdir', return_value=[]) - def test_given_no_files_then_error(self, _): - - with self.assertRaises(IOError): - vdsgen.find_files("/test/path", "stripe_") - - -class SimpleFunctionsTest(unittest.TestCase): - - def test_generate_vds_name(self): - expected_name = "stripe_vds.h5" - files = 
["stripe_1.h5", "stripe_2.h5", "stripe_3.h5", - "stripe_4.h5", "stripe_5.h5", "stripe_6.h5"] - - vds_name = vdsgen.construct_vds_name("stripe_", files) - - self.assertEqual(expected_name, vds_name) - - mock_data = dict(data=MagicMock(shape=(3, 256, 2048), dtype="uint16")) - - @patch(h5py_patch_path + '.File', return_value=mock_data) - def test_grab_metadata(self, h5file_mock): - expected_data = dict(frames=3, height=256, width=2048, dtype="uint16") - - meta_data = vdsgen.grab_metadata("/test/path") - - h5file_mock.assert_called_once_with("/test/path", "r") - self.assertEqual(expected_data, meta_data) - - @patch(vdsgen_patch_path + '.grab_metadata', - return_value=dict(frames=3, height=256, width=2048, dtype="uint16")) - def test_process_source_datasets_given_valid_data(self, grab_mock): - files = ["stripe_1.h5", "stripe_2.h5"] - expected_source = vdsgen.Source(frames=3, height=256, width=2048, - dtype="uint16", datasets=files) - - source = vdsgen.process_source_datasets(files) - - grab_mock.assert_has_calls([call("stripe_1.h5"), call("stripe_2.h5")]) - self.assertEqual(expected_source, source) - - @patch(vdsgen_patch_path + '.grab_metadata', - side_effect=[dict(frames=3, height=256, width=2048, dtype="uint16"), - dict(frames=4, height=256, width=2048, - dtype="uint16")]) - def test_process_source_datasets_given_mismatched_data(self, grab_mock): - files = ["stripe_1.h5", "stripe_2.h5"] - - with self.assertRaises(ValueError): - vdsgen.process_source_datasets(files) - - grab_mock.assert_has_calls([call("stripe_1.h5"), call("stripe_2.h5")]) - - def test_construct_vds_metadata(self): - source = vdsgen.Source(frames=3, height=256, width=2048, - dtype="uint16", datasets=[""]*6) - expected_vds = vdsgen.VDS(shape=(3, 1586, 2048), spacing=266, - path="/test/path") - - vds = vdsgen.construct_vds_metadata(source, "/test/path") - - self.assertEqual(expected_vds, vds) - - @patch(h5py_patch_path + '.VirtualMap') - @patch(h5py_patch_path + '.VirtualSource') - 
@patch(h5py_patch_path + '.VirtualTarget') - def test_create_vds_maps(self, target_mock, source_mock, map_mock): - source = vdsgen.Source(frames=3, height=256, width=2048, - dtype="uint16", datasets=["source"]*6) - vds = vdsgen.VDS(shape=(3, 1586, 2048), spacing=266, path="/test/path") - - map_list = vdsgen.create_vds_maps(source, vds) - - target_mock.assert_called_once_with("/test/path", "full_frame", - shape=(3, 1586, 2048)) - source_mock.assert_has_calls([call("source", "data", - shape=(3, 256, 2048))]*6) - # TODO: Improve this assert by passing numpy arrays to check slicing - map_mock.assert_has_calls([call(source_mock.return_value, - target_mock.return_value.__getitem__.return_value, - dtype="uint16")]*6) - self.assertEqual([map_mock.return_value]*6, map_list) - - -class MainTest(unittest.TestCase): - - file_mock = MagicMock() - - @patch(h5py_patch_path + '.File', return_value=file_mock) - @patch(vdsgen_patch_path + '.create_vds_maps') - @patch(vdsgen_patch_path + '.construct_vds_metadata') - @patch(vdsgen_patch_path + '.process_source_datasets') - @patch(vdsgen_patch_path + '.construct_vds_name', - return_value="stripe_vds.h5") - @patch(vdsgen_patch_path + '.find_files', - return_value=["stripe_1.hdf5", "stripe_2.hdf5", "stripe_3.hdf5"]) - def test_generate_vds(self, find_mock, gen_mock, process_mock, - construct_mock, create_mock, h5file_mock): - vds_file_mock = self.file_mock.__enter__.return_value - - vdsgen.generate_vds("/test/path", "stripe_") - - find_mock.assert_called_once_with("/test/path", "stripe_") - gen_mock.assert_called_once_with("stripe_", find_mock.return_value) - process_mock.assert_called_once_with(find_mock.return_value) - construct_mock.assert_called_once_with(process_mock.return_value, - "/test/path/stripe_vds.h5") - create_mock.assert_called_once_with(process_mock.return_value, - construct_mock.return_value) - h5file_mock.assert_called_once_with("/test/path/stripe_vds.h5", "w", - libver="latest") - 
vds_file_mock.create_virtual_dataset.assert_called_once_with( - VMlist=create_mock.return_value, fill_value=0x1) - - @patch(vdsgen_patch_path + '.generate_vds') - @patch(vdsgen_patch_path + '.parse_args', - return_value=MagicMock(path="/test/path", prefix="stripe_")) - def test_main(self, parse_mock, generate_mock): - args_mock = parse_mock.return_value - - vdsgen.main() - - parse_mock.assert_called_once_with() - generate_mock.assert_called_once_with(args_mock.path, args_mock.prefix) diff --git a/tests/vdsgenerator_test.py b/tests/vdsgenerator_test.py new file mode 100644 index 0000000..4aaf63e --- /dev/null +++ b/tests/vdsgenerator_test.py @@ -0,0 +1,361 @@ +import os +import sys +import unittest + +from pkg_resources import require +require("mock") +from mock import MagicMock, patch, call + +from vdsgen import vdsgenerator +from vdsgen.vdsgenerator import VDSGenerator + +vdsgen_patch_path = "vdsgen.vdsgenerator" +VDSGenerator_patch_path = vdsgen_patch_path + ".VDSGenerator" +h5py_patch_path = "h5py" + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "h5py")) + + +class VDSGeneratorTester(VDSGenerator): + + """A version of VDSGenerator without initialisation. + + For testing single methods of the class. Must have required attributes + passed before calling testee function. 
+ + """ + + def __init__(self, **kwargs): + for attribute, value in kwargs.items(): + self.__setattr__(attribute, value) + + +class VDSGeneratorInitTest(unittest.TestCase): + + @patch('os.path.isfile', return_value=True) + @patch(VDSGenerator_patch_path + '.process_source_datasets') + @patch(VDSGenerator_patch_path + '.construct_vds_name', + return_value="stripe_vds.hdf5") + @patch(VDSGenerator_patch_path + '.find_files', + return_value=["/test/path/stripe_1.hdf5", + "/test/path/stripe_2.hdf5", + "/test/path/stripe_3.hdf5"]) + def test_generate_vds_defaults(self, find_mock, construct_mock, + process_mock, isfile_mock): + expected_files = ["stripe_1.hdf5", "stripe_2.hdf5", "stripe_3.hdf5"] + + gen = VDSGenerator("/test/path", prefix="stripe_") + + find_mock.assert_called_once_with() + construct_mock.assert_called_once_with(expected_files) + process_mock.assert_called_once_with() + + self.assertEqual("/test/path", gen.path) + self.assertEqual("stripe_", gen.prefix) + self.assertEqual("stripe_vds.hdf5", gen.name) + self.assertEqual(find_mock.return_value, gen.datasets) + self.assertEqual(process_mock.return_value, gen.source_metadata) + self.assertEqual("data", gen.source_node) + self.assertEqual("full_frame", gen.target_node) + self.assertEqual(10, gen.stripe_spacing) + self.assertEqual(10, gen.module_spacing) + self.assertEqual(gen.CREATE, gen.mode) + + def test_generate_vds_given_args(self): + files = ["stripe_1.h5", "stripe_2.h5"] + file_paths = ["/test/path/" + file_ for file_ in files] + source_dict = dict(shape=(3, 256, 2048), dtype="int16") + source = vdsgenerator.Source(frames=(3,), height=256, width=2048, + dtype="int16") + + gen = VDSGenerator("/test/path", + files=files, + output="vds.hdf5", + source=source_dict, + source_node="entry/data/data", + target_node="entry/detector/detector1", + stripe_spacing=3, module_spacing=127) + + self.assertEqual("/test/path", gen.path) + self.assertEqual("stripe_", gen.prefix) + self.assertEqual("vds.hdf5", gen.name) + 
self.assertEqual(file_paths, gen.datasets) + self.assertEqual(source, gen.source_metadata) + self.assertEqual("entry/data/data", gen.source_node) + self.assertEqual("entry/detector/detector1", gen.target_node) + self.assertEqual(3, gen.stripe_spacing) + self.assertEqual(127, gen.module_spacing) + self.assertEqual(gen.CREATE, gen.mode) + + def test_generate_vds_prefix_and_files_then_error(self): + files = ["stripe_1.h5", "stripe_2.h5"] + source_dict = dict(frames=3, height=256, width=2048, dtype="int16") + + with self.assertRaises(ValueError): + VDSGenerator("/test/path", + prefix="stripe_", files=files, + output="vds.hdf5", + source=source_dict, + source_node="entry/data/data", + target_node="entry/detector/detector1", + stripe_spacing=3, module_spacing=127) + + @patch('os.path.isfile', return_value=False) + def test_generate_vds_no_source_or_files_then_error(self, _): + + with self.assertRaises(IOError) as e: + VDSGenerator("/test/path", + files=["file1", "file2"], + output="vds.hdf5") + + self.assertEqual("File /test/path/file1 does not exist. 
To create VDS " + "from raw files that haven't been created yet, " + "source must be provided.", + e.exception.message) + + +class FindFilesTest(unittest.TestCase): + + def setUp(self): + self.gen = VDSGeneratorTester(path="/test/path", prefix="stripe_") + + @patch('os.listdir', + return_value=["stripe_1.h5", "stripe_2.h5", "stripe_3.h5", + "stripe_4.h5", "stripe_5.h5", "stripe_6.h5"]) + def test_given_files_then_return(self, _): + expected_files = ["/test/path/stripe_1.h5", "/test/path/stripe_2.h5", + "/test/path/stripe_3.h5", "/test/path/stripe_4.h5", + "/test/path/stripe_5.h5", "/test/path/stripe_6.h5"] + + files = self.gen.find_files() + + self.assertEqual(expected_files, files) + + @patch('os.listdir', + return_value=["stripe_4.h5", "stripe_1.h5", "stripe_6.h5", + "stripe_3.h5", "stripe_2.h5", "stripe_5.h5"]) + def test_given_files_out_of_order_then_return(self, _): + expected_files = ["/test/path/stripe_1.h5", "/test/path/stripe_2.h5", + "/test/path/stripe_3.h5", "/test/path/stripe_4.h5", + "/test/path/stripe_5.h5", "/test/path/stripe_6.h5"] + + files = self.gen.find_files() + + self.assertEqual(expected_files, files) + + @patch('os.listdir', return_value=["stripe_1.h5"]) + def test_given_one_file_then_error(self, _): + + with self.assertRaises(IOError): + self.gen.find_files() + + @patch('os.listdir', return_value=[]) + def test_given_no_files_then_error(self, _): + + with self.assertRaises(IOError): + self.gen.find_files() + + +class SimpleFunctionsTest(unittest.TestCase): + + def test_generate_vds_name(self): + gen = VDSGeneratorTester(prefix="stripe_") + expected_name = "stripe_vds.h5" + files = ["stripe_1.h5", "stripe_2.h5", "stripe_3.h5", + "stripe_4.h5", "stripe_5.h5", "stripe_6.h5"] + + vds_name = gen.construct_vds_name(files) + + self.assertEqual(expected_name, vds_name) + + mock_data = dict(data=MagicMock(shape=(3, 256, 2048), dtype="uint16")) + + @patch(h5py_patch_path + '.File', return_value=mock_data) + def test_grab_metadata(self, h5file_mock): 
+ gen = VDSGeneratorTester(source_node="data") + expected_data = dict(frames=(3,), height=256, width=2048, dtype="uint16") + + meta_data = gen.grab_metadata("/test/path/stripe.hdf5") + + h5file_mock.assert_called_once_with("/test/path/stripe.hdf5", "r") + self.assertEqual(expected_data, meta_data) + + @patch(VDSGenerator_patch_path + '.grab_metadata', + return_value=dict(frames=(3,), height=256, width=2048, dtype="uint16")) + def test_process_source_datasets_given_valid_data(self, grab_mock): + gen = VDSGeneratorTester(datasets=["stripe_1.h5", "stripe_2.h5"]) + expected_source = vdsgenerator.Source(frames=(3,), height=256, + width=2048, + dtype="uint16") + + source = gen.process_source_datasets() + + grab_mock.assert_has_calls([call("stripe_1.h5"), call("stripe_2.h5")]) + self.assertEqual(expected_source, source) + + @patch(VDSGenerator_patch_path + '.grab_metadata', + side_effect=[dict(frames=3, height=256, width=2048, dtype="uint16"), + dict(frames=4, height=256, width=2048, + dtype="uint16")]) + def test_process_source_datasets_given_mismatched_data(self, grab_mock): + gen = VDSGeneratorTester(datasets=["stripe_1.h5", "stripe_2.h5"]) + + with self.assertRaises(ValueError): + gen.process_source_datasets() + + grab_mock.assert_has_calls([call("stripe_1.h5"), call("stripe_2.h5")]) + + def test_construct_vds_metadata(self): + gen = VDSGeneratorTester(datasets=[""] * 6, stripe_spacing=10, + module_spacing=100) + source = vdsgenerator.Source(frames=(3,), height=256, width=2048, + dtype="uint16") + expected_vds = vdsgenerator.VDS(shape=(3, 1766, 2048), + spacing=[10, 100, 10, 100, 10, 0]) + + vds = gen.construct_vds_metadata(source) + + self.assertEqual(expected_vds, vds) + + @patch(h5py_patch_path + '.VirtualMap') + @patch(h5py_patch_path + '.VirtualSource') + @patch(h5py_patch_path + '.VirtualTarget') + def test_create_vds_maps(self, target_mock, source_mock, map_mock): + gen = VDSGeneratorTester(output_file="/test/path/vds.hdf5", + stripe_spacing=10, 
module_spacing=100, + target_node="full_frame", source_node="data", + datasets=["source"] * 6, name="vds.hdf5") + source = vdsgenerator.Source(frames=(3,), height=256, width=2048, + dtype="uint16") + vds = vdsgenerator.VDS(shape=(3, 1586, 2048), spacing=[10] * 5 + [0]) + + map_list = gen.create_vds_maps(source, vds) + + target_mock.assert_called_once_with("/test/path/vds.hdf5", + "full_frame", + shape=(3, 1586, 2048)) + source_mock.assert_has_calls([call("source", "data", + shape=(3, 256, 2048))] * 6) + # TODO: Improve this assert by passing numpy arrays to check slicing + map_mock.assert_has_calls([ + call(source_mock.return_value, + target_mock.return_value.__getitem__.return_value, + dtype="uint16")]*6) + self.assertEqual([map_mock.return_value]*6, map_list) + + +class ValidateNodeTest(unittest.TestCase): + + def setUp(self): + self.file_mock = MagicMock() + + def test_validate_node_creates(self): + gen = VDSGeneratorTester(target_node="/entry/detector/detector1") + self.file_mock.get.return_value = None + + gen.validate_node(self.file_mock) + + self.file_mock.create_group.assert_called_once_with("/entry/detector") + + def test_validate_node_exists_then_no_op(self): + gen = VDSGeneratorTester(target_node="entry/detector/detector1") + self.file_mock.get.return_value = "Group" + + gen.validate_node(self.file_mock) + + self.file_mock.create_group.assert_not_called() + + def test_validate_node_trailing_slash_then_removed(self): + gen = VDSGeneratorTester(target_node="/entry/detector/detector1//") + self.file_mock.get.return_value = None + + gen.validate_node(self.file_mock) + + self.file_mock.create_group.assert_called_once_with("/entry/detector") + + +class GenerateVDSTest(unittest.TestCase): + + file_mock = MagicMock() + + @patch('os.path.isfile', return_value=False) + @patch(VDSGenerator_patch_path + '.validate_node') + @patch(h5py_patch_path + '.File', return_value=file_mock) + @patch(VDSGenerator_patch_path + '.create_vds_maps') + @patch(VDSGenerator_patch_path 
+ '.construct_vds_metadata') + def test_generate_vds_create(self, construct_mock, create_mock, + h5file_mock, validate_mock, isfile_mock): + source_mock = MagicMock() + gen = VDSGeneratorTester(path="/test/path", prefix="stripe_", + output_file="/test/path/vds.hdf5", + name="vds.hdf5", + target_node="full_frame", source_node="data", + datasets=["stripe_1.hdf5", "stripe_2.hdf5", + "stripe_3.hdf5"], + source_metadata=source_mock) + self.file_mock.reset_mock() + vds_file_mock = self.file_mock.__enter__.return_value + vds_file_mock.get.return_value = None + + gen.generate_vds() + + isfile_mock.assert_called_once_with("/test/path/vds.hdf5") + construct_mock.assert_called_once_with(source_mock) + create_mock.assert_called_once_with(source_mock, + construct_mock.return_value) + validate_mock.assert_called_once_with(vds_file_mock) + h5file_mock.assert_called_once_with( + "/test/path/vds.hdf5", "w", libver="latest") + vds_file_mock.create_virtual_dataset.assert_called_once_with( + VMlist=create_mock.return_value, fillvalue=0x1) + + @patch('os.path.isfile', return_value=True) + @patch(VDSGenerator_patch_path + '.validate_node') + @patch(h5py_patch_path + '.File', return_value=file_mock) + @patch(VDSGenerator_patch_path + '.create_vds_maps') + @patch(VDSGenerator_patch_path + '.construct_vds_metadata') + def test_generate_vds_append(self, construct_mock, create_mock, + h5file_mock, validate_mock, isfile_mock): + source_mock = MagicMock() + gen = VDSGeneratorTester(path="/test/path", prefix="stripe_", + output_file="/test/path/vds.hdf5", + name="vds.hdf5", + target_node="full_frame", source_node="data", + datasets=["stripe_1.hdf5", "stripe_2.hdf5", + "stripe_3.hdf5"], + source_metadata=source_mock) + self.file_mock.reset_mock() + vds_file_mock = self.file_mock.__enter__.return_value + vds_file_mock.get.return_value = None + + gen.generate_vds() + + isfile_mock.assert_called_once_with("/test/path/vds.hdf5") + construct_mock.assert_called_once_with(source_mock) + 
create_mock.assert_called_once_with(source_mock, + construct_mock.return_value) + validate_mock.assert_called_once_with(vds_file_mock) + h5file_mock.assert_has_calls([ + call("/test/path/vds.hdf5", "r", libver="latest"), + call("/test/path/vds.hdf5", "a", libver="latest")]) + vds_file_mock.create_virtual_dataset.assert_called_once_with( + VMlist=create_mock.return_value, fillvalue=0x1) + + @patch('os.path.isfile', return_value=True) + @patch(h5py_patch_path + '.File', return_value=file_mock) + def test_generate_vds_node_exists_then_error(self, h5file_mock, + isfile_mock): + source_mock = MagicMock() + gen = VDSGeneratorTester(path="/test/path", prefix="stripe_", + output_file="/test/path/vds.hdf5", + name="vds.hdf5", + target_node="full_frame", source_node="data", + datasets=["stripe_1.hdf5", "stripe_2.hdf5", + "stripe_3.hdf5"], + source_metadata=source_mock) + self.file_mock.reset_mock() + vds_file_mock = self.file_mock.__enter__.return_value + vds_file_mock.get.return_value = "Group" + + with self.assertRaises(IOError): + gen.generate_vds() diff --git a/vdsgen/__init__.py b/vdsgen/__init__.py index b47eda2..a2f2240 100644 --- a/vdsgen/__init__.py +++ b/vdsgen/__init__.py @@ -1,4 +1,4 @@ -"""Make 'generate_vds' easy to import.""" -from vdsgen import generate_vds +"""Make VDSGenerator easy to import.""" +from vdsgenerator import VDSGenerator -__all__ = ["generate_vds"] +__all__ = ["VDSGenerator"] diff --git a/vdsgen/app.py b/vdsgen/app.py new file mode 100644 index 0000000..4eeb52b --- /dev/null +++ b/vdsgen/app.py @@ -0,0 +1,116 @@ +import sys +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + +from vdsgenerator import VDSGenerator + +help_message = """ +------------------------------------------------------------------------------- +A script to create a virtual dataset composed of multiple raw HDF5 files. + +The minimum required arguments are and either -p or -f . 
+ +For example: + + > ../vdsgen/app.py /scratch/images -p stripe_ + > ../vdsgen/app.py /scratch/images -f stripe_1.hdf5 stripe_2.hdf5 + +You can create an empty VDS, for raw files that don't exist yet, with the -e +flag; you will then need to provide --shape and --data_type, though defaults +are provided for these. +------------------------------------------------------------------------------- +""" + + +def parse_args(): + """Parse command line arguments.""" + parser = ArgumentParser(usage=help_message, + formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument( + "path", type=str, help="Root folder of source files and VDS.") + + # Definition of file names in - Common prefix or explicit list + file_definition = parser.add_mutually_exclusive_group(required=True) + file_definition.add_argument( + "-p", "--prefix", type=str, default=None, dest="prefix", + help="Prefix of files to search for - e.g 'stripe_' to combine " + "'stripe_1.hdf5' and 'stripe_2.hdf5'.") + file_definition.add_argument( + "-f", "--files", nargs="*", type=str, default=None, dest="files", + help="Explicit names of raw files in .") + + # Arguments required to allow VDS to be created before raw files exist + empty_vds = parser.add_argument_group() + empty_vds.add_argument( + "-e", "--empty", action="store_true", dest="empty", + help="Make empty VDS pointing to datasets that don't exist yet.") + empty_vds.add_argument( + "--shape", type=int, nargs="*", default=[1, 256, 2048], dest="shape", + help="Shape of dataset - 'frames height width', where frames is N " + "dimensional.") + empty_vds.add_argument( + "-t", "--data_type", type=str, default="uint16", dest="data_type", + help="Data type of raw datasets.") + + # Arguments to override defaults - each is atomic + other_args = parser.add_argument_group() + other_args.add_argument( + "-o", "--output", type=str, default=None, dest="output", + help="Output file name. 
If None then generated as input file prefix " + "with vds suffix.") + other_args.add_argument( + "-s", "--stripe_spacing", type=int, dest="stripe_spacing", + default=VDSGenerator.stripe_spacing, + help="Spacing between two stripes in a module.") + other_args.add_argument( + "-m", "--module_spacing", type=int, dest="module_spacing", + default=VDSGenerator.module_spacing, + help="Spacing between two modules.") + other_args.add_argument( + "--source_node", type=str, dest="source_node", + default=VDSGenerator.source_node, + help="Data node in source HDF5 files.") + other_args.add_argument( + "--target_node", type=str, dest="target_node", + default=VDSGenerator.target_node, help="Data node in VDS file.") + other_args.add_argument( + "-l", "--log_level", type=int, dest="log_level", + default=VDSGenerator.log_level, + help="Logging level (off=3, info=2, debug=1).") + + args = parser.parse_args() + args.shape = tuple(args.shape) + + if args.empty and args.files is None: + parser.error( + "To make an empty VDS you must explicitly define --files for the " + "eventual raw datasets.") + if args.files is not None and len(args.files) < 2: + parser.error("Must define at least two files to combine.") + + return args + + +def main(): + """Run program.""" + args = parse_args() + + if args.empty: + source_metadata = dict(shape=args.shape, dtype=args.data_type) + else: + source_metadata = None + + gen = VDSGenerator(args.path, + prefix=args.prefix, files=args.files, + output=args.output, + source=source_metadata, + source_node=args.source_node, + target_node=args.target_node, + stripe_spacing=args.stripe_spacing, + module_spacing=args.module_spacing, + log_level=args.log_level) + + gen.generate_vds() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/vdsgen/vdsgen.py b/vdsgen/vdsgen.py deleted file mode 100644 index af86758..0000000 --- a/vdsgen/vdsgen.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/bin/env dls-python -"""A CLI tool for generating virtual datasets from individual 
logging.basicConfig(level=logging.INFO)

Source = namedtuple("Source",
                    ["datasets", "frames", "height", "width", "dtype"])
VDS = namedtuple("VDS", ["shape", "spacing", "path"])

DATASET_SPACING = 10  # Pixel spacing between each dataset in VDS


def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments holding `path` and `prefix`

    """
    parser = ArgumentParser()
    parser.add_argument("path", type=str,
                        help="Path to folder containing HDF5 files.")
    parser.add_argument("prefix", type=str,
                        help="Root name of images - e.g 'stripe_' to combine "
                             "the images 'stripe_1.hdf5', 'stripe_2.hdf5' "
                             "and 'stripe_3.hdf5' located at .")

    return parser.parse_args()


def find_files(path, prefix):
    """Find HDF5 files in given folder with given prefix.

    The prefix is treated as literal text (not as a regular expression) and
    the whole file name must match, so names such as 'stripe_1.h5.bak' are
    ignored.

    Args:
        path(str): Path to folder containing files
        prefix(str): Root name of image files

    Returns:
        list: HDF5 files in folder that have the given prefix

    Raises:
        IOError: If fewer than two matching files are found

    """
    # Escape the prefix so characters like '.' or '(' match themselves, and
    # anchor the end so trailing junk after the extension is rejected.
    regex = re.compile(re.escape(prefix) + r"\d\.(hdf5|h5)$")

    files = [os.path.join(path, file_)
             for file_ in sorted(os.listdir(path))
             if regex.match(file_)]

    if not files:
        raise IOError("No files matching pattern found.")
    if len(files) < 2:
        raise IOError("Folder must contain more than one matching HDF5 file.")

    return files


def construct_vds_name(prefix, files):
    """Generate the file name for the VDS from the sub files.

    The extension of the first file is reused for the VDS file.

    Args:
        prefix(str): Root name of image files
        files(list(str)): HDF5 files being combined

    Returns:
        str: Name of VDS file

    """
    _, ext = os.path.splitext(files[0])
    vds_name = "{prefix}vds{ext}".format(prefix=prefix, ext=ext)

    return vds_name


def grab_metadata(file_path):
    """Grab data from given HDF5 file.

    The file is opened read-only and closed again before returning.

    Args:
        file_path(str): Path to HDF5 file

    Returns:
        dict: Number of frames, height, width and data type of datasets

    """
    # Use a context manager so the file handle is released even if the
    # "data" node is missing or its shape is not three dimensional.
    with h5.File(file_path, 'r') as h5_file:
        h5_data = h5_file["data"]
        frames, height, width = h5_data.shape
        data_type = h5_data.dtype

    return dict(frames=frames, height=height, width=width, dtype=data_type)


def process_source_datasets(datasets):
    """Grab data from the given HDF5 files and check for consistency.

    Args:
        datasets(list(str)): Datasets to grab data from

    Returns:
        Source: Number of datasets and the attributes of them (frames, height
            width and data type)

    Raises:
        ValueError: If any file's metadata differs from the first file's

    """
    data = grab_metadata(datasets[0])
    # Every other file must agree with the first on all attributes
    for path in datasets[1:]:
        temp_data = grab_metadata(path)
        for attribute, value in data.items():
            if temp_data[attribute] != value:
                raise ValueError("Files have mismatched {}".format(attribute))

    return Source(frames=data['frames'], height=data['height'],
                  width=data['width'], dtype=data['dtype'], datasets=datasets)


def construct_vds_metadata(source, output_file):
    """Construct VDS data attributes from source attributes.

    Args:
        source(Source): Attributes of data sets
        output_file(str): File path of new VDS

    Returns:
        VDS: Shape, dataset spacing and output path of virtual data set

    """
    datasets = len(source.datasets)
    # Stack the datasets vertically with a gap between (not after) each one
    height = (source.height * datasets) + (DATASET_SPACING * (datasets - 1))
    shape = (source.frames, height, source.width)
    # Distance between the first rows of two consecutive datasets
    spacing = source.height + DATASET_SPACING

    return VDS(shape=shape, spacing=spacing, path=output_file)


def create_vds_maps(source, vds_data):
    """Create a list of VirtualMaps of raw data to the VDS.

    Args:
        source(Source): Source attributes
        vds_data(VDS): VDS attributes

    Returns:
        list(VirtualMap): Maps describing links between raw data and VDS

    """
    source_shape = (source.frames, source.height, source.width)
    vds = h5.VirtualTarget(vds_data.path, "full_frame", shape=vds_data.shape)

    map_list = []
    for idx, dataset in enumerate(source.datasets):
        logging.info("Processing dataset %s", idx + 1)

        v_source = h5.VirtualSource(dataset, "data", shape=source_shape)

        # Each dataset occupies `height` rows, offset by the VDS spacing
        start = idx * vds_data.spacing
        stop = start + source.height
        v_target = vds[:, start:stop, :]

        v_map = h5.VirtualMap(v_source, v_target, dtype=source.dtype)
        map_list.append(v_map)

    return map_list


def generate_vds(path, prefix):
    """Generate a virtual dataset.

    Args:
        path(str): Path to folder containing HDF5 files
        prefix(str): Prefix of HDF5 files to generate from (in folder)
            e.g. image_ for image_1.hdf5, image_2.hdf5, image_3.hdf5

    """
    file_paths = find_files(path, prefix)
    vds_name = construct_vds_name(prefix, file_paths)
    output_file = os.path.join(path, vds_name)

    file_names = [os.path.basename(file_) for file_ in file_paths]
    logging.info("Combining datasets %s into %s",
                 ", ".join(file_names), vds_name)

    source = process_source_datasets(file_paths)
    vds_data = construct_vds_metadata(source, output_file)
    map_list = create_vds_maps(source, vds_data)

    logging.info("Creating VDS at %s", output_file)
    with h5.File(output_file, "w", libver="latest") as vds_file:
        vds_file.create_virtual_dataset(VMlist=map_list, fill_value=0x1)

    logging.info("Creation successful!")


def main():
    """Run program."""
    args = parse_args()
    generate_vds(args.path, args.prefix)


if __name__ == "__main__":
    sys.exit(main())
Source = namedtuple("Source", ["frames", "height", "width", "dtype"])
VDS = namedtuple("VDS", ["shape", "spacing"])


class VDSGenerator(object):

    """A class to generate Virtual Datasets from raw HDF5 files."""

    # Constants
    CREATE = "w"  # Will overwrite any existing file
    APPEND = "a"
    READ = "r"
    FULL_SLICE = slice(None)

    # Default Values
    stripe_spacing = 10  # Pixel spacing between stripes in a module
    module_spacing = 10  # Pixel spacing between modules
    source_node = "data"  # Data node in source HDF5 files
    target_node = "full_frame"  # Data node in VDS file
    mode = CREATE  # Write mode for vds file
    log_level = 2

    # The logger is defined on the class, so it is shared by all instances
    logger = logging.getLogger("VDSGenerator")
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(log_level * 10)

    def __init__(self, path, prefix=None, files=None, output=None, source=None,
                 source_node=None, target_node=None,
                 stripe_spacing=None, module_spacing=None,
                 log_level=None):
        """
        Args:
            path(str): Root folder to find raw files and create VDS
            prefix(str): Prefix of HDF5 files to generate from
                e.g. image_ for image_1.hdf5, image_2.hdf5, image_3.hdf5
            files(list(str)): List of HDF5 files to generate from
            output(str): Name of VDS file.
            source(dict): Height, width, data_type and frames for source data
                Provide this to create a VDS for raw files that don't exist yet
            source_node(str): Data node in source HDF5 files
            target_node(str): Data node in VDS file
            stripe_spacing(int): Spacing between stripes in module
            module_spacing(int): Spacing between modules
            log_level(int): Logging level (off=3, info=2, debug=1) -
                Default is info

        Raises:
            ValueError: If both or neither of prefix and files are given
            IOError: If source is not given and a raw file does not exist

        """
        # Exactly one of prefix / files must be supplied
        if (prefix is None) == (files is None):
            raise ValueError("One, and only one, of prefix or files required.")

        self.path = path

        # Overwrite default values with arguments, if given
        if source_node is not None:
            self.source_node = source_node
        if target_node is not None:
            self.target_node = target_node
        if stripe_spacing is not None:
            self.stripe_spacing = stripe_spacing
        if module_spacing is not None:
            self.module_spacing = module_spacing
        if log_level is not None:
            # Note: this sets the level of the class-wide logger
            self.logger.setLevel(log_level * 10)

        # If Files not given, find files using path and prefix.
        if files is None:
            self.prefix = prefix
            self.datasets = self.find_files()
            files = [os.path.basename(path_) for path_ in self.datasets]
        # Else, get common prefix of given files and store full path
        else:
            self.prefix = os.path.commonprefix(files)
            self.datasets = [os.path.join(path, file_) for file_ in files]

        # If output vds file name given, use, otherwise generate a default
        if output is None:
            self.name = self.construct_vds_name(files)
        else:
            self.name = output

        # If source not given, check files exist and get metadata.
        if source is None:
            for file_ in self.datasets:
                if not os.path.isfile(file_):
                    raise IOError(
                        "File {} does not exist. To create VDS from raw "
                        "files that haven't been created yet, source "
                        "must be provided.".format(file_))
            self.source_metadata = self.process_source_datasets()
        # Else, store given source metadata
        else:
            frames, height, width = self.parse_shape(source['shape'])
            self.source_metadata = Source(
                frames=frames, height=height, width=width,
                dtype=source['dtype'])

        self.output_file = os.path.abspath(os.path.join(self.path, self.name))

    @staticmethod
    def parse_shape(shape):
        """Split shape into height, width and frames.

        Args:
            shape(tuple): Shape of dataset

        Returns:
            frames, height, width - frames is always returned as a tuple so
                that it can be concatenated with (height, width)

        """
        # The last two elements of shape are the height and width of the image
        height, width = shape[-2:]
        # Everything before that is the frames for each axis. Coerce to a
        # tuple so that a list shape does not break later tuple concatenation.
        frames = tuple(shape[:-2])

        return frames, height, width

    def generate_vds(self):
        """Generate a virtual dataset.

        If the output file already exists it is opened in append mode, unless
        it already contains the target node.

        Raises:
            IOError: If the output file already has an entry for target node

        """
        if os.path.isfile(self.output_file):
            with h5.File(self.output_file, self.READ, libver="latest") as vds:
                node = vds.get(self.target_node)
            if node is not None:
                raise IOError("VDS {file} already has an entry for node "
                              "{node}".format(file=self.output_file,
                                              node=self.target_node))
            else:
                self.mode = self.APPEND

        vds_data = self.construct_vds_metadata(self.source_metadata)
        map_list = self.create_vds_maps(self.source_metadata, vds_data)

        self.logger.info("Creating VDS at %s", self.output_file)
        with h5.File(self.output_file, self.mode, libver="latest") as vds:
            self.validate_node(vds)
            vds.create_virtual_dataset(VMlist=map_list, fillvalue=0x1)

    def find_files(self):
        """Find HDF5 files in given folder with given prefix.

        The prefix is treated as literal text (not as a regular expression)
        and the whole file name must match, so names with anything after the
        extension are ignored.

        Returns:
            list: HDF5 files in folder that have the given prefix

        Raises:
            IOError: If fewer than two matching files are found

        """
        # Escape the prefix so characters like '.' match themselves, and
        # anchor the end so e.g. 'stripe_1.h5.bak' is not picked up.
        regex = re.compile(
            re.escape(self.prefix) + r"\d+\.(hdf5|hdf|h5)$")

        files = [os.path.abspath(os.path.join(self.path, file_))
                 for file_ in sorted(os.listdir(self.path))
                 if regex.match(file_)]

        if not files:
            raise IOError("No files matching pattern found. Got path: {path}, "
                          "prefix: {prefix}".format(path=self.path,
                                                    prefix=self.prefix))
        elif len(files) < 2:
            raise IOError("Folder must contain more than one matching HDF5 "
                          "file.")
        else:
            self.logger.debug(
                "Found datasets %s",
                ", ".join([os.path.basename(f) for f in files]))
            return files

    def construct_vds_name(self, files):
        """Generate the file name for the VDS from the sub files.

        The extension of the first file is reused for the VDS file.

        Args:
            files(list(str)): HDF5 files being combined

        Returns:
            str: Name of VDS file

        """
        _, ext = os.path.splitext(files[0])
        vds_name = "{prefix}vds{ext}".format(prefix=self.prefix, ext=ext)

        self.logger.debug("Generated VDS name: %s", vds_name)
        return vds_name

    def grab_metadata(self, file_path):
        """Grab data from given HDF5 file.

        The file is opened read-only and closed again before returning.

        Args:
            file_path(str): Path to HDF5 file

        Returns:
            dict: Number of frames, height, width and data type of datasets

        """
        # Use a context manager so the file handle is always released
        with h5.File(file_path, 'r') as h5_file:
            h5_data = h5_file[self.source_node]
            frames, height, width = self.parse_shape(h5_data.shape)
            data_type = h5_data.dtype

        return dict(frames=frames, height=height, width=width, dtype=data_type)

    def process_source_datasets(self):
        """Grab data from the given HDF5 files and check for consistency.

        Returns:
            Source: Number of datasets and the attributes of them (frames,
                height width and data type)

        Raises:
            ValueError: If any file's metadata differs from the first file's

        """
        data = self.grab_metadata(self.datasets[0])
        # Every other file must agree with the first on all attributes
        for dataset in self.datasets[1:]:
            temp_data = self.grab_metadata(dataset)
            for attribute, value in data.items():
                if temp_data[attribute] != value:
                    raise ValueError("Files have mismatched "
                                     "{}".format(attribute))

        source = Source(frames=data['frames'], height=data['height'],
                        width=data['width'], dtype=data['dtype'])

        self.logger.debug("Source metadata retrieved: %s", source)
        return source

    def construct_vds_metadata(self, source):
        """Construct VDS data attributes from source attributes.

        Args:
            source(Source): Attributes of data sets

        Returns:
            VDS: Shape and per-dataset trailing spacing of virtual data set

        """
        stripes = len(self.datasets)
        # Gaps alternate: stripe spacing after even-indexed stripes, module
        # spacing after odd-indexed stripes.
        spacing = [self.stripe_spacing if idx % 2 == 0
                   else self.module_spacing
                   for idx in range(stripes)]
        # We don't want the final stripe to have a gap afterwards
        if spacing:
            spacing[-1] = 0

        height = (source.height * stripes) + sum(spacing)
        shape = source.frames + (height, source.width)

        vds = VDS(shape=shape, spacing=spacing)
        self.logger.debug("VDS metadata constructed: %s", vds)
        return vds

    def _target_indices(self, source, vds_data):
        """Compute the VDS selection that each source dataset maps onto.

        Args:
            source(Source): Source attributes
            vds_data(VDS): VDS attributes

        Returns:
            list(tuple(slice)): One index per dataset; full slices for every
                frame axis and the width, and a slice of exactly
                source.height rows for the height axis

        """
        frame_slices = [self.FULL_SLICE] * len(source.frames)

        indices = []
        current_position = 0
        for idx in range(len(self.datasets)):
            start = current_position
            # The target region must be the same height as the source data;
            # the gap is skipped over rather than included in the selection.
            stop = start + source.height
            current_position = stop + vds_data.spacing[idx]

            indices.append(tuple(frame_slices +
                                 [slice(start, stop), self.FULL_SLICE]))

        return indices

    def create_vds_maps(self, source, vds_data):
        """Create a list of VirtualMaps of raw data to the VDS.

        Args:
            source(Source): Source attributes
            vds_data(VDS): VDS attributes

        Returns:
            list(VirtualMap): Maps describing links between raw data and VDS

        """
        source_shape = source.frames + (source.height, source.width)
        vds = h5.VirtualTarget(self.output_file, self.target_node,
                               shape=vds_data.shape)

        map_list = []
        indices = self._target_indices(source, vds_data)
        for dataset, index in zip(self.datasets, indices):

            v_source = h5.VirtualSource(dataset, self.source_node,
                                        shape=source_shape)

            v_target = vds[index]
            v_map = h5.VirtualMap(v_source, v_target, dtype=source.dtype)

            self.logger.debug("Mapping dataset %s to %s of %s.",
                              os.path.basename(dataset), index, self.name)
            map_list.append(v_map)

        return map_list

    def validate_node(self, vds_file):
        """Check if it is possible to create the given node.

        Strip any trailing slashes from the target node and create the
        sub-group of the target node if it doesn't exist.

        Args:
            vds_file(h5py.File): File to check for node

        """
        self.target_node = self.target_node.rstrip("/")

        if "/" in self.target_node:
            # Everything before the final '/' is the group holding the node
            sub_group = self.target_node.rsplit("/", 1)[0]
            if vds_file.get(sub_group) is None:
                vds_file.create_group(sub_group)