Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PlotLosses figsize #142

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 20 additions & 25 deletions examples/minimal.ipynb

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions livelossplot/outputs/matplotlib_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
after_subplot: Optional[Callable[[plt.Axes, str, str], None]] = None,
before_plots: Optional[Callable[[plt.Figure, np.ndarray, int], None]] = None,
after_plots: Optional[Callable[[plt.Figure], None]] = None,
figsize: Optional[Tuple[int, int]] = None,
):
"""
Args:
Expand All @@ -36,6 +37,7 @@ def __init__(
after_subplot: function which will be called after every subplot
before_plots: function which will be called before all subplots
after_plots: function which will be called after all subplots
figsize: optional tuple to explicitly set figure size (overrides cell_size calculation)
"""
self.cell_size = cell_size
self.max_cols = max_cols
Expand All @@ -47,6 +49,7 @@ def __init__(
self._after_subplot = after_subplot if after_subplot else self._default_after_subplot
self._before_plots = before_plots if before_plots else self._default_before_plots
self._after_plots = after_plots if after_plots else self._default_after_plots
self.figsize = figsize

def send(self, logger: MainLogger):
"""Draw figures with metrics and show"""
Expand Down Expand Up @@ -87,9 +90,12 @@ def _default_before_plots(self, fig: plt.Figure, axes: np.ndarray, num_of_log_gr
num_of_log_groups: number of log groups
"""
clear_output(wait=True)
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
fig.set_size_inches(figsize_x, figsize_y)
if self.figsize is not None:
fig.set_size_inches(*self.figsize)
else:
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
fig.set_size_inches(figsize_x, figsize_y)
if num_of_log_groups < axes.size:
for idx, ax in enumerate(axes[-1]):
if idx >= (num_of_log_groups + len(self.extra_plots)) % self.max_cols:
Expand Down
9 changes: 8 additions & 1 deletion livelossplot/plot_losses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import warnings
from typing import Type, TypeVar, List, Union
from typing import Type, TypeVar, List, Union, Optional, Tuple

import livelossplot
from livelossplot.main_logger import MainLogger
from livelossplot import outputs
from livelossplot.outputs.matplotlib_plot import MatplotlibPlot

BO = TypeVar('BO', bound=outputs.BaseOutput)

Expand All @@ -12,10 +13,12 @@ class PlotLosses:
"""
Class collect metrics from the training engine and send it to plugins, when send is called
"""

def __init__(
self,
outputs: List[Union[Type[BO], str]] = ['MatplotlibPlot', 'ExtremaPrinter'],
mode: str = 'notebook',
figsize: Optional[Tuple[int, int]] = None,
**kwargs
):
"""
Expand All @@ -24,12 +27,16 @@ def __init__(
or strings for livelossplot built-in output methods with default parameters
mode: Options: 'notebook' or 'script' - some of outputs need to change some behaviors,
depending on the working environment
figsize: tuple of (width, height) in inches for the figure
**kwargs: key-arguments which are passed to MainLogger constructor
"""
self.logger = MainLogger(**kwargs)
self.outputs = [getattr(livelossplot.outputs, out)() if isinstance(out, str) else out for out in outputs]
for out in self.outputs:
out.set_output_mode(mode)
if figsize is not None and isinstance(out, MatplotlibPlot):
print(f"Setting figsize to {figsize}")
out.figsize = figsize

def update(self, *args, **kwargs):
"""update logs with arguments that will be passed to main logger"""
Expand Down
Loading