Skip to content

Commit

Permalink
Merge pull request #80 from jennmald/plotting
Browse files Browse the repository at this point in the history
fix fitness and constraint plots
  • Loading branch information
thomaswmorris authored Feb 3, 2025
2 parents 0d80ed3 + 49752b5 commit fb2b081
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/blop/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
targets = agent.train_targets()[obj.name].numpy()

values = obj._untransform(targets)
# mask does not generate properly when values is a tensor (returns values of 0 instead of booleans)
values = np.array(values)

val_vmin, val_vmax = np.nanquantile(values, q=[0.01, 0.99])
val_norm = (
Expand All @@ -170,7 +172,13 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
obj_vmin, obj_vmax = np.nanquantile(targets, q=[0.01, 0.99])
obj_norm = mpl.colors.Normalize(obj_vmin, obj_vmax)

val_ax = agent.obj_axes[obj_index, 0].scatter(x_values, y_values, c=values, s=size, norm=val_norm, cmap=cmap)
# mask for nan values, uses unfilled o marker
mask = np.isnan(values)

val_ax = agent.obj_axes[obj_index, 0].scatter(
x_values[~mask], y_values[~mask], c=values[~mask], s=size, norm=val_norm, cmap=cmap
)
agent.obj_axes[obj_index, 0].scatter(x_values[mask], y_values[mask], marker="o", ec="k", fc="w", s=size)

# mean and sigma will have shape (*input_shape,)
test_posterior = obj.model.posterior(model_inputs)
Expand All @@ -180,9 +188,15 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
# test_values = obj.fitness_inverse(test_mean) if obj.kind == "fitness" else test_mean

test_constraint = None
if obj.constraint is not None:
if obj.constraint is None:
# test_constraint = obj.constraint_probability(model_inputs).detach().squeeze().numpy()
test_constraint = agent.constraint(model_inputs).squeeze().numpy()
else:
test_constraint = obj.constraint_probability(model_inputs).detach().squeeze().numpy()

fitness_ax = None
fit_err_ax = None

if gridded:
# _ = agent.obj_axes[obj_index, 1].pcolormesh(
# test_x,
Expand All @@ -192,22 +206,22 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
# cmap=cmap,
# norm=val_norm,
# )
if obj.constraint is not None:
if obj.constraint is None:
fitness_ax = agent.obj_axes[obj_index, 1].pcolormesh(
test_x,
test_y,
test_mean,
shading=shading,
cmap=cmap,
norm=obj_norm,
cmap=cmap,
)
fit_err_ax = agent.obj_axes[obj_index, 2].pcolormesh(
test_x,
test_y,
test_sigma,
shading=shading,
cmap=cmap,
norm=mpl.colors.LogNorm(),
cmap=cmap,
)

if test_constraint is not None:
Expand Down Expand Up @@ -260,7 +274,7 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
val_cbar = agent.obj_fig.colorbar(val_ax, ax=agent.obj_axes[obj_index, 0], location="bottom", aspect=32, shrink=0.8)
val_cbar.set_label(f"{obj.units or ''}")

if obj.constraint is not None:
if obj.constraint is None:
_ = agent.obj_fig.colorbar(fitness_ax, ax=agent.obj_axes[obj_index, 1], location="bottom", aspect=32, shrink=0.8)
_ = agent.obj_fig.colorbar(fit_err_ax, ax=agent.obj_axes[obj_index, 2], location="bottom", aspect=32, shrink=0.8)

Expand Down

0 comments on commit fb2b081

Please sign in to comment.