Skip to content

Commit

Permalink
fixed read-only DOFs
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Jan 18, 2024
1 parent a7b6439 commit 3c21a4b
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 46 deletions.
6 changes: 3 additions & 3 deletions blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __getattr__(self, attr):

raise AttributeError(f"DOFList object has no attribute named '{attr}'.")

def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = MAX_TEST_INPUTS):
def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = 2**16):
"""
Use napari to see a high-dimensional array.
Expand Down Expand Up @@ -821,7 +821,7 @@ def train_inputs(self, index=None, **subset_kwargs):
inputs = self.table.loc[:, dof.name].values.copy()

# check that inputs values are inside acceptable values
valid = (inputs >= dof.trust_bounds[0]) & (inputs <= dof.trust_bounds[1])
valid = (inputs >= dof.trust_lower_bound) & (inputs <= dof.trust_upper_bound)
inputs = np.where(valid, inputs, np.nan)

# transform if needed
Expand All @@ -840,7 +840,7 @@ def train_targets(self, index=None, **subset_kwargs):
targets = self.table.loc[:, obj.name].values.copy()

# check that targets values are inside acceptable values
valid = (targets >= obj.trust_bounds[0]) & (targets <= obj.trust_bounds[1])
valid = (targets >= obj.trust_lower_bound) & (targets <= obj.trust_upper_bound)
targets = np.where(valid, targets, np.nan)

# transform if needed
Expand Down
46 changes: 26 additions & 20 deletions blop/dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"read_only": "bool",
"log": "bool",
"tags": "object",
"device": "object",
}


Expand Down Expand Up @@ -68,11 +67,11 @@ class DOF:
An ophyd device. If not supplied, a dummy ophyd device will be generated.
"""

name: str
name: str = None
description: str = ""
units: str = ""
search_bounds: Tuple[float, float]
search_bounds: Tuple[float, float] = None
trust_bounds: Tuple[float, float] = None
units: str = ""
read_only: bool = False
active: bool = True
log: bool = False
Expand All @@ -81,14 +80,13 @@ class DOF:

# Some post-processing. This is specific to dataclasses
def __post_init__(self):
if self.trust_bounds is None:
if self.log:
self.trust_bounds = (0, np.inf)
else:
self.trust_bounds = (-np.inf, np.inf)

self.search_bounds = tuple(self.search_bounds)
self.trust_bounds = tuple(self.trust_bounds)
if self.trust_bounds is not None:
self.trust_bounds = tuple(self.trust_bounds)
if self.search_bounds is None:
if not self.read_only:
raise ValueError("You must specify search_bounds if the device is not read-only.")
else:
self.search_bounds = tuple(self.search_bounds)

self.uuid = str(uuid.uuid4())

Expand All @@ -101,34 +99,42 @@ def __post_init__(self):
if not self.read_only:
# check that the device has a put method
if isinstance(self.device, SignalRO):
raise ValueError("Must specify read_only=True for a read-only device!")
raise ValueError("You must specify read_only=True for a read-only device.")

if self.log:
if not self.search_lower_bound > 0:
raise ValueError("Search bounds must be positive if log=True.")

# all dof degrees of freedom are hinted
self.device.kind = "hinted"

@property
def search_lower_bound(self):
return float(self.search_bounds[0])
if self.read_only:
raise ValueError("Read-only DOFs do not have search bounds.")
return float(self.summary.search_bounds[0])

@property
def search_upper_bound(self):
return float(self.search_bounds[1])
if self.read_only:
raise ValueError("Read-only DOFs do not have search bounds.")
return float(self.summary.search_bounds[1])

@property
def trust_lower_bound(self):
return float(self.trust_bounds[0])
return float(self.summary.trust_bounds[0])

@property
def trust_upper_bound(self):
return float(self.trust_bounds[1])
return float(self.summary.trust_bounds[1])

@property
def readback(self):
return self.device.read()[self.device.name]["value"]

@property
def summary(self) -> pd.Series:
series = pd.Series(index=list(DOF_FIELD_TYPES.keys()))
series = pd.Series(index=list(DOF_FIELD_TYPES.keys()), dtype="object")
for attr in series.index:
value = getattr(self, attr)
if attr == "trust_bounds":
Expand Down Expand Up @@ -203,11 +209,11 @@ def device_names(self) -> list:

@property
def search_lower_bounds(self) -> np.array:
return np.array([dof.search_lower_bound for dof in self.dofs])
return np.array([dof.search_lower_bound if not dof.read_only else dof.readback for dof in self.dofs])

@property
def search_upper_bounds(self) -> np.array:
return np.array([dof.search_upper_bound for dof in self.dofs])
return np.array([dof.search_upper_bound if not dof.read_only else dof.readback for dof in self.dofs])

@property
def search_bounds(self) -> np.array:
Expand Down
32 changes: 22 additions & 10 deletions blop/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,14 @@ def label(self):
return f"{'log ' if self.log else ''}{self.description}"

@property
def summary(self):
series = pd.Series()
for attr, dtype in OBJ_FIELD_TYPES.items():
series[attr] = getattr(self, attr)

def summary(self) -> pd.Series:
series = pd.Series(index=list(OBJ_FIELD_TYPES.keys()), dtype="object")
for attr in series.index:
value = getattr(self, attr)
if attr == "trust_bounds":
if value is None:
value = (0, np.inf) if self.log else (-np.inf, np.inf)
series[attr] = value
return series

def __repr__(self):
Expand All @@ -112,13 +115,21 @@ def __repr__(self):
def __repr_html__(self):
return self.summary.__repr_html__()

@property
def trust_lower_bound(self):
return float(self.summary.trust_bounds[0])

@property
def trust_upper_bound(self):
return float(self.summary.trust_bounds[1])

@property
def noise(self):
return self.model.likelihood.noise.item() if hasattr(self, "model") else None

@property
def snr(self):
return np.round(1 / self.model.likelihood.noise.sqrt().item(), 1) if hasattr(self, "model") else None
return np.round(1 / self.model.likelihood.noise.sqrt().item(), 3) if hasattr(self, "model") else None

@property
def n(self):
Expand Down Expand Up @@ -151,10 +162,11 @@ def __len__(self):
def summary(self) -> pd.DataFrame:
table = pd.DataFrame(columns=list(OBJ_FIELD_TYPES.keys()), index=self.names)

for attr, dtype in OBJ_FIELD_TYPES.items():
for obj in self.objectives:
table.at[obj.name, attr] = getattr(obj, attr)
for obj in self.objectives:
for attr, value in obj.summary.items():
table.at[obj.name, attr] = value

for attr, dtype in OBJ_FIELD_TYPES.items():
table[attr] = table[attr].astype(dtype)

return table
Expand All @@ -180,7 +192,7 @@ def names(self) -> list:
return [obj.name for obj in self.objectives]

@property
def targets(self) -> np.array:
def targets(self) -> list:
"""
Returns an array of the objective targets.
"""
Expand Down
22 changes: 11 additions & 11 deletions blop/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from databroker import Broker
from ophyd.utils import make_dir_tree

from blop.bayesian import DOF, Agent, Objective
from blop.bayesian.dofs import BrownianMotion
from blop import DOF, Agent, Objective
from blop.dofs import BrownianMotion
from blop.utils import functions


Expand Down Expand Up @@ -51,8 +51,8 @@ def agent(db):
"""

dofs = [
DOF(name="x1", limits=(-8.0, 8.0)),
DOF(name="x2", limits=(-8.0, 8.0)),
DOF(name="x1", search_bounds=(-8.0, 8.0)),
DOF(name="x2", search_bounds=(-8.0, 8.0)),
]

objectives = [Objective(name="himmelblau", target="min")]
Expand Down Expand Up @@ -85,8 +85,8 @@ def digestion(db, uid):
return products

dofs = [
DOF(name="x1", limits=(-5.0, 5.0)),
DOF(name="x2", limits=(-5.0, 5.0)),
DOF(name="x1", search_bounds=(-5.0, 5.0)),
DOF(name="x2", search_bounds=(-5.0, 5.0)),
]

objectives = [Objective(name="obj1", target="min"), Objective(name="obj2", target="min")]
Expand All @@ -110,11 +110,11 @@ def agent_with_passive_dofs(db):
"""

dofs = [
DOF(name="x1", limits=(-5.0, 5.0)),
DOF(name="x2", limits=(-5.0, 5.0)),
DOF(name="x3", limits=(-5.0, 5.0), active=False),
DOF(BrownianMotion(name="brownian1"), read_only=True),
DOF(BrownianMotion(name="brownian2"), read_only=True, active=False),
DOF(name="x1", search_bounds=(-5.0, 5.0)),
DOF(name="x2", search_bounds=(-5.0, 5.0)),
DOF(name="x3", search_bounds=(-5.0, 5.0), active=False),
DOF(device=BrownianMotion(name="brownian1"), read_only=True),
DOF(device=BrownianMotion(name="brownian2"), read_only=True, active=False),
]

objectives = [
Expand Down
4 changes: 2 additions & 2 deletions blop/tests/test_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ def test_agent_save_load_data(agent, RE):
RE(agent.learn("qr", n=4))
agent.save_data("/tmp/test_save_data.h5")
agent.reset()
agent.load_data(data_file="/tmp/test_save_data.h5")
agent.load_data("/tmp/test_save_data.h5")
RE(agent.learn("qr", n=4))


def test_agent_save_load_hypers(agent, RE):
RE(agent.learn("qr", n=4))
agent.save_hypers("/tmp/test_save_hypers.h5")
agent.reset()
RE(agent.learn("qr", n=16, hypers_file="/tmp/test_save_hypers.h5"))
RE(agent.learn("qr", n=16, hypers="/tmp/test_save_hypers.h5"))

0 comments on commit 3c21a4b

Please sign in to comment.