diff --git a/pyiron_base/generic/datacontainer.py b/pyiron_base/generic/datacontainer.py index 196fba0e4..eba914a0a 100644 --- a/pyiron_base/generic/datacontainer.py +++ b/pyiron_base/generic/datacontainer.py @@ -259,7 +259,7 @@ def __new__(cls, *args, **kwargs): return instance - def __init__(self, init=None, table_name=None, lazy=False): + def __init__(self, init=None, table_name=None, lazy=False, wrap_blacklist=()): """ Create new container. @@ -268,11 +268,13 @@ def __init__(self, init=None, table_name=None, lazy=False): translated to nested containers table_name (str): default name of the data container in HDF5 lazy (bool): if True, use :class:`.HDFStub` to load values lazily from HDF5 + wrap_blacklist (tuple of types): any values in `init` that are instances of the given types are *not* + wrapped in :class:`.DataContainer` """ self.table_name = table_name self._lazy = lazy if init is not None: - self.update(init, wrap=True) + self.update(init, wrap=True, blacklist=wrap_blacklist) def __len__(self): return len(self._store) @@ -432,13 +434,6 @@ def __repr__(self): else: return name + "([" + ", ".join("{!r}".format(v) for v in self._store) + "])" - @classmethod - def _wrap_val(cls, val): - if isinstance(val, (tuple, list, dict)): - return cls(val) - else: - return val - @property def read_only(self): """ @@ -595,7 +590,14 @@ def _search_parent(self, key, stop_on_first_hit=True): first_hit = hit return first_hit - def update(self, init, wrap=False, **kwargs): + @classmethod + def _wrap_val(cls, val, blacklist): + if isinstance(val, (Sequence, Set, Mapping)) and not isinstance(val, blacklist): + return cls(val, wrap_blacklist=blacklist) + else: + return val + + def update(self, init, wrap=False, blacklist=(), **kwargs): """ Add all elements or key-value pairs from init to this container. If wrap is not given, behaves as the generic method. @@ -603,18 +605,22 @@ def update(self, init, wrap=False, **kwargs): Args: init (Sequence, Set, Mapping): container to draw new elements from wrap (bool): if True wrap all encountered Sequences and Mappings in - DataContainers recursively + :class:`.DataContainer` recursively + blacklist (list of types): when `wrap` is True, don't wrap these types even if they're instances of Sequence + or Mapping **kwargs: update from this mapping as well """ - if wrap: + if str not in blacklist: + blacklist += (str,) + if wrap and (isinstance(wrap, bool) or not isinstance(init, blacklist)): if isinstance(init, (Sequence, Set)): for v in init: - self.append(self._wrap_val(v)) + self.append(self._wrap_val(v, blacklist)) elif isinstance(init, Mapping): for i, (k, v) in enumerate(init.items()): k = _normalize(k) - v = self._wrap_val(v) + v = self._wrap_val(v, blacklist) if isinstance(k, int): if k == i: self.append(v) @@ -629,7 +635,7 @@ def update(self, init, wrap=False, **kwargs): ValueError("init must be Sequence, Set or Mapping") for k in kwargs: - self[k] = kwargs[k] + self[k] = self._wrap_val(kwargs[k], blacklist) else: super().update(init, **kwargs) diff --git a/tests/generic/test_datacontainer.py b/tests/generic/test_datacontainer.py index d407222b0..bf765e024 100644 --- a/tests/generic/test_datacontainer.py +++ b/tests/generic/test_datacontainer.py @@ -10,16 +10,16 @@ import os import unittest import warnings +import h5py import numpy as np import pandas as pd class Sub(DataContainer): - def __init__(self, init=None, table_name=None, lazy=False): - super().__init__(init=init, table_name=table_name, lazy=lazy) + def __init__(self, init=None, table_name=None, lazy=False, wrap_blacklist=()): + super().__init__(init=init, table_name=table_name, lazy=lazy, wrap_blacklist=()) self.foo = 42 - class TestDataContainer(TestWithCleanProject): @property @@ -194,6 +194,47 @@ def test_update(self): pl.update(d) self.assertEqual(dict(pl), d, "update without options does not call generic method") + def test_update_blacklist(self): + """Wrapping nested mapping should only apply to types not in the blacklist.""" + pl = DataContainer() + pl.update([ {"a": 1, "b": 2}, [{"c": 3, "d": 4}] ], wrap=True, blacklist=(dict,)) + self.assertTrue(isinstance(pl[0], dict), "nested dict wrapped, even if black listed") + self.assertTrue(isinstance(pl[1][0], dict), "nested dict wrapped, even if black listed") + pl.clear() + + pl.update({"a": [1, 2, 3], "b": {"c": [4, 5, 6]}}, wrap=True, blacklist=(list,)) + self.assertTrue(isinstance(pl.a, list), "nested list wrapped, even if black listed") + self.assertTrue(isinstance(pl.b.c, list), "nested list wrapped, even if black listed") + pl.clear() + + def test_wrap_hdf(self): + """DataContainer should be able to be initialized by HDF objects.""" + h = self.project.create_hdf(self.project.path, "wrap_test") + h["foo"] = 42 + h.create_group("bar")["test"] = 23 + h["bar"].create_group("nested")["test"] = 23 + d = DataContainer(h) + self.assertTrue(isinstance(d.bar, DataContainer), + "HDF group not wrapped from ProjectHDFio.") + self.assertTrue(isinstance(d.bar.nested, DataContainer), + "Nested HDF group not wrapped from ProjectHDFio.") + self.assertEqual(d.foo, 42, "Top-level node not correctly wrapped from ProjectHDFio.") + self.assertEqual(d.bar.test, 23, "Nested node not correctly wrapped from ProjectHDFio.") + self.assertEqual(d.bar.nested.test, 23, "Nested node not correctly wrapped from ProjectHDFio.") + + h = h5py.File(h.file_name) + d = DataContainer(h) + self.assertTrue(isinstance(d.wrap_test.bar, DataContainer), + "HDF group not wrapped from h5py.File.") + self.assertTrue(isinstance(d.wrap_test.bar.nested, DataContainer), + "Nested HDF group not wrapped from h5py.File.") + self.assertEqual(d.wrap_test.foo, h["wrap_test/foo"], + "Top-level node not correctly wrapped from h5py.File.") + self.assertEqual(d.wrap_test.bar.test, h["wrap_test/bar/test"], + "Nested node not correctly wrapped from h5py.File.") + self.assertEqual(d.wrap_test.bar.nested.test, h["wrap_test/bar/nested/test"], + "Nested node not correctly wrapped from h5py.File.") + def test_extend(self): pl = DataContainer() pl.extend([1, 2, 3])