Skip to content

Commit

Permalink
Plotting: only include time slider when needed
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731246552
  • Loading branch information
jcitrin authored and Torax team committed Feb 27, 2025
1 parent b50deb1 commit 40d863f
Showing 1 changed file with 34 additions and 17 deletions.
51 changes: 34 additions & 17 deletions torax/plotting/plotruns_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ def __post_init__(self):
if len(self.axes) > self.rows * self.cols:
raise ValueError('len(axes) in plot_config is more than rows * columns.')

@property
def contains_spatial_plot_type(self) -> bool:
"""Checks if any plot is a spatial plottype."""
return any(
plot_properties.plot_type == PlotType.SPATIAL
for plot_properties in self.axes
)


@dataclasses.dataclass
class PlotData:
Expand Down Expand Up @@ -164,8 +172,8 @@ class PlotData:
ti_volume_avg: Volume-averaged ion temperature [:math:`\mathrm{keV}`].
ne_volume_avg: Volume-averaged electron density [:math:`\mathrm{10^{20}
m^{-3}}`].
ni_volume_avg: Volume-averaged ion density
[:math:`\mathrm{10^{20} m^{-3}}`].
ni_volume_avg: Volume-averaged ion density [:math:`\mathrm{10^{20}
m^{-3}}`].
W_thermal_tot: Total thermal stored energy [:math:`\mathrm{MJ}`].
q95: Safety factor at 95% of the normalized poloidal flux.
"""
Expand Down Expand Up @@ -443,17 +451,19 @@ def plot_run(
)

format_plots(plot_config, plotdata1, plotdata2, axes)
timeslider = create_slider(slider_ax, plotdata1, plotdata2)
fig.canvas.draw()

def update(newtime):
"""Update plots with new values following slider manipulation."""
fig.constrained_layout = False
_update(newtime, plot_config, plotdata1, lines1, plotdata2, lines2)
fig.constrained_layout = True
fig.canvas.draw_idle()
# Only create the slider if needed.
if plot_config.contains_spatial_plot_type:
timeslider = create_slider(slider_ax, plotdata1, plotdata2)
def update(newtime):
"""Update plots with new values following slider manipulation."""
fig.constrained_layout = False
_update(newtime, plot_config, plotdata1, lines1, plotdata2, lines2)
fig.constrained_layout = True
fig.canvas.draw_idle()

timeslider.on_changed(update)

timeslider.on_changed(update)
fig.canvas.draw()
plt.show()

Expand Down Expand Up @@ -668,16 +678,23 @@ def create_figure(plot_config: FigureProperties):
),
constrained_layout=True,
)
# Create the GridSpec - leave space for the slider at the bottom
gs = gridspec.GridSpec(
rows + 1, cols, figure=fig, height_ratios=[1] * rows + [0.2]
) # Adjust 0.2 for slider height
# Create the GridSpec - Adjust height ratios to include the slider
# in the plot, only if a slider is required:
if plot_config.contains_spatial_plot_type:
# Add an extra smaller is a spatial plottypeider
height_ratios = [1] * rows + [0.2]
gs = gridspec.GridSpec(
rows + 1, cols, figure=fig, height_ratios=height_ratios
)
# slider spans all columns
slider_ax = fig.add_subplot(gs[rows, :])
else:
gs = gridspec.GridSpec(rows, cols, figure=fig)
slider_ax = None

axes = []
for i in range(rows * cols):
row = i // cols
col = i % cols
axes.append(fig.add_subplot(gs[row, col])) # Add subplots to the grid
# slider spans all columns in the last row
slider_ax = fig.add_subplot(gs[rows, :])
return fig, axes, slider_ax

0 comments on commit 40d863f

Please sign in to comment.