diff --git a/process/io/plot_solutions.py b/process/io/plot_solutions.py index 2de7875d8..17396f434 100644 --- a/process/io/plot_solutions.py +++ b/process/io/plot_solutions.py @@ -54,6 +54,7 @@ def plot_mfile_solutions( normalising_tag: Optional[str] = None, rmse: bool = False, normalisation_type: Optional[str] = "init", + plot_obj_func: bool = True, ) -> Tuple[mpl.figure.Figure, pd.DataFrame]: """Plot multiple solutions, optionally normalised by a given solution. @@ -68,6 +69,8 @@ def plot_mfile_solutions( :type rmse: bool, optional :param normalisation_type: opt param normalisation to use: one of ["init", "range", None], defaults to "init" :type normalisation_type: str, optional + :param plot_obj_func: show the objective function plot, defaults to True + :type: plot_obj_func: bool, optional :return: figure and dataframe of solutions :rtype: Tuple[mpl.figure.Figure, pd.DataFrame] """ @@ -122,6 +125,7 @@ def plot_mfile_solutions( rmse_df=rmse_df, normalising_tag=normalising_tag, normalisation_type=normalisation_type, + plot_obj_func=plot_obj_func, ) # Return fig for optional further customisation @@ -314,6 +318,7 @@ def _plot_solutions( plot_title: str, normalising_tag: Union[str, None], rmse_df: pd.DataFrame, + plot_obj_func: bool, ) -> mpl.figure.Figure: """Plot multiple solutions, optionally normalised by a given solution. @@ -329,6 +334,8 @@ def _plot_solutions( :type normalising_tag: Union[str, None] :param rmse_df: RMS errors relative to reference solution :type rmse_df: pd.DataFrame + :param plot_obj_func: show the objective function plot if True + :type: plot_obj_func: bool :return: figure containing varying numbers of axes :rtype: mpl.figure.Figure """ @@ -344,7 +351,7 @@ def _plot_solutions( # Acquire objective function name(s), then check only one type is being plotted objf_list = norm_objf_df[NORM_OBJF_NAME].unique() - if len(objf_list) != 1: + if plot_obj_func and len(objf_list) != 1: raise ValueError("Can't plot different objective functions on the same plot") objf_name = objf_list[0] @@ -400,29 +407,42 @@ def _plot_solutions( ) + external_legend_height figsize = [7, final_subfig_height_opt_params] - # Add obj func/RMSE subfig - base_subfig_height_obj_func = 0.75 - if rmse_df is not None: + # Add fig height for obj func/RMSE subfig if required + if plot_obj_func and rmse_df is not None: nrows_obj_func_rmse_subplot = 2 - final_subfig_height_obj_func = ( - nrows_obj_func_rmse_subplot * base_subfig_height_obj_func - ) - else: + elif plot_obj_func or rmse_df is not None: nrows_obj_func_rmse_subplot = 1 - final_subfig_height_obj_func = base_subfig_height_obj_func + else: + nrows_obj_func_rmse_subplot = 0 + base_subfig_height_obj_func = 0.75 + final_subfig_height_obj_func = ( + nrows_obj_func_rmse_subplot * base_subfig_height_obj_func + ) figsize[1] += final_subfig_height_obj_func - fig = plt.figure(layout="constrained", figsize=figsize) fig.suptitle(plot_title) - # 2 subfig rows: 1 for opt params, 1 for objective function and optional RMSE - # Use actual calculated height to define height ratios - subfigs = fig.subfigures( - nrows=2, - ncols=1, - height_ratios=[final_subfig_height_opt_params, final_subfig_height_obj_func], - ) - axs_opt_params = subfigs[0].subplots(nrows=nrows_opt_param_subplot) + + # Plot opt params plot only: one figure + # Opt params and objective function/RMSE plots: fig with 2 subfigures + + if nrows_obj_func_rmse_subplot == 0: + # Just plotting opt params fig + opt_params_subfig = fig + else: + # Plotting opt params and obj func/rmse subfigs + # Use actual calculated height to define height ratios + subfigs = fig.subfigures( + nrows=2, + ncols=1, + height_ratios=[ + final_subfig_height_opt_params, + final_subfig_height_obj_func, + ], + ) + opt_params_subfig = subfigs[0] + obj_func_subfig = subfigs[1] + axs_opt_params = opt_params_subfig.subplots(nrows=nrows_opt_param_subplot) # Adapt x axis label for normalisation type if normalisation_type == "init": @@ -468,8 +488,8 @@ def _plot_solutions( handles, labels = axs_opt_params[1].get_legend_handles_labels() axs_opt_params[0].legend(handles=handles, labels=labels, ncols=3, loc="center") - subfigs[0].supxlabel(x_axis_label) - subfigs[0].supylabel("Optimisation parameter") + opt_params_subfig.supxlabel(x_axis_label) + opt_params_subfig.supylabel("Optimisation parameter") else: sns.stripplot( data=opt_params_values_with_names_df_melt, @@ -485,30 +505,31 @@ def _plot_solutions( axs_opt_params.legend() axs_opt_params.grid() - # Plot objf change separately - axs_opt_params = subfigs[1].subplots(nrows=nrows_obj_func_rmse_subplot) - if nrows_obj_func_rmse_subplot > 1: - ax_obj_func = axs_opt_params[0] - ax_rmse = axs_opt_params[1] - else: - ax_obj_func = axs_opt_params - - # Melt for seaborn stripplot - norm_objf_values_df_melt = norm_objf_values_df.melt(id_vars=TAG) - sns.stripplot( - data=norm_objf_values_df_melt, - x="value", - y="variable", - hue=TAG, - jitter=True, - ax=ax_obj_func, - formatter=lambda label: objf_name, - ) + if plot_obj_func: + # Plot objf change separately + axs_opt_params = obj_func_subfig.subplots(nrows=nrows_obj_func_rmse_subplot) + if nrows_obj_func_rmse_subplot > 1: + ax_obj_func = axs_opt_params[0] + ax_rmse = axs_opt_params[1] + else: + ax_obj_func = axs_opt_params + + # Melt for seaborn stripplot + norm_objf_values_df_melt = norm_objf_values_df.melt(id_vars=TAG) + sns.stripplot( + data=norm_objf_values_df_melt, + x="value", + y="variable", + hue=TAG, + jitter=True, + ax=ax_obj_func, + formatter=lambda label: objf_name, + ) - ax_obj_func.get_legend().remove() - # Objective function values are not normalised (no initial value): be explicit - ax_obj_func.set_xlabel("Objective function value") - ax_obj_func.set_ylabel("") + ax_obj_func.get_legend().remove() + # Objective function values are not normalised (no initial value): be explicit + ax_obj_func.set_xlabel("Objective function value") + ax_obj_func.set_ylabel("") if rmse_df is not None: # Ensure solution legend colours are the same: tags match between opt