From 2493e026d5a0a0291f14d050a6d494080791e268 Mon Sep 17 00:00:00 2001 From: Thomas Morris Date: Sun, 21 Apr 2024 15:22:33 -0700 Subject: [PATCH] assume DOF type for read-only DOFs --- src/blop/dofs.py | 83 +++++++++++++++++------------ src/blop/tests/conftest.py | 2 +- src/blop/tests/test_dofs.py | 8 ++- src/blop/tests/test_passive_dofs.py | 6 +-- src/blop/tests/test_plots.py | 12 ++--- 5 files changed, 62 insertions(+), 49 deletions(-) diff --git a/src/blop/dofs.py b/src/blop/dofs.py index 44faff4..8ecce32 100644 --- a/src/blop/dofs.py +++ b/src/blop/dofs.py @@ -1,5 +1,6 @@ import time as ttime import uuid +import warnings from collections.abc import Iterable, Sequence from dataclasses import dataclass, field, fields from operator import attrgetter @@ -44,30 +45,30 @@ def _validate_dofs(dofs): return list(dofs) -def _validate_continuous_dof_domains(search_domain, trust_domain, domain): +def _validate_continuous_dof_domains(search_domain, trust_domain, domain, read_only): """ A DOF MUST have a search domain, and it MIGHT have a trust domain or a transform domain. Check that all the domains are kosher by enforcing that: search_domain \\subseteq trust_domain \\subseteq domain """ + if not read_only: + try: + search_domain = tuple((float(search_domain[0]), float(search_domain[1]))) + assert len(search_domain) == 2 + except: # noqa + raise ValueError("If type='continuous', then 'search_domain' must be a tuple of two numbers.") - try: - search_domain = tuple((float(search_domain[0]), float(search_domain[1]))) - assert len(search_domain) == 2 - except: # noqa - raise ValueError("If type='continuous', then 'search_domain' must be a tuple of two numbers.") + if search_domain[0] >= search_domain[1]: + raise ValueError("The lower search bound must be strictly less than the upper search bound.") - if search_domain[0] >= search_domain[1]: - raise ValueError("The lower search bound must be strictly less than the upper search bound.") + if domain is not None: + if (search_domain[0] <= domain[0]) or (search_domain[1] >= domain[1]): + raise ValueError(f"The search domain {search_domain} must be a strict subset of the domain {domain}.") - if domain is not None: - if (search_domain[0] <= domain[0]) or (search_domain[1] >= domain[1]): - raise ValueError(f"The search domain {search_domain} must be a strict subset of the domain {domain}.") - - if trust_domain is not None: - if (search_domain[0] < trust_domain[0]) or (search_domain[1] > trust_domain[1]): - raise ValueError(f"The search domain {search_domain} must be a subset of the trust domain {trust_domain}.") + if trust_domain is not None: + if (search_domain[0] < trust_domain[0]) or (search_domain[1] > trust_domain[1]): + raise ValueError(f"The search domain {search_domain} must be a subset of the trust domain {trust_domain}.") if (trust_domain is not None) and (domain is not None): if (trust_domain[0] < domain[0]) or (trust_domain[1] > domain[1]): @@ -153,36 +154,50 @@ def __post_init__(self): self.name = self.device.name else: raise ValueError("You must specify exactly one of 'name' or 'device'.") - - if self.search_domain is None: - if not self.read_only: - raise ValueError("You must specify search_domain if read_only=False.") - - if self.type is None: - if isinstance(self.search_domain, tuple): - self.type = "continuous" - elif isinstance(self.search_domain, set): - if len(self.search_domain) == 2: - self.type = "binary" + if self.read_only: + if self.type is None: + if isinstance(self.readback, float): + self.type = "continuous" else: self.type = "categorical" + warnings.warn(f"No type was specified for DOF {self.name}. Assuming type={self.type}.") + else: + if self.search_domain is None: + raise ValueError("You must specify the search domain if read_only=False.") + # if there is no type, infer it from the search_domain + if self.type is None: + if isinstance(self.search_domain, tuple): + self.type = "continuous" + elif isinstance(self.search_domain, set): + if len(self.search_domain) == 2: + self.type = "binary" + else: + self.type = "categorical" + else: + raise TypeError("'search_domain' must be either a 2-tuple of numbers or a set.") if self.type not in DOF_TYPES: - raise ValueError(f"'type' must be one of {DOF_TYPES}") + raise ValueError(f"Invalid DOF type '{self.type}'. 'type' must be one of {DOF_TYPES}.") # our input is usually continuous if self.type == "continuous": - _validate_continuous_dof_domains(self._search_domain, self._trust_domain, self.domain) + if not self.read_only: + _validate_continuous_dof_domains( + search_domain=self._search_domain, + trust_domain=self._trust_domain, + domain=self.domain, + read_only=self.read_only, + ) - self.search_domain = tuple((float(self.search_domain[0]), float(self.search_domain[1]))) + self.search_domain = tuple((float(self.search_domain[0]), float(self.search_domain[1]))) - if self.device is None: - center = float(self._untransform(np.mean([self._transform(np.array(self.search_domain))]))) - self.device = Signal(name=self.name, value=center) + if self.device is None: + center = float(self._untransform(np.mean([self._transform(np.array(self.search_domain))]))) + self.device = Signal(name=self.name, value=center) # otherwise it must be discrete else: - _validate_discrete_dof_domains(self._search_domain, self._trust_domain) + _validate_discrete_dof_domains(search_domain=self._search_domain, trust_domain=self._trust_domain) if self.type == "binary": if self.search_domain is None: @@ -213,7 +228,7 @@ def _search_domain(self): if self.read_only: value = self.readback if self.type == "continuous": - return tuple(value, value) + return tuple((value, value)) else: return {value} else: diff --git a/src/blop/tests/conftest.py b/src/blop/tests/conftest.py index 33d8dbc..6506f53 100644 --- a/src/blop/tests/conftest.py +++ b/src/blop/tests/conftest.py @@ -146,7 +146,7 @@ def digestion(db, uid): @pytest.fixture(scope="function") -def agent_with_passive_dofs(db): +def agent_with_read_only_dofs(db): """ A simple agent minimizing two Himmelblau's functions """ diff --git a/src/blop/tests/test_dofs.py b/src/blop/tests/test_dofs.py index 5ff3664..4c83613 100644 --- a/src/blop/tests/test_dofs.py +++ b/src/blop/tests/test_dofs.py @@ -9,23 +9,21 @@ def test_dof_types(): description="A binary DOF", type="binary", name="x2", - search_domain=["in", "out"], - trust_domain=["in"], + search_domain={"in", "out"}, units="is it in or out?", ) dof3 = DOF( description="An ordinal DOF", type="ordinal", name="x3", - search_domain=["low", "medium", "high"], - trust_domain=["low", "medium"], + search_domain={"low", "medium", "high"}, units="noise level", ) dof4 = DOF( description="A categorical DOF", type="categorical", name="x4", - search_domain=["mango", "orange", "banana", "papaya"], + search_domain={"mango", "orange", "banana", "papaya"}, units="fruit", ) diff --git a/src/blop/tests/test_passive_dofs.py b/src/blop/tests/test_passive_dofs.py index bc8157a..fb8624a 100644 --- a/src/blop/tests/test_passive_dofs.py +++ b/src/blop/tests/test_passive_dofs.py @@ -2,6 +2,6 @@ @pytest.mark.test_func -def test_passive_dofs(agent_with_passive_dofs, RE): - RE(agent_with_passive_dofs.learn("qr", n=32)) - RE(agent_with_passive_dofs.learn("qei", n=2)) +def test_read_only_dofs(agent_with_read_only_dofs, RE): + RE(agent_with_read_only_dofs.learn("qr", n=32)) + RE(agent_with_read_only_dofs.learn("qei", n=2)) diff --git a/src/blop/tests/test_plots.py b/src/blop/tests/test_plots.py index e67836c..4f2c535 100644 --- a/src/blop/tests/test_plots.py +++ b/src/blop/tests/test_plots.py @@ -19,10 +19,10 @@ def test_plots_multiple_objs(RE, mo_agent): mo_agent.plot_history() -def test_plots_passive_dofs(RE, agent_with_passive_dofs): - RE(agent_with_passive_dofs.learn("qr", n=16)) +def test_plots_read_only_dofs(RE, agent_with_read_only_dofs): + RE(agent_with_read_only_dofs.learn("qr", n=16)) - agent_with_passive_dofs.plot_objectives() - agent_with_passive_dofs.plot_acquisition() - agent_with_passive_dofs.plot_validity() - agent_with_passive_dofs.plot_history() + agent_with_read_only_dofs.plot_objectives() + agent_with_read_only_dofs.plot_acquisition() + agent_with_read_only_dofs.plot_validity() + agent_with_read_only_dofs.plot_history()