diff --git a/spektral/layers/ops/sparse.py b/spektral/layers/ops/sparse.py index 7048348c..b10e09e9 100644 --- a/spektral/layers/ops/sparse.py +++ b/spektral/layers/ops/sparse.py @@ -54,11 +54,10 @@ def add_self_loops_indices(indices, n_nodes=None): :return: Tensor of rank 2, the indices to a SparseTensor. """ n_nodes = tf.reduce_max(indices) + 1 if n_nodes is None else n_nodes - row, col = indices[..., 0], indices[..., 1] - mask = tf.ensure_shape(row != col, row.shape) - sl_indices = tf.range(n_nodes, dtype=row.dtype)[:, None] + mask = tf.map_fn(lambda i: i[0] != i[1], indices, tf.bool) + sl_indices = tf.range(n_nodes, dtype=indices.dtype)[:, None] sl_indices = tf.repeat(sl_indices, 2, -1) - indices = tf.concat((indices[mask], sl_indices), 0) + indices = tf.concat((tf.boolean_mask(indices, mask), sl_indices), 0) dummy_values = tf.ones_like(indices[:, 0]) indices, _ = gen_sparse_ops.sparse_reorder( indices, dummy_values, (n_nodes, n_nodes)