forked from epfl-lts2/pygsp
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
379 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
r""" | ||
Filtering a graph signal | ||
======================== | ||
A graph signal is filtered by transforming it to the spectral domain (via the | ||
Fourier transform), performing a point-wise multiplication (motivated by the | ||
convolution theorem), and transforming it back to the vertex domain (via the | ||
inverse graph Fourier transform). | ||
.. note:: | ||
In practice, filtering is implemented in the vertex domain to avoid the | ||
computationally expensive graph Fourier transform. To do so, filters are | ||
implemented as polynomials of the eigenvalues / Laplacian. Hence, filtering | ||
a signal reduces to its multiplications with sparse matrices (the graph | ||
Laplacian). | ||
""" | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import pygsp as pg | ||
|
||
#plt.rc('font', family='Latin Modern Roman') | ||
plt.rc('text', usetex=True) | ||
plt.rc('text.latex', preamble=r'\usepackage{lmodern}') | ||
|
||
G = pg.graphs.Sensor(seed=42) | ||
G.compute_fourier_basis() | ||
|
||
#g = pg.filters.Rectangular(G, band_max=0.2) | ||
g = pg.filters.Expwin(G, band_max=0.5) | ||
|
||
fig, axes = plt.subplots(1, 3, figsize=(12, 4)) | ||
fig.subplots_adjust(hspace=0.5) | ||
|
||
x = np.random.RandomState(1).normal(size=G.N) | ||
#x = np.random.RandomState(42).uniform(-1, 1, size=G.N) | ||
x = 3 * x / np.linalg.norm(x) | ||
y = g.filter(x) | ||
x_hat = G.gft(x).squeeze() | ||
y_hat = G.gft(y).squeeze() | ||
|
||
limits = [x.min(), x.max()] | ||
|
||
G.plot(x, limits=limits, ax=axes[0], title='input signal $x$ in the vertex domain') | ||
axes[0].text(0, -0.1, '$x^T L x = {:.2f}$'.format(G.dirichlet_energy(x))) | ||
axes[0].set_axis_off() | ||
|
||
g.plot(ax=axes[1], alpha=1) | ||
line_filt = axes[1].lines[-2] | ||
line_in, = axes[1].plot(G.e, np.abs(x_hat), '.-') | ||
line_out, = axes[1].plot(G.e, np.abs(y_hat), '.-') | ||
#axes[1].set_xticks(range(0, 16, 4)) | ||
axes[1].set_xlabel(r'graph frequency $\lambda$') | ||
axes[1].set_ylabel(r'frequency content $\hat{x}(\lambda)$') | ||
axes[1].set_title(r'signals in the spectral domain') | ||
axes[1].legend(['input signal $\hat{x}$']) | ||
labels = [ | ||
r'input signal $\hat{x}$', | ||
'kernel $g$', | ||
r'filtered signal $\hat{y}$', | ||
] | ||
axes[1].legend([line_in, line_filt, line_out], labels, loc='upper right') | ||
|
||
G.plot(y, limits=limits, ax=axes[2], title='filtered signal $y$ in the vertex domain') | ||
axes[2].text(0, -0.1, '$y^T L y = {:.2f}$'.format(G.dirichlet_energy(y))) | ||
axes[2].set_axis_off() | ||
|
||
fig.tight_layout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
r""" | ||
Fourier basis of graphs | ||
======================= | ||
The eigenvectors of the graph Laplacian form the Fourier basis. | ||
The eigenvalues are a measure of variation of their corresponding eigenvector. | ||
The lower the eigenvalue, the smoother the eigenvector. They are hence a | ||
measure of "frequency". | ||
In classical signal processing, Fourier modes are completely delocalized, like | ||
on the grid graph. For general graphs however, Fourier modes might be | ||
localized. See :attr:`pygsp.graphs.Graph.coherence`. | ||
""" | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import pygsp as pg | ||
|
||
#plt.rc('font', family='Latin Modern Roman') | ||
plt.rc('text', usetex=True) | ||
plt.rc('text.latex', preamble=r'\usepackage{lmodern}') | ||
|
||
n_eigenvectors = 7 | ||
|
||
fig, axes = plt.subplots(2, 7, figsize=(15, 4)) | ||
|
||
def plot_eigenvectors(G, axes): | ||
G.compute_fourier_basis(n_eigenvectors) | ||
limits = [f(G.U) for f in (np.min, np.max)] | ||
for i, ax in enumerate(axes): | ||
G.plot(G.U[:, i], limits=limits, colorbar=False, vertex_size=50, ax=ax) | ||
energy = abs(G.dirichlet_energy(G.U[:, i])) | ||
ax.set_title(r'$u_{0}^\top L u_{0} = {1:.2f}$'.format(i+1, energy)) | ||
ax.set_axis_off() | ||
|
||
G = pg.graphs.Grid2d(10, 10) | ||
plot_eigenvectors(G, axes[0]) | ||
fig.subplots_adjust(hspace=0.5, right=0.8) | ||
cax = fig.add_axes([0.82, 0.60, 0.01, 0.26]) | ||
fig.colorbar(axes[0, -1].collections[1], cax=cax, ticks=[-0.2, 0, 0.2]) | ||
|
||
G = pg.graphs.Sensor(seed=42) | ||
plot_eigenvectors(G, axes[1]) | ||
fig.subplots_adjust(hspace=0.5, right=0.8) | ||
cax = fig.add_axes([0.82, 0.16, 0.01, 0.26]) | ||
fig.colorbar(axes[1, -1].collections[1], cax=cax, ticks=[-0.4, 0, 0.4]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
r""" | ||
Fourier transform | ||
================= | ||
The graph Fourier transform (:meth:`pygsp.graphs.Graph.gft`) transforms a | ||
signal from the vertex domain to the spectral domain. The smoother the signal | ||
(see :meth:`pygsp.graphs.Graph.dirichlet_energy`), the lower in the frequencies | ||
its energy is concentrated. | ||
""" | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import pygsp as pg | ||
|
||
#plt.rc('font', family='Latin Modern Roman') | ||
plt.rc('text', usetex=True) | ||
plt.rc('text.latex', preamble=r'\usepackage{lmodern}') | ||
|
||
G = pg.graphs.Sensor(seed=42) | ||
G.compute_fourier_basis() | ||
|
||
scales = [10, 3, 0] | ||
limit = 0.32 | ||
|
||
fig, axes = plt.subplots(2, len(scales), figsize=(12, 4)) | ||
fig.subplots_adjust(hspace=0.5) | ||
|
||
x0 = np.random.RandomState(1).normal(size=G.N) | ||
for i, scale in enumerate(scales): | ||
g = pg.filters.Heat(G, scale) | ||
x = g.filter(x0).squeeze() | ||
x /= np.linalg.norm(x) | ||
x_hat = G.gft(x).squeeze() | ||
|
||
assert np.all((-limit < x) & (x < limit)) | ||
G.plot(x, limits=[-limit, limit], ax=axes[0, i]) | ||
axes[0, i].set_axis_off() | ||
axes[0, i].set_title('$x^T L x = {:.2f}$'.format(G.dirichlet_energy(x))) | ||
|
||
axes[1, i].plot(G.e, np.abs(x_hat), '.-') | ||
axes[1, i].set_xticks(range(0, 16, 4)) | ||
axes[1, i].set_xlabel(r'graph frequency $\lambda$') | ||
axes[1, i].set_ylim(-0.05, 0.95) | ||
|
||
axes[1, 0].set_ylabel(r'frequency content $\hat{x}(\lambda)$') | ||
|
||
# axes[0, 0].set_title(r'$x$: signal in the vertex domain') | ||
# axes[1, 0].set_title(r'$\hat{x}$: signal in the spectral domain') | ||
|
||
fig.tight_layout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
r""" | ||
Heat diffusion on graphs | ||
======================== | ||
Solve the heat equation by filtering the initial conditions with the heat | ||
kernel. | ||
""" | ||
|
||
from os import path | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import pygsp as pg | ||
|
||
#plt.rc('font', family='Latin Modern Roman') | ||
plt.rc('text', usetex=True) | ||
plt.rc('text.latex', preamble=r'\usepackage{lmodern}') | ||
|
||
n_side = 13 | ||
G = pg.graphs.Grid2d(n_side) | ||
G.compute_fourier_basis() | ||
|
||
sources = [ | ||
(n_side//4 * n_side) + (n_side//4), | ||
(n_side*3//4 * n_side) + (n_side*3//4), | ||
] | ||
x = np.zeros(G.n_vertices) | ||
x[sources] = 5 | ||
|
||
times = [0, 5, 10, 20] | ||
|
||
fig, axes = plt.subplots(2, len(times), figsize=(12, 5)) | ||
for i, t in enumerate(times): | ||
g = pg.filters.Heat(G, scale=t) | ||
title = fr'$\hat{{f}}({t}) = g_{{1,{t}}} \odot \hat{{f}}(0)$' | ||
g.plot(alpha=1, ax=axes[0, i], title=title) | ||
axes[0, i].set_xlabel(r'$\lambda$') | ||
# axes[0, i].set_ylabel(r'$g(\lambda)$') | ||
if i > 0: | ||
axes[0, i].set_ylabel('') | ||
y = g.filter(x) | ||
line, = axes[0, i].plot(G.e, G.gft(y)) | ||
labels = [fr'$\hat{{f}}({t})$', fr'$g_{{1,{t}}}$'] | ||
axes[0, i].legend([line, axes[0, i].lines[-3]], labels, loc='lower right') | ||
G.plot(y, edges=False, highlight=sources, ax=axes[1, i], title=fr'$f({t})$') | ||
axes[1, i].set_aspect('equal', 'box') | ||
axes[1, i].set_axis_off() | ||
|
||
fig.tight_layout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
r""" | ||
Kernel localization | ||
=================== | ||
In classical signal processing, a filter can be translated in the vertex | ||
domain. We cannot do that on graphs. Instead, we can | ||
:meth:`~pygsp.filters.Filter.localize` a filter kernel. Note how on classic | ||
structures (like the ring), the localized kernel is the same everywhere, while | ||
it changes when localized on irregular graphs. | ||
""" | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import pygsp as pg | ||
|
||
#plt.rc('font', family='Latin Modern Roman') | ||
plt.rc('text', usetex=True) | ||
plt.rc('text.latex', preamble=r'\usepackage{lmodern}') | ||
|
||
fig, axes = plt.subplots(2, 4, figsize=(10, 4)) | ||
|
||
graphs = [ | ||
pg.graphs.Ring(40), | ||
pg.graphs.Sensor(64, seed=42), | ||
] | ||
|
||
locations = [0, 10, 20] | ||
|
||
for graph, axs in zip(graphs, axes): | ||
graph.compute_fourier_basis() | ||
g = pg.filters.Heat(graph) | ||
g.plot(ax=axs[0], title='heat kernel') | ||
axs[0].set_xlabel(r'eigenvalues $\lambda$') | ||
axs[0].set_ylabel(r'$g(\lambda) = \exp \left( \frac{{-{}\lambda}}{{\lambda_{{max}}}} \right)$'.format(g.scale[0])) | ||
maximum = 0 | ||
for loc in locations: | ||
x = g.localize(loc) | ||
maximum = np.maximum(maximum, x.max()) | ||
for loc, ax in zip(locations, axs[1:]): | ||
graph.plot(g.localize(loc), limits=[0, maximum], highlight=loc, ax=ax, | ||
title=r'$g(L) \delta_{{{}}}$'.format(loc)) | ||
ax.set_axis_off() | ||
|
||
fig.tight_layout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
r""" | ||
Random walks | ||
============ | ||
Probability of a random walker to be on any given vertex after a given number | ||
of steps starting from a given distribution. | ||
""" | ||
|
||
# sphinx_gallery_thumbnail_number = 2 | ||
|
||
import numpy as np | ||
from scipy import sparse | ||
from matplotlib import pyplot as plt | ||
import pygsp as pg | ||
|
||
#plt.rc('font', family='Latin Modern Roman') | ||
plt.rc('text', usetex=True) | ||
plt.rc('text.latex', preamble=r'\usepackage{lmodern}') | ||
|
||
N = 7 | ||
steps = [0, 1, 2, 3] | ||
|
||
graph = pg.graphs.Grid2d(N) | ||
delta = np.zeros(graph.N) | ||
delta[N//2*N + N//2] = 1 | ||
|
||
probability = sparse.diags(graph.dw**(-1)) @ graph.W | ||
|
||
fig, axes = plt.subplots(1, len(steps), figsize=(12, 3)) | ||
for step, ax in zip(steps, axes): | ||
state = delta @ probability**step | ||
graph.plot(state, ax=ax, title=r'$\delta P^{}$'.format(step)) | ||
ax.set_axis_off() | ||
|
||
fig.tight_layout() | ||
|
||
############################################################################### | ||
# Stationary distribution. | ||
|
||
graphs = [ | ||
pg.graphs.Ring(10), | ||
pg.graphs.Grid2d(5), | ||
pg.graphs.Comet(8, 4), | ||
pg.graphs.BarabasiAlbert(20, seed=42), | ||
] | ||
|
||
fig, axes = plt.subplots(1, len(graphs), figsize=(12, 3)) | ||
|
||
for graph, ax in zip(graphs, axes): | ||
|
||
if not hasattr(graph, 'coords'): | ||
graph.set_coordinates(seed=10) | ||
|
||
P = sparse.diags(graph.dw**(-1)) @ graph.W | ||
|
||
# e, u = np.linalg.eig(P.T.toarray()) | ||
# np.testing.assert_allclose(np.linalg.inv(u.T) @ np.diag(e) @ u.T, | ||
# P.toarray(), atol=1e-10) | ||
# np.testing.assert_allclose(np.abs(e[0]), 1) | ||
# stationary = np.abs(u.T[0]) | ||
|
||
e, u = sparse.linalg.eigs(P.T, k=1, which='LR') | ||
np.testing.assert_allclose(e, 1) | ||
stationary = np.abs(u).squeeze() | ||
assert np.all(stationary < 0.71) | ||
|
||
colorbar = False if type(graph) is pg.graphs.Ring else True | ||
graph.plot(stationary, colorbar=colorbar, ax=ax, title='$xP = x$') | ||
ax.set_axis_off() | ||
|
||
fig.tight_layout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
r""" | ||
Wave propagation on graphs | ||
========================== | ||
Solve the wave equation by filtering the initial conditions with the wave | ||
kernel. | ||
""" | ||
|
||
from os import path | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import pygsp as pg | ||
|
||
#plt.rc('font', family='Latin Modern Roman') | ||
plt.rc('text', usetex=True) | ||
plt.rc('text.latex', preamble=r'\usepackage{lmodern}') | ||
|
||
n_side = 13 | ||
G = pg.graphs.Grid2d(n_side) | ||
G.compute_fourier_basis() | ||
|
||
sources = [ | ||
(n_side//4 * n_side) + (n_side//4), | ||
(n_side*3//4 * n_side) + (n_side*3//4), | ||
] | ||
x = np.zeros(G.n_vertices) | ||
x[sources] = 5 | ||
|
||
times = [0, 5, 10, 20] | ||
|
||
fig, axes = plt.subplots(2, len(times), figsize=(12, 5)) | ||
for i, t in enumerate(times): | ||
g = pg.filters.Wave(G, time=t, speed=1) | ||
title = fr'$\hat{{f}}({t}) = g_{{1,{t}}} \odot \hat{{f}}(0)$' | ||
g.plot(alpha=1, ax=axes[0, i], title=title) | ||
axes[0, i].set_xlabel(r'$\lambda$') | ||
# axes[0, i].set_ylabel(r'$g(\lambda)$') | ||
if i > 0: | ||
axes[0, i].set_ylabel('') | ||
y = g.filter(x) | ||
line, = axes[0, i].plot(G.e, G.gft(y)) | ||
labels = [fr'$\hat{{f}}({t})$', fr'$g_{{1,{t}}}$'] | ||
axes[0, i].legend([line, axes[0, i].lines[-3]], labels, loc='lower right') | ||
G.plot(y, edges=False, highlight=sources, ax=axes[1, i], title=fr'$f({t})$') | ||
axes[1, i].set_aspect('equal', 'box') | ||
axes[1, i].set_axis_off() | ||
|
||
fig.tight_layout() |