Skip to content

Commit

Permalink
plot a second signal as vertex size along color
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeff committed Nov 8, 2018
1 parent dc6cc66 commit 50274a5
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 72 deletions.
6 changes: 3 additions & 3 deletions doc/tutorials/intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ which assign a set of values (a vector in :math:`\mathbb{R}^d`) at every node
>>>
>>> fig, axes = plt.subplots(1, 2, figsize=(10, 3))
>>> for i, ax in enumerate(axes):
... _ = G.plot_signal(G.U[:, i+1], vertex_size=30, ax=ax)
... _ = G.plot_signal(G.U[:, i+1], size=30, ax=ax)
... _ = ax.set_title('Eigenvector {}'.format(i+2))
... ax.set_axis_off()
>>> fig.tight_layout()
Expand Down Expand Up @@ -227,9 +227,9 @@ low-pass filter.
>>> s2 = g.filter(s)
>>>
>>> fig, axes = plt.subplots(1, 2, figsize=(10, 3))
>>> _ = G.plot_signal(s, vertex_size=30, title='noisy', ax=axes[0])
>>> _ = G.plot_signal(s, size=30, title='noisy', ax=axes[0])
>>> axes[0].set_axis_off()
>>> _ = G.plot_signal(s2, vertex_size=30, title='cleaned', ax=axes[1])
>>> _ = G.plot_signal(s2, size=30, title='cleaned', ax=axes[1])
>>> axes[1].set_axis_off()
>>> fig.tight_layout()

Expand Down
13 changes: 6 additions & 7 deletions pygsp/graphs/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,15 +690,14 @@ def plot(self, edges=None, index=False, backend=None, vertex_size=None,
return _plot_graph(self, edges=edges, index=index, backend=backend,
vertex_size=vertex_size, title=title, ax=ax)

def plot_signal(self, signal=None, edges=None, vertex_size=None, highlight=[],
index=False, colorbar=True, limits=None, backend=None,
title=None, ax=None):
def plot_signal(self, color=None, size=None, highlight=[], edges=None,
index=False, colorbar=True, limits=None, ax=None,
title=None, backend=None):
r"""Docstring overloaded at import time."""
from pygsp.plotting import _plot_signal
return _plot_signal(self, signal=signal, edges=edges,
vertex_size=vertex_size, highlight=highlight,
index=index, colorbar=colorbar, limits=limits,
backend=backend, title=title, ax=ax)
return _plot_signal(self, color=color, size=size, highlight=highlight,
edges=edges, index=index, colorbar=colorbar,
limits=limits, ax=ax, title=title, backend=backend)

def plot_spectrogram(self, node_idx=None):
r"""Docstring overloaded at import time."""
Expand Down
130 changes: 68 additions & 62 deletions pygsp/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@

def _import_plt():
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
# Not used directly, but needed for 3D projection.
from mpl_toolkits.mplot3d import Axes3D # noqa
except Exception:
raise ImportError('Cannot import matplotlib. Choose another backend '
'or try to install it with '
'pip (or conda) install matplotlib.')
return plt
return mpl, plt


def _import_qtg():
Expand All @@ -75,7 +76,7 @@ def inner(obj, **kwargs):

# Create a figure and an axis if none were passed.
if kwargs['ax'] is None:
plt = _import_plt()
_, plt = _import_plt()
fig = plt.figure()
global _plt_figures
_plt_figures.append(fig)
Expand Down Expand Up @@ -120,7 +121,7 @@ def close_all():

global _plt_figures
for fig in _plt_figures:
plt = _import_plt()
_, plt = _import_plt()
plt.close(fig)
_plt_figures = []

Expand All @@ -130,13 +131,13 @@ def show(*args, **kwargs):
By default, showing plots does not block the prompt.
"""
plt = _import_plt()
_, plt = _import_plt()
plt.show(*args, **kwargs)


def close(*args, **kwargs):
r"""Close created figures, alias to plt.close()."""
plt = _import_plt()
_, plt = _import_plt()
plt.close(*args, **kwargs)


Expand Down Expand Up @@ -202,8 +203,8 @@ def _plot_graph(G, edges, index, backend, vertex_size, title, ax):
if backend == 'pyqtgraph':
_qtg_plot_graph(G, edges=edges, vertex_size=vertex_size, title=title)
elif backend == 'matplotlib':
return _plt_plot_signal(G, signal=None, edges=edges,
vertex_size=vertex_size, highlight=[],
return _plt_plot_signal(G, color=None, edges=edges,
size=vertex_size, highlight=[],
index=index, colorbar=False, limits=[0, 0],
title=title, ax=ax)
else:
Expand Down Expand Up @@ -353,27 +354,30 @@ def _plt_plot_filter(filters, n, eigenvalues, sum, ax, **kwargs):
ax.set_ylabel(r'$\hat{g}(\lambda)$: filter response')


def _plot_signal(G, signal, edges, vertex_size, highlight, index, colorbar,
limits, backend, title, ax):
r"""Plot a signal on the graph.
def _plot_signal(G, color, size, highlight, edges, index, colorbar,
limits, ax, title, backend):
r"""Plot a graph with signals as color or vertex size.
Parameters
----------
signal : array of int
Signal to plot. Signal length should be equal to the number of nodes.
edges : bool
True to draw edges, false to only draw vertices.
Default True if less than 10,000 edges to draw.
Note that drawing many edges can be slow.
cp : list of int
NOT IMPLEMENTED. Camera position when plotting a 3D graph.
vertex_size : float
Size of circle representing each node.
Defaults to G.plotting['vertex_size'].
color : array-like or matplotlib color
Signal to plot as vertex color.
Signal length should be equal to the number of nodes.
If None, all vertices will have the same color.
Alternatively, a color (any format accepted by matplotlib) can be passed.
size : array-like or int
Signal to plot as vertex size (matplotlib only).
Signal length should be equal to the number of nodes.
If None, all vertices will have the size graph.plotting['vertex_size'].
Alternatively, a size can be passed as an integer.
highlight : iterable
List of indices of vertices to be highlighted.
Useful to e.g. show where a filter was localized.
Useful for example to show where a filter was localized.
Only available with the matplotlib backend.
edges : bool
Whether to draw edges in addition to vertices.
Default to True if less than 10,000 edges to draw.
Note that drawing many edges can be slow.
index : bool
Whether to print the node index (in the adjacency / Laplacian matrix
and signal vectors) on top of each node.
Expand All @@ -383,21 +387,16 @@ def _plot_signal(G, signal, edges, vertex_size, highlight, index, colorbar,
Whether to plot a colorbar indicating the signal's amplitude.
Only available with the matplotlib backend.
limits : [vmin, vmax]
Maps colors from vmin to vmax.
Map colors from vmin to vmax.
Defaults to signal minimum and maximum value.
Only available with the matplotlib backend.
bar : boolean
NOT IMPLEMENTED. Signal values are displayed using colors when False,
and bars when True (default False).
bar_width : int
NOT IMPLEMENTED. Width of the bar (default 1).
backend: {'matplotlib', 'pyqtgraph'}
Defines the drawing backend to use. Defaults to :data:`BACKEND`.
title : str
Title of the figure.
ax : :class:`matplotlib.axes.Axes`
Axes where to draw the graph. Optional, created if not passed.
Only available with the matplotlib backend.
title : str
Title of the figure.
backend: {'matplotlib', 'pyqtgraph'}
Defines the drawing backend to use. Defaults to :data:`BACKEND`.
Returns
-------
Expand All @@ -409,9 +408,9 @@ def _plot_signal(G, signal, edges, vertex_size, highlight, index, colorbar,
Examples
--------
>>> import matplotlib
>>> G = graphs.Grid2d(8)
>>> signal = np.sin((np.arange(8**2) * 2*np.pi/8**2))
>>> fig, ax = G.plot_signal(signal)
>>> graph = graphs.Sensor(seed=42)
>>> signal = np.random.RandomState(42).normal(size=graph.n_nodes)
>>> fig, ax = graph.plot_signal(color=signal, size=graph.dw)
"""
if not hasattr(G, 'coords'):
Expand All @@ -421,52 +420,59 @@ def _plot_signal(G, signal, edges, vertex_size, highlight, index, colorbar,
if G.coords.ndim != 1 and check_2d_3d:
raise AttributeError('Coordinates should be in 1D, 2D or 3D space.')

if signal is None:
mpl, _ = _import_plt()
if color is None or mpl.colors.is_color_like(color):
limits = [0, 0]
colorbar = False
else:
signal = signal.squeeze()
if G.coords.ndim == 2 and signal.ndim != 1:
color = np.asarray(color).squeeze()
if color.ndim == 0 or color.shape[0] != G.N:
raise ValueError('Signal should have length G.N = {}.'.format(G.N))
if G.coords.ndim != 1 and color.ndim != 1:
raise ValueError('Can plot only one signal (not {}) with {}D '
'coordinates.'.format(signal.shape[1],
'coordinates.'.format(color.shape[1],
G.coords.shape[1]))
if signal.shape[0] != G.N:
raise ValueError('Signal length is {}, should be '
'G.N = {}.'.format(signal.shape[0], G.N))
if np.any(np.iscomplex(signal)):
raise ValueError("Can't display complex signals.")

if size is None:
size = G.plotting['vertex_size']
elif not np.isscalar(size):
size = np.asarray(size).squeeze()
if size.shape[0] != G.N:
raise ValueError('Signal should have length G.N = {}.'.format(G.N))
if size.ndim != 1:
raise ValueError('Can plot only one signal (not {}) as '
'size.'.format(size.shape[1]))
size -= size.min()
size /= size.max() + 1e-10
size = 500 * size**2 + 50

if edges is None:
edges = G.Ne < 10e3

if vertex_size is None:
vertex_size = G.plotting['vertex_size']
if limits is None:
limits = [1.05*color.min(), 1.05*color.max()]

if title is None:
title = G.__repr__(limit=4)

if limits is None:
limits = [1.05*signal.min(), 1.05*signal.max()]

if backend is None:
backend = BACKEND

G = _handle_directed(G)

if backend == 'pyqtgraph':
_qtg_plot_signal(G, signal=signal, edges=edges,
vertex_size=vertex_size, limits=limits, title=title)
_qtg_plot_signal(G, signal=color, vertex_size=size, edges=edges,
limits=limits, title=title)
elif backend == 'matplotlib':
return _plt_plot_signal(G, signal=signal, edges=edges,
vertex_size=vertex_size, limits=limits,
title=title, highlight=highlight,
index=index, colorbar=colorbar, ax=ax)
return _plt_plot_signal(G, color=color, size=size, highlight=highlight,
edges=edges, index=index, colorbar=colorbar,
limits=limits, ax=ax, title=title)
else:
raise ValueError('Unknown backend {}.'.format(backend))


@_plt_handle_figure
def _plt_plot_signal(G, signal, edges, vertex_size, highlight, index, colorbar,
def _plt_plot_signal(G, color, size, highlight, edges, index, colorbar,
limits, ax):

if edges:
Expand Down Expand Up @@ -503,27 +509,27 @@ def _plt_plot_signal(G, signal, edges, vertex_size, highlight, index, colorbar,
coords_hl = G.coords[highlight]

if G.coords.ndim == 1:
ax.plot(G.coords, signal, alpha=0.5)
ax.plot(G.coords, color, alpha=0.5)
ax.set_ylim(limits)
for coord_hl in coords_hl:
ax.axvline(x=coord_hl, color='C1', linewidth=2)

elif G.coords.shape[1] == 2:
sc = ax.scatter(G.coords[:, 0], G.coords[:, 1],
c=signal, s=vertex_size,
c=color, s=size,
marker='o', linewidths=0, alpha=0.5, zorder=2,
vmin=limits[0], vmax=limits[1])
ax.scatter(coords_hl[:, 0], coords_hl[:, 1],
s=2*vertex_size, zorder=3,
s=2*size, zorder=3,
marker='o', c='None', edgecolors='C1', linewidths=2)

elif G.coords.shape[1] == 3:
sc = ax.scatter(G.coords[:, 0], G.coords[:, 1], G.coords[:, 2],
c=signal, s=vertex_size,
c=color, s=size,
marker='o', linewidths=0, alpha=0.5, zorder=2,
vmin=limits[0], vmax=limits[1])
ax.scatter(coords_hl[:, 0], coords_hl[:, 1], coords_hl[:, 2],
s=2*vertex_size, zorder=3,
s=2*size, zorder=3,
marker='o', c='None', edgecolors='C1', linewidths=2)
try:
ax.view_init(elev=G.plotting['elevation'],
Expand All @@ -533,7 +539,7 @@ def _plt_plot_signal(G, signal, edges, vertex_size, highlight, index, colorbar,
pass

if G.coords.ndim != 1 and colorbar:
plt = _import_plt()
_, plt = _import_plt()
plt.colorbar(sc, ax=ax)

if index:
Expand Down
15 changes: 15 additions & 0 deletions pygsp/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,20 @@ def test(G):
G = graphs.Torus(Nv=5)
test(G)

def test_signals(self):
"""Test the different kind of parameters that can be passed."""
G = graphs.Sensor()
G.plot_signal()
G.plot_signal(G.dw)
G.plot_signal(G.dw, list(G.dw))
G.plot_signal(list(G.dw), G.dw)
G.plot_signal(G.dw[:, np.newaxis], G.dw[np.newaxis, :])
G.plot_signal('r', 100)
G.plot_signal((0.5, 0.5, 0.5, 0.5))
self.assertRaises(ValueError, G.plot_signal, 10)
self.assertRaises(ValueError, G.plot_signal, (0.5, 0.5))
self.assertRaises(ValueError, G.plot_signal, size=[2, 3, 4, 5])
self.assertRaises(ValueError, G.plot_signal, size=[G.dw, G.dw])


suite = unittest.TestLoader().loadTestsFromTestCase(TestCase)

0 comments on commit 50274a5

Please sign in to comment.