diff --git a/microwakeword/mixednet.py b/microwakeword/mixednet.py index 076d467..75cbb9f 100644 --- a/microwakeword/mixednet.py +++ b/microwakeword/mixednet.py @@ -231,7 +231,7 @@ def __call__(self, inputs): return x -class SpatialAttention(tf.keras.layers.Layer): +class SpatialAttention: """Spatial Attention Layer based on CBAM: Convolutional Block Attention Module https://arxiv.org/pdf/1807.06521v2 @@ -239,14 +239,12 @@ class SpatialAttention(tf.keras.layers.Layer): object (_type_): _description_ """ - def __init__(self, kernel_size, ring_buffer_size, **kwargs): - super().__init__(**kwargs) - + def __init__(self, kernel_size, ring_buffer_size): self.kernel_size = kernel_size self.ring_buffer_size = ring_buffer_size - def call(self, inputs): - tranposed = tf.transpose(inputs, perm=[0, 1, 3, 2]) + def __call__(self, inputs): + tranposed = tf.keras.ops.transpose(inputs, axes=[0, 1, 3, 2]) channel_avg = tf.keras.layers.AveragePooling2D( pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2]) )(tranposed) @@ -254,6 +252,7 @@ def call(self, inputs): pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2]) )(tranposed) pooled = tf.keras.layers.Concatenate(axis=-1)([channel_avg, channel_max]) + attention = stream.Stream( cell=tf.keras.layers.Conv2D( 1, @@ -275,12 +274,6 @@ def call(self, inputs): return net * attention - def get_config(self): - return { - "kernel_size": self.kernel_size, - "ring_buffer_size": self.ring_buffer_size, - } - def model(flags, shape, batch_size): """MixedNet model. @@ -368,7 +361,10 @@ def model(flags, shape, batch_size): if net.shape[1] > 1: if flags.spatial_attention: - net = SpatialAttention(4, net.shape[1] - 1)(net) + net = SpatialAttention( + kernel_size=4, + ring_buffer_size=net.shape[1] - 1, + )(net) else: net = stream.Stream( cell=tf.keras.layers.Identity(),