From d93cbc8b2ce628a6ea8674dad991f65b8b830aa3 Mon Sep 17 00:00:00 2001
From: crstngc
Date: Sat, 25 Jan 2025 19:08:00 -0700
Subject: [PATCH] Add more blocks to autoencoders

---
 scico/flax/autoencoders/blocks.py | 82 +++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/scico/flax/autoencoders/blocks.py b/scico/flax/autoencoders/blocks.py
index e95cb4d3..173ae835 100644
--- a/scico/flax/autoencoders/blocks.py
+++ b/scico/flax/autoencoders/blocks.py
@@ -19,6 +19,7 @@
 import flax.linen as nn
 from flax.core import Scope  # noqa
 from flax.linen.module import _Sentinel  # noqa
+from scico.flax.blocks import upscale_nn
 
 # The imports of Scope and _Sentinel (above) are required to silence
 # "cannot resolve forward reference" warnings when building sphinx api
@@ -121,3 +122,84 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
         x = x.reshape((x.shape[0], -1))
 
         return x
+
+
+class ConvPoolBlock(nn.Module):
+    """Define convolution and pooling Flax block.
+
+    Args:
+        num_filters: Number of filters in the convolutional layer of the
+            block. Corresponds to the number of channels in the output
+            tensor.
+        kernel_size: A shape tuple defining the size of the convolution
+            filters.
+        strides: A shape tuple defining the size of strides in convolution.
+        activation: Flax function defining the activation operation to apply.
+        pooling: Flax function defining the pooling operation to apply.
+        window_shape: A shape tuple defining the window to reduce over in
+            the pooling operation.
+    """
+
+    num_filters: int
+    kernel_size: Tuple[int, int] = (3, 3)
+    strides: Tuple[int, int] = (1, 1)
+    activation: Callable = nn.leaky_relu
+    pooling: Callable = nn.max_pool
+    window_shape: Tuple[int, int] = (2, 2)
+
+    @nn.compact
+    def __call__(self, x: ArrayLike) -> ArrayLike:
+        """Apply convolution followed by activation and pooling.
+
+        Args:
+            x: The array to be transformed.
+
+        Returns:
+            The transformed input.
+        """
+        x = nn.Conv(
+            self.num_filters,
+            self.kernel_size,
+            strides=self.strides,
+            use_bias=False,
+            padding="CIRCULAR",
+        )(x)
+        x = self.activation(x)
+        x = self.pooling(x, self.window_shape, strides=self.window_shape, padding="SAME")
+
+        return x
+
+
+class ConvUpsampleBlock(nn.Module):
+    """Define convolution and upsample Flax block.
+
+    Args:
+        num_filters: Number of filters in the convolutional layer of the
+            block. Corresponds to the number of channels in the output
+            tensor.
+        kernel_size: A shape tuple defining the size of the convolution
+            filters.
+        strides: A shape tuple defining the size of strides in convolution.
+        activation: Flax function defining the activation operation to apply.
+        upsampling_scale: Integer scaling factor.
+    """
+
+    num_filters: int
+    kernel_size: Tuple[int, int] = (3, 3)
+    strides: Tuple[int, int] = (1, 1)
+    activation: Callable = nn.leaky_relu
+    upsampling_scale: int = 2
+
+    @nn.compact
+    def __call__(self, x: ArrayLike) -> ArrayLike:
+        x = nn.ConvTranspose(
+            self.num_filters,
+            self.kernel_size,
+            strides=self.strides,
+            use_bias=False,
+            padding="CIRCULAR",
+        )(x)
+        x = self.activation(x)
+        x = upscale_nn(x, self.upsampling_scale)
+
+        return x
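
A minimal usage sketch (not part of the patch): it assumes the patch has been applied so that ConvPoolBlock and ConvUpsampleBlock are importable from scico.flax.autoencoders.blocks; the input shape, filter counts, and PRNG seed are illustrative assumptions, not values from the patch.

import jax
import jax.numpy as jnp

from scico.flax.autoencoders.blocks import ConvPoolBlock, ConvUpsampleBlock

key = jax.random.PRNGKey(0)
x = jnp.ones((1, 16, 16, 1))  # NHWC input: batch of one 16x16 single-channel image

# Downsampling path: 3x3 circular convolution, leaky ReLU, then 2x2 max
# pooling halves the spatial dimensions.
down = ConvPoolBlock(num_filters=8)
variables = down.init(key, x)
y = down.apply(variables, x)  # shape (1, 8, 8, 8)

# Upsampling path: 3x3 transposed circular convolution, leaky ReLU, then
# nearest-neighbor upscaling (upscale_nn) restores the spatial dimensions.
up = ConvUpsampleBlock(num_filters=1, upsampling_scale=2)
variables_up = up.init(key, y)
z = up.apply(variables_up, y)  # shape (1, 16, 16, 1)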