Skip to content

Commit

Permalink
Merge pull request #142 from stared/i139-figsize
Browse files Browse the repository at this point in the history
Add PlotLosses figsize
  • Loading branch information
stared authored Jan 3, 2025
2 parents 93aa155 + b5a0550 commit 40c887c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 29 deletions.
45 changes: 20 additions & 25 deletions examples/minimal.ipynb

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions livelossplot/outputs/matplotlib_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
after_subplot: Optional[Callable[[plt.Axes, str, str], None]] = None,
before_plots: Optional[Callable[[plt.Figure, np.ndarray, int], None]] = None,
after_plots: Optional[Callable[[plt.Figure], None]] = None,
figsize: Optional[Tuple[int, int]] = None,
):
"""
Args:
Expand All @@ -36,6 +37,7 @@ def __init__(
after_subplot: function which will be called after every subplot
before_plots: function which will be called before all subplots
after_plots: function which will be called after all subplots
figsize: optional tuple to explicitly set figure size (overrides cell_size calculation)
"""
self.cell_size = cell_size
self.max_cols = max_cols
Expand All @@ -47,6 +49,7 @@ def __init__(
self._after_subplot = after_subplot if after_subplot else self._default_after_subplot
self._before_plots = before_plots if before_plots else self._default_before_plots
self._after_plots = after_plots if after_plots else self._default_after_plots
self.figsize = figsize

def send(self, logger: MainLogger):
"""Draw figures with metrics and show"""
Expand Down Expand Up @@ -87,9 +90,12 @@ def _default_before_plots(self, fig: plt.Figure, axes: np.ndarray, num_of_log_gr
num_of_log_groups: number of log groups
"""
clear_output(wait=True)
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
fig.set_size_inches(figsize_x, figsize_y)
if self.figsize is not None:
fig.set_size_inches(*self.figsize)
else:
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
fig.set_size_inches(figsize_x, figsize_y)
if num_of_log_groups < axes.size:
for idx, ax in enumerate(axes[-1]):
if idx >= (num_of_log_groups + len(self.extra_plots)) % self.max_cols:
Expand Down
9 changes: 8 additions & 1 deletion livelossplot/plot_losses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import warnings
from typing import Type, TypeVar, List, Union
from typing import Type, TypeVar, List, Union, Optional, Tuple

import livelossplot
from livelossplot.main_logger import MainLogger
from livelossplot import outputs
from livelossplot.outputs.matplotlib_plot import MatplotlibPlot

BO = TypeVar('BO', bound=outputs.BaseOutput)

Expand All @@ -12,10 +13,12 @@ class PlotLosses:
"""
Class collect metrics from the training engine and send it to plugins, when send is called
"""

def __init__(
self,
outputs: List[Union[Type[BO], str]] = ['MatplotlibPlot', 'ExtremaPrinter'],
mode: str = 'notebook',
figsize: Optional[Tuple[int, int]] = None,
**kwargs
):
"""
Expand All @@ -24,12 +27,16 @@ def __init__(
or strings for livelossplot built-in output methods with default parameters
mode: Options: 'notebook' or 'script' - some of outputs need to change some behaviors,
depending on the working environment
figsize: tuple of (width, height) in inches for the figure
**kwargs: key-arguments which are passed to MainLogger constructor
"""
self.logger = MainLogger(**kwargs)
self.outputs = [getattr(livelossplot.outputs, out)() if isinstance(out, str) else out for out in outputs]
for out in self.outputs:
out.set_output_mode(mode)
if figsize is not None and isinstance(out, MatplotlibPlot):
print(f"Setting figsize to {figsize}")
out.figsize = figsize

def update(self, *args, **kwargs):
"""update logs with arguments that will be passed to main logger"""
Expand Down

0 comments on commit 40c887c

Please sign in to comment.