Skip to content

Commit

Permalink
Add more blocks to autoencoders
Browse files Browse the repository at this point in the history
  • Loading branch information
crstngc committed Jan 26, 2025
1 parent d306e1c commit d93cbc8
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions scico/flax/autoencoders/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import flax.linen as nn
from flax.core import Scope # noqa
from flax.linen.module import _Sentinel # noqa
from scico.flax.blocks import upscale_nn

# The imports of Scope and _Sentinel (above) are required to silence
# "cannot resolve forward reference" warnings when building sphinx api
Expand Down Expand Up @@ -121,3 +122,84 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
x = x.reshape((x.shape[0], -1))

return x


class ConvPoolBlock(nn.Module):
"""Define convolution and pooling Flax block.
Args:
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the output
tensor.
kernel_size: A shape tuple defining the size of the convolution
filters.
strides: A shape tuple defining the size of strides in convolution.
activation: Flax function defining the activation operation to apply.
pooling: Flax function defining the pooling operation to apply.
window_shape: A shape tuple defining the window to reduce over in
the pooling operation.
"""

num_filters: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
activation: Callable = nn.leaky_relu
pooling: Callable = nn.max_pool
window_shape: Tuple[int, int] = (2, 2)

@nn.compact
def __call__(self, x: ArrayLike) -> ArrayLike:
"""Apply convolution followed by activation and pooling.
Args:
inputs: The array to be transformed.
Returns:
The transformed input.
"""
x = nn.Conv(
self.num_filters,
self.kernel_size,
strides=self.strides,
use_bias=False,
padding="CIRCULAR",
)(x)
x = self.activation(x)
x = self.pooling(x, self.window_shape, strides=self.window_shape, padding="SAME")

return x


class ConvUpsampleBlock(nn.Module):
"""Define convolution and upsample Flax block.
Args:
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the output
tensor.
kernel_size: A shape tuple defining the size of the convolution
filters.
strides: A shape tuple defining the size of strides in convolution.
activation: Flax function defining the activation operation to apply.
upsampling_scale: Integer scaling factor.
"""

num_filters: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
activation: Callable = nn.leaky_relu
upsampling_scale: int = 2

@nn.compact
def __call__(self, x: ArrayLike) -> ArrayLike:
x = nn.ConvTranspose(
self.num_filters,
self.kernel_size,
strides=self.strides,
use_bias=False,
padding="CIRCULAR",
)(x)
x = self.activation(x)
x = upscale_nn(x, self.upsampling_scale)

return x

0 comments on commit d93cbc8

Please sign in to comment.