Skip to content

Commit

Permalink
Improve joined legend labels
Browse files Browse the repository at this point in the history
  • Loading branch information
eltos committed Jan 2, 2024
1 parent 64dfc24 commit 040ee59
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
25 changes: 16 additions & 9 deletions xplt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,19 +670,14 @@ def _create_artists(self, callback):
axis is the axis and the string p is the property to plot.
"""
self.artists = []
self._legend_entries = []
for i, ppp in enumerate(self.on_y):
self.artists.append([])
self._legend_entries.append([])
for j, pp in enumerate(ppp):
self.artists[i].append([])
a = self.axis(i, j)
for k, p in enumerate(pp):
artist = callback(i, j, k, a, p)
self.artists[i][j].append(artist)
for art in flattened(artist):
if art:
self._legend_entries[i].append(art)

self.legend(i, show="auto")

Expand Down Expand Up @@ -747,14 +742,26 @@ def legend(self, subplot="all", show=True, **kwargs):
show = True # always show legend if single subplot is specified

for s in subplot:
# use topmost axes for legend
ax = self.axflat_twin[s][-1] if len(self.axflat_twin[s]) > 0 else self.axflat[s]
handles = self._legend_entries[s]
# aggregate handles and use topmost axes for legend
handles = []
for ax in [self.axflat[s], *self.axflat_twin[s]]:
handles.extend(ax.get_legend_handles_labels()[0])
if not show or (show == "auto" and len(handles) <= 1):
if ax.get_legend():
ax.get_legend().remove()
else:
ax.legend(handles=handles, **kwargs)
# join handles
handle_map = {
h: [h] for h in handles if not hasattr(h, "_join_legend_entry_with")
}
for h in handles:
if main_handle := getattr(h, "_join_legend_entry_with", None):
handle_map[main_handle].append(h)
handles = [tuple(hs) for hs in handle_map.values()]
labels = [h.get_label() for h in handle_map]

# show legend
ax.legend(handles=handles, labels=labels, **kwargs)

def autoscale(self, subplot="all", reset=False, freeze=True):
"""Autoscale the axes of a subplot
Expand Down
7 changes: 2 additions & 5 deletions xplt/timestructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ def create_artists(i, j, k, ax, p):
self._errkw = kwargs.copy()
self._errkw.update(zorder=1.8, alpha=0.3, ls="-", lw=0)
errorbar = ax.fill_between([], [], [], **self._errkw)
errorbar._join_legend_entry_with = plot
else:
errorbar = None
if poisson:
Expand All @@ -1073,11 +1074,7 @@ def create_artists(i, j, k, ax, p):

# legend with combined patch
if std:
# merge plot and errorbar patches
for i, h in enumerate(self._legend_entries):
labels = [h[0].get_label()] + [_.get_label() for _ in h[2:]]
self._legend_entries[i] = [tuple(h[0:2])] + h[2:]
self.legend(i, show="auto", labels=labels)
self.legend()

# set data
if particles is not None:
Expand Down

0 comments on commit 040ee59

Please sign in to comment.