Skip to content

Commit

Permalink
Update ARAD example; update and parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
alanjvano committed Nov 30, 2023
1 parent de9cb90 commit b650d05
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
47 changes: 45 additions & 2 deletions examples/modeling/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
import tensorflow as tf
from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF

import numpy as np

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer


if len(sys.argv) == 2:
import matplotlib
matplotlib.use("Agg")
Expand All @@ -33,8 +40,44 @@ def show_summaries(model):
pass


arad_v1 = ARAD(model=ARADv1TF())
arad_v1 = ARAD(model=ARADv1TF(latent_dim=5))
show_summaries(arad_v1.model)

arad_v2 = ARAD(model=ARADv2TF())
arad_v2 = ARAD(model=ARADv2TF(latent_dim=8))
show_summaries(arad_v2.model)

# Generate some training data
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1)

static_synth = StaticSynthesizer(
samples_per_seed=250,
snr_function="log10",
return_fg=False,
return_gross=True,
)
_, train_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
train_ss.normalize()

print("training ARADv1...")
arad_v1.fit(train_ss, epochs=50, verbose=True)
print("training ARADv2...")
arad_v2.fit(train_ss, epochs=50, verbose=True)

# Generate some test data
static_synth.samples_per_seed = 50
_, test_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
test_ss.normalize()

# Predict
arad_v1_reconstructions = arad_v1.predict(test_ss, verbose=True)
recon_errors = test_ss.info["recon_error"].values
ood_decisions = test_ss.info["ood"].values
print((f"ARADv1: mean reconstruction error = {np.mean(recon_errors):.3f} (KLD)\n"
f" OOD rate = {np.mean(ood_decisions):.2f}"))

arad_v2_reconstructions = arad_v2.predict(test_ss, verbose=True)
recon_errors = test_ss.info["recon_error"].values
ood_decisions = test_ss.info["ood"].values
print((f"ARADv2: mean reconstruction error = {np.mean(recon_errors):.3f} (JSD)\n"
f" OOD rate = {np.mean(ood_decisions):.2f}"))
8 changes: 6 additions & 2 deletions riid/models/neural_nets/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def __init__(self, latent_dim: int = 5):
decoder_input = Input(shape=(latent_dim,), name="decoder_input")
x = Dense(units=40, name="D2")(decoder_input)
x = Dropout(rate=0.1, name="D2_dropout")(x)
decoder_output = Dense(units=128, name="D3")(x)
decoder_output = Dense(
units=128,
activation=sigmoid, # unclear from paper, seems to be necessary
name="D3"
)(x)
decoder = Model(decoder_input, decoder_output, name="decoder")

# Autoencoder
Expand Down Expand Up @@ -129,7 +133,7 @@ class ARADv2TF(Model):
- Ghawaly Jr, James M., et al. "Characterization of the Autoencoder Radiation Anomaly Detection
(ARAD) model." Engineering Applications of Artificial Intelligence 111 (2022): 104761.
"""
def __init__(self, latent_dim: int = 5):
def __init__(self, latent_dim: int = 8):
"""
Args:
latent_dim: dimension of internal latent represention.
Expand Down

0 comments on commit b650d05

Please sign in to comment.