From 40d863f1da7c876375648147e1d0ab60fde90f42 Mon Sep 17 00:00:00 2001 From: Jonathan Citrin Date: Wed, 26 Feb 2025 03:36:01 -0800 Subject: [PATCH] Plotting: only include time slider when needed PiperOrigin-RevId: 731246552 --- torax/plotting/plotruns_lib.py | 51 ++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/torax/plotting/plotruns_lib.py b/torax/plotting/plotruns_lib.py index ef0fa47a..03715c71 100644 --- a/torax/plotting/plotruns_lib.py +++ b/torax/plotting/plotruns_lib.py @@ -82,6 +82,14 @@ def __post_init__(self): if len(self.axes) > self.rows * self.cols: raise ValueError('len(axes) in plot_config is more than rows * columns.') + @property + def contains_spatial_plot_type(self) -> bool: + """Checks if any plot is a spatial plottype.""" + return any( + plot_properties.plot_type == PlotType.SPATIAL + for plot_properties in self.axes + ) + @dataclasses.dataclass class PlotData: @@ -164,8 +172,8 @@ class PlotData: ti_volume_avg: Volume-averaged ion temperature [:math:`\mathrm{keV}`]. ne_volume_avg: Volume-averaged electron density [:math:`\mathrm{10^{20} m^{-3}}`]. - ni_volume_avg: Volume-averaged ion density - [:math:`\mathrm{10^{20} m^{-3}}`]. + ni_volume_avg: Volume-averaged ion density [:math:`\mathrm{10^{20} + m^{-3}}`]. W_thermal_tot: Total thermal stored energy [:math:`\mathrm{MJ}`]. q95: Safety factor at 95% of the normalized poloidal flux. """ @@ -443,17 +451,19 @@ def plot_run( ) format_plots(plot_config, plotdata1, plotdata2, axes) - timeslider = create_slider(slider_ax, plotdata1, plotdata2) - fig.canvas.draw() - def update(newtime): - """Update plots with new values following slider manipulation.""" - fig.constrained_layout = False - _update(newtime, plot_config, plotdata1, lines1, plotdata2, lines2) - fig.constrained_layout = True - fig.canvas.draw_idle() + # Only create the slider if needed. + if plot_config.contains_spatial_plot_type: + timeslider = create_slider(slider_ax, plotdata1, plotdata2) + def update(newtime): + """Update plots with new values following slider manipulation.""" + fig.constrained_layout = False + _update(newtime, plot_config, plotdata1, lines1, plotdata2, lines2) + fig.constrained_layout = True + fig.canvas.draw_idle() + + timeslider.on_changed(update) - timeslider.on_changed(update) fig.canvas.draw() plt.show() @@ -668,16 +678,23 @@ def create_figure(plot_config: FigureProperties): ), constrained_layout=True, ) - # Create the GridSpec - leave space for the slider at the bottom - gs = gridspec.GridSpec( - rows + 1, cols, figure=fig, height_ratios=[1] * rows + [0.2] - ) # Adjust 0.2 for slider height + # Create the GridSpec - Adjust height ratios to include the slider + # in the plot, only if a slider is required: + if plot_config.contains_spatial_plot_type: + # Add an extra smaller is a spatial plottypeider + height_ratios = [1] * rows + [0.2] + gs = gridspec.GridSpec( + rows + 1, cols, figure=fig, height_ratios=height_ratios + ) + # slider spans all columns + slider_ax = fig.add_subplot(gs[rows, :]) + else: + gs = gridspec.GridSpec(rows, cols, figure=fig) + slider_ax = None axes = [] for i in range(rows * cols): row = i // cols col = i % cols axes.append(fig.add_subplot(gs[row, col])) # Add subplots to the grid - # slider spans all columns in the last row - slider_ax = fig.add_subplot(gs[rows, :]) return fig, axes, slider_ax