Skip to content

Commit

Permalink
Fix ARAD predictions from onnx model.
Browse files Browse the repository at this point in the history
  • Loading branch information
alanjvano committed Jan 15, 2024
1 parent 99df3cf commit 03c06ba
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
6 changes: 3 additions & 3 deletions examples/modeling/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF
from riid.models.neural_nets.arad import ARAD

# Config
rng = np.random.default_rng(42)
Expand Down Expand Up @@ -39,13 +39,13 @@

# Train the models
print("Training ARADv1...")
arad_v1 = ARAD(model=ARADv1TF())
arad_v1 = ARAD(arad_version="v1")
arad_v1.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v1.predict(gross_train_ss)
v1_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

print("Training ARADv2...")
arad_v2 = ARAD(model=ARADv2TF())
arad_v2 = ARAD(arad_version="v2")
arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v2.predict(gross_train_ss)
v2_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)
Expand Down
27 changes: 19 additions & 8 deletions riid/models/neural_nets/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,18 @@ def call(self, x):
class ARAD(PyRIIDModel):
"""PyRIID-compatible ARAD model to work with SampleSets.
"""
def __init__(self, model: Model = ARADv2TF()):
def __init__(self, arad_version: str = "v1", latent_dim=None):
"""
Args:
model: instantiated model of the desired version of ARAD to use.
arad_version: version of ARAD to use, can be "v1" or "v2".
latent_dim: dimension of internal latent represention, by default will
select version used referenced the paper for the corresponding ARAD version
"""
super().__init__()

self.model = model
self.arad_version = arad_version.lower()
self.latent_dim = latent_dim
self.model = None

def _check_spectra(self, ss):
"""Checks if SampleSet spectra are compatible with ARAD models."""
Expand Down Expand Up @@ -280,9 +284,10 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2,

x = ss.get_samples().astype(float)

is_v1 = isinstance(self.model, ARADv1TF)
is_v2 = isinstance(self.model, ARADv2TF)
is_v1 = self.arad_version == "v1"
is_v2 = self.arad_version == "v2"
if is_v1:
model_class = ARADv1TF
optimizer = tf.keras.optimizers.Nadam(
learning_rate=1e-4
)
Expand All @@ -294,6 +299,7 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2,
batch_size = 64
y = None
elif is_v2:
model_class = ARADv2TF
optimizer = tf.keras.optimizers.Adam(
learning_rate=0.01,
epsilon=0.05
Expand All @@ -306,7 +312,12 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2,
batch_size = 32
y = x
else:
raise ValueError("Invalid model provided, must be ARADv1TF or ARADv2TF.")
raise ValueError("Invalid ARAD version provided, must be v1 or v2.")

if not self.model:
self.model = model_class(
latent_dim=self.latent_dim
) if self.latent_dim else model_class()

self.model.compile(
loss=loss_func,
Expand Down Expand Up @@ -361,8 +372,8 @@ def predict(self, ss: SampleSet, verbose=False):

reconstructed_spectra = self.get_predictions(x, verbose=verbose)

is_v1 = isinstance(self.model, ARADv1TF)
is_v2 = isinstance(self.model, ARADv2TF)
is_v1 = self.arad_version == "v1"
is_v2 = self.arad_version == "v2"
if is_v1:
# Entropy is equivalent to KL Divergence with how it is used here
reconstruction_metric = entropy
Expand Down

0 comments on commit 03c06ba

Please sign in to comment.