Skip to content

Commit

Permalink
Fix 5D shape validation issues with concat layer
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 1, 2024
1 parent cfa32a3 commit 3923400
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
10 changes: 6 additions & 4 deletions keras/src/layers/merging/concatenate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.merging.base_merge import Merge
Expand Down Expand Up @@ -50,15 +52,15 @@ def build(self, input_shape):
return

reduced_inputs_shapes = [list(shape) for shape in input_shape]
reduced_inputs_shapes_copy = copy.copy(reduced_inputs_shapes)
shape_set = set()

for i in range(len(reduced_inputs_shapes)):
for i in range(len(reduced_inputs_shapes_copy)):
# Convert self.axis to positive axis for each input
# in case self.axis is a negative number
concat_axis = self.axis % len(reduced_inputs_shapes[i])
concat_axis = self.axis % len(reduced_inputs_shapes_copy[i])
# Skip batch axis.
for axis, axis_value in enumerate(
reduced_inputs_shapes[i][1:], start=1
reduced_inputs_shapes_copy, start=1
):
# Remove squeezable axes (axes with value of 1)
# if not in the axis that will be used for concatenation
Expand Down
30 changes: 30 additions & 0 deletions keras/src/layers/merging/merging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import backend
from keras.src import layers
from keras.src import models
from keras.src import ops
from keras.src import testing


Expand Down Expand Up @@ -339,6 +340,35 @@ def test_concatenate_with_mask(self):
)
self.assertAllClose(output._keras_mask, [[1, 1, 1, 1]])

def test_concatenate_errors(self):
# This should work
x1 = np.ones((1, 1, 1, 1, 5))
x2 = np.ones((1, 1, 1, 1, 4))
out = layers.Concatenate(axis=-1)([x1, x2])
self.assertEqual(ops.shape(out), (1, 1, 1, 1, 9))

# This won't
x1 = np.ones((1, 2, 1, 1, 5))
x2 = np.ones((1, 1, 1, 1, 4))
with self.assertRaisesRegex(
ValueError,
(
"requires inputs with matching shapes "
"except for the concatenation axis"
),
):
out = layers.Concatenate(axis=-1)([x1, x2])
x1 = np.ones((1, 2, 1, 2, 1))
x2 = np.ones((1, 1, 1, 3, 1))
with self.assertRaisesRegex(
ValueError,
(
"requires inputs with matching shapes "
"except for the concatenation axis"
),
):
out = layers.Concatenate(axis=1)([x1, x2])

@parameterized.named_parameters(TEST_PARAMETERS)
@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
Expand Down

0 comments on commit 3923400

Please sign in to comment.