Skip to content

Commit

Permalink
load cell content correctly when only_load is used
Browse files Browse the repository at this point in the history
  • Loading branch information
skjerns committed Aug 31, 2023
1 parent 2fe6991 commit 946a59e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 10 deletions.
19 changes: 13 additions & 6 deletions mat73/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,21 @@ class HDF5Decoder():
def __init__(self, verbose=True, use_attrdict=False,
only_include=None):

# if only_include is a string, convert into a list
if isinstance(only_include, str):
only_include = [only_include]

# make sure all paths start with '/' and are not ending in '/'
if only_include is not None:
only_include = [s if s[0]=='/' else f'/{s}' for s in only_include]
only_include = [s[:-1] if s[-1]=='/' else s for s in only_include]

self.verbose = verbose
self._dict_class = AttrDict if use_attrdict else dict
self.refs = {} # this is used in case of matlab matrices
self.only_include = only_include

# set a check if requested include_only var was actually found
if only_include is not None:
_vardict = dict(zip(only_include, [False]*len(only_include)))
self._found_include_var = _vardict
Expand Down Expand Up @@ -97,7 +103,7 @@ def mat2dict(self, hdf5):
return d

# @profile
def unpack_mat(self, hdf5, depth=0, MATLAB_class=None):
def unpack_mat(self, hdf5, depth=0, MATLAB_class=None, force=False):
"""
unpack a h5py entry: if it's a group expand,
if it's a dataset convert
Expand All @@ -111,7 +117,7 @@ def unpack_mat(self, hdf5, depth=0, MATLAB_class=None):

for key in hdf5:
elem = hdf5[key]
if not self.is_included(elem):
if not self.is_included(elem) and not force:
continue
if 'MATLAB_class' in elem.attrs:
MATLAB_class = elem.attrs.get('MATLAB_class')
Expand Down Expand Up @@ -176,7 +182,7 @@ def unpack_mat(self, hdf5, depth=0, MATLAB_class=None):

return d
elif isinstance(hdf5, h5py._hl.dataset.Dataset):
if self.is_included(hdf5):
if self.is_included(hdf5) or force:
return self.convert_mat(hdf5, depth, MATLAB_class=MATLAB_class)
else:
raise Exception(f'Unknown hdf5 type: {key}:{type(hdf5)}')
Expand Down Expand Up @@ -229,7 +235,8 @@ def convert_mat(self, dataset, depth, MATLAB_class=None):
# some weird style MATLAB have no refs, but direct floats or int
if isinstance(ref, Iterable):
for r in ref:
entry = self.unpack_mat(self.refs.get(r), depth+1)
# force=True because we want to load cell contents
entry = self.unpack_mat(self.refs.get(r), depth+1, force=True)
row.append(entry)
else:
row = [ref]
Expand Down Expand Up @@ -335,8 +342,8 @@ def savemat(filename, verbose=True):

if __name__=='__main__':
# for testing / debugging
d = loadmat('../tests/testfile2.mat')
d = loadmat('../tests/testfile11.mat', only_include='foo')


# file = '../tests/testfile8.mat'
# data = loadmat(file)
# data = loadmat(file)
24 changes: 20 additions & 4 deletions tests/test_mat73.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Testing(unittest.TestCase):

def setUp(self):
"""make links to test files and make sure they are present"""
for i in range(1, 9):
for i in range(1, 12):
file = 'testfile{}.mat'.format(i)
if not os.path.exists(file):
file = os.path.join('./tests', file)
Expand All @@ -55,9 +55,9 @@ def test_file_obj_loading(self):
d = mat73.loadmat(f, use_attrdict=False)
data = d['data']
assert len(d)==3
assert len(d.keys())==3
assert len(d.keys())==3



def test_file1_noattr(self):
"""Test each default MATLAB type loads correctly"""
d = mat73.loadmat(self.testfile1, use_attrdict=False)
Expand Down Expand Up @@ -396,8 +396,24 @@ def test_file10_nullchars(self):
self.assertEqual(len(data['char_string']), 11, 'not all elements loaded')
self.assertEqual(data['char_array'], '\x01\x02\x03\x00\x04\x05\x06')

def test_file11_specificvars_cells(self):
"""see if contents of cells are also loaded when using only_include"""
# check regular loading works
data = mat73.loadmat(self.testfile11)
assert len(data)==1
assert data['foo'][0]==1
assert data['foo'][1]==2

# load cells correctly
data = mat73.loadmat(self.testfile11, only_include=['foo'])
assert len(data)==1
assert data['foo'][0]==1
assert data['foo'][1]==2

# loading should be empty for non-existend var
data = mat73.loadmat(self.testfile11, only_include=['bar'])
assert len(data)==0

if __name__ == '__main__':

unittest.main()
unittest.main()
Binary file added tests/testfile11.mat
Binary file not shown.

0 comments on commit 946a59e

Please sign in to comment.