diff --git a/xplt/base.py b/xplt/base.py index c0e44bf..1c4dcda 100644 --- a/xplt/base.py +++ b/xplt/base.py @@ -670,19 +670,14 @@ def _create_artists(self, callback): axis is the axis and the string p is the property to plot. """ self.artists = [] - self._legend_entries = [] for i, ppp in enumerate(self.on_y): self.artists.append([]) - self._legend_entries.append([]) for j, pp in enumerate(ppp): self.artists[i].append([]) a = self.axis(i, j) for k, p in enumerate(pp): artist = callback(i, j, k, a, p) self.artists[i][j].append(artist) - for art in flattened(artist): - if art: - self._legend_entries[i].append(art) self.legend(i, show="auto") @@ -747,14 +742,26 @@ def legend(self, subplot="all", show=True, **kwargs): show = True # always show legend if single subplot is specified for s in subplot: - # use topmost axes for legend - ax = self.axflat_twin[s][-1] if len(self.axflat_twin[s]) > 0 else self.axflat[s] - handles = self._legend_entries[s] + # aggregate handles and use topmost axes for legend + handles = [] + for ax in [self.axflat[s], *self.axflat_twin[s]]: + handles.extend(ax.get_legend_handles_labels()[0]) if not show or (show == "auto" and len(handles) <= 1): if ax.get_legend(): ax.get_legend().remove() else: - ax.legend(handles=handles, **kwargs) + # join handles + handle_map = { + h: [h] for h in handles if not hasattr(h, "_join_legend_entry_with") + } + for h in handles: + if main_handle := getattr(h, "_join_legend_entry_with", None): + handle_map[main_handle].append(h) + handles = [tuple(hs) for hs in handle_map.values()] + labels = [h.get_label() for h in handle_map] + + # show legend + ax.legend(handles=handles, labels=labels, **kwargs) def autoscale(self, subplot="all", reset=False, freeze=True): """Autoscale the axes of a subplot diff --git a/xplt/timestructure.py b/xplt/timestructure.py index 2ebee35..74f110d 100644 --- a/xplt/timestructure.py +++ b/xplt/timestructure.py @@ -1060,6 +1060,7 @@ def create_artists(i, j, k, ax, p): self._errkw = kwargs.copy() self._errkw.update(zorder=1.8, alpha=0.3, ls="-", lw=0) errorbar = ax.fill_between([], [], [], **self._errkw) + errorbar._join_legend_entry_with = plot else: errorbar = None if poisson: @@ -1073,11 +1074,7 @@ def create_artists(i, j, k, ax, p): # legend with combined patch if std: - # merge plot and errorbar patches - for i, h in enumerate(self._legend_entries): - labels = [h[0].get_label()] + [_.get_label() for _ in h[2:]] - self._legend_entries[i] = [tuple(h[0:2])] + h[2:] - self.legend(i, show="auto", labels=labels) + self.legend() # set data if particles is not None: