Skip to content

Commit

Permalink
Fix resizing for arrays with shape!=() & more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrv committed Oct 21, 2021
1 parent 81fd14c commit 4c09aec
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
8 changes: 5 additions & 3 deletions pyiron_base/generic/flattenedstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def get_array_ragged(self, name):
numpy.ndarray, dtype=object: ragged arrray of all elements in all chunks
"""
if name in self._per_chunk_arrays:
return self._per_chunk_arrays[name].copy()
return self.get_array(name)
return np.array([self.get_array(name, i) for i in range(len(self))],
dtype=object)

Expand All @@ -374,12 +374,14 @@ def get_array_filled(self, name):
numpy.ndarray: padded arrray of all elements in all chunks
"""
if name in self._per_chunk_arrays:
return self._per_chunk_arrays[name].copy()
return self.get_array(name)
values = self.get_array_ragged(name)
max_len = self._per_chunk_arrays["length"].max()
def resize_and_pad(v):
l = len(v)
v = np.resize(v, max_len)
per_shape = self._per_element_arrays[name].shape[1:]
v = np.resize(v, max_len * np.prod(per_shape, dtype=int))
v = v.reshape((max_len,) + per_shape)
if name in self._fill_values:
fill = self._fill_values[name]
else:
Expand Down
19 changes: 19 additions & 0 deletions tests/generic/test_flattenedstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,27 @@ def test_get_array_filled(self):
store.set_array("fill", 0, [-1])
store.set_array("fill", 1, [-2, -3])
store.set_array("fill", 2, [-4, -5, -6])
store.add_array("complex", shape=(3,), dtype=np.float64)
store.set_array("complex", 0, [ [1, 1, 1] ])
store.set_array("complex", 1, [ [2, 2, 2],
[2, 2, 2],
])
store.set_array("complex", 2, [ [3, 3, 3],
[3, 3, 3],
[3, 3, 3],
])
val = store.get_array_filled("elem")
self.assertEqual(val.shape, (3, 3), "shape not correct!")
self.assertTrue(np.array_equal(val, [[1, -1, -1], [2, 3, -1], [4, 5, 6]]),
"values in returned array not the same as in original array!")
self.assertEqual(store.get_array_filled("fill")[0, 1], 23.42,
"incorrect fill value!")
val = store.get_array_filled("complex")
self.assertEqual(val.shape, (3, 3, 3), "shape not correct!")
self.assertTrue(np.array_equal(
store.get_array("chunk"),
store.get_array_filled("chunk"),
), "get_array_filled does not give same result as get_array for per chunk array")

def test_get_array_ragged(self):
"""get_array_ragged should return a raggend array of all elements in the storage."""
Expand All @@ -218,6 +233,10 @@ def test_get_array_ragged(self):
f"array {i} has incorrect length!")
self.assertTrue(np.array_equal(v, [[1], [2, 3], [4, 5, 6]][i]),
f"array {i} has incorrect values, {v}!")
self.assertTrue(np.array_equal(
store.get_array("chunk"),
store.get_array_ragged("chunk"),
), "get_array_ragged does not give same result as get_array for per chunk array")

def test_has_array(self):
"""hasarray should return correct information for added array; None otherwise."""
Expand Down

0 comments on commit 4c09aec

Please sign in to comment.