From f7b96980404740eeafb5e344999d6b7cfb9a0743 Mon Sep 17 00:00:00 2001 From: Edward Date: Fri, 11 Nov 2022 17:18:37 +0800 Subject: [PATCH 1/2] Typo fixed: _make_wondows -> _make_windows --- model/layers/moat_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/layers/moat_blocks.py b/model/layers/moat_blocks.py index c0419d2..ea96e4e 100644 --- a/model/layers/moat_blocks.py +++ b/model/layers/moat_blocks.py @@ -460,7 +460,7 @@ def build(self, input_shape: list[int]) -> None: kernel_initializer=self._config.kernel_initializer, bias_initializer=self._config.bias_initializer) - def _make_wondows(self, inputs): + def _make_windows(self, inputs): _, height, width, channels = inputs.get_shape().with_rank(4).as_list() inputs = tf.reshape( inputs, @@ -531,7 +531,7 @@ def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor: def _func(output): output = self._attention_norm(output) _, height, width, _ = output.get_shape().with_rank(4).as_list() - output = self._make_wondows(output) + output = self._make_windows(output) output = self._attention(output) output = self._remove_windows(output, height, width) return output From 63d2a23d7f1df6447ca601e651eef145b0d57f4d Mon Sep 17 00:00:00 2001 From: Edward Date: Fri, 11 Nov 2022 22:58:53 +0800 Subject: [PATCH 2/2] Fixed global window implementation --- model/layers/moat_blocks.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/model/layers/moat_blocks.py b/model/layers/moat_blocks.py index ea96e4e..4a248be 100644 --- a/model/layers/moat_blocks.py +++ b/model/layers/moat_blocks.py @@ -530,10 +530,18 @@ def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor: attention_shortcut = output def _func(output): output = self._attention_norm(output) - _, height, width, _ = output.get_shape().with_rank(4).as_list() - output = self._make_windows(output) + _, height, width, channels = output.get_shape().with_rank(4).as_list() + + if self._config.window_size: + 
+        output = self._make_windows(output)
       output = self._attention(output)
-      output = self._remove_windows(output, height, width)
+
+      if self._config.window_size:
+        output = self._remove_windows(output, height, width)
+      else:
+        output = tf.reshape(output, [-1, height, width, channels])
+
       return output
 
     func = _func