Skip to content

Commit

Permalink
Add PyRIID wrapper for ARAD models, still bug with ARADv1
Browse files Browse the repository at this point in the history
  • Loading branch information
alanjvano committed Nov 20, 2023
1 parent 700d9a2 commit 46f70d5
Showing 1 changed file with 107 additions and 11 deletions.
118 changes: 107 additions & 11 deletions riid/models/neural_nets/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
MaxPool1D, Reshape, UpSampling1D)
from keras.models import Model
from keras.regularizers import L1L2, L2, Regularizer, _check_penalty_number
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from scipy.spatial.distance import jensenshannon
from scipy.stats import entropy

from riid.data.sampleset import SampleSet
from riid.losses import mish
from riid.losses import mish, jensen_shannon_distance
from riid.models import TFModelBase


Expand Down Expand Up @@ -75,13 +78,14 @@ def __init__(self, latent_dim: int = 5):
b3 = self._get_branch(encoder_input, b3_config, 0.1, "softplus", "B3", 5)

x = Concatenate(axis=1)([b1, b2, b3])
x = Reshape((15,), name="reshape")(x)
x = Dense(units=latent_dim, kernel_regularizer=KLDRegularizer(sparsity=0.5),
name="D1_latent_space")(x)
encoder_output = BatchNormalization(name="D1_batch_norm")(x)
encoder = Model(encoder_input, encoder_output, name="encoder")

# Decoder
decoder_input = Input(shape=(latent_dim, 1), name="decoder_input")
decoder_input = Input(shape=(latent_dim,), name="decoder_input")
x = Dense(units=40, name="D2")(decoder_input)
x = Dropout(rate=0.1, name="D2_dropout")(x)
decoder_output = Dense(units=128, name="D3")(x)
Expand Down Expand Up @@ -199,7 +203,7 @@ def __init__(self, latent_dim: int = 5):
upsample_name = f"US{i}"
x = UpSampling1D(size=max_pool_size, name=upsample_name)(x)
x = BatchNormalization(name=f"{upsample_name}_batch_norm")(x)
decoder_output = Conv1DTranspose(
x = Conv1DTranspose(
kernel_size=7,
strides=1,
filters=1,
Expand All @@ -208,6 +212,7 @@ def __init__(self, latent_dim: int = 5):
kernel_initializer=GlorotNormal,
name=f"tconv{i}"
)(x)
decoder_output = Reshape((128,), name="reshape_final")(x)
decoder = Model(decoder_input, decoder_output, name="decoder")

# Autoencoder
Expand Down Expand Up @@ -239,12 +244,103 @@ def __init__(self, model: Model = ARADv2TF()):

# TODO: enable saving as ONNX

def fit(self, ss: SampleSet):
pass # TODO: fit
def fit(self, ss: SampleSet, epochs: int = 300, es_verbose: int = 0,
verbose: bool = False):
"""Fit a model to the given `SampleSet`."""
if ss.n_samples <= 0:
raise ValueError("No spectr[a|um] provided!")

norm_ss = ss[:]
norm_ss.downsample_spectra(target_bins=128)
norm_ss.normalize()
spectra = norm_ss.get_samples().astype(float)

if isinstance(self.model, ARADv1TF):
optimizer = tf.keras.optimizers.Nadam(
learning_rate=1e-4
)
# loss_func = tf.keras.losses.LogCosh()
loss_func = jensen_shannon_distance
es_patience = 5
es_min_delta = 1e-7
lr_sched_patience = 3
lr_sched_min_delta = 1e-8
batch_size = 64

elif isinstance(self.model, ARADv2TF):
optimizer = tf.keras.optimizers.Adam(
learning_rate=0.01,
epsilon=0.05
)
loss_func = jensen_shannon_distance
es_patience = 6
es_min_delta = 1e-4
lr_sched_patience = 3
lr_sched_min_delta = 1e-4
batch_size = 32

self.model.compile(
loss=loss_func,
optimizer=optimizer
)
callbacks = [
EarlyStopping(
monitor="val_loss",
patience=es_patience,
verbose=es_verbose,
restore_best_weights=True,
mode="min",
min_delta=es_min_delta
),
ReduceLROnPlateau(
monitor="val_loss",
factor=0.1,
patience=lr_sched_patience,
min_delta=lr_sched_min_delta
)
]

history = self.model.fit(
spectra,
spectra,
epochs=epochs,
verbose=verbose,
validation_split=0.2,
callbacks=callbacks,
shuffle=True,
batch_size=batch_size
)
self.history = history.history

return history

def predict(self, ss: SampleSet, ood_threshold: float = 0.5,
verbose=False):
"""Generate reconstructions for given `SampleSet`.
Args:
ss: `SampleSet` of `n` spectra where `n` >= 1
ood_threshold: reconstruction error threshold for OOD spectra
Returns:
reconstructed_spectra: output of ARAD model
"""
norm_ss = ss[:]
norm_ss.downsample_spectra(target_bins=128)
norm_ss.normalize()
spectra = norm_ss.get_samples().astype(float)

reconstructed_spectra = self.model.predict(spectra, verbose=verbose)

if isinstance(self.model, ARADv1TF):
reconstruction_metric = entropy

elif isinstance(self.model, ARADv2TF):
reconstruction_metric = jensenshannon

reconstruction_errors = reconstruction_metric(spectra, reconstructed_spectra, axis=1)
ood_decisions = reconstruction_errors > ood_threshold
ss.info["recon_error"] = reconstruction_errors
ss.info["ood"] = ood_decisions

def predict(self, ss: SampleSet, ood_threshold: float = 0.5):
pass
# TODO: predict
# TODO: save results in as:
# SampleSet.info.ood
# SampleSet.info.recon_error
return reconstructed_spectra

0 comments on commit 46f70d5

Please sign in to comment.