Skip to content

Commit

Permalink
change marker for bad points and fix colorbar
Browse files Browse the repository at this point in the history
  • Loading branch information
jennmald committed Feb 3, 2025
1 parent 7fc4151 commit df2819e
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/blop/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
targets = agent.train_targets()[obj.name].numpy()

values = obj._untransform(targets)
# mask does not generate properly when values is a tensor (returns values of 0 instead of booleans)
values = np.array(values)

val_vmin, val_vmax = np.nanquantile(values, q=[0.01, 0.99])
val_norm = (
Expand All @@ -170,10 +172,11 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
obj_vmin, obj_vmax = np.nanquantile(targets, q=[0.01, 0.99])
obj_norm = mpl.colors.Normalize(obj_vmin, obj_vmax)

# mask for nan values, uses x marker
# mask for nan values, uses unfilled o marker
mask = np.isnan(values)
val_ax = agent.obj_axes[obj_index, 0].scatter(x_values, y_values, c=values, s=size, norm=val_norm, cmap=cmap)
val_ax = agent.obj_axes[obj_index, 0].scatter(x_values[mask], y_values[mask], c="k", marker="x", linewidths=4, s=75)

val_ax = agent.obj_axes[obj_index, 0].scatter(x_values[~mask], y_values[~mask], c=values[~mask], s=size, norm=val_norm, cmap=cmap)
agent.obj_axes[obj_index, 0].scatter(x_values[mask], y_values[mask], marker="o", ec='k', fc='w', s=size)

# mean and sigma will have shape (*input_shape,)
test_posterior = obj.model.posterior(model_inputs)
Expand Down Expand Up @@ -266,7 +269,7 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
norm=mpl.colors.LogNorm(),
)

val_cbar = agent.obj_fig.colorbar(val_ax, ax=agent.obj_axes[obj_index, 0], location="bottom", aspect=32, shrink=0.8)
val_cbar = agent.obj_fig.colorbar(val_ax, ax=agent.obj_axes[obj_index,0], location="bottom", aspect=32, shrink=0.8)
val_cbar.set_label(f"{obj.units or ''}")

if obj.constraint is None:
Expand Down

0 comments on commit df2819e

Please sign in to comment.