diff --git a/fit/datamodules/tomo_rec/TRecDataModule.py b/fit/datamodules/tomo_rec/TRecDataModule.py index 721fdd6..0112e8e 100644 --- a/fit/datamodules/tomo_rec/TRecDataModule.py +++ b/fit/datamodules/tomo_rec/TRecDataModule.py @@ -105,9 +105,10 @@ def prepare_data(self, *args, **kwargs): assert mnist_train.shape[1] == self.gt_shape assert mnist_train.shape[2] == self.gt_shape - mnist_train = np.clip(mnist_train, 50, 255) - mnist_val = np.clip(mnist_val, 50, 255) - mnist_test = np.clip(mnist_test, 50, 255) + circle = self.__get_circle__() + mnist_train = circle * np.clip(mnist_train, 50, 255) + mnist_val = circle * np.clip(mnist_val, 50, 255) + mnist_test = circle * np.clip(mnist_test, 50, 255) self.mean = mnist_train.mean() self.std = mnist_train.std() @@ -116,7 +117,6 @@ def prepare_data(self, *args, **kwargs): mnist_val = normalize(mnist_val, self.mean, self.std) mnist_test = normalize(mnist_test, self.mean, self.std) - circle = self.__get_circle__() mnist_train *= circle mnist_val *= circle mnist_test *= circle