Add functionality for multi-level autoencoders
crstngc committed Jan 26, 2025
1 parent d93cbc8 commit 5aa1a27
Showing 4 changed files with 339 additions and 18 deletions.
18 changes: 9 additions & 9 deletions scico/flax/autoencoders/blocks.py
@@ -134,17 +134,17 @@ class ConvPoolBlock(nn.Module):
kernel_size: A shape tuple defining the size of the convolution
filters.
strides: A shape tuple defining the size of strides in convolution.
- activation: Flax function defining the activation operation to apply.
- pooling: Flax function defining the pooling operation to apply.
+ activation_fn: Flax function defining the activation operation to apply.
+ pooling_fn: Flax function defining the pooling operation to apply.
window_shape: A shape tuple defining the window to reduce over in
the pooling operation.
"""

num_filters: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
- activation: Callable = nn.leaky_relu
- pooling: Callable = nn.max_pool
+ activation_fn: Callable = nn.leaky_relu
+ pooling_fn: Callable = nn.max_pool
window_shape: Tuple[int, int] = (2, 2)

@nn.compact
@@ -164,8 +164,8 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
use_bias=False,
padding="CIRCULAR",
)(x)
- x = self.activation(x)
- x = self.pooling(x, self.window_shape, strides=self.window_shape, padding="SAME")
+ x = self.activation_fn(x)
+ x = self.pooling_fn(x, self.window_shape, strides=self.window_shape, padding="SAME")

return x

@@ -180,14 +180,14 @@ class ConvUpsampleBlock(nn.Module):
kernel_size: A shape tuple defining the size of the convolution
filters.
strides: A shape tuple defining the size of strides in convolution.
- activation: Flax function defining the activation operation to apply.
+ activation_fn: Flax function defining the activation operation to apply.
upsampling_scale: Integer scaling factor.
"""

num_filters: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
- activation: Callable = nn.leaky_relu
+ activation_fn: Callable = nn.leaky_relu
upsampling_scale: int = 2

@nn.compact
@@ -199,7 +199,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
use_bias=False,
padding="CIRCULAR",
)(x)
- x = self.activation(x)
+ x = self.activation_fn(x)
x = upscale_nn(x, self.upsampling_scale)

return x
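
The only API change in these blocks is the rename of the `activation`/`pooling` fields to `activation_fn`/`pooling_fn`; the `__call__` signatures are unchanged. Below is a minimal sketch of constructing and applying both blocks with the renamed fields, assuming they are importable from the `scico.flax.autoencoders.blocks` module shown in this diff; the shapes and hyperparameters are illustrative, not taken from the commit.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

# Import path taken from the file header in this diff.
from scico.flax.autoencoders.blocks import ConvPoolBlock, ConvUpsampleBlock

# Downsampling block: circular conv -> activation_fn -> pooling_fn over window_shape.
pool_block = ConvPoolBlock(
    num_filters=16,
    activation_fn=nn.relu,    # formerly `activation`
    pooling_fn=nn.avg_pool,   # formerly `pooling`
    window_shape=(2, 2),
)

# Upsampling block: circular conv -> activation_fn -> nearest-neighbor upscale.
up_block = ConvUpsampleBlock(num_filters=16, activation_fn=nn.relu, upsampling_scale=2)

x = jnp.ones((1, 32, 32, 1))  # NHWC toy input
key = jax.random.PRNGKey(0)

variables = pool_block.init(key, x)
y = pool_block.apply(variables, x)  # expected (1, 16, 16, 16) after 2x2 pooling

variables_up = up_block.init(key, y)
z = up_block.apply(variables_up, y)  # expected (1, 32, 32, 16) after 2x upsampling

print(y.shape, z.shape)
```

Existing code that passed `activation=` or `pooling=` keyword arguments only needs the corresponding rename.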
2 changes: 1 addition & 1 deletion scico/flax/autoencoders/diagnostics.py
@@ -26,7 +26,7 @@ def stats_obj() -> Tuple[IterationStats, Callable]:
:class:`~.diagnostics.IterationStats` object takes care of both
printing stats to command line and storing them for further analysis.
"""
- # epoch, time learning rate loss and snr (train and
+ # epoch, time learning rate loss and kl-loss (train and
# eval) fields
itstat_fields = {
"Epoch": "%d",
35 changes: 27 additions & 8 deletions scico/flax/autoencoders/steps.py
@@ -42,6 +42,29 @@ class VAEMetricsDict(TypedDict, total=False):
learning_rate: float


+ def kl_loss_fn(mean, logvar):
+     """Compute KL divergence loss from given mean and log variance.
+     Args:
+         mean: Mean in latent space. For multi-level VAE this is a list
+             of means for each latent level.
+         logvar: Log variances in latent space. For multi-level VAE
+             this is a list of log variances for each latent level.
+     """
+     if isinstance(mean, list):  # For multi-level VAE
+         kl_loss = 0.0
+         for j, m in enumerate(mean):
+             reduce_dims = list(range(1, len(m.shape)))
+             kl_loss = kl_loss + jnp.mean(
+                 -0.5 * jnp.sum(1 + logvar[j] - m**2 - jnp.exp(logvar[j]), axis=reduce_dims)
+             )
+     else:  # For regular VAE
+         reduce_dims = list(range(1, len(mean.shape)))
+         kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=reduce_dims))
+
+     return kl_loss


def train_step_vae(
state: TrainState,
batch: ArrayLike,
@@ -94,8 +117,7 @@ def loss_fn(params: PyTree, x: ArrayLike, key: ArrayLike):
mse_loss = criterion(output, x).sum(axis=reduce_dims).mean()
# KL loss term to keep encoder output close to standard
# normal distribution.
- reduce_dims = list(range(1, len(mean.shape)))
- kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=reduce_dims))
+ kl_loss = kl_loss_fn(mean, logvar)
loss = mse_loss + kl_weight * kl_loss
losses = (loss, mse_loss, kl_loss)
return loss, losses
@@ -178,8 +200,7 @@ def loss_fn(params: PyTree, x: ArrayLike, c: ArrayLike, key: ArrayLike):
mse_loss = criterion(output, x).sum(axis=reduce_dims).mean()
# KL loss term to keep encoder output close to standard
# normal distribution.
- reduce_dims = list(range(1, len(mean.shape)))
- kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=reduce_dims))
+ kl_loss = kl_loss_fn(mean, logvar)
loss = mse_loss + kl_weight * kl_loss
losses = (loss, mse_loss, kl_loss)
return loss, losses
@@ -241,8 +262,7 @@ def eval_step_vae(
mse_loss = criterion(output, batch["image"]).sum(axis=reduce_dims).mean()
# KL loss term to keep encoder output close to standard
# normal distribution.
- reduce_dims = list(range(1, len(mean.shape)))
- kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=reduce_dims))
+ kl_loss = kl_loss_fn(mean, logvar)
loss = mse_loss + kl_weight * kl_loss
metrics: VAEMetricsDict = {"loss": loss, "mse": mse_loss, "kl": kl_loss}
return metrics
@@ -286,8 +306,7 @@ def eval_step_vae_class_conditional(
mse_loss = criterion(output, batch["image"]).sum(axis=reduce_dims).mean()
# KL loss term to keep encoder output close to standard
# normal distribution.
- reduce_dims = list(range(1, len(mean.shape)))
- kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=reduce_dims))
+ kl_loss = kl_loss_fn(mean, logvar)
loss = mse_loss + kl_weight * kl_loss
metrics: VAEMetricsDict = {"loss": loss, "mse": mse_loss, "kl": kl_loss}
return metrics
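
The KL term that `kl_loss_fn` centralizes is the closed-form divergence between the encoder posterior N(mean, exp(logvar)) and a standard normal, -0.5 * sum(1 + logvar - mean**2 - exp(logvar)), averaged over the batch; for a multi-level VAE the per-level terms are summed. A small self-contained sketch of the same arithmetic follows (toy shapes chosen for illustration, not taken from the commit):

```python
import jax.numpy as jnp


def kl_term(mean, logvar):
    """Closed-form KL(N(mean, exp(logvar)) || N(0, I)), averaged over the batch."""
    reduce_dims = list(range(1, len(mean.shape)))
    return jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=reduce_dims))


# Regular VAE: a single latent level of shape (batch, latent_dim).
mean = jnp.zeros((4, 8))
logvar = jnp.zeros((4, 8))
print(kl_term(mean, logvar))  # 0.0: a standard-normal posterior incurs no KL penalty

# Multi-level VAE: a list of latent levels; per-level KL terms are summed,
# mirroring the isinstance(mean, list) branch of kl_loss_fn above.
means = [jnp.zeros((4, 8)), 0.5 * jnp.ones((4, 16))]
logvars = [jnp.zeros((4, 8)), jnp.zeros((4, 16))]
kl_total = sum(kl_term(m, lv) for m, lv in zip(means, logvars))
print(kl_total)  # 2.0: level 1 contributes 0, level 2 contributes -0.5 * 16 * (1 - 0.25 - 1)
```

The train and eval steps then form the total objective as mse_loss + kl_weight * kl_loss, so switching from a single latent array to a list of latent levels changes nothing in the surrounding step functions.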