diff --git a/examples/modeling/arad.py b/examples/modeling/arad.py index efecb2e..36dc06f 100644 --- a/examples/modeling/arad.py +++ b/examples/modeling/arad.py @@ -9,7 +9,7 @@ from riid.data.synthetic import get_dummy_seeds from riid.data.synthetic.seed import SeedMixer from riid.data.synthetic.static import StaticSynthesizer -from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF +from riid.models.neural_nets.arad import ARAD # Config rng = np.random.default_rng(42) @@ -39,13 +39,13 @@ # Train the models print("Training ARADv1...") -arad_v1 = ARAD(model=ARADv1TF()) +arad_v1 = ARAD(arad_version="v1") arad_v1.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE) arad_v1.predict(gross_train_ss) v1_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE) print("Training ARADv2...") -arad_v2 = ARAD(model=ARADv2TF()) +arad_v2 = ARAD(arad_version="v2") arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE) arad_v2.predict(gross_train_ss) v2_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE) diff --git a/riid/models/neural_nets/arad.py b/riid/models/neural_nets/arad.py index 417b8ea..7a447a7 100644 --- a/riid/models/neural_nets/arad.py +++ b/riid/models/neural_nets/arad.py @@ -238,14 +238,18 @@ def call(self, x): class ARAD(PyRIIDModel): """PyRIID-compatible ARAD model to work with SampleSets. """ - def __init__(self, model: Model = ARADv2TF()): + def __init__(self, arad_version: str = "v1", latent_dim=None): """ Args: - model: instantiated model of the desired version of ARAD to use. + arad_version: version of ARAD to use, can be "v1" or "v2". + latent_dim: dimension of internal latent represention, by default will + select version used referenced the paper for the corresponding ARAD version """ super().__init__() - self.model = model + self.arad_version = arad_version.lower() + self.latent_dim = latent_dim + self.model = None def _check_spectra(self, ss): """Checks if SampleSet spectra are compatible with ARAD models.""" @@ -280,9 +284,10 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, x = ss.get_samples().astype(float) - is_v1 = isinstance(self.model, ARADv1TF) - is_v2 = isinstance(self.model, ARADv2TF) + is_v1 = self.arad_version == "v1" + is_v2 = self.arad_version == "v2" if is_v1: + model_class = ARADv1TF optimizer = tf.keras.optimizers.Nadam( learning_rate=1e-4 ) @@ -294,6 +299,7 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, batch_size = 64 y = None elif is_v2: + model_class = ARADv2TF optimizer = tf.keras.optimizers.Adam( learning_rate=0.01, epsilon=0.05 @@ -306,7 +312,12 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, batch_size = 32 y = x else: - raise ValueError("Invalid model provided, must be ARADv1TF or ARADv2TF.") + raise ValueError("Invalid ARAD version provided, must be v1 or v2.") + + if not self.model: + self.model = model_class( + latent_dim=self.latent_dim + ) if self.latent_dim else model_class() self.model.compile( loss=loss_func, @@ -361,8 +372,8 @@ def predict(self, ss: SampleSet, verbose=False): reconstructed_spectra = self.get_predictions(x, verbose=verbose) - is_v1 = isinstance(self.model, ARADv1TF) - is_v2 = isinstance(self.model, ARADv2TF) + is_v1 = self.arad_version == "v1" + is_v2 = self.arad_version == "v2" if is_v1: # Entropy is equivalent to KL Divergence with how it is used here reconstruction_metric = entropy