diff --git a/docs/autogen.py b/docs/autogen.py
index b123d5d8..67424f95 100644
--- a/docs/autogen.py
+++ b/docs/autogen.py
@@ -48,7 +48,9 @@
"functions": [],
"methods": [],
"classes": [
+ layers.SRCPool,
layers.DiffPool,
+ layers.LaPool,
layers.MinCutPool,
layers.SAGPool,
layers.TopKPool,
diff --git a/docs/templates/creating-layer.md b/docs/templates/creating-layer.md
index 5ad15a66..33e0f9e6 100644
--- a/docs/templates/creating-layer.md
+++ b/docs/templates/creating-layer.md
@@ -53,16 +53,17 @@ def call(self, inputs):
```
Then, we implement the `message` function.
-The `get_i` and `get_j` built-in methods can be used to automatically access either side of the edges \(i \leftarrow j\). For instance, we can use `get_j` to access the node features `x[j]` of all neighbors `j`.
+The `get_sources` and `get_targets` built-in methods can be used to automatically retrieve the node attributes of nodes that are sending (sources) or receiving (targets) a message.
+For instance, we can use `get_targets` to access the node features `x[j]` of all neighbors `j`.
-If you need direct access to the edge indices, you can use the `index_i` and `index_j` attributes.
+If you need direct access to the edge indices, you can use the `index_sources` and `index_targets` attributes.
-In this case, we only need to get the neighbors' features and return them:
+In this case, we only need to get the neighbors' features and return them:
```py
def message(self, x):
# Get the node features of all neighbors
- return self.get_j(x)
+ return self.get_sources(x)
```
Then, we define an aggregation function for the messages. We can use a simple average of the nodes:
@@ -70,11 +71,12 @@ Then, we define an aggregation function for the messages. We can use a simple av
```py
from spektral.layers.ops import scatter_mean
+
def aggregate(self, messages):
- return scatter_mean(messages, self.index_i, self.n_nodes)
+ return scatter_mean(messages, self.index_targets, self.n_nodes)
```
-**Note**: `n_nodes` is computed dynamically at the start of propagation, exactly like `index_i`.
+**Note**: `n_nodes` is computed dynamically at the start of propagation, exactly like `index_targets`.
Since there are a few common aggregation functions that are often used in the literature, you can also skip the implementation of this method and simply pass a special keyword to the `__init__()` method of the superclass:
diff --git a/examples/node_prediction/citation_gcn.py b/examples/node_prediction/citation_gcn.py
index e5016bf0..71354904 100644
--- a/examples/node_prediction/citation_gcn.py
+++ b/examples/node_prediction/citation_gcn.py
@@ -40,7 +40,7 @@ def mask_to_weights(mask):
for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)
-model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
+model = GCN(n_labels=dataset.n_labels)
model.compile(
optimizer=Adam(learning_rate),
loss=CategoricalCrossentropy(reduction="sum"),
diff --git a/examples/node_prediction/citation_gcn_custom.py b/examples/node_prediction/citation_gcn_custom.py
index 7563a398..6b5b4a4d 100644
--- a/examples/node_prediction/citation_gcn_custom.py
+++ b/examples/node_prediction/citation_gcn_custom.py
@@ -22,7 +22,7 @@
x, a, y = graph.x, graph.a, graph.y
mask_tr, mask_va, mask_te = dataset.mask_tr, dataset.mask_va, dataset.mask_te
-model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
+model = GCN(n_labels=dataset.n_labels)
optimizer = Adam(lr=1e-2)
loss_fn = CategoricalCrossentropy()
diff --git a/examples/other/explain_node_predictions.py b/examples/other/explain_node_predictions.py
index 63794573..2b1ce157 100644
--- a/examples/other/explain_node_predictions.py
+++ b/examples/other/explain_node_predictions.py
@@ -37,7 +37,7 @@ def mask_to_weights(mask):
for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)
-model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
+model = GCN(n_labels=dataset.n_labels)
model.compile(
optimizer=Adam(learning_rate),
loss=CategoricalCrossentropy(reduction="sum"),
diff --git a/examples/other/node_clustering_mincut.py b/examples/other/node_clustering_mincut.py
index 16a016e3..c47fca82 100644
--- a/examples/other/node_clustering_mincut.py
+++ b/examples/other/node_clustering_mincut.py
@@ -57,7 +57,7 @@ def train_step(inputs):
a_in = Input(shape=(None,), name="A_in", sparse=True)
x_1 = GCSConv(16, activation="elu")([x_in, a_in])
-x_1, a_1, s_1 = MinCutPool(n_clusters, return_mask=True)([x_1, a_in])
+x_1, a_1, s_1 = MinCutPool(n_clusters, return_selection=True)([x_1, a_in])
model = Model([x_in, a_in], [x_1, s_1])
diff --git a/setup.py b/setup.py
index 2b90290f..51bbb143 100644
--- a/setup.py
+++ b/setup.py
@@ -5,18 +5,18 @@
setup(
name="spektral",
- version="1.0.8",
+ version="1.0.9",
packages=find_packages(),
install_requires=[
"joblib",
"lxml",
"networkx",
- "numpy<1.20",
+ "numpy",
"pandas",
"requests",
"scikit-learn",
"scipy",
- "tensorflow>=2.1.0",
+ "tensorflow>=2.2.0",
"tqdm",
],
url="https://github.com/danielegrattarola/spektral",
@@ -27,8 +27,8 @@
long_description=long_description,
long_description_content_type="text/markdown",
classifiers=[
- "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
],
)
diff --git a/spektral/data/loaders.py b/spektral/data/loaders.py
index c6b70d54..41301ae7 100644
--- a/spektral/data/loaders.py
+++ b/spektral/data/loaders.py
@@ -3,6 +3,7 @@
from spektral.data.utils import (
batch_generator,
+ collate_labels_batch,
collate_labels_disjoint,
get_spec,
prepend_none,
@@ -78,10 +79,10 @@ def train_step(inputs, target):
**Arguments**
- `dataset`: a `spektral.data.Dataset` object;
- - `batch_size`: size of the mini-batches;
- - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+ - `batch_size`: int, size of the mini-batches;
+ - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
iterates indefinitely;
- - `shuffle`: whether to shuffle the dataset at the start of each epoch.
+ - `shuffle`: bool, whether to shuffle the dataset at the start of each epoch.
"""
def __init__(self, dataset, batch_size=1, epochs=None, shuffle=True):
@@ -178,11 +179,10 @@ class SingleLoader(Loader):
**Arguments**
- `dataset`: a `spektral.data.Dataset` object with only one graph;
- - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+ - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
iterates indefinitely;
- - `shuffle`: whether to shuffle the data at the start of each epoch;
- - `sample_weights`: if given, these will be appended to the output
- automatically.
+ - `shuffle`: bool, whether to shuffle the data at the start of each epoch;
+ - `sample_weights`: Numpy array, will be appended to the output automatically.
**Output**
@@ -197,9 +197,8 @@ class SingleLoader(Loader):
- `e`: same as `dataset[0].e`;
`labels` is the same as `dataset[0].y`.
- `sample_weights` is the same object passed to the constructor.
-
+ `sample_weights` is the same array passed when creating the loader.
"""
def __init__(self, dataset, epochs=None, sample_weights=None):
@@ -262,6 +261,8 @@ class DisjointLoader(Loader):
**Arguments**
- `dataset`: a graph Dataset;
+ - `node_level`: bool, if `True` stack the labels vertically for node-level
+ prediction;
- `batch_size`: size of the mini-batches;
- `epochs`: number of epochs to iterate over the dataset. By default (`None`)
iterates indefinitely;
@@ -321,7 +322,7 @@ def tf_signature(self):
Adjacency matrix has shape [n_nodes, n_nodes]
Node features have shape [n_nodes, n_node_features]
Edge features have shape [n_edges, n_edge_features]
- Targets have shape [..., n_labels]
+ Targets have shape [*, n_labels]
"""
signature = self.dataset.signature
if "y" in signature:
@@ -347,33 +348,40 @@ class BatchLoader(Loader):
If `n_max` is the number of nodes of the biggest graph in the batch, then
the padding consist of adding zeros to the node features, adjacency matrix,
and edge attributes of each graph so that they have shapes
- `(n_max, n_node_features)`, `(n_max, n_max)`, and
- `(n_max, n_max, n_edge_features)` respectively.
+ `[n_max, n_node_features]`, `[n_max, n_max]`, and
+ `[n_max, n_max, n_edge_features]` respectively.
The zero-padding is done batch-wise, which saves up memory at the cost of
more computation. If latency is an issue but memory isn't, or if the
dataset has graphs with a similar number of nodes, you can use
- the `PackedBatchLoader` that first zero-pads all the dataset and then
+ the `PackedBatchLoader` that zero-pads all the dataset once and then
iterates over it.
Note that the adjacency matrix and edge attributes are returned as dense
- arrays (mostly due to the lack of support for sparse tensor operations for
- rank >2).
+ arrays.
- Only graph-level labels are supported with this loader (i.e., labels are not
- zero-padded because they are assumed to have no "node" dimensions).
+ if `mask=True`, node attributes will be extended with a binary mask that indicates
+ valid nodes (the last feature of each node will be 1 if the node was originally in
+ the graph and 0 if it is a fake node added by zero-padding).
+
+ Use this flag in conjunction with layers.base.GraphMasking to start the propagation
+ of masks in a model (necessary for node-level prediction and models that use a
+ dense pooling layer like DiffPool or MinCutPool).
+
+ If `node_level=False`, the labels are interpreted as graph-level labels and
+ are returned as an array of shape `[batch, n_labels]`.
+ If `node_level=True`, then the labels are padded along the node dimension and are
+ returned as an array of shape `[batch, n_max, n_labels]`.
**Arguments**
- `dataset`: a graph Dataset;
- - `mask`: if True, node attributes will be extended with a binary mask that
- indicates valid nodes (the last feature of each node will be 1 if the node is valid
- and 0 otherwise). Use this flag in conjunction with layers.base.GraphMasking to
- start the propagation of masks in a model.
- - `batch_size`: size of the mini-batches;
- - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+ - `mask`: bool, whether to add a mask to the node features;
+ - `batch_size`: int, size of the mini-batches;
+ - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
iterates indefinitely;
- - `shuffle`: whether to shuffle the data at the start of each epoch.
+ - `shuffle`: bool, whether to shuffle the data at the start of each epoch;
+ - `node_level`: bool, if `True` pad the labels along the node dimension;
**Output**
@@ -385,11 +393,22 @@ class BatchLoader(Loader):
- `a`: adjacency matrices of shape `[batch, n_max, n_max]`;
- `e`: edge attributes of shape `[batch, n_max, n_max, n_edge_features]`.
- `labels` have shape `[batch, n_labels]`.
+ `labels` have shape `[batch, n_labels]` if `node_level=False` or
+ `[batch, n_max, n_labels]` otherwise.
"""
- def __init__(self, dataset, mask=False, batch_size=1, epochs=None, shuffle=True):
+ def __init__(
+ self,
+ dataset,
+ mask=False,
+ batch_size=1,
+ epochs=None,
+ shuffle=True,
+ node_level=False,
+ ):
self.mask = mask
+ self.node_level = node_level
+ self.signature = dataset.signature
super().__init__(dataset, batch_size=batch_size, epochs=epochs, shuffle=shuffle)
def collate(self, batch):
@@ -397,7 +416,7 @@ def collate(self, batch):
y = packed.pop("y_list", None)
if y is not None:
- y = np.array(y)
+ y = collate_labels_batch(y, node_level=self.node_level)
output = to_batch(**packed, mask=self.mask)
output = sp_matrices_to_sp_tensors(output)
@@ -415,12 +434,13 @@ def tf_signature(self):
Adjacency matrix has shape [batch, n_nodes, n_nodes]
Node features have shape [batch, n_nodes, n_node_features]
Edge features have shape [batch, n_nodes, n_nodes, n_edge_features]
- Targets have shape [batch, ..., n_labels]
+ Labels have shape [batch, n_labels]
"""
- signature = self.dataset.signature
+ signature = self.signature
for k in signature:
signature[k]["shape"] = prepend_none(signature[k]["shape"])
- if "x" in signature:
+ if "x" in signature and self.mask:
+ # In case we have a mask, the mask is concatenated to the features
signature["x"]["shape"] = signature["x"]["shape"][:-1] + (
signature["x"]["shape"][-1] + 1,
)
@@ -430,6 +450,9 @@ def tf_signature(self):
if "e" in signature:
# Edge attributes have an extra None dimension in batch mode
signature["e"]["shape"] = prepend_none(signature["e"]["shape"])
+ if "y" in signature and self.node_level:
+ # Node labels have an extra None dimension
+ signature["y"]["shape"] = prepend_none(signature["y"]["shape"])
return to_tf_signature(signature)
@@ -454,10 +477,12 @@ class PackedBatchLoader(BatchLoader):
**Arguments**
- `dataset`: a graph Dataset;
- - `batch_size`: size of the mini-batches;
- - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+ - `mask`: bool, whether to add a mask to the node features;
+ - `batch_size`: int, size of the mini-batches;
+ - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
iterates indefinitely;
- - `shuffle`: whether to shuffle the data at the start of each epoch.
+ - `shuffle`: bool, whether to shuffle the data at the start of each epoch;
+ - `node_level`: bool, if `True` pad the labels along the node dimension;
**Output**
@@ -469,12 +494,26 @@ class PackedBatchLoader(BatchLoader):
- `a`: adjacency matrices of shape `[batch, n_max, n_max]`;
- `e`: edge attributes of shape `[batch, n_max, n_max, n_edge_features]`.
- `labels` have shape `[batch, ..., n_labels]`.
+ `labels` have shape `[batch, n_labels]` if `node_level=False` or
+ `[batch, n_max, n_labels]` otherwise.
"""
- def __init__(self, dataset, mask=False, batch_size=1, epochs=None, shuffle=True):
+ def __init__(
+ self,
+ dataset,
+ mask=False,
+ batch_size=1,
+ epochs=None,
+ shuffle=True,
+ node_level=False,
+ ):
super().__init__(
- dataset, mask=mask, batch_size=batch_size, epochs=epochs, shuffle=shuffle
+ dataset,
+ mask=mask,
+ batch_size=batch_size,
+ epochs=epochs,
+ shuffle=shuffle,
+ node_level=node_level,
)
# Drop the Dataset container and work on packed tensors directly
@@ -482,9 +521,8 @@ def __init__(self, dataset, mask=False, batch_size=1, epochs=None, shuffle=True)
y = packed.pop("y_list", None)
if y is not None:
- y = np.array(y)
+ y = collate_labels_batch(y, node_level=self.node_level)
- self.signature = dataset.signature
self.dataset = to_batch(**packed, mask=mask)
if y is not None:
self.dataset += (y,)
@@ -501,29 +539,6 @@ def collate(self, batch):
else:
return batch[:-1], batch[-1]
- def tf_signature(self):
- """
- Adjacency matrix has shape [batch, n_nodes, n_nodes]
- Node features have shape [batch, n_nodes, n_node_features]
- Edge features have shape [batch, n_nodes, n_nodes, n_edge_features]
- Targets have shape [batch, ..., n_labels]
- """
- signature = self.signature
- for k in signature:
- signature[k]["shape"] = prepend_none(signature[k]["shape"])
- if "x" in signature:
- signature["x"]["shape"] = signature["x"]["shape"][:-1] + (
- signature["x"]["shape"][-1] + 1,
- )
- if "a" in signature:
- # Adjacency matrix in batch mode is dense
- signature["a"]["spec"] = tf.TensorSpec
- if "e" in signature:
- # Edge attributes have an extra None dimension in batch mode
- signature["e"]["shape"] = prepend_none(signature["e"]["shape"])
-
- return to_tf_signature(signature)
-
@property
def steps_per_epoch(self):
if len(self.dataset) > 0:
@@ -544,10 +559,10 @@ class MixedLoader(Loader):
**Arguments**
- `dataset`: a graph Dataset;
- - `batch_size`: size of the mini-batches;
- - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+ - `batch_size`: int, size of the mini-batches;
+ - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
iterates indefinitely;
- - `shuffle`: whether to shuffle the data at the start of each epoch.
+ - `shuffle`: bool, whether to shuffle the data at the start of each epoch.
**Output**
diff --git a/spektral/data/utils.py b/spektral/data/utils.py
index 29d4c0e7..7e096867 100644
--- a/spektral/data/utils.py
+++ b/spektral/data/utils.py
@@ -287,3 +287,11 @@ def collate_labels_disjoint(y_list, node_level=False):
if len(np.shape(y_list[0])) == 0:
y_list = [np.array([y_]) for y_ in y_list]
return np.array(y_list)
+
+
+def collate_labels_batch(y_list, node_level=False):
+ if node_level:
+ n_max = max([x.shape[0] for x in y_list])
+ return pad_jagged_array(y_list, (n_max, -1))
+ else:
+ return np.array(y_list)
diff --git a/spektral/layers/convolutional/agnn_conv.py b/spektral/layers/convolutional/agnn_conv.py
index 1790bd5f..18823cb6 100644
--- a/spektral/layers/convolutional/agnn_conv.py
+++ b/spektral/layers/convolutional/agnn_conv.py
@@ -68,14 +68,14 @@ def call(self, inputs, **kwargs):
return output
def message(self, x, x_norm=None):
- x_j = self.get_j(x)
- x_norm_i = self.get_i(x_norm)
- x_norm_j = self.get_j(x_norm)
+ x_j = self.get_sources(x)
+ x_norm_i = self.get_targets(x_norm)
+ x_norm_j = self.get_sources(x_norm)
alpha = self.beta * tf.reduce_sum(x_norm_i * x_norm_j, axis=-1)
if len(alpha.shape) == 2:
alpha = tf.transpose(alpha) # For mixed mode
- alpha = ops.unsorted_segment_softmax(alpha, self.index_i, self.n_nodes)
+ alpha = ops.unsorted_segment_softmax(alpha, self.index_targets, self.n_nodes)
if len(alpha.shape) == 2:
alpha = tf.transpose(alpha) # For mixed mode
alpha = alpha[..., None]
diff --git a/spektral/layers/convolutional/crystal_conv.py b/spektral/layers/convolutional/crystal_conv.py
index 1dad7d59..d865fc0e 100644
--- a/spektral/layers/convolutional/crystal_conv.py
+++ b/spektral/layers/convolutional/crystal_conv.py
@@ -34,12 +34,10 @@ class CrystalConv(MessagePassing):
**Output**
- - Node features with the same shape of the input, but the last dimension
- changed to `channels`.
+ - Node features with the same shape of the input.
**Arguments**
- - `channels`: integer, number of output channels;
- `activation`: activation function;
- `use_bias`: bool, add a bias vector to the output;
- `kernel_initializer`: initializer for the weights;
@@ -53,7 +51,6 @@ class CrystalConv(MessagePassing):
def __init__(
self,
- channels,
aggregate="sum",
activation=None,
use_bias=True,
@@ -79,7 +76,6 @@ def __init__(
bias_constraint=bias_constraint,
**kwargs
)
- self.channels = channels
def build(self, input_shape):
assert len(input_shape) >= 2
@@ -92,14 +88,15 @@ def build(self, input_shape):
bias_constraint=self.bias_constraint,
dtype=self.dtype,
)
- self.dense_f = Dense(self.channels, activation="sigmoid", **layer_kwargs)
- self.dense_s = Dense(self.channels, activation=self.activation, **layer_kwargs)
+ channels = input_shape[0][-1]
+ self.dense_f = Dense(channels, activation="sigmoid", **layer_kwargs)
+ self.dense_s = Dense(channels, activation=self.activation, **layer_kwargs)
self.built = True
def message(self, x, e=None):
- x_i = self.get_i(x)
- x_j = self.get_j(x)
+ x_i = self.get_targets(x)
+ x_j = self.get_sources(x)
to_concat = [x_i, x_j]
if e is not None:
@@ -111,7 +108,3 @@ def message(self, x, e=None):
def update(self, embeddings, x=None):
return x + embeddings
-
- @property
- def config(self):
- return {"channels": self.channels}
diff --git a/spektral/layers/convolutional/ecc_conv.py b/spektral/layers/convolutional/ecc_conv.py
index be1105ee..c1bed5c3 100644
--- a/spektral/layers/convolutional/ecc_conv.py
+++ b/spektral/layers/convolutional/ecc_conv.py
@@ -169,11 +169,11 @@ def call(self, inputs, mask=None):
if mode == modes.MIXED:
target_shape = (tf.shape(x)[0],) + target_shape
kernel = tf.reshape(kernel_network, target_shape)
- index_i = a.indices[:, 1]
- index_j = a.indices[:, 0]
- messages = tf.gather(x, index_j, axis=-2)
+ index_targets = a.indices[:, 1]
+ index_sources = a.indices[:, 0]
+ messages = tf.gather(x, index_sources, axis=-2)
messages = tf.einsum("...ab,...abc->...ac", messages, kernel)
- output = ops.scatter_sum(messages, index_i, N)
+ output = ops.scatter_sum(messages, index_targets, N)
if self.root:
output += K.dot(x, self.root_kernel)
diff --git a/spektral/layers/convolutional/edge_conv.py b/spektral/layers/convolutional/edge_conv.py
index 3af8673f..84c47d8d 100644
--- a/spektral/layers/convolutional/edge_conv.py
+++ b/spektral/layers/convolutional/edge_conv.py
@@ -115,8 +115,8 @@ def build(self, input_shape):
self.built = True
def message(self, x, **kwargs):
- x_i = self.get_i(x)
- x_j = self.get_j(x)
+ x_i = self.get_targets(x)
+ x_j = self.get_sources(x)
return self.mlp(K.concatenate((x_i, x_j - x_i)))
@property
diff --git a/spektral/layers/convolutional/gat_conv.py b/spektral/layers/convolutional/gat_conv.py
index 1f3d1d8f..ad338d03 100644
--- a/spektral/layers/convolutional/gat_conv.py
+++ b/spektral/layers/convolutional/gat_conv.py
@@ -240,7 +240,8 @@ def _call_dense(self, x, a):
attn_coef = attn_for_self + attn_for_neighs
attn_coef = tf.nn.leaky_relu(attn_coef, alpha=0.2)
- mask = -10e9 * (1.0 - a)
+ mask = tf.where(a == 0.0, -10e9, 0.0)
+ mask = tf.cast(mask, dtype=attn_coef.dtype)
attn_coef += mask[..., None, :]
attn_coef = tf.nn.softmax(attn_coef, axis=-1)
attn_coef_drop = self.dropout(attn_coef)
diff --git a/spektral/layers/convolutional/message_passing.py b/spektral/layers/convolutional/message_passing.py
index 398b97a4..5e0e404d 100644
--- a/spektral/layers/convolutional/message_passing.py
+++ b/spektral/layers/convolutional/message_passing.py
@@ -53,9 +53,11 @@ class MessagePassing(Layer):
Computes messages, equivalent to \(\phi\) in the definition.
Any extra keyword argument of this function will be populated by
`propagate()` if a matching keyword is found.
- Use `self.get_i()` and `self.get_j()` to gather the elements using the
- indices `i` or `j` of the adjacency matrix. Equivalently, you can access
- the indices themselves via the `index_i` and `index_j` attributes.
+ The `get_sources` and `get_targets` built-in methods can be used to automatically
+ retrieve the node attributes of nodes that are sending (sources) or receiving
+ (targets) a message.
+ If you need direct access to the edge indices, you can use the `index_sources` and
+ `index_targets` attributes.
```python
aggregate(messages, **kwargs)
@@ -110,8 +112,8 @@ def build(self, input_shape):
def propagate(self, x, a, e=None, **kwargs):
self.n_nodes = tf.shape(x)[-2]
- self.index_i = a.indices[:, 1]
- self.index_j = a.indices[:, 0]
+ self.index_targets = a.indices[:, 1] # Nodes receiving the message
+ self.index_sources = a.indices[:, 0] # Nodes sending the message (ie neighbors)
# Message
msg_kwargs = self.get_kwargs(x, a, e, self.msg_signature, kwargs)
@@ -128,19 +130,19 @@ def propagate(self, x, a, e=None, **kwargs):
return output
def message(self, x, **kwargs):
- return self.get_j(x)
+ return self.get_sources(x)
def aggregate(self, messages, **kwargs):
- return self.agg(messages, self.index_i, self.n_nodes)
+ return self.agg(messages, self.index_targets, self.n_nodes)
def update(self, embeddings, **kwargs):
return embeddings
- def get_i(self, x):
- return tf.gather(x, self.index_i, axis=-2)
+ def get_targets(self, x):
+ return tf.gather(x, self.index_targets, axis=-2)
- def get_j(self, x):
- return tf.gather(x, self.index_j, axis=-2)
+ def get_sources(self, x):
+ return tf.gather(x, self.index_sources, axis=-2)
def get_kwargs(self, x, a, e, signature, kwargs):
output = {}
diff --git a/spektral/layers/convolutional/tag_conv.py b/spektral/layers/convolutional/tag_conv.py
index 7304c42f..36aa33af 100644
--- a/spektral/layers/convolutional/tag_conv.py
+++ b/spektral/layers/convolutional/tag_conv.py
@@ -107,7 +107,7 @@ def call(self, inputs, **kwargs):
return self.linear(output)
def message(self, x, edge_weight=None):
- x_j = self.get_j(x)
+ x_j = self.get_sources(x)
return edge_weight[:, None] * x_j
@property
diff --git a/spektral/layers/convolutional/xenet_conv.py b/spektral/layers/convolutional/xenet_conv.py
index de911b5a..d3925a2a 100644
--- a/spektral/layers/convolutional/xenet_conv.py
+++ b/spektral/layers/convolutional/xenet_conv.py
@@ -1,7 +1,7 @@
from collections.abc import Iterable
import tensorflow as tf
-from tensorflow.keras.layers import Concatenate, Dense, Multiply, PReLU, ReLU, Reshape
+from tensorflow.keras.layers import Concatenate, Dense, Multiply, PReLU, ReLU
from tensorflow.python.ops import gen_sparse_ops
from spektral.layers.convolutional.conv import Conv
@@ -149,8 +149,8 @@ def call(self, inputs, **kwargs):
return x_out, e_out
def message(self, x, e=None):
- x_i = self.get_i(x) # Features of self
- x_j = self.get_j(x) # Features of neighbours
+ x_i = self.get_targets(x) # Features of self
+ x_j = self.get_sources(x) # Features of neighbours
# Features of outgoing edges are simply the edge features
e_ij = e
@@ -163,7 +163,7 @@ def message(self, x, e=None):
# tf.transpose(E, perm=(1, 0, 2))
# where E has shape (N, N, S).
reorder_idx = gen_sparse_ops.sparse_reorder(
- tf.stack([self.index_i, self.index_j], axis=-1),
+ tf.stack([self.index_targets, self.index_sources], axis=-1),
tf.range(tf.shape(e)[0]),
(self.n_nodes, self.n_nodes),
)[1]
@@ -181,20 +181,16 @@ def message(self, x, e=None):
return stack_ij
def aggregate(self, messages, x=None):
- # Note: messages == stack_ij
if self.attention:
incoming_att = self.incoming_att_sigmoid(messages)
incoming = self.incoming_att_multiply([incoming_att, messages])
- incoming = self.agg(incoming, self.index_i, self.n_nodes)
+ incoming = self.agg(incoming, self.index_targets, self.n_nodes)
outgoing_att = self.outgoing_att_sigmoid(messages)
outgoing = self.outgoing_att_multiply([outgoing_att, messages])
- outgoing = self.agg(outgoing, self.index_j, self.n_nodes)
+ outgoing = self.agg(outgoing, self.index_sources, self.n_nodes)
else:
- # The equivalent numpy notation for these operations is:
- # incoming[i] = np.sum(stack_ij[self.index_i == i])
- # outgoing[j] = np.sum(stack_ij[self.index_j == j])
- incoming = self.agg(messages, self.index_i, self.n_nodes)
- outgoing = self.agg(messages, self.index_j, self.n_nodes)
+ incoming = self.agg(messages, self.index_targets, self.n_nodes)
+ outgoing = self.agg(messages, self.index_sources, self.n_nodes)
return tf.concat([x, incoming, outgoing], axis=-1), messages
diff --git a/spektral/layers/pooling/__init__.py b/spektral/layers/pooling/__init__.py
index eee5c057..af7f1240 100644
--- a/spektral/layers/pooling/__init__.py
+++ b/spektral/layers/pooling/__init__.py
@@ -7,6 +7,8 @@
GlobalSumPool,
SortPool,
)
+from .la_pool import LaPool
from .mincut_pool import MinCutPool
from .sag_pool import SAGPool
+from .src import SRCPool
from .topk_pool import TopKPool
diff --git a/spektral/layers/pooling/diff_pool.py b/spektral/layers/pooling/diff_pool.py
index a546c980..e644a7d5 100644
--- a/spektral/layers/pooling/diff_pool.py
+++ b/spektral/layers/pooling/diff_pool.py
@@ -3,11 +3,10 @@
from tensorflow.keras import backend as K
from spektral.layers import ops
-from spektral.layers.ops import modes
-from spektral.layers.pooling.pool import Pool
+from spektral.layers.pooling.src import SRCPool
-class DiffPool(Pool):
+class DiffPool(SRCPool):
r"""
A DiffPool layer from the paper
@@ -16,53 +15,53 @@ class DiffPool(Pool):
**Mode**: batch.
- This layer computes a soft clustering \(\S\) of the input graphs using a GNN,
- and reduces graphs as follows:
+ This layer learns a soft clustering of the input graph as follows:
$$
\begin{align}
\S &= \textrm{GNN}_{embed}(\A, \X); \\
\Z &= \textrm{GNN}_{pool}(\A, \X); \\
+ \X' &= \S^\top \Z; \\
\A' &= \S^\top \A \S; \\
- \X' &= \S^\top \Z
\end{align}
$$
where:
$$
\textrm{GNN}_{\square}(\A, \X) = \D^{-1/2} \A \D^{-1/2} \X \W_{\square}.
$$
- The number of output channels of \(\textrm{GNN}_{embed}\) is controlled by
- the `channels` parameter.
+ The number of output channels of \(\textrm{GNN}_{embed}\) is controlled by the
+ `channels` parameter.
- Two auxiliary loss terms are also added to the model: the _link prediction
- loss_
+ Two auxiliary loss terms are also added to the model: the link prediction loss
$$
L_{LP} = \big\| \A - \S\S^\top \big\|_F
$$
- and the _entropy loss_
+ and the entropy loss
$$
L_{E} - \frac{1}{N} \sum\limits_{i = 1}^{N} \S \log (\S).
$$
- The layer can be used without a supervised loss, to compute node clustering
- simply by minimizing the two auxiliary losses.
+ The layer can be used without a supervised loss to compute node clustering by
+ minimizing the two auxiliary losses.
**Input**
- - Node features of shape `([batch], n_nodes, n_node_features)`;
- - Adjacency matrix of shape `([batch], n_nodes, n_nodes)`;
+ - Node features of shape `(batch, n_nodes_in, n_node_features)`;
+ - Adjacency matrix of shape `(batch, n_nodes_in, n_nodes_in)`;
**Output**
- - Reduced node features of shape `([batch], K, channels)`;
- - Reduced adjacency matrix of shape `([batch], K, K)`;
- - If `return_mask=True`, the soft clustering matrix of shape `([batch], n_nodes, K)`.
+ - Reduced node features of shape `(batch, n_nodes_out, channels)`;
+ - Reduced adjacency matrix of shape `(batch, n_nodes_out, n_nodes_out)`;
+ - If `return_selection=True`, the selection matrix of shape
+ `(batch, n_nodes_in, n_nodes_out)`.
**Arguments**
- `k`: number of output nodes;
- - `channels`: number of output channels (if None, the number of output
- channels is assumed to be the same as the input);
- - `return_mask`: boolean, whether to return the cluster assignment matrix;
+ - `channels`: number of output channels (if `None`, the number of output channels is
+ the same as the input);
+ - `return_selection`: boolean, whether to return the selection matrix;
+ - `activation`: activation to apply after reduction;
- `kernel_initializer`: initializer for the weights;
- `kernel_regularizer`: regularization applied to the weights;
- `kernel_constraint`: constraint applied to the weights;
@@ -72,16 +71,15 @@ def __init__(
self,
k,
channels=None,
- return_mask=False,
+ return_selection=False,
activation=None,
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
kernel_constraint=None,
**kwargs
):
-
super().__init__(
- activation=activation,
+ return_selection=return_selection,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
kernel_constraint=kernel_constraint,
@@ -89,16 +87,14 @@ def __init__(
)
self.k = k
self.channels = channels
- self.return_mask = return_mask
+ self.activation = activations.get(activation)
def build(self, input_shape):
- F = input_shape[0][-1]
-
+ in_channels = input_shape[0][-1]
if self.channels is None:
- self.channels = F
-
+ self.channels = in_channels
self.kernel_emb = self.add_weight(
- shape=(F, self.channels),
+ shape=(in_channels, self.channels),
name="kernel_emb",
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -106,82 +102,80 @@ def build(self, input_shape):
)
self.kernel_pool = self.add_weight(
- shape=(F, self.k),
+ shape=(in_channels, self.k),
name="kernel_pool",
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
)
-
super().build(input_shape)
def call(self, inputs, mask=None):
- X, A = inputs
-
- N = K.shape(A)[-1]
- # Check if the layer is operating in mixed or batch mode
- mode = ops.autodetect_mode(X, A)
- self.reduce_loss = mode in (modes.MIXED, modes.BATCH)
+ x, a, i = self.get_inputs(inputs)
- # Get normalized adjacency
- if K.is_sparse(A):
- I_ = tf.sparse.eye(N, dtype=A.dtype)
- A_ = tf.sparse.add(A, I_)
+ # Graph filter for GNNs
+ if K.is_sparse(a):
+ i_n = tf.sparse.eye(self.n_nodes, dtype=a.dtype)
+ a_ = tf.sparse.add(a, i_n)
else:
- I_ = tf.eye(N, dtype=A.dtype)
- A_ = A + I_
- fltr = ops.normalize_A(A_)
-
- # Node embeddings
- Z = K.dot(X, self.kernel_emb)
- Z = ops.modal_dot(fltr, Z)
- if self.activation is not None:
- Z = self.activation(Z)
-
- # Compute cluster assignment matrix
- S = K.dot(X, self.kernel_pool)
- S = ops.modal_dot(fltr, S)
- S = activations.softmax(S, axis=-1) # softmax applied row-wise
+ i_n = tf.eye(self.n_nodes, dtype=a.dtype)
+ a_ = a + i_n
+ fltr = ops.normalize_A(a_)
+
+ output = self.pool(x, a, i, fltr=fltr, mask=mask)
+ return output
+
+ def select(self, x, a, i, fltr=None, mask=None):
+ s = ops.modal_dot(fltr, K.dot(x, self.kernel_pool))
+ s = activations.softmax(s, axis=-1)
if mask is not None:
- S *= mask[0]
-
- # Link prediction loss
- S_gram = ops.modal_dot(S, S, transpose_b=True)
- if mode == modes.MIXED:
- A = tf.sparse.to_dense(A)[None, ...]
- if K.is_sparse(A):
- LP_loss = tf.sparse.add(A, -S_gram) # A/tf.norm(A) - S_gram/tf.norm(S_gram)
- else:
- LP_loss = A - S_gram
- LP_loss = tf.norm(LP_loss, axis=(-1, -2))
- if self.reduce_loss:
- LP_loss = K.mean(LP_loss)
- self.add_loss(LP_loss)
+ s *= mask[0]
- # Entropy loss
- entr = tf.negative(
- tf.reduce_sum(tf.multiply(S, K.log(S + K.epsilon())), axis=-1)
- )
- entr_loss = K.mean(entr, axis=-1)
- if self.reduce_loss:
+ # Auxiliary losses
+ lp_loss = self.link_prediction_loss(a, s)
+ entr_loss = self.entropy_loss(s)
+ if K.ndim(x) == 3:
+ lp_loss = K.mean(lp_loss)
entr_loss = K.mean(entr_loss)
+ self.add_loss(lp_loss)
self.add_loss(entr_loss)
- # Pooling
- X_pooled = ops.modal_dot(S, Z, transpose_a=True)
- A_pooled = ops.matmul_at_b_a(S, A)
+ return s
- output = [X_pooled, A_pooled]
+ def reduce(self, x, s, fltr=None):
+ z = ops.modal_dot(fltr, K.dot(x, self.kernel_emb))
+ z = self.activation(z)
- if self.return_mask:
- output.append(S)
+ return ops.modal_dot(s, z, transpose_a=True)
- return output
+ def connect(self, a, s, **kwargs):
+ return ops.matmul_at_b_a(s, a)
+
+ def reduce_index(self, i, s, **kwargs):
+ i_mean = tf.math.segment_mean(i, i)
+ i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k)
+
+ return i_pool
+
+ @staticmethod
+ def link_prediction_loss(a, s):
+ s_gram = ops.modal_dot(s, s, transpose_b=True)
+ if K.is_sparse(a):
+ lp_loss = tf.sparse.add(a, -s_gram)
+ else:
+ lp_loss = a - s_gram
+ lp_loss = tf.norm(lp_loss, axis=(-1, -2))
+ return lp_loss
+
+ @staticmethod
+ def entropy_loss(s):
+ entr = tf.negative(
+ tf.reduce_sum(tf.multiply(s, K.log(s + K.epsilon())), axis=-1)
+ )
+ entr_loss = K.mean(entr, axis=-1)
+ return entr_loss
- @property
- def config(self):
- return {
- "k": self.k,
- "channels": self.channels,
- "return_mask": self.return_mask,
- }
+ def get_config(self):
+ config = {"k": self.k, "channels": self.channels}
+ base_config = super().get_config()
+ return {**base_config, **config}
diff --git a/spektral/layers/pooling/global_pool.py b/spektral/layers/pooling/global_pool.py
index fef28448..f12ed430 100644
--- a/spektral/layers/pooling/global_pool.py
+++ b/spektral/layers/pooling/global_pool.py
@@ -327,12 +327,14 @@ def call(self, inputs):
X = inputs
attn_coeff = K.dot(X, self.attn_kernel)
attn_coeff = K.squeeze(attn_coeff, -1)
- attn_coeff = K.softmax(attn_coeff)
if self.data_mode == "single":
+ attn_coeff = K.softmax(attn_coeff)
output = K.dot(attn_coeff[None, ...], X)
elif self.data_mode == "batch":
+ attn_coeff = K.softmax(attn_coeff)
output = K.batch_dot(attn_coeff, X)
else:
+ attn_coeff = ops.unsorted_segment_softmax(attn_coeff, I, K.shape(X)[0])
output = attn_coeff[:, None] * X
output = tf.math.segment_sum(output, I)
diff --git a/spektral/layers/pooling/la_pool.py b/spektral/layers/pooling/la_pool.py
new file mode 100644
index 00000000..253c1759
--- /dev/null
+++ b/spektral/layers/pooling/la_pool.py
@@ -0,0 +1,187 @@
+import tensorflow as tf
+from scipy import sparse
+from tensorflow.keras import backend as K
+
+from spektral.layers import ops
+from spektral.layers.pooling.src import SRCPool
+
+
+class LaPool(SRCPool):
+ r"""
+ A Laplacian pooling (LaPool) layer from the paper
+
+ > [Towards Interpretable Sparse Graph Representation Learning with Laplacian Pooling](https://arxiv.org/abs/1905.11577)
+ > Emmanuel Noutahi et al.
+
+ **Mode**: disjoint.
+
+ This layer computes a soft clustering of the graph by first identifying a set of
+ leaders, and then assigning every remaining node to the cluster of the closest
+ leader:
+ $$
+ \V = \norm{\L\X}_d; \\
+ \i = \{ i \mid \V_i > \V_j, \forall j \in \cN(i) \} \\
+ \S^\top = \textrm{SparseMax}\left( \beta \frac{\X\X_{\i}^\top}{\norm{\X}\norm{\X_{\i}}} \right)
+ $$
+ \(\beta\) is a regularization vecotr that is applied element-wise to the selection
+ matrix.
+ If `shortest_path_reg=True`, it is equal to the inverse of the shortest path between
+ each node and its corresponding leader (this can be expensive since it runs on CPU).
+ Otherwise it is equal to 1.
+
+ The reduction and connection are computed as \(\X' = \S\X\) and
+ \(\A' = \S^\top\A\S\), respectively.
+
+ Note that the number of nodes in the output graph depends on the input node features.
+
+ **Input**
+
+ - Node features of shape `(n_nodes_in, n_node_features)`;
+ - Adjacency matrix of shape `(n_nodes_in, n_nodes_in)`;
+
+ **Output**
+
+ - Reduced node features of shape `(n_nodes_out, channels)`;
+ - Reduced adjacency matrix of shape `(n_nodes_out, n_nodes_out)`;
+ - If `return_selection=True`, the selection matrix of shape
+ `(n_nodes_in, n_nodes_out)`.
+
+ **Arguments**
+
+ - `shortest_path_reg`: boolean, apply the shortest path regularization described in
+ the papaer (can be expensive);
+ - `return_selection`: boolean, whether to return the selection matrix;
+ """
+
+ def __init__(self, shortest_path_reg=True, return_selection=False, **kwargs):
+ super().__init__(return_selection=return_selection, **kwargs)
+
+ self.shortest_path_reg = shortest_path_reg
+
+ def call(self, inputs, **kwargs):
+ x, a, i = self.get_inputs(inputs)
+
+ # Select leaders
+ lap = laplacian(a)
+ v = ops.modal_dot(lap, x)
+ v = tf.norm(v, axis=-1, keepdims=1)
+
+ row = a.indices[:, 0]
+ col = a.indices[:, 1]
+ leader_check = tf.cast(tf.gather(v, row) >= tf.gather(v, col), tf.int32)
+ leader_mask = ops.scatter_prod(leader_check[:, 0], row, self.n_nodes)
+ leader_mask = tf.cast(leader_mask, tf.bool)
+
+ return self.pool(x, a, i, leader_mask=leader_mask)
+
+ def select(self, x, a, i, leader_mask=None):
+ # Cosine similarity
+ if i is None:
+ i = tf.zeros(self.n_nodes, dtype=tf.int32)
+ cosine_similarity = sparse_cosine_similarity(x, self.n_nodes, leader_mask, i)
+
+ # Shortest path regularization
+ if self.shortest_path_reg:
+
+ def shortest_path(a_):
+ return sparse.csgraph.shortest_path(a_, directed=False)
+
+ np_fn_input = tf.sparse.to_dense(a) if K.is_sparse(a) else a
+ beta = 1 / tf.numpy_function(shortest_path, [np_fn_input], tf.float64)
+ beta = tf.where(tf.math.is_inf(beta), tf.zeros_like(beta), beta)
+ beta = tf.boolean_mask(beta, leader_mask, axis=1)
+ beta = tf.cast(
+ tf.ensure_shape(beta, cosine_similarity.shape), cosine_similarity.dtype
+ )
+ else:
+ beta = 1.0
+
+ s = tf.sparse.softmax(cosine_similarity)
+ s = beta * tf.sparse.to_dense(s)
+
+ # Leaders end up entirely in their own cluster
+ kronecker_delta = tf.boolean_mask(
+ tf.eye(self.n_nodes, dtype=s.dtype), leader_mask, axis=1
+ )
+
+ # Create clustering
+ s = tf.where(leader_mask[:, None], kronecker_delta, s)
+
+ return s
+
+ def reduce(self, x, s, **kwargs):
+ return ops.modal_dot(s, x, transpose_a=True)
+
+ def connect(self, a, s, **kwargs):
+ return ops.matmul_at_b_a(s, a)
+
+ def reduce_index(self, i, s, leader_mask=None):
+ i_pool = tf.boolean_mask(i, leader_mask)
+
+ return i_pool
+
+ def get_config(self):
+ config = {"shortest_path_reg": self.shortest_path_reg}
+ base_config = super().get_config()
+ return {**base_config, **config}
+
+
+def laplacian(a):
+ d = ops.degree_matrix(a, return_sparse_batch=True)
+ if K.is_sparse(a):
+ a = a.__mul__(-1)
+ else:
+ a = -a
+
+ return tf.sparse.add(d, a)
+
+
+def reduce_sum(x, **kwargs):
+ if K.is_sparse(x):
+ return tf.sparse.reduce_sum(x, **kwargs)
+ else:
+ return tf.reduce_sum(x, **kwargs)
+
+
+def sparse_cosine_similarity(x, n_nodes, mask, i):
+ mask = tf.cast(mask, tf.int32)
+ leader_idx = tf.where(mask)
+
+ # Number of nodes in each graph
+ ns = tf.math.segment_sum(tf.ones_like(i), i)
+ ks = tf.math.segment_sum(mask, i)
+
+ # s will be a block-diagonal matrix where entry i,j is the cosine
+ # similarity between node i and leader j.
+ # The code below creates the indices of the sparse block-diagonal matrix
+ # Row indices of the block-diagonal S
+ starts = tf.cumsum(ns) - ns
+ starts = tf.repeat(starts, ks)
+ stops = tf.cumsum(ns)
+ stops = tf.repeat(stops, ks)
+ index_n = tf.ragged.range(starts, stops).flat_values
+
+ # Column indices of the block-diagonal S
+ index_k = tf.repeat(leader_idx, tf.repeat(ns, ks))
+ index_k_for_s = tf.repeat(tf.range(tf.reduce_sum(ks)), tf.repeat(ns, ks))
+
+ # Make index int64
+ index_n = tf.cast(index_n, tf.int64)
+ index_k = tf.cast(index_k, tf.int64)
+ index_k_for_s = tf.cast(index_k_for_s, tf.int64)
+
+ # Compute similarity between nodes and leaders
+ x_n = tf.gather(x, index_n)
+ x_n_norm = tf.norm(x_n, axis=-1)
+ x_k = tf.gather(x, index_k)
+ x_k_norm = tf.norm(x_k, axis=-1)
+ values = tf.reduce_sum(x_n * x_k, -1) / (x_n_norm * x_k_norm)
+
+ # Create a sparse tensor for S
+ indices = tf.stack((index_n, index_k_for_s), 1)
+ s = tf.SparseTensor(
+ values=values, indices=indices, dense_shape=(n_nodes, tf.reduce_sum(ks))
+ )
+ s = tf.sparse.reorder(s)
+
+ return s
diff --git a/spektral/layers/pooling/mincut_pool.py b/spektral/layers/pooling/mincut_pool.py
index 4d6b8e8c..8b296b48 100644
--- a/spektral/layers/pooling/mincut_pool.py
+++ b/spektral/layers/pooling/mincut_pool.py
@@ -4,10 +4,10 @@
from tensorflow.keras.layers import Dense
from spektral.layers import ops
-from spektral.layers.pooling.pool import Pool
+from spektral.layers.pooling.src import SRCPool
-class MinCutPool(Pool):
+class MinCutPool(SRCPool):
r"""
A MinCut pooling layer from the paper
@@ -16,22 +16,21 @@ class MinCutPool(Pool):
**Mode**: batch.
- This layer computes a soft clustering \(\S\) of the input graphs using a MLP,
- and reduces graphs as follows:
+ This layer learns a soft clustering of the input graph as follows:
$$
\begin{align}
\S &= \textrm{MLP}(\X); \\
- \A' &= \S^\top \A \S; \\
- \X' &= \S^\top \X
+ \X' &= \S^\top \X \\
+ \A' &= \S^\top \A \S; \\
\end{align}
$$
- where MLP is a multi-layer perceptron with softmax output.
+ where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output.
- Two auxiliary loss terms are also added to the model: the _minCUT loss_
+ Two auxiliary loss terms are also added to the model: the minimum cut loss
$$
L_c = - \frac{ \mathrm{Tr}(\S^\top \A \S) }{ \mathrm{Tr}(\S^\top \D \S) }
$$
- and the _orthogonality loss_
+ and the orthogonality loss
$$
L_o = \left\|
\frac{\S^\top \S}{\| \S^\top \S \|_F}
@@ -39,28 +38,30 @@ class MinCutPool(Pool):
\right\|_F.
$$
- The layer can be used without a supervised loss, to compute node clustering
- simply by minimizing the two auxiliary losses.
+ The layer can be used without a supervised loss to compute node clustering by
+ minimizing the two auxiliary losses.
**Input**
- - Node features of shape `([batch], n_nodes, n_node_features)`;
- - Symmetrically normalized adjacency matrix of shape `([batch], n_nodes, n_nodes)`;
+ - Node features of shape `(batch, n_nodes_in, n_node_features)`;
+ - Symmetrically normalized adjacency matrix of shape
+ `(batch, n_nodes_in, n_nodes_in)`;
**Output**
- - Reduced node features of shape `([batch], K, n_node_features)`;
- - Reduced adjacency matrix of shape `([batch], K, K)`;
- - If `return_mask=True`, the soft clustering matrix of shape `([batch], n_nodes, K)`.
+ - Reduced node features of shape `(batch, n_nodes_out, n_node_features)`;
+ - Reduced adjacency matrix of shape `(batch, n_nodes_out, n_nodes_out)`;
+ - If `return_selection=True`, the selection matrix of shape
+ `(batch, n_nodes_in, n_nodes_out)`.
**Arguments**
- `k`: number of output nodes;
- - `mlp_hidden`: list of integers, number of hidden units for each hidden
- layer in the MLP used to compute cluster assignments (if None, the MLP has
- only the output layer);
+ - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in
+ the MLP used to compute cluster assignments (if `None`, the MLP has only one output
+ layer);
- `mlp_activation`: activation for the MLP layers;
- - `return_mask`: boolean, whether to return the cluster assignment matrix;
+ - `return_selection`: boolean, whether to return the selection matrix;
- `use_bias`: use bias in the MLP;
- `kernel_initializer`: initializer for the weights of the MLP;
- `bias_initializer`: initializer for the bias of the MLP;
@@ -75,9 +76,7 @@ def __init__(
k,
mlp_hidden=None,
mlp_activation="relu",
- return_mask=False,
- activation=None,
- use_bias=True,
+ return_selection=False,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
@@ -86,10 +85,8 @@ def __init__(
bias_constraint=None,
**kwargs
):
-
super().__init__(
- activation=activation,
- use_bias=use_bias,
+ return_selection=return_selection,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
@@ -98,10 +95,10 @@ def __init__(
bias_constraint=bias_constraint,
**kwargs
)
+
self.k = k
- self.mlp_hidden = mlp_hidden if mlp_hidden else []
+ self.mlp_hidden = mlp_hidden if mlp_hidden is not None else []
self.mlp_activation = mlp_activation
- self.return_mask = return_mask
def build(self, input_shape):
layer_kwargs = dict(
@@ -112,65 +109,81 @@ def build(self, input_shape):
kernel_constraint=self.kernel_constraint,
bias_constraint=self.bias_constraint,
)
- mlp_layers = []
- for _, channels in enumerate(self.mlp_hidden):
- mlp_layers.append(Dense(channels, self.mlp_activation, **layer_kwargs))
- mlp_layers.append(Dense(self.k, "softmax", **layer_kwargs))
- self.mlp = Sequential(mlp_layers)
+ self.mlp = Sequential(
+ [
+ Dense(channels, self.mlp_activation, **layer_kwargs)
+ for channels in self.mlp_hidden
+ ]
+ + [Dense(self.k, "softmax", **layer_kwargs)]
+ )
super().build(input_shape)
def call(self, inputs, mask=None):
- X, A = inputs
+ x, a, i = self.get_inputs(inputs)
+ return self.pool(x, a, i, mask=mask)
- # Check if the layer is operating in batch mode (X and A have rank 3)
- batch_mode = K.ndim(X) == 3
-
- # Compute cluster assignment matrix
- S = self.mlp(X)
+ def select(self, x, a, i, mask=None):
+ s = self.mlp(x)
if mask is not None:
- S *= mask[0]
+ s *= mask[0]
- # MinCut regularization
- A_pooled = ops.matmul_at_b_a(S, A)
- num = tf.linalg.trace(A_pooled)
- D = ops.degree_matrix(A)
- den = tf.linalg.trace(ops.matmul_at_b_a(S, D)) + K.epsilon()
- cut_loss = -(num / den)
- if batch_mode:
+ # Orthogonality loss
+ ortho_loss = self.orthogonality_loss(s)
+ if K.ndim(a) == 3:
+ ortho_loss = K.mean(ortho_loss)
+ self.add_loss(ortho_loss)
+
+ return s
+
+ def reduce(self, x, s, **kwargs):
+ return ops.modal_dot(s, x, transpose_a=True)
+
+ def connect(self, a, s, **kwargs):
+ a_pool = ops.matmul_at_b_a(s, a)
+
+ # MinCut loss
+ cut_loss = self.mincut_loss(a, s, a_pool)
+ if K.ndim(a) == 3:
cut_loss = K.mean(cut_loss)
self.add_loss(cut_loss)
- # Orthogonality regularization
- SS = ops.modal_dot(S, S, transpose_a=True)
- I_S = tf.eye(self.k, dtype=SS.dtype)
- ortho_loss = tf.norm(
- SS / tf.norm(SS, axis=(-1, -2), keepdims=True) - I_S / tf.norm(I_S),
- axis=(-1, -2),
+ # Post-processing of A
+ a_pool = tf.linalg.set_diag(
+ a_pool, tf.zeros(K.shape(a_pool)[:-1], dtype=a_pool.dtype)
)
- if batch_mode:
- ortho_loss = K.mean(ortho_loss)
- self.add_loss(ortho_loss)
+ a_pool = ops.normalize_A(a_pool)
- # Pooling
- X_pooled = ops.modal_dot(S, X, transpose_a=True)
- A_pooled = tf.linalg.set_diag(
- A_pooled, tf.zeros(K.shape(A_pooled)[:-1], dtype=A_pooled.dtype)
- ) # Remove diagonal
- A_pooled = ops.normalize_A(A_pooled)
+ return a_pool
- output = [X_pooled, A_pooled]
+ def reduce_index(self, i, s, **kwargs):
+ i_mean = tf.math.segment_mean(i, i)
+ i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k)
- if self.return_mask:
- output.append(S)
+ return i_pool
- return output
+ def orthogonality_loss(self, s):
+ ss = ops.modal_dot(s, s, transpose_a=True)
+ i_s = tf.eye(self.k, dtype=ss.dtype)
+ ortho_loss = tf.norm(
+ ss / tf.norm(ss, axis=(-1, -2), keepdims=True) - i_s / tf.norm(i_s),
+ axis=(-1, -2),
+ )
+ return ortho_loss
+
+ @staticmethod
+ def mincut_loss(a, s, a_pool):
+ num = tf.linalg.trace(a_pool)
+ d = ops.degree_matrix(a)
+ den = tf.linalg.trace(ops.matmul_at_b_a(s, d))
+ cut_loss = -(num / den)
+ return cut_loss
- @property
- def config(self):
- return {
+ def get_config(self):
+ config = {
"k": self.k,
"mlp_hidden": self.mlp_hidden,
"mlp_activation": self.mlp_activation,
- "return_mask": self.return_mask,
}
+ base_config = super().get_config()
+ return {**base_config, **config}
diff --git a/spektral/layers/pooling/pool.py b/spektral/layers/pooling/pool.py
deleted file mode 100644
index a37afb91..00000000
--- a/spektral/layers/pooling/pool.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from tensorflow.keras.layers import Layer
-
-from spektral.utils.keras import (
- deserialize_kwarg,
- is_keras_kwarg,
- is_layer_kwarg,
- serialize_kwarg,
-)
-
-
-class Pool(Layer):
- r"""
- A general class for pooling layers.
-
- You can extend this class to create custom implementations of pooling layers.
-
- Any extension of this class must implement the `call(self, inputs)` and
- `config(self)` methods.
-
- **Arguments**:
-
- - ``**kwargs`: additional keyword arguments specific to Keras' Layers, like
- regularizers, initializers, constraints, etc.
- """
-
- def __init__(self, **kwargs):
- super().__init__(**{k: v for k, v in kwargs.items() if is_keras_kwarg(k)})
- self.supports_masking = True
- self.kwargs_keys = []
- for key in kwargs:
- if is_layer_kwarg(key):
- attr = kwargs[key]
- attr = deserialize_kwarg(key, attr)
- self.kwargs_keys.append(key)
- setattr(self, key, attr)
-
- def build(self, input_shape):
- self.built = True
-
- def call(self, inputs):
- raise NotImplementedError
-
- def get_config(self):
- base_config = super().get_config()
- keras_config = {}
- for key in self.kwargs_keys:
- keras_config[key] = serialize_kwarg(key, getattr(self, key))
- return {**base_config, **keras_config, **self.config}
-
- def compute_mask(self, inputs, mask=None):
- if mask is not None:
- return None
-
- @property
- def config(self):
- return {}
diff --git a/spektral/layers/pooling/sag_pool.py b/spektral/layers/pooling/sag_pool.py
index b6620e8e..57dc928a 100644
--- a/spektral/layers/pooling/sag_pool.py
+++ b/spektral/layers/pooling/sag_pool.py
@@ -1,6 +1,8 @@
+import tensorflow as tf
from tensorflow.keras import backend as K
-from spektral.layers.pooling.topk_pool import TopKPool, ops
+from spektral.layers import ops
+from spektral.layers.pooling.topk_pool import TopKPool
class SAGPool(TopKPool):
@@ -12,40 +14,44 @@ class SAGPool(TopKPool):
**Mode**: single, disjoint.
- This layer computes the following operations:
+ This layer computes:
$$
\y = \textrm{GNN}(\A, \X); \;\;\;\;
\i = \textrm{rank}(\y, K); \;\;\;\;
\X' = (\X \odot \textrm{tanh}(\y))_\i; \;\;\;\;
\A' = \A_{\i, \i}
$$
- where \(\textrm{rank}(\y, K)\) returns the indices of the top K values of
- \(\y\) and
+ where \(\textrm{rank}(\y, K)\) returns the indices of the top K values of \(\y\) and
$$
\textrm{GNN}(\A, \X) = \A \X \W.
$$
- \(K\) is defined for each graph as a fraction of the number of nodes,
- controlled by the `ratio` argument.
+ \(K\) is defined for each graph as a fraction of the number of nodes, controlled by
+ the `ratio` argument.
+
+ The gating operation \(\textrm{tanh}(\y)\) (Cangea et al.) can be replaced with a
+ sigmoid (Gao & Ji).
**Input**
- - Node features of shape `(n_nodes, n_node_features)`;
- - Binary adjacency matrix of shape `(n_nodes, n_nodes)`;
+ - Node features of shape `(n_nodes_in, n_node_features)`;
+ - Adjacency matrix of shape `(n_nodes_in, n_nodes_in)`;
- Graph IDs of shape `(n_nodes, )` (only in disjoint mode);
**Output**
- - Reduced node features of shape `(ratio * n_nodes, n_node_features)`;
- - Reduced adjacency matrix of shape `(ratio * n_nodes, ratio * n_nodes)`;
- - Reduced graph IDs of shape `(ratio * n_nodes, )` (only in disjoint mode);
- - If `return_mask=True`, the binary pooling mask of shape `(ratio * n_nodes, )`.
+ - Reduced node features of shape `(ratio * n_nodes_in, n_node_features)`;
+ - Reduced adjacency matrix of shape `(ratio * n_nodes_in, ratio * n_nodes_in)`;
+ - Reduced graph IDs of shape `(ratio * n_nodes_in, )` (only in disjoint mode);
+ - If `return_selection=True`, the selection mask of shape `(ratio * n_nodes_in, )`.
+ - If `return_score=True`, the scoring vector of shape `(n_nodes_in, )`
**Arguments**
- `ratio`: float between 0 and 1, ratio of nodes to keep in each graph;
- - `return_mask`: boolean, whether to return the binary mask used for pooling;
- - `sigmoid_gating`: boolean, use a sigmoid gating activation instead of a
+ - `return_selection`: boolean, whether to return the selection mask;
+ - `return_score`: boolean, whether to return the node scoring vector;
+ - `sigmoid_gating`: boolean, use a sigmoid activation for gating instead of a
tanh;
- `kernel_initializer`: initializer for the weights;
- `kernel_regularizer`: regularization applied to the weights;
@@ -55,7 +61,8 @@ class SAGPool(TopKPool):
def __init__(
self,
ratio,
- return_mask=False,
+ return_selection=False,
+ return_score=False,
sigmoid_gating=False,
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
@@ -64,7 +71,8 @@ def __init__(
):
super().__init__(
ratio,
- return_mask=return_mask,
+ return_selection=return_selection,
+ return_score=return_score,
sigmoid_gating=sigmoid_gating,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
@@ -72,7 +80,21 @@ def __init__(
**kwargs
)
- def compute_scores(self, X, A, I):
- scores = K.dot(X, self.kernel)
- scores = ops.modal_dot(A, scores)
- return scores
+ def call(self, inputs):
+ x, a, i = self.get_inputs(inputs)
+
+ # Graph filter for GNN
+ if K.is_sparse(a):
+ i_n = tf.sparse.eye(self.n_nodes, dtype=a.dtype)
+ a_ = tf.sparse.add(a, i_n)
+ else:
+ i_n = tf.eye(self.n_nodes, dtype=a.dtype)
+ a_ = a + i_n
+ fltr = ops.normalize_A(a_)
+
+ y = ops.modal_dot(fltr, K.dot(x, self.kernel))
+ output = self.pool(x, a, i, y=y)
+ if self.return_score:
+ output.append(y)
+
+ return output
diff --git a/spektral/layers/pooling/src.py b/spektral/layers/pooling/src.py
new file mode 100644
index 00000000..1664d6d0
--- /dev/null
+++ b/spektral/layers/pooling/src.py
@@ -0,0 +1,315 @@
+import inspect
+
+import tensorflow as tf
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Layer
+
+from spektral.utils.keras import (
+ deserialize_kwarg,
+ is_keras_kwarg,
+ is_layer_kwarg,
+ serialize_kwarg,
+)
+
+
+class SRCPool(Layer):
+ r"""
+ A general class for graph pooling layers based on the "Select, Reduce,
+ Connect" framework presented in:
+
+ > [Understanding Pooling in Graph Neural Networks.](https://arxiv.org/abs/2110.05292)
+ > Daniele Grattarola et al.
+
+ This layer computes:
+ $$
+ \begin{align}
+ & \mathcal{S} = \left\{\mathcal{S}_k\right\}_{k=1:K} = \textsc{Sel}(\mathcal{G}) \\
+ & \mathcal{X}'=\left\{\textsc{Red}( \mathcal{G}, \mathcal{S}_k )\right\}_{k=1:K} \\
+ & \mathcal{E}'=\left\{\textsc{Con}( \mathcal{G}, \mathcal{S}_k, \mathcal{S}_l )\right\}_{k,L=1:K} \\
+ \end{align}
+ $$
+ Where \(\textsc{Sel}\) is a node equivariant selection function that computes
+ the supernode assignments \(\mathcal{S}_k\), \(\textsc{Red}\) is a
+ permutation-invariant function to reduce the supernodes into the new node
+ attributes, and \(\textsc{Con}\) is a permutation-invariant connection
+ function that computes the link between the pooled nodes.
+
+ By extending this class, it is possible to create any pooling layer in the
+ SRC formalism.
+
+ **Input**
+
+ - `x`: Tensor of shape `([batch], N, F)` representing node features;
+ - `a`: Tensor or SparseTensor of shape `([batch], N, N)` representing the
+ adjacency matrix;
+ - `i`: (optional) Tensor of integers with shape `(N, )` representing the
+ batch index;
+
+ **Output**
+
+ - `x_pool`: Tensor of shape `([batch], K, F)`, representing the node
+ features of the output. `K` is the number of output nodes and depends on the
+ specific pooling strategy;
+ - `a_pool`: Tensor or SparseTensor of shape `([batch], K, K)` representing
+ the adjacency matrix of the output;
+ - `i_pool`: (only if i was given as input) Tensor of integers with shape
+ `(K, )` representing the batch index of the output;
+ - `s`: (if `return_selection=True`) Tensor or SparseTensor representing the
+ supernode assignments;
+
+ **API**
+
+ - `pool(x, a, i, **kwargs)`: pools the graph and returns the reduced node
+ features and adjacency matrix. If the batch index `i` is not `None`, a
+ reduced version of `i` will be returned as well.
+ Any given `kwargs` will be passed as keyword arguments to `select()`,
+ `reduce()` and `connect()` if any matching key is found.
+ The mandatory arguments of `pool()` **must** be computed in `call()` by
+ calling `self.get_inputs(inputs)`.
+ - `select(x, a, i, **kwargs)`: computes supernode assignments mapping the
+ nodes of the input graph to the nodes of the output.
+ - `reduce(x, s, **kwargs)`: reduces the supernodes to form the nodes of the
+ pooled graph.
+ - `connect(a, s, **kwargs)`: connects the reduced supernodes.
+ - `reduce_index(i, s, **kwargs)`: helper function to reduce the batch index
+ (only called if `i` is given as input).
+
+ When overriding any function of the API, it is possible to access the
+ true number of nodes of the input (`n_nodes`) as a Tensor in the instance variable
+ `self.n_nodes` (this is populated by `self.get_inputs()` at the beginning of
+ `call()`).
+
+ **Arguments**:
+
+ - `return_selection`: if `True`, the Tensor used to represent supernode assignments
+ will be returned with `x_pool`, `a_pool`, and `i_pool`;
+ """
+
+ def __init__(self, return_selection=False, **kwargs):
+ # kwargs for the Layer class are handled automatically
+ super().__init__(**{k: v for k, v in kwargs.items() if is_keras_kwarg(k)})
+ self.supports_masking = True
+ self.return_selection = return_selection
+
+ # *_regularizer, *_constraint, *_initializer, activation, and use_bias
+ # are dealt with automatically if passed to the constructor
+ self.kwargs_keys = []
+ for key in kwargs:
+ if is_layer_kwarg(key):
+ attr = kwargs[key]
+ attr = deserialize_kwarg(key, attr)
+ self.kwargs_keys.append(key)
+ setattr(self, key, attr)
+
+ # Signature of the SRC functions
+ self.sel_signature = inspect.signature(self.select).parameters
+ self.red_signature = inspect.signature(self.reduce).parameters
+ self.con_signature = inspect.signature(self.connect).parameters
+ self.i_red_signature = inspect.signature(self.reduce_index).parameters
+
+ self._n_nodes = None
+
+ def build(self, input_shape):
+ super().build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ # Always start the call() method with get_inputs(inputs) to set self.n_nodes
+ x, a, i = self.get_inputs(inputs)
+
+ return self.pool(x, a, i)
+
+ def pool(self, x, a, i, **kwargs):
+ """
+ This is the core method of the SRC class, which runs a full pass of
+ selection, reduction and connection.
+ It is usually not necessary to modify this function. Any previous/shared
+ operations should be done in `call()` and their results can be passed to
+ the three SRC functions via keyword arguments (any kwargs given to this
+ function will be matched to the signature of `select()`, `reduce()` and
+ `connect()` and propagated as input to the three functions).
+ Any pooling logic should go in the SRC functions themselves.
+ :param x: Tensor of shape `([batch], N, F)`;
+ :param a: Tensor or SparseTensor of shape `([batch], N, N)`;
+ :param i: only in single/disjoint mode, Tensor of integers with shape
+ `(N, )`; otherwise, `None`;
+ :param kwargs: additional keyword arguments for `select()`, `reduce()`
+ and `connect()`. Any matching kwargs will be passed to each of the three
+ functions.
+ :return:
+ - `x_pool`: Tensor of shape `([batch], K, F)`, where `K` is the
+ number of output nodes and depends on the pooling strategy;
+ - `a_pool`: Tensor or SparseTensor of shape `([batch], K, K)`;
+ - `i_pool`: (only if `i` is not `None`) Tensor of integers with shape
+ `(K, )`;
+ """
+ # Select
+ sel_kwargs = self._get_kwargs(x, a, i, self.sel_signature, kwargs)
+ s = self.select(x, a, i, **sel_kwargs)
+
+ # Reduce
+ red_kwargs = self._get_kwargs(x, a, i, self.red_signature, kwargs)
+ x_pool = self.reduce(x, s, **red_kwargs)
+
+ # Index reduce
+ i_red_kwargs = self._get_kwargs(x, a, i, self.i_red_signature, kwargs)
+ i_pool = self.reduce_index(i, s, **i_red_kwargs) if i is not None else None
+
+ # Connect
+ con_kwargs = self._get_kwargs(x, a, i, self.con_signature, kwargs)
+ a_pool = self.connect(a, s, **con_kwargs)
+
+ return self.get_outputs(x_pool, a_pool, i_pool, s)
+
+ def select(self, x, a, i, **kwargs):
+ """
+ Selection function. Given the graph, computes the supernode assignments
+ that will eventually be mapped to the `K` nodes of the pooled graph.
+ Supernode assignments are usually represented as a dense matrix of shape
+ `(N, K)` or sparse indices of shape `(K, )`.
+ :param x: Tensor of shape `([batch], N, F)`;
+ :param a: Tensor or SparseTensor (depending on the implementation of the
+ SRC functions) of shape `([batch], N, N)`;
+ :param i: Tensor of integers with shape `(N, )` or `None`;
+ :param kwargs: additional keyword arguments.
+ :return: Tensor representing supernode assignments.
+ """
+ return tf.range(tf.shape(i))
+
+ def reduce(self, x, s, **kwargs):
+ """
+ Reduction function. Given a selection, reduces the supernodes to form
+ the nodes of the new graph.
+ :param x: Tensor of shape `([batch], N, F)`;
+ :param s: Tensor representing supernode assignments, as computed by
+ `select()`;
+ :param kwargs: additional keyword arguments; when overriding this
+ function, any keyword argument defined explicitly as `key=None` will be
+ automatically filled in when calling `pool(key=value)`.
+ :return: Tensor of shape `([batch], K, F)` representing the node attributes of
+ the pooled graph.
+ """
+ return tf.gather(x, s)
+
+ def connect(self, a, s, **kwargs):
+ """
+ Connection function. Given a selection, connects the nodes of the pooled
+ graphs.
+ :param a: Tensor or SparseTensor of shape `([batch], N, N)`;
+ :param s: Tensor representing supernode assignments, as computed by
+ `select()`;
+ :param kwargs: additional keyword arguments; when overriding this
+ function, any keyword argument defined explicitly as `key=None` will be
+ automatically filled in when calling `pool(key=value)`.
+ :return: Tensor or SparseTensor of shape `([batch], K, K)` representing
+ the adjacency matrix of the pooled graph.
+ """
+ return sparse_connect(a, s, self.n_nodes)
+
+ def reduce_index(self, i, s, **kwargs):
+ """
+ Helper function to reduce the batch index `i`. Given a selection,
+ returns a new batch index for the pooled graph. This is only called by
+ `pool()` when `i` is given as input to the layer.
+ :param i: Tensor of integers with shape `(N, )`;
+ :param s: Tensor representing supernode assignments, as computed by
+ `select()`.
+ :param kwargs: additional keyword arguments; when overriding this
+ function, any keyword argument defined explicitly as `key=None` will be
+ automatically filled in when calling `pool(key=value)`.
+ :return: Tensor of integers of shape `(K, )`.
+ """
+ return tf.gather(i, s)
+
+ @staticmethod
+ def _get_kwargs(x, a, i, signature, kwargs):
+ output = {}
+ for k in signature.keys():
+ if signature[k].default is inspect.Parameter.empty or k == "kwargs":
+ pass
+ elif k == "x":
+ output[k] = x
+ elif k == "a":
+ output[k] = a
+ elif k == "i":
+ output[k] = i
+ elif k in kwargs:
+ output[k] = kwargs[k]
+ else:
+ raise ValueError("Missing key {} for signature {}".format(k, signature))
+
+ return output
+
+ def get_inputs(self, inputs):
+ if len(inputs) == 3:
+ x, a, i = inputs
+ if K.ndim(i) == 2:
+ i = i[:, 0]
+ assert K.ndim(i) == 1, "i must have rank 1"
+ elif len(inputs) == 2:
+ x, a = inputs
+ i = None
+ else:
+ raise ValueError(
+ "Expected 2 or 3 inputs tensors (x, a, i), got {}.".format(len(inputs))
+ )
+
+ self.n_nodes = tf.shape(x)[-2]
+
+ return x, a, i
+
+ def get_outputs(self, x_pool, a_pool, i_pool, s):
+ output = [x_pool, a_pool]
+ if i_pool is not None:
+ output.append(i_pool)
+ if self.return_selection:
+ output.append(s)
+
+ return output
+
+ def get_config(self):
+ config = {
+ "return_selection": self.return_selection,
+ }
+ for key in self.kwargs_keys:
+ config[key] = serialize_kwarg(key, getattr(self, key))
+ base_config = super().get_config()
+ return {**base_config, **config}
+
+ def compute_mask(self, inputs, mask=None):
+ # After pooling all nodes are always valid
+ return None
+
+ @property
+ def n_nodes(self):
+ if self._n_nodes is None:
+ raise ValueError(
+ "self.n_nodes has not been defined. Have you called "
+ "self.get_inputs(inputs) at the beginning of call()?"
+ )
+ return self._n_nodes
+
+ @n_nodes.setter
+ def n_nodes(self, value):
+ self._n_nodes = value
+
+ @n_nodes.deleter
+ def n_nodes(self):
+ self._n_nodes = None
+
+
+def sparse_connect(A, S, N):
+ N_sel = tf.cast(tf.shape(S), tf.int64)[0]
+ m = tf.scatter_nd(S[:, None], tf.range(N_sel) + 1, (N,)) - 1
+
+ row, col = A.indices[:, 0], A.indices[:, 1]
+ r_mask = tf.gather(m, row)
+ c_mask = tf.gather(m, col)
+ mask_total = (r_mask >= 0) & (c_mask >= 0)
+ r_new = tf.boolean_mask(r_mask, mask_total)
+ c_new = tf.boolean_mask(c_mask, mask_total)
+ v_new = tf.boolean_mask(A.values, mask_total)
+
+ output = tf.SparseTensor(
+ values=v_new, indices=tf.stack((r_new, c_new), 1), dense_shape=(N_sel, N_sel)
+ )
+ return tf.sparse.reorder(output)
diff --git a/spektral/layers/pooling/topk_pool.py b/spektral/layers/pooling/topk_pool.py
index e1af473e..7fa99c36 100644
--- a/spektral/layers/pooling/topk_pool.py
+++ b/spektral/layers/pooling/topk_pool.py
@@ -2,10 +2,10 @@
from tensorflow.keras import backend as K
from spektral.layers import ops
-from spektral.layers.pooling.pool import Pool
+from spektral.layers.pooling.src import SRCPool
-class TopKPool(Pool):
+class TopKPool(SRCPool):
r"""
A gPool/Top-K layer from the papers
@@ -19,7 +19,7 @@ class TopKPool(Pool):
**Mode**: single, disjoint.
- This layer computes the following operations:
+ This layer computes:
$$
\y = \frac{\X\p}{\|\p\|}; \;\;\;\;
\i = \textrm{rank}(\y, K); \;\;\;\;
@@ -32,27 +32,29 @@ class TopKPool(Pool):
\(K\) is defined for each graph as a fraction of the number of nodes,
controlled by the `ratio` argument.
- Note that the the gating operation \(\textrm{tanh}(\y)\) (Cangea et al.)
- can be replaced with a sigmoid (Gao & Ji).
+ The gating operation \(\textrm{tanh}(\y)\) (Cangea et al.) can be replaced with a
+ sigmoid (Gao & Ji).
**Input**
- - Node features of shape `(n_nodes, n_node_features)`;
- - Binary adjacency matrix of shape `(n_nodes, n_nodes)`;
+ - Node features of shape `(n_nodes_in, n_node_features)`;
+ - Adjacency matrix of shape `(n_nodes_in, n_nodes_in)`;
- Graph IDs of shape `(n_nodes, )` (only in disjoint mode);
**Output**
- - Reduced node features of shape `(ratio * n_nodes, n_node_features)`;
- - Reduced adjacency matrix of shape `(ratio * n_nodes, ratio * n_nodes)`;
- - Reduced graph IDs of shape `(ratio * n_nodes, )` (only in disjoint mode);
- - If `return_mask=True`, the binary pooling mask of shape `(ratio * n_nodes, )`.
+ - Reduced node features of shape `(ratio * n_nodes_in, n_node_features)`;
+ - Reduced adjacency matrix of shape `(ratio * n_nodes_in, ratio * n_nodes_in)`;
+ - Reduced graph IDs of shape `(ratio * n_nodes_in, )` (only in disjoint mode);
+ - If `return_selection=True`, the selection mask of shape `(ratio * n_nodes_in, )`.
+ - If `return_score=True`, the scoring vector of shape `(n_nodes_in, )`
**Arguments**
- `ratio`: float between 0 and 1, ratio of nodes to keep in each graph;
- - `return_mask`: boolean, whether to return the binary mask used for pooling;
- - `sigmoid_gating`: boolean, use a sigmoid gating activation instead of a
+ - `return_selection`: boolean, whether to return the selection mask;
+ - `return_score`: boolean, whether to return the node scoring vector;
+ - `sigmoid_gating`: boolean, use a sigmoid activation for gating instead of a
tanh;
- `kernel_initializer`: initializer for the weights;
- `kernel_regularizer`: regularization applied to the weights;
@@ -62,7 +64,8 @@ class TopKPool(Pool):
def __init__(
self,
ratio,
- return_mask=False,
+ return_selection=False,
+ return_score=False,
sigmoid_gating=False,
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
@@ -70,21 +73,21 @@ def __init__(
**kwargs
):
super().__init__(
+ return_selection=return_selection,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
kernel_constraint=kernel_constraint,
**kwargs
)
self.ratio = ratio
- self.return_mask = return_mask
+ self.return_score = return_score
self.sigmoid_gating = sigmoid_gating
self.gating_op = K.sigmoid if self.sigmoid_gating else K.tanh
def build(self, input_shape):
- self.F = input_shape[0][-1]
- self.N = input_shape[0][0]
+ self.n_nodes = input_shape[0][0]
self.kernel = self.add_weight(
- shape=(self.F, 1),
+ shape=(input_shape[0][-1], 1),
name="kernel",
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -92,60 +95,88 @@ def build(self, input_shape):
)
super().build(input_shape)
- def call(self, inputs):
- if len(inputs) == 3:
- X, A, I = inputs
- self.data_mode = "disjoint"
- else:
- X, A = inputs
- I = tf.zeros(tf.shape(X)[:1])
- self.data_mode = "single"
- if K.ndim(I) == 2:
- I = I[:, 0]
- I = tf.cast(I, tf.int32)
-
- A_is_sparse = K.is_sparse(A)
-
- # Get mask
- y = self.compute_scores(X, A, I)
- N = K.shape(X)[-2]
- indices = ops.segment_top_k(y[:, 0], I, self.ratio)
- indices = tf.sort(indices) # required for ordered SparseTensors
- mask = ops.indices_to_mask(indices, N)
-
- # Multiply X and y to make layer differentiable
- features = X * self.gating_op(y)
-
- axis = 0 if len(A.shape) == 2 else 1 # Cannot use negative axis
- # Reduce X
- X_pooled = tf.gather(features, indices, axis=axis)
-
- # Reduce A
- if A_is_sparse:
- A_pooled, _ = ops.gather_sparse_square(A, indices, mask=mask)
- else:
- A_pooled = tf.gather(A, indices, axis=axis)
- A_pooled = tf.gather(A_pooled, indices, axis=axis + 1)
-
- output = [X_pooled, A_pooled]
-
- # Reduce I
- if self.data_mode == "disjoint":
- I_pooled = tf.gather(I, indices)
- output.append(I_pooled)
-
- if self.return_mask:
- output.append(mask)
+ def call(self, inputs, **kwargs):
+ x, a, i = self.get_inputs(inputs)
+ y = K.dot(x, K.l2_normalize(self.kernel))
+ output = self.pool(x, a, i, y=y)
+ if self.return_score:
+ output.append(y)
return output
- def compute_scores(self, X, A, I):
- return K.dot(X, K.l2_normalize(self.kernel))
+ def select(self, x, a, i, y=None):
+ if i is None:
+ i = tf.zeros(self.n_nodes)
+ s = segment_top_k(y[:, 0], i, self.ratio)
- @property
- def config(self):
- return {
+ return tf.sort(s)
+
+ def reduce(self, x, s, y=None):
+ x_pool = tf.gather(x * self.gating_op(y), s)
+
+ return x_pool
+
+ def get_outputs(self, x_pool, a_pool, i_pool, s):
+ output = [x_pool, a_pool]
+ if i_pool is not None:
+ output.append(i_pool)
+ if self.return_selection:
+ # Convert sparse indices to boolean mask
+ s = tf.scatter_nd(s[:, None], tf.ones_like(s), (self.n_nodes,))
+ output.append(s)
+
+ return output
+
+ def get_config(self):
+ config = {
"ratio": self.ratio,
- "return_mask": self.return_mask,
- "sigmoid_gating": self.sigmoid_gating,
}
+ base_config = super().get_config()
+ return {**base_config, **config}
+
+
+def segment_top_k(x, i, ratio):
+ """
+ Returns indices to get the top K values in x segment-wise, according to
+ the segments defined in I. K is not fixed, but it is defined as a ratio of
+ the number of elements in each segment.
+ :param x: a rank 1 Tensor;
+ :param i: a rank 1 Tensor with segment IDs for x;
+ :param ratio: float, ratio of elements to keep for each segment;
+ :return: a rank 1 Tensor containing the indices to get the top K values of
+ each segment in x.
+ """
+ i = tf.cast(i, tf.int32)
+ n = tf.shape(i)[0]
+ n_nodes = tf.math.segment_sum(tf.ones_like(i), i)
+ batch_size = tf.shape(n_nodes)[0]
+ n_nodes_max = tf.reduce_max(n_nodes)
+ cumulative_n_nodes = tf.concat(
+ (tf.zeros(1, dtype=n_nodes.dtype), tf.cumsum(n_nodes)[:-1]), 0
+ )
+ index = tf.range(n)
+ index = (index - tf.gather(cumulative_n_nodes, i)) + (i * n_nodes_max)
+
+ dense_x = tf.zeros(batch_size * n_nodes_max, dtype=x.dtype) - 1e20
+ dense_x = tf.tensor_scatter_nd_update(dense_x, index[:, None], x)
+ dense_x = tf.reshape(dense_x, (batch_size, n_nodes_max))
+
+ perm = tf.argsort(dense_x, direction="DESCENDING")
+ perm = perm + cumulative_n_nodes[:, None]
+ perm = tf.reshape(perm, (-1,))
+
+ k = tf.cast(tf.math.ceil(ratio * tf.cast(n_nodes, tf.float32)), i.dtype)
+
+ # This costs more memory
+ # to_rep = tf.tile(tf.constant([1., 0.]), (batch_size,))
+ # rep_times = tf.reshape(tf.concat((k[:, None], (n_nodes_max - k)[:, None]), -1), (-1,))
+ # mask = ops.repeat(to_rep, rep_times)
+ # perm = tf.boolean_mask(perm, mask)
+
+ # This is slower
+ r_range = tf.ragged.range(k).flat_values
+ r_delta = ops.repeat(tf.range(batch_size) * n_nodes_max, k)
+ mask = r_range + r_delta
+ perm = tf.gather(perm, mask)
+
+ return perm
diff --git a/spektral/models/gcn.py b/spektral/models/gcn.py
index 5c850d35..37b8a67e 100644
--- a/spektral/models/gcn.py
+++ b/spektral/models/gcn.py
@@ -31,7 +31,6 @@ class GCN(tf.keras.Model):
- `use_bias`: whether to add a learnable bias to the two GCNConv layers;
- `dropout_rate`: `rate` used in `Dropout` layers;
- `l2_reg`: l2 regularization strength;
- - `n_input_channels`: number of input channels, required for tf 2.1;
- `**kwargs`: passed to `Model.__init__`.
"""
@@ -44,7 +43,6 @@ def __init__(
use_bias=False,
dropout_rate=0.5,
l2_reg=2.5e-4,
- n_input_channels=None,
**kwargs,
):
super().__init__(**kwargs)
@@ -56,7 +54,6 @@ def __init__(
self.use_bias = use_bias
self.dropout_rate = dropout_rate
self.l2_reg = l2_reg
- self.n_input_channels = n_input_channels
reg = tf.keras.regularizers.l2(l2_reg)
self._d0 = tf.keras.layers.Dropout(dropout_rate)
self._gcn0 = gcn_conv.GCNConv(
@@ -67,13 +64,6 @@ def __init__(
n_labels, activation=output_activation, use_bias=use_bias
)
- if tf.version.VERSION < "2.2":
- if n_input_channels is None:
- raise ValueError("n_input_channels required for tf < 2.2")
- x = tf.keras.Input((n_input_channels,), dtype=tf.float32)
- a = tf.keras.Input((None,), dtype=tf.float32, sparse=True)
- self._set_inputs((x, a))
-
def get_config(self):
return dict(
n_labels=self.n_labels,
@@ -83,7 +73,6 @@ def get_config(self):
use_bias=self.use_bias,
dropout_rate=self.dropout_rate,
l2_reg=self.l2_reg,
- n_input_channels=self.n_input_channels,
)
def call(self, inputs):
@@ -91,10 +80,7 @@ def call(self, inputs):
x, a = inputs
else:
x, a, _ = inputs # So that the model can be used with DisjointLoader
- if self.n_input_channels is None:
- self.n_input_channels = x.shape[-1]
- else:
- assert self.n_input_channels == x.shape[-1]
+
x = self._d0(x)
x = self._gcn0([x, a])
x = self._d1(x)
diff --git a/tests/test_layers/convolutional/test_crystal_conv.py b/tests/test_layers/convolutional/test_crystal_conv.py
index acf0c24b..a61121bc 100644
--- a/tests/test_layers/convolutional/test_crystal_conv.py
+++ b/tests/test_layers/convolutional/test_crystal_conv.py
@@ -1,11 +1,11 @@
-from core import MODES, run_layer
+from core import MODES, F, run_layer
from spektral import layers
config = {
"layer": layers.CrystalConv,
"modes": [MODES["SINGLE"], MODES["MIXED"]],
- "kwargs": {"channels": 7},
+ "kwargs": {"channels": F}, # Set channels same as node features
"dense": False,
"sparse": True,
"edges": True,
diff --git a/tests/test_layers/pooling/core.py b/tests/test_layers/pooling/core.py
index feca8eb0..59927f17 100644
--- a/tests/test_layers/pooling/core.py
+++ b/tests/test_layers/pooling/core.py
@@ -29,8 +29,7 @@ def _check_output_and_model_output_shapes(true_shape, model_shape):
def _check_number_of_nodes(N_pool_expected, N_pool_true):
- if N_pool_expected is not None:
- assert N_pool_expected == N_pool_true or N_pool_true is None
+ assert N_pool_expected == N_pool_true or N_pool_true is None
def _test_single_mode(layer, sparse=False, **kwargs):
@@ -55,17 +54,20 @@ def _test_single_mode(layer, sparse=False, **kwargs):
X_pool, A_pool, mask = output
if "ratio" in kwargs.keys():
N_exp = kwargs["ratio"] * N
+ N_pool_expected = int(np.ceil(N_exp))
elif "k" in kwargs.keys():
- N_exp = kwargs["k"]
+ N_pool_expected = int(kwargs["k"])
else:
- raise ValueError("Need k or ratio.")
- N_pool_expected = int(np.ceil(N_exp))
+ N_pool_expected = None
+
N_pool_true = A_pool.shape[-1]
+ assert N_pool_true > 0
- _check_number_of_nodes(N_pool_expected, N_pool_true)
+ if N_pool_expected is not None:
+ _check_number_of_nodes(N_pool_expected, N_pool_true)
- assert X_pool.shape == (N_pool_expected, F)
- assert A_pool.shape == (N_pool_expected, N_pool_expected)
+ assert X_pool.shape == (N_pool_expected, F)
+ assert A_pool.shape == (N_pool_expected, N_pool_expected)
output_shape = [o.shape for o in output]
_check_output_and_model_output_shapes(output_shape, model.output_shape)
@@ -89,17 +91,18 @@ def _test_batch_mode(layer, **kwargs):
X_pool, A_pool, mask = output
if "ratio" in kwargs.keys():
N_exp = kwargs["ratio"] * N
+ N_pool_expected = int(np.ceil(N_exp))
elif "k" in kwargs.keys():
- N_exp = kwargs["k"]
+ N_pool_expected = int(kwargs["k"])
else:
- raise ValueError("Need k or ratio.")
- N_pool_expected = int(np.ceil(N_exp))
+ N_pool_expected = None
N_pool_true = A_pool.shape[-1]
- _check_number_of_nodes(N_pool_expected, N_pool_true)
+ if N_pool_expected is not None:
+ _check_number_of_nodes(N_pool_expected, N_pool_true)
- assert X_pool.shape == (batch_size, N_pool_expected, F)
- assert A_pool.shape == (batch_size, N_pool_expected, N_pool_expected)
+ assert X_pool.shape == (batch_size, N_pool_expected, F)
+ assert A_pool.shape == (batch_size, N_pool_expected, N_pool_expected)
output_shape = [o.shape for o in output]
_check_output_and_model_output_shapes(output_shape, model.output_shape)
@@ -128,20 +131,26 @@ def _test_disjoint_mode(layer, sparse=False, **kwargs):
output = model(input_data)
- X_pool, A_pool, I_pool, mask = output
- N_pool_expected = (
- np.ceil(kwargs["ratio"] * N1)
- + np.ceil(kwargs["ratio"] * N2)
- + np.ceil(kwargs["ratio"] * N3)
- )
- N_pool_expected = int(N_pool_expected)
+ X_pool, A_pool, I_pool, s = output
+
+ if "ratio" in kwargs.keys():
+ N_pool_expected = int(
+ np.ceil(kwargs["ratio"] * N1)
+ + np.ceil(kwargs["ratio"] * N2)
+ + np.ceil(kwargs["ratio"] * N3)
+ )
+ elif "k" in kwargs.keys():
+ N_pool_expected = int(kwargs["k"])
+ else:
+ N_pool_expected = None
N_pool_true = A_pool.shape[0]
- _check_number_of_nodes(N_pool_expected, N_pool_true)
+ if N_pool_expected is not None:
+ _check_number_of_nodes(N_pool_expected, N_pool_true)
- assert X_pool.shape == (N_pool_expected, F)
- assert A_pool.shape == (N_pool_expected, N_pool_expected)
- assert I_pool.shape == (N_pool_expected,)
+ assert X_pool.shape == (N_pool_expected, F)
+ assert A_pool.shape == (N_pool_expected, N_pool_expected)
+ assert I_pool.shape == (N_pool_expected,)
output_shape = [o.shape for o in output]
_check_output_and_model_output_shapes(output_shape, model.output_shape)
diff --git a/tests/test_layers/pooling/test_diff_pool.py b/tests/test_layers/pooling/test_diff_pool.py
index 91076427..a1cc2e35 100644
--- a/tests/test_layers/pooling/test_diff_pool.py
+++ b/tests/test_layers/pooling/test_diff_pool.py
@@ -4,7 +4,7 @@
config = {
"layer": layers.DiffPool,
"modes": [MODES["SINGLE"], MODES["BATCH"]],
- "kwargs": {"k": 5, "return_mask": True},
+ "kwargs": {"k": 5, "return_selection": True},
"dense": True,
"sparse": True,
}
diff --git a/tests/test_layers/pooling/test_la_pool.py b/tests/test_layers/pooling/test_la_pool.py
new file mode 100644
index 00000000..25a3b2f5
--- /dev/null
+++ b/tests/test_layers/pooling/test_la_pool.py
@@ -0,0 +1,14 @@
+from spektral import layers
+from tests.test_layers.pooling.core import MODES, run_layer
+
+config = {
+ "layer": layers.LaPool,
+ "modes": [MODES["SINGLE"], MODES["DISJOINT"]],
+ "kwargs": {"return_selection": True},
+ "dense": False,
+ "sparse": True,
+}
+
+
+def test_layer():
+ run_layer(config)
diff --git a/tests/test_layers/pooling/test_mincut_pool.py b/tests/test_layers/pooling/test_mincut_pool.py
index 9f80af58..379b38f8 100644
--- a/tests/test_layers/pooling/test_mincut_pool.py
+++ b/tests/test_layers/pooling/test_mincut_pool.py
@@ -4,7 +4,7 @@
config = {
"layer": layers.MinCutPool,
"modes": [MODES["SINGLE"], MODES["BATCH"]],
- "kwargs": {"k": 5, "return_mask": True, "mlp_hidden": [32]},
+ "kwargs": {"k": 5, "return_selection": True, "mlp_hidden": [32]},
"dense": True,
"sparse": True,
}
diff --git a/tests/test_layers/pooling/test_sag_pool.py b/tests/test_layers/pooling/test_sag_pool.py
index 21d43e3c..55022847 100644
--- a/tests/test_layers/pooling/test_sag_pool.py
+++ b/tests/test_layers/pooling/test_sag_pool.py
@@ -4,8 +4,8 @@
config = {
"layer": layers.SAGPool,
"modes": [MODES["SINGLE"], MODES["DISJOINT"]],
- "kwargs": {"ratio": 0.5, "return_mask": True},
- "dense": True,
+ "kwargs": {"ratio": 0.5, "return_selection": True},
+ "dense": False,
"sparse": True,
}
diff --git a/tests/test_layers/pooling/test_topk_pool.py b/tests/test_layers/pooling/test_topk_pool.py
index 470cff0f..a9f85397 100644
--- a/tests/test_layers/pooling/test_topk_pool.py
+++ b/tests/test_layers/pooling/test_topk_pool.py
@@ -4,8 +4,8 @@
config = {
"layer": layers.TopKPool,
"modes": [MODES["SINGLE"], MODES["DISJOINT"]],
- "kwargs": {"ratio": 0.5, "return_mask": True},
- "dense": True,
+ "kwargs": {"ratio": 0.5, "return_selection": True},
+ "dense": False,
"sparse": True,
}