diff --git a/mat73/core.py b/mat73/core.py index 01d96c8..c44de97 100644 --- a/mat73/core.py +++ b/mat73/core.py @@ -46,15 +46,21 @@ class HDF5Decoder(): def __init__(self, verbose=True, use_attrdict=False, only_include=None): + # if only_include is a string, convert into a list if isinstance(only_include, str): only_include = [only_include] + + # make sure all paths start with '/' and are not ending in '/' if only_include is not None: only_include = [s if s[0]=='/' else f'/{s}' for s in only_include] only_include = [s[:-1] if s[-1]=='/' else s for s in only_include] + self.verbose = verbose self._dict_class = AttrDict if use_attrdict else dict self.refs = {} # this is used in case of matlab matrices self.only_include = only_include + + # set a check if requested include_only var was actually found if only_include is not None: _vardict = dict(zip(only_include, [False]*len(only_include))) self._found_include_var = _vardict @@ -97,7 +103,7 @@ def mat2dict(self, hdf5): return d # @profile - def unpack_mat(self, hdf5, depth=0, MATLAB_class=None): + def unpack_mat(self, hdf5, depth=0, MATLAB_class=None, force=False): """ unpack a h5py entry: if it's a group expand, if it's a dataset convert @@ -111,7 +117,7 @@ def unpack_mat(self, hdf5, depth=0, MATLAB_class=None): for key in hdf5: elem = hdf5[key] - if not self.is_included(elem): + if not self.is_included(elem) and not force: continue if 'MATLAB_class' in elem.attrs: MATLAB_class = elem.attrs.get('MATLAB_class') @@ -176,7 +182,7 @@ def unpack_mat(self, hdf5, depth=0, MATLAB_class=None): return d elif isinstance(hdf5, h5py._hl.dataset.Dataset): - if self.is_included(hdf5): + if self.is_included(hdf5) or force: return self.convert_mat(hdf5, depth, MATLAB_class=MATLAB_class) else: raise Exception(f'Unknown hdf5 type: {key}:{type(hdf5)}') @@ -229,7 +235,8 @@ def convert_mat(self, dataset, depth, MATLAB_class=None): # some weird style MATLAB have no refs, but direct floats or int if isinstance(ref, Iterable): for r in ref: - entry = self.unpack_mat(self.refs.get(r), depth+1) + # force=True because we want to load cell contents + entry = self.unpack_mat(self.refs.get(r), depth+1, force=True) row.append(entry) else: row = [ref] @@ -335,8 +342,8 @@ def savemat(filename, verbose=True): if __name__=='__main__': # for testing / debugging - d = loadmat('../tests/testfile2.mat') + d = loadmat('../tests/testfile11.mat', only_include='foo') # file = '../tests/testfile8.mat' - # data = loadmat(file) + # data = loadmat(file) \ No newline at end of file diff --git a/tests/test_mat73.py b/tests/test_mat73.py index e9f9a09..f7e5c05 100644 --- a/tests/test_mat73.py +++ b/tests/test_mat73.py @@ -34,7 +34,7 @@ class Testing(unittest.TestCase): def setUp(self): """make links to test files and make sure they are present""" - for i in range(1, 9): + for i in range(1, 12): file = 'testfile{}.mat'.format(i) if not os.path.exists(file): file = os.path.join('./tests', file) @@ -55,9 +55,9 @@ def test_file_obj_loading(self): d = mat73.loadmat(f, use_attrdict=False) data = d['data'] assert len(d)==3 - assert len(d.keys())==3 + assert len(d.keys())==3 + - def test_file1_noattr(self): """Test each default MATLAB type loads correctly""" d = mat73.loadmat(self.testfile1, use_attrdict=False) @@ -396,8 +396,24 @@ def test_file10_nullchars(self): self.assertEqual(len(data['char_string']), 11, 'not all elements loaded') self.assertEqual(data['char_array'], '\x01\x02\x03\x00\x04\x05\x06') + def test_file11_specificvars_cells(self): + """see if contents of cells are also loaded when using only_include""" + # check regular loading works + data = mat73.loadmat(self.testfile11) + assert len(data)==1 + assert data['foo'][0]==1 + assert data['foo'][1]==2 + + # load cells correctly + data = mat73.loadmat(self.testfile11, only_include=['foo']) + assert len(data)==1 + assert data['foo'][0]==1 + assert data['foo'][1]==2 + # loading should be empty for non-existend var + data = mat73.loadmat(self.testfile11, only_include=['bar']) + assert len(data)==0 if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/testfile11.mat b/tests/testfile11.mat new file mode 100644 index 0000000..a629bf2 Binary files /dev/null and b/tests/testfile11.mat differ