Skip to content

Commit

Permalink
Improve documentation of VAE classes
Browse files Browse the repository at this point in the history
  • Loading branch information
crstngc committed Jan 31, 2025
1 parent 5f5b385 commit c0c0735
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 9 deletions.
23 changes: 23 additions & 0 deletions scico/flax/autoencoders/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,26 @@ def sample_fn(z: ArrayLike, c: ArrayLike) -> ArrayLike:
return model.apply(params, z, c, method=model.decode_cond)

return sample_fn


def build_sample_conditional_return_latent_fn(model, params):
"""Function to generate samples from model and return intermediate
decodings too.
Args:
model: Variational autoencoder model to generate samples from.
params: Parameters of trained model.
"""

@jax.jit
def sample_fn(z, c):
"""Generate samples from latent representation conditioned
on sample class and return intermediate representations too.
Args:
z: Representation in latent space.
c: Class to generate samples from.
"""
return model.apply(params, z, c, method=model.decode_cond_return_latent)

return sample_fn
62 changes: 56 additions & 6 deletions scico/flax/autoencoders/uvarautoencoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,20 @@ class MultiLevelDecoder(nn.Module):
upsampling_scale: int = 2

@nn.compact
def __call__(self, xlist) -> ArrayLike:
def __call__(self, xlist, lat=False) -> ArrayLike:
"""Apply multi-level decoder.
Args:
x: The array with multi-level latent representations to be
xlist: The list with multi-level latent representations to be
decoded.
lat: Flag to indicate if the latent representations are to be
returned.
Returns:
The reconstructed signal.
The reconstructed signal and, if requested, the intermediate
representations.
"""
xlup = []
h, w = self.platent_shape
x = None
for j, nfilters in enumerate(self.num_filters):
Expand All @@ -149,9 +153,11 @@ def __call__(self, xlist) -> ArrayLike:
x = xl
else:
x = x + xl
xlup.append(x)
x = ConvUpsampleBlock(
nfilters, self.kernel_size, self.strides, self.activation_fn, self.upsampling_scale
)(x)
xlup.append(x)
h = h * self.upsampling_scale
w = w * self.upsampling_scale

Expand All @@ -164,6 +170,9 @@ def __call__(self, xlist) -> ArrayLike:
padding="CIRCULAR",
)(x)

if lat:
return x, xlup

return x


Expand Down Expand Up @@ -255,12 +264,27 @@ def setup(self):
self.class_proj = [nn.Dense(self.cond_width) for _ in range(len(self.encoder_filters))]

def encode(self, x: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
"""Encode using multiple latent representations."""
"""Encode using multiple latent representations.
Args:
x: Signals to encode.
Returns:
The mean and log variances of the encoded signals.
"""
mean, logvar = self.encoder(x)
return mean, logvar

def decode_cond(self, xlist: ArrayLike, c: ArrayLike):
"""Class-conditional decoding using multiple latent representations."""
"""Class-conditional decoding using multiple latent representations.
Args:
xlist: List of random generations in latent spaces.
c: Classes to generate samples from.
Returns:
The generated samples.
"""
xl = []
for j, z_ in enumerate(xlist):
x = self.post_latent_proj[j](z_)
Expand All @@ -269,8 +293,34 @@ def decode_cond(self, xlist: ArrayLike, c: ArrayLike):
x = self.decoder(xl)
return x

def decode_cond_return_latent(self, xlist: ArrayLike, c: ArrayLike):
"""Class-conditional decoding using multiple latent representations.
The different latent representations are returned also.
Args:
xlist: List of random generations in latent spaces.
c: Classes to generate samples from.
Returns:
The generated samples as well as the intermediate representations.
"""
xl = []
for j, z_ in enumerate(xlist):
x = self.post_latent_proj[j](z_)
x = x + self.class_proj[j](c)
xl.append(x)
x, xlup = self.decoder(xl, lat=True)
return x, xlup

def decode(self, xlist: ArrayLike):
"""Class-independent decoding using multiple latent representations."""
"""Class-independent decoding using multiple latent representations.
Args:
xlist: List of random generations in latent spaces.
Returns:
The generated samples.
"""
x = self.decoder(xlist)
return x

Expand Down
28 changes: 25 additions & 3 deletions scico/flax/autoencoders/varautoencoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,42 @@ def setup(self):
self.class_proj = nn.Dense(self.cond_width)

def encode(self, x: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
"""Variational encoding."""
"""Variational encoding.
Args:
x: Signals to encode.
Returns:
The mean and log variances of the encoded signals.
"""
mean, logvar = self.encoder(x)
return mean, logvar

def decode_cond(self, x: ArrayLike, c: ArrayLike):
"""Class-conditional decoding."""
"""Class-conditional decoding.
Args:
x: Random generation in latent space to decode.
c: Classes to generate samples from.
Returns:
The generated samples.
"""
assert self.cond_width > 0
x = self.post_latent_proj(x)
x = x + self.class_proj(c)
x = self.decoder(x)
return x

def decode(self, x: ArrayLike):
"""Class-independent decoding."""
"""Class-independent decoding.
Args:
x: Random generation in latent space to decode.
Returns:
The generated samples.
"""
x = self.decoder(x)
return x

Expand Down

0 comments on commit c0c0735

Please sign in to comment.