From 392340095855f3d1b28db4ebe06743683943df96 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 31 Oct 2024 19:36:45 -0700 Subject: [PATCH] Fix 5D shape validation issues with concat layer --- keras/src/layers/merging/concatenate.py | 10 ++++---- keras/src/layers/merging/merging_test.py | 30 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/keras/src/layers/merging/concatenate.py b/keras/src/layers/merging/concatenate.py index 7e240786ac3..f9d4d39ff3c 100644 --- a/keras/src/layers/merging/concatenate.py +++ b/keras/src/layers/merging/concatenate.py @@ -1,3 +1,5 @@ +import copy + from keras.src import ops from keras.src.api_export import keras_export from keras.src.layers.merging.base_merge import Merge @@ -50,15 +52,15 @@ def build(self, input_shape): return reduced_inputs_shapes = [list(shape) for shape in input_shape] + reduced_inputs_shapes_copy = copy.copy(reduced_inputs_shapes) shape_set = set() - - for i in range(len(reduced_inputs_shapes)): + for i in range(len(reduced_inputs_shapes_copy)): # Convert self.axis to positive axis for each input # in case self.axis is a negative number - concat_axis = self.axis % len(reduced_inputs_shapes[i]) + concat_axis = self.axis % len(reduced_inputs_shapes_copy[i]) # Skip batch axis. for axis, axis_value in enumerate( - reduced_inputs_shapes[i][1:], start=1 + reduced_inputs_shapes_copy, start=1 ): # Remove squeezable axes (axes with value of 1) # if not in the axis that will be used for concatenation diff --git a/keras/src/layers/merging/merging_test.py b/keras/src/layers/merging/merging_test.py index a3e2c5ffc07..0008ffd7af8 100644 --- a/keras/src/layers/merging/merging_test.py +++ b/keras/src/layers/merging/merging_test.py @@ -5,6 +5,7 @@ from keras.src import backend from keras.src import layers from keras.src import models +from keras.src import ops from keras.src import testing @@ -339,6 +340,35 @@ def test_concatenate_with_mask(self): ) self.assertAllClose(output._keras_mask, [[1, 1, 1, 1]]) + def test_concatenate_errors(self): + # This should work + x1 = np.ones((1, 1, 1, 1, 5)) + x2 = np.ones((1, 1, 1, 1, 4)) + out = layers.Concatenate(axis=-1)([x1, x2]) + self.assertEqual(ops.shape(out), (1, 1, 1, 1, 9)) + + # This won't + x1 = np.ones((1, 2, 1, 1, 5)) + x2 = np.ones((1, 1, 1, 1, 4)) + with self.assertRaisesRegex( + ValueError, + ( + "requires inputs with matching shapes " + "except for the concatenation axis" + ), + ): + out = layers.Concatenate(axis=-1)([x1, x2]) + x1 = np.ones((1, 2, 1, 2, 1)) + x2 = np.ones((1, 1, 1, 3, 1)) + with self.assertRaisesRegex( + ValueError, + ( + "requires inputs with matching shapes " + "except for the concatenation axis" + ), + ): + out = layers.Concatenate(axis=1)([x1, x2]) + @parameterized.named_parameters(TEST_PARAMETERS) @pytest.mark.skipif( not backend.SUPPORTS_SPARSE_TENSORS,