Skip to content

Commit

Permalink
assume DOF type for read-only DOFs
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Apr 21, 2024
1 parent cb7f461 commit 2493e02
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 49 deletions.
83 changes: 49 additions & 34 deletions src/blop/dofs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time as ttime
import uuid
import warnings
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field, fields
from operator import attrgetter
Expand Down Expand Up @@ -44,30 +45,30 @@ def _validate_dofs(dofs):
return list(dofs)


def _validate_continuous_dof_domains(search_domain, trust_domain, domain):
def _validate_continuous_dof_domains(search_domain, trust_domain, domain, read_only):
"""
A DOF MUST have a search domain, and it MIGHT have a trust domain or a transform domain.
Check that all the domains are kosher by enforcing that:
search_domain \\subseteq trust_domain \\subseteq domain
"""
if not read_only:
try:
search_domain = tuple((float(search_domain[0]), float(search_domain[1])))
assert len(search_domain) == 2
except: # noqa
raise ValueError("If type='continuous', then 'search_domain' must be a tuple of two numbers.")

try:
search_domain = tuple((float(search_domain[0]), float(search_domain[1])))
assert len(search_domain) == 2
except: # noqa
raise ValueError("If type='continuous', then 'search_domain' must be a tuple of two numbers.")
if search_domain[0] >= search_domain[1]:
raise ValueError("The lower search bound must be strictly less than the upper search bound.")

if search_domain[0] >= search_domain[1]:
raise ValueError("The lower search bound must be strictly less than the upper search bound.")
if domain is not None:
if (search_domain[0] <= domain[0]) or (search_domain[1] >= domain[1]):
raise ValueError(f"The search domain {search_domain} must be a strict subset of the domain {domain}.")

if domain is not None:
if (search_domain[0] <= domain[0]) or (search_domain[1] >= domain[1]):
raise ValueError(f"The search domain {search_domain} must be a strict subset of the domain {domain}.")

if trust_domain is not None:
if (search_domain[0] < trust_domain[0]) or (search_domain[1] > trust_domain[1]):
raise ValueError(f"The search domain {search_domain} must be a subset of the trust domain {trust_domain}.")
if trust_domain is not None:
if (search_domain[0] < trust_domain[0]) or (search_domain[1] > trust_domain[1]):
raise ValueError(f"The search domain {search_domain} must be a subset of the trust domain {trust_domain}.")

if (trust_domain is not None) and (domain is not None):
if (trust_domain[0] < domain[0]) or (trust_domain[1] > domain[1]):
Expand Down Expand Up @@ -153,36 +154,50 @@ def __post_init__(self):
self.name = self.device.name
else:
raise ValueError("You must specify exactly one of 'name' or 'device'.")

if self.search_domain is None:
if not self.read_only:
raise ValueError("You must specify search_domain if read_only=False.")

if self.type is None:
if isinstance(self.search_domain, tuple):
self.type = "continuous"
elif isinstance(self.search_domain, set):
if len(self.search_domain) == 2:
self.type = "binary"
if self.read_only:
if self.type is None:
if isinstance(self.readback, float):
self.type = "continuous"
else:
self.type = "categorical"
warnings.warn(f"No type was specified for DOF {self.name}. Assuming type={self.type}.")
else:
if self.search_domain is None:
raise ValueError("You must specify the search domain if read_only=False.")
# if there is no type, infer it from the search_domain
if self.type is None:
if isinstance(self.search_domain, tuple):
self.type = "continuous"
elif isinstance(self.search_domain, set):
if len(self.search_domain) == 2:
self.type = "binary"
else:
self.type = "categorical"
else:
raise TypeError("'search_domain' must be either a 2-tuple of numbers or a set.")

if self.type not in DOF_TYPES:
raise ValueError(f"'type' must be one of {DOF_TYPES}")
raise ValueError(f"Invalid DOF type '{self.type}'. 'type' must be one of {DOF_TYPES}.")

# our input is usually continuous
if self.type == "continuous":
_validate_continuous_dof_domains(self._search_domain, self._trust_domain, self.domain)
if not self.read_only:
_validate_continuous_dof_domains(
search_domain=self._search_domain,
trust_domain=self._trust_domain,
domain=self.domain,
read_only=self.read_only,
)

self.search_domain = tuple((float(self.search_domain[0]), float(self.search_domain[1])))
self.search_domain = tuple((float(self.search_domain[0]), float(self.search_domain[1])))

if self.device is None:
center = float(self._untransform(np.mean([self._transform(np.array(self.search_domain))])))
self.device = Signal(name=self.name, value=center)
if self.device is None:
center = float(self._untransform(np.mean([self._transform(np.array(self.search_domain))])))
self.device = Signal(name=self.name, value=center)

# otherwise it must be discrete
else:
_validate_discrete_dof_domains(self._search_domain, self._trust_domain)
_validate_discrete_dof_domains(search_domain=self._search_domain, trust_domain=self._trust_domain)

if self.type == "binary":
if self.search_domain is None:
Expand Down Expand Up @@ -213,7 +228,7 @@ def _search_domain(self):
if self.read_only:
value = self.readback
if self.type == "continuous":
return tuple(value, value)
return tuple((value, value))
else:
return {value}
else:
Expand Down
2 changes: 1 addition & 1 deletion src/blop/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def digestion(db, uid):


@pytest.fixture(scope="function")
def agent_with_passive_dofs(db):
def agent_with_read_only_dofs(db):
"""
A simple agent minimizing two Himmelblau's functions
"""
Expand Down
8 changes: 3 additions & 5 deletions src/blop/tests/test_dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,21 @@ def test_dof_types():
description="A binary DOF",
type="binary",
name="x2",
search_domain=["in", "out"],
trust_domain=["in"],
search_domain={"in", "out"},
units="is it in or out?",
)
dof3 = DOF(
description="An ordinal DOF",
type="ordinal",
name="x3",
search_domain=["low", "medium", "high"],
trust_domain=["low", "medium"],
search_domain={"low", "medium", "high"},
units="noise level",
)
dof4 = DOF(
description="A categorical DOF",
type="categorical",
name="x4",
search_domain=["mango", "orange", "banana", "papaya"],
search_domain={"mango", "orange", "banana", "papaya"},
units="fruit",
)

Expand Down
6 changes: 3 additions & 3 deletions src/blop/tests/test_passive_dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@


@pytest.mark.test_func
def test_passive_dofs(agent_with_passive_dofs, RE):
RE(agent_with_passive_dofs.learn("qr", n=32))
RE(agent_with_passive_dofs.learn("qei", n=2))
def test_read_only_dofs(agent_with_read_only_dofs, RE):
RE(agent_with_read_only_dofs.learn("qr", n=32))
RE(agent_with_read_only_dofs.learn("qei", n=2))
12 changes: 6 additions & 6 deletions src/blop/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def test_plots_multiple_objs(RE, mo_agent):
mo_agent.plot_history()


def test_plots_passive_dofs(RE, agent_with_passive_dofs):
RE(agent_with_passive_dofs.learn("qr", n=16))
def test_plots_read_only_dofs(RE, agent_with_read_only_dofs):
RE(agent_with_read_only_dofs.learn("qr", n=16))

agent_with_passive_dofs.plot_objectives()
agent_with_passive_dofs.plot_acquisition()
agent_with_passive_dofs.plot_validity()
agent_with_passive_dofs.plot_history()
agent_with_read_only_dofs.plot_objectives()
agent_with_read_only_dofs.plot_acquisition()
agent_with_read_only_dofs.plot_validity()
agent_with_read_only_dofs.plot_history()

0 comments on commit 2493e02

Please sign in to comment.