Skip to content

Commit

Permalink
Merge pull request #471 from pyiron/datacontainer_wrap
Browse files Browse the repository at this point in the history
Use a blacklist from wrapping values
  • Loading branch information
pmrv authored Oct 26, 2021
2 parents 28c60c8 + 32be283 commit 46245b5
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 18 deletions.
36 changes: 21 additions & 15 deletions pyiron_base/generic/datacontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def __new__(cls, *args, **kwargs):

return instance

def __init__(self, init=None, table_name=None, lazy=False):
def __init__(self, init=None, table_name=None, lazy=False, wrap_blacklist=()):
"""
Create new container.
Expand All @@ -268,11 +268,13 @@ def __init__(self, init=None, table_name=None, lazy=False):
translated to nested containers
table_name (str): default name of the data container in HDF5
lazy (bool): if True, use :class:`.HDFStub` to load values lazily from HDF5
wrap_blacklist (tuple of types): any values in `init` that are instances of the given types are *not*
wrapped in :class:`.DataContainer`
"""
self.table_name = table_name
self._lazy = lazy
if init is not None:
self.update(init, wrap=True)
self.update(init, wrap=True, blacklist=wrap_blacklist)

def __len__(self):
    """Return the number of entries held in the container's `_store`."""
    entries = self._store
    return len(entries)
Expand Down Expand Up @@ -432,13 +434,6 @@ def __repr__(self):
else:
return name + "([" + ", ".join("{!r}".format(v) for v in self._store) + "])"

@classmethod
def _wrap_val(cls, val):
    """
    Wrap plain built-in containers in this class, pass anything else through.

    Args:
        val: value to inspect

    Returns:
        `cls(val)` if `val` is a tuple, list or dict, otherwise `val` itself
    """
    needs_wrapping = isinstance(val, (tuple, list, dict))
    return cls(val) if needs_wrapping else val

@property
def read_only(self):
"""
Expand Down Expand Up @@ -595,26 +590,37 @@ def _search_parent(self, key, stop_on_first_hit=True):
first_hit = hit
return first_hit

def update(self, init, wrap=False, **kwargs):
@classmethod
def _wrap_val(cls, val, blacklist):
    """
    Wrap container-like values in this class unless their type is blacklisted.

    Args:
        val: value to inspect
        blacklist (tuple of types): types that are returned as they are, even if
            they are Sequences, Sets or Mappings

    Returns:
        `cls(val, wrap_blacklist=blacklist)` if `val` is a Sequence, Set or Mapping
        and not an instance of a blacklisted type, otherwise `val` itself
    """
    is_container = isinstance(val, (Sequence, Set, Mapping))
    # The blacklist is only consulted for container-like values.
    if not is_container or isinstance(val, blacklist):
        return val
    # Hand the blacklist on so deeper nesting levels honour it as well.
    return cls(val, wrap_blacklist=blacklist)

def update(self, init, wrap=False, blacklist=(), **kwargs):
"""
Add all elements or key-value pairs from init to this container. If wrap is
not given, behaves as the generic method.
Args:
init (Sequence, Set, Mapping): container to draw new elements from
wrap (bool): if True wrap all encountered Sequences and Mappings in
DataContainers recursively
:class:`.DataContainer` recursively
blacklist (tuple of types): when `wrap` is True, don't wrap values of these types even if they're instances of
Sequence, Set or Mapping
**kwargs: update from this mapping as well
"""
if wrap:
if str not in blacklist:
blacklist += (str,)
if wrap and (isinstance(wrap, bool) or not isinstance(init, blacklist)):
if isinstance(init, (Sequence, Set)):
for v in init:
self.append(self._wrap_val(v))
self.append(self._wrap_val(v, blacklist))

elif isinstance(init, Mapping):
for i, (k, v) in enumerate(init.items()):
k = _normalize(k)
v = self._wrap_val(v)
v = self._wrap_val(v, blacklist)
if isinstance(k, int):
if k == i:
self.append(v)
Expand All @@ -629,7 +635,7 @@ def update(self, init, wrap=False, **kwargs):
ValueError("init must be Sequence, Set or Mapping")

for k in kwargs:
self[k] = kwargs[k]
self[k] = self._wrap_val(kwargs[k], blacklist)
else:
super().update(init, **kwargs)

Expand Down
47 changes: 44 additions & 3 deletions tests/generic/test_datacontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
import os
import unittest
import warnings
import h5py
import numpy as np
import pandas as pd


class Sub(DataContainer):
    """Small :class:`DataContainer` subclass for the tests that sets a fixed attribute `foo`."""

    def __init__(self, init=None, table_name=None, lazy=False, wrap_blacklist=()):
        """
        Create the container and set `foo` to 42.

        Args:
            init: forwarded to :class:`DataContainer`
            table_name (str): forwarded to :class:`DataContainer`
            lazy (bool): forwarded to :class:`DataContainer`
            wrap_blacklist (tuple of types): forwarded to :class:`DataContainer`
        """
        # Forward the caller's blacklist instead of a literal (): nested values are wrapped via
        # `cls(val, wrap_blacklist=blacklist)`, so dropping it here would silently ignore the blacklist.
        super().__init__(init=init, table_name=table_name, lazy=lazy, wrap_blacklist=wrap_blacklist)
        self.foo = 42


class TestDataContainer(TestWithCleanProject):

@property
Expand Down Expand Up @@ -194,6 +194,47 @@ def test_update(self):
pl.update(d)
self.assertEqual(dict(pl), d, "update without options does not call generic method")

def test_update_blacklist(self):
    """Blacklisted types must stay unwrapped when updating with `wrap=True`, at every nesting level."""
    container = DataContainer()

    # dict is blacklisted: dicts stay plain, both at the top level and inside a nested list
    list_with_dicts = [{"a": 1, "b": 2}, [{"c": 3, "d": 4}]]
    container.update(list_with_dicts, wrap=True, blacklist=(dict,))
    self.assertTrue(isinstance(container[0], dict), "nested dict wrapped, even if black listed")
    self.assertTrue(isinstance(container[1][0], dict), "nested dict wrapped, even if black listed")
    container.clear()

    # list is blacklisted: lists stay plain, both directly and inside a nested mapping
    dict_with_lists = {"a": [1, 2, 3], "b": {"c": [4, 5, 6]}}
    container.update(dict_with_lists, wrap=True, blacklist=(list,))
    self.assertTrue(isinstance(container.a, list), "nested list wrapped, even if black listed")
    self.assertTrue(isinstance(container.b.c, list), "nested list wrapped, even if black listed")
    container.clear()

def test_wrap_hdf(self):
    """DataContainer should be able to be initialized by HDF objects."""
    # Build a small hierarchy: one top-level node, one group and one group nested inside it.
    hdf = self.project.create_hdf(self.project.path, "wrap_test")
    hdf["foo"] = 42
    hdf.create_group("bar")["test"] = 23
    hdf["bar"].create_group("nested")["test"] = 23

    # First initialize from the project's HDF object.
    d = DataContainer(hdf)
    self.assertTrue(isinstance(d.bar, DataContainer),
                    "HDF group not wrapped from ProjectHDFio.")
    self.assertTrue(isinstance(d.bar.nested, DataContainer),
                    "Nested HDF group not wrapped from ProjectHDFio.")
    self.assertEqual(d.foo, 42, "Top-level node not correctly wrapped from ProjectHDFio.")
    self.assertEqual(d.bar.test, 23, "Nested node not correctly wrapped from ProjectHDFio.")
    self.assertEqual(d.bar.nested.test, 23, "Nested node not correctly wrapped from ProjectHDFio.")

    # Then initialize from the same file opened with h5py. Open it read-only in a context manager
    # so the handle is closed even if an assertion fails; the assertions read from the file and
    # therefore have to run while it is still open.
    with h5py.File(hdf.file_name, "r") as h5_file:
        d = DataContainer(h5_file)
        self.assertTrue(isinstance(d.wrap_test.bar, DataContainer),
                        "HDF group not wrapped from h5py.File.")
        self.assertTrue(isinstance(d.wrap_test.bar.nested, DataContainer),
                        "Nested HDF group not wrapped from h5py.File.")
        self.assertEqual(d.wrap_test.foo, h5_file["wrap_test/foo"],
                         "Top-level node not correctly wrapped from h5py.File.")
        self.assertEqual(d.wrap_test.bar.test, h5_file["wrap_test/bar/test"],
                         "Nested node not correctly wrapped from h5py.File.")
        self.assertEqual(d.wrap_test.bar.nested.test, h5_file["wrap_test/bar/nested/test"],
                         "Nested node not correctly wrapped from h5py.File.")

def test_extend(self):
pl = DataContainer()
pl.extend([1, 2, 3])
Expand Down

0 comments on commit 46245b5

Please sign in to comment.