diff --git a/n3fit/src/n3fit/backends/keras_backend/base_layers.py b/n3fit/src/n3fit/backends/keras_backend/base_layers.py
index 43f7622e99..3f1726fe16 100644
--- a/n3fit/src/n3fit/backends/keras_backend/base_layers.py
+++ b/n3fit/src/n3fit/backends/keras_backend/base_layers.py
@@ -18,18 +18,13 @@
 """
 
 from tensorflow import expand_dims, math, nn
-from tensorflow.keras.layers import (  # pylint: disable=unused-import
-    Dropout,
-    Input,
-    Lambda,
-    concatenate,
-)
 from tensorflow.keras.layers import Dense as KerasDense
+from tensorflow.keras.layers import Input, Lambda, concatenate  # pylint: disable=unused-import
 from tensorflow.keras.layers import LSTM, Concatenate  # pylint: disable=unused-import
 from tensorflow.keras.regularizers import l1_l2
 
 from n3fit.backends import MetaLayer
-from n3fit.backends.keras_backend.multi_dense import MultiDense
+from n3fit.backends.keras_backend.multi_dense import MultiDense, MultiDropout
 
 # Custom activation functions
@@ -163,7 +158,7 @@ def apply_dense(xinput):
         LSTM_modified,
         {"kernel_initializer": "glorot_normal", "units": 5, "activation": "sigmoid"},
     ),
-    "dropout": (Dropout, {"rate": 0.0}),
+    "dropout": (MultiDropout, {"rate": 0.0}),
     "concatenate": (Concatenate, {}),
 }
 
diff --git a/n3fit/src/n3fit/backends/keras_backend/multi_dense.py b/n3fit/src/n3fit/backends/keras_backend/multi_dense.py
index 5045711537..3e7dbdf4ec 100644
--- a/n3fit/src/n3fit/backends/keras_backend/multi_dense.py
+++ b/n3fit/src/n3fit/backends/keras_backend/multi_dense.py
@@ -2,7 +2,7 @@
 
 import tensorflow as tf
 from tensorflow.keras.initializers import Initializer
-from tensorflow.keras.layers import Dense
+from tensorflow.keras.layers import Dense, Dropout
 
 
 class MultiDense(Dense):
@@ -177,3 +177,22 @@ def __call__(self, shape, dtype=None, **kwargs):
             per_replica_weights.append(single_initializer(shape, dtype, **kwargs))
 
         return tf.stack(per_replica_weights, axis=0)
+
+
+class MultiDropout(Dropout):
+    """
+    Dropout that broadcasts over the replica axis, so that the same dropout mask
+    is applied to every replica.
+
+    Input shape: (batch_size, replicas, gridsize, features)
+
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.replica_axis = 1
+
+    def _get_noise_shape(self, inputs):
+        input_shape = list(inputs.shape)
+        noise_shape = input_shape[: self.replica_axis] + [1] + input_shape[self.replica_axis + 1 :]
+        return noise_shape
diff --git a/n3fit/src/n3fit/tests/test_multidense.py b/n3fit/src/n3fit/tests/test_multidense.py
index ed848ffedc..42a77e314e 100644
--- a/n3fit/src/n3fit/tests/test_multidense.py
+++ b/n3fit/src/n3fit/tests/test_multidense.py
@@ -4,7 +4,7 @@
 from tensorflow.keras.initializers import GlorotUniform
 from tensorflow.keras.layers import Dense
 
-from n3fit.backends.keras_backend.multi_dense import MultiDense
+from n3fit.backends.keras_backend.multi_dense import MultiDense, MultiDropout
 from n3fit.model_gen import generate_nn
 
 
@@ -68,9 +68,20 @@ def test_initializers():
     np.testing.assert_allclose(multi_dense_weights, stacked_weights)
 
 
-def main():
-    test_initializers()
+def test_dropout():
+    replicas = 100
+    x_size = 10
+    features = 1
+    input_shape = (1, replicas, x_size, features)
+    test_input = tf.ones(shape=input_shape)
+    dropout_layer = MultiDropout(rate=0.5, seed=44)
 
-if __name__ == '__main__':
-    main()
+    test_output = dropout_layer(test_input, training=True)
+
+    # Check that for every replica the same x values are dropped
+    zero_indices_first_replica = np.where(test_output.numpy()[0, 0, :, 0] == 0)
+    nonzero_indices_first_replica = np.where(test_output.numpy()[0, 0, :, 0] != 0)
+
+    assert np.all(test_output.numpy()[:, :, zero_indices_first_replica, :] == 0)
+    assert np.all(test_output.numpy()[:, :, nonzero_indices_first_replica, :] != 0)
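
For reference, the broadcasting behaviour introduced by MultiDropout can be reproduced with the stock Keras Dropout layer by passing an explicit noise_shape that has a singleton entry on the replica axis. The sketch below is illustrative only (the shapes, rate and seed are arbitrary and not taken from the patch): a single dropout mask is drawn and broadcast over all replicas, which is what the overridden _get_noise_shape computes from the input at call time.

# Illustrative sketch, not part of the patch: a noise_shape with a 1 on the
# replica axis makes Keras draw one dropout mask and broadcast it over replicas.
import numpy as np
import tensorflow as tf

batch, replicas, gridsize, features = 1, 4, 10, 1
x = tf.ones((batch, replicas, gridsize, features))

# For this input, MultiDropout._get_noise_shape would return [1, 1, 10, 1]
layer = tf.keras.layers.Dropout(rate=0.5, noise_shape=(batch, 1, gridsize, features), seed=0)
y = layer(x, training=True).numpy()

# The zero pattern along the grid axis is identical for every replica
reference_mask = y[0, 0, :, 0] == 0
for r in range(replicas):
    assert np.array_equal(y[0, r, :, 0] == 0, reference_mask)

The patch obtains the same effect without hard-coding shapes: by overriding _get_noise_shape, the noise shape is derived from the actual input at call time, with only the replica axis collapsed to 1.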