Skip to content

Commit

Permalink
Add MultiDropout
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jan 11, 2024
1 parent 36fdf33 commit c23a59f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
11 changes: 3 additions & 8 deletions n3fit/src/n3fit/backends/keras_backend/base_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,13 @@
"""

from tensorflow import expand_dims, math, nn
from tensorflow.keras.layers import ( # pylint: disable=unused-import
Dropout,
Input,
Lambda,
concatenate,
)
from tensorflow.keras.layers import Dense as KerasDense
from tensorflow.keras.layers import Input, Lambda, concatenate # pylint: disable=unused-import
from tensorflow.keras.layers import LSTM, Concatenate # pylint: disable=unused-import
from tensorflow.keras.regularizers import l1_l2

from n3fit.backends import MetaLayer
from n3fit.backends.keras_backend.multi_dense import MultiDense
from n3fit.backends.keras_backend.multi_dense import MultiDense, MultiDropout


# Custom activation functions
Expand Down Expand Up @@ -163,7 +158,7 @@ def apply_dense(xinput):
LSTM_modified,
{"kernel_initializer": "glorot_normal", "units": 5, "activation": "sigmoid"},
),
"dropout": (Dropout, {"rate": 0.0}),
"dropout": (MultiDropout, {"rate": 0.0}),
"concatenate": (Concatenate, {}),
}

Expand Down
21 changes: 20 additions & 1 deletion n3fit/src/n3fit/backends/keras_backend/multi_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import tensorflow as tf
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dense, Dropout


class MultiDense(Dense):
Expand Down Expand Up @@ -177,3 +177,22 @@ def __call__(self, shape, dtype=None, **kwargs):
per_replica_weights.append(single_initializer(shape, dtype, **kwargs))

return tf.stack(per_replica_weights, axis=0)


class MultiDropout(Dropout):
    """
    Dropout that broadcasts to the replica axis, to make sure that the dropout rate is constant
    per replica.

    The noise shape handed to the parent ``Dropout`` layer is the input shape with the
    replica axis set to 1, so that one and the same mask is used for every replica.

    Input shape: (batch_size, replicas, gridsize, features)

    Parameters
    ----------
        **kwargs:
            forwarded unchanged to the parent ``Dropout`` layer (e.g. ``rate``, ``seed``)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Position of the replica dimension in the input (see "Input shape" above)
        self.replica_axis = 1

    def _get_noise_shape(self, inputs):
        """
        Return the shape of the dropout mask: the shape of ``inputs`` with the
        replica axis replaced by 1.

        When every dimension of ``inputs`` is statically known, a plain list of ints
        is returned. Dimensions that are not known statically (``None``) cannot be
        used to sample the mask, so they are read from the runtime shape instead and
        the result is returned as a tensor.
        """
        input_shape = list(inputs.shape)
        noise_shape = input_shape[: self.replica_axis] + [1] + input_shape[self.replica_axis + 1 :]

        if all(dim is not None for dim in noise_shape):
            return noise_shape

        # Some dimension (e.g. a variable grid size) is only known at run time:
        # take those entries from the dynamic shape, keep the static ones as they are.
        runtime_shape = tf.shape(inputs)
        return tf.convert_to_tensor(
            [runtime_shape[i] if dim is None else dim for i, dim in enumerate(noise_shape)]
        )
21 changes: 16 additions & 5 deletions n3fit/src/n3fit/tests/test_multidense.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.layers import Dense

from n3fit.backends.keras_backend.multi_dense import MultiDense
from n3fit.backends.keras_backend.multi_dense import MultiDense, MultiDropout
from n3fit.model_gen import generate_nn


Expand Down Expand Up @@ -68,9 +68,20 @@ def test_initializers():
np.testing.assert_allclose(multi_dense_weights, stacked_weights)


def main():
test_initializers()
def test_dropout():
    """Check that MultiDropout drops exactly the same x values in every replica."""
    n_replicas, n_x, n_features = 100, 10, 1
    all_ones = tf.ones(shape=(1, n_replicas, n_x, n_features))

    layer = MultiDropout(rate=0.5, seed=44)

    # Convert to numpy once and reuse the array in all checks below
    result = layer(all_ones, training=True).numpy()

    # Mask of the x values that were dropped in the first replica
    dropped_in_first = result[0, 0, :, 0] == 0

    # Every replica must agree with the first one on what is dropped and what is kept
    assert np.all(result[:, :, dropped_in_first, :] == 0)
    assert np.all(result[:, :, ~dropped_in_first, :] != 0)


# Allows running this module directly as a script
if __name__ == '__main__':
    main()

0 comments on commit c23a59f

Please sign in to comment.