From 32a50a9d31c42bd14e13fb445fe8d412aa786c9c Mon Sep 17 00:00:00 2001
From: Alan Van Omen <46762315+alanjvano@users.noreply.github.com>
Date: Wed, 10 Jan 2024 09:36:27 -0700
Subject: [PATCH] Add `ARADLatentPredictor`; standardize model format.

Standardizing the model format involved the following:

- Remove the ONNX runtime dependency
- Change saving as ONNX to a one-way export
- Change model saving and loading for all models to use the same JSON format
- Fix various bugs

Co-authored-by: Tyler Morrow
---
 examples/modeling/arad.py                   |  54 +--
 examples/modeling/arad_latent_prediction.py |  82 ++++
 .../modeling/label_proportion_estimation.py |   9 +-
 pyproject.toml                              |   2 +-
 riid/__init__.py                            |   2 +
 riid/data/sampleset.py                      |   4 +-
 riid/models/__init__.py                     | 271 +++++------
 riid/models/bayes.py                        |  59 +--
 riid/models/neural_nets/__init__.py         | 254 ++++------
 riid/models/neural_nets/arad.py             | 438 +++++++++++++++---
 tests/anomaly_tests.py                      |   4 +-
 tests/model_tests.py                        | 103 +++-
 12 files changed, 841 insertions(+), 441 deletions(-)
 create mode 100644 examples/modeling/arad_latent_prediction.py

diff --git a/examples/modeling/arad.py b/examples/modeling/arad.py
index efecb2e5..846975fe 100644
--- a/examples/modeling/arad.py
+++ b/examples/modeling/arad.py
@@ -9,7 +9,7 @@
 from riid.data.synthetic import get_dummy_seeds
 from riid.data.synthetic.seed import SeedMixer
 from riid.data.synthetic.static import StaticSynthesizer
-from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF
+from riid.models.neural_nets.arad import ARADv1, ARADv2
 
 # Config
 rng = np.random.default_rng(42)
@@ -37,47 +37,33 @@
 _, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
 gross_train_ss.normalize()
 
-# Train the models
-print("Training ARADv1...")
-arad_v1 = ARAD(model=ARADv1TF())
-arad_v1.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
-arad_v1.predict(gross_train_ss)
-v1_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)
-
-print("Training ARADv2...")
-arad_v2 = ARAD(model=ARADv2TF())
-arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
-arad_v2.predict(gross_train_ss)
-v2_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)
-
 # Generate test data
 static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
 _, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
 test_ss.normalize()
 
-# Predict
+# Train the models
+results = {}
+models = [ARADv1, ARADv2]
+for model_class in models:
+    arad = model_class()
+    model_name = arad.__class__.__name__
 
-arad_v1_reconstructions = arad_v1.predict(test_ss, verbose=True)
-arad_v1_ood = test_ss.info.recon_error.values > v1_ood_threshold
-arad_v1_false_positive_rate = arad_v1_ood.mean()
-arad_v1_mean_recon_error = test_ss.info.recon_error.values.mean()
+    print(f"Training and testing {model_name}...")
+    arad.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
+    arad.predict(gross_train_ss)
+    ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)
 
-arad_v2_reconstructions = arad_v2.predict(test_ss, verbose=True)
-arad_v2_ood = test_ss.info.recon_error.values > v2_ood_threshold
-arad_v2_false_positive_rate = arad_v2_ood.mean()
-arad_v2_mean_recon_error = test_ss.info.recon_error.values.mean()
+    reconstructions = arad.predict(test_ss, verbose=True)
+    ood = test_ss.info.recon_error.values > ood_threshold
+    false_positive_rate = ood.mean()
+    mean_recon_error = test_ss.info.recon_error.values.mean()
 
-results = {
-    "ARADv1": {
-        "ood_threshold": f"KLD={v1_ood_threshold:.4f}",
-        "mean_recon_error":
arad_v1_mean_recon_error,
-        "false_positive_rate": arad_v1_false_positive_rate,
-    },
-    "ARADv2": {
-        "ood_threshold": f"JSD={v2_ood_threshold:.4f}",
-        "mean_recon_error": arad_v2_mean_recon_error,
-        "false_positive_rate": arad_v2_false_positive_rate,
+    results[model_name] = {
+        "ood_threshold": f"{ood_threshold:.4f}",
+        "mean_recon_error": mean_recon_error,
+        "false_positive_rate": false_positive_rate,
     }
-}
+
 print(f"Target False Positive Rate: {1-OOD_QUANTILE:.4f}")
 print(pd.DataFrame.from_dict(results))
diff --git a/examples/modeling/arad_latent_prediction.py b/examples/modeling/arad_latent_prediction.py
new file mode 100644
index 00000000..fed2d805
--- /dev/null
+++ b/examples/modeling/arad_latent_prediction.py
@@ -0,0 +1,82 @@
+# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
+# Under the terms of Contract DE-NA0003525 with NTESS,
+# the U.S. Government retains certain rights in this software.
+"""This example demonstrates how to train a regressor or classifier branch
+from an ARAD latent space.
+"""
+import numpy as np
+from sklearn.metrics import f1_score, mean_squared_error
+
+from riid.data.synthetic import get_dummy_seeds
+from riid.data.synthetic.seed import SeedMixer
+from riid.data.synthetic.static import StaticSynthesizer
+from riid.models.neural_nets.arad import ARADv2, ARADLatentPredictor
+
+# Config
+rng = np.random.default_rng(42)
+VERBOSE = True
+# Some of the following parameters are set low because this example runs on GitHub Actions and
+# we want to keep its runtime short.
+# When running this locally, change the values per their corresponding comment, otherwise
+# the results likely will not be meaningful.
+EPOCHS = 5  # Change this to 20+
+N_MIXTURES = 50  # Change this to 1000+
+TRAIN_SAMPLES_PER_SEED = 5  # Change this to 20+
+TEST_SAMPLES_PER_SEED = 5
+
+# Generate training data
+fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=128, rng=rng).split_fg_and_bg()
+mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng).generate(N_MIXTURES)
+static_synth = StaticSynthesizer(
+    samples_per_seed=TRAIN_SAMPLES_PER_SEED,
+    snr_function_args=(0, 0),
+    return_fg=False,
+    return_gross=True,
+    rng=rng,
+)
+_, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
+gross_train_ss.normalize()
+
+# Generate test data
+static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
+_, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
+test_ss.normalize()
+
+# Train ARAD model
+print("Training ARAD")
+arad_v2 = ARADv2()
+arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
+
+# Train regressor to predict live time
+print("Training Regressor")
+arad_regressor = ARADLatentPredictor()
+_ = arad_regressor.fit(
+    arad_v2.model,
+    gross_train_ss,
+    target_info_columns=["live_time"],
+    epochs=10,
+    batch_size=5,
+    verbose=VERBOSE,
+)
+regression_predictions = arad_regressor.predict(test_ss)
+regression_score = mean_squared_error(test_ss.info.live_time, regression_predictions)
+print("Regressor MSE: {:.3f}".format(regression_score))
+
+# Train classifier to predict isotope
+print("Training Classifier")
+arad_classifier = ARADLatentPredictor(
+    loss="categorical_crossentropy",
+    metrics=("accuracy", "categorical_crossentropy"),
+    final_activation="softmax"
+)
+arad_classifier.fit(
+    arad_v2.model,
+    gross_train_ss,
+    target_level="Isotope",
+    epochs=10,
+    batch_size=5,
+    verbose=VERBOSE,
+)
+arad_classifier.predict(test_ss)
+classification_score = f1_score(test_ss.get_labels(),
test_ss.get_predictions(), average="micro") +print("Classification F1 Score: {:.3f}".format(classification_score)) diff --git a/examples/modeling/label_proportion_estimation.py b/examples/modeling/label_proportion_estimation.py index 271f4cd3..8f3a4ef5 100644 --- a/examples/modeling/label_proportion_estimation.py +++ b/examples/modeling/label_proportion_estimation.py @@ -45,7 +45,6 @@ optimizer="RMSprop", learning_rate=1e-2, hidden_layer_activation="relu", - l2_alpha=1e-4, dropout=0.05, ) @@ -54,7 +53,7 @@ bg_seeds_ss, bg_ss, batch_size=10, - epochs=10, + epochs=2, validation_split=0.2, verbose=True, bg_cps=300 @@ -74,8 +73,9 @@ ) print(f"Mean Test MAE: {test_meas.mean():.3f}") -# Save model in ONNX format -model_info_path, model_path = model.save("./model.onnx") +# Save model +model_path = "./model.json" +model.save(model_path, overwrite=True) loaded_model = LabelProportionEstimator() loaded_model.load(model_path) @@ -89,5 +89,4 @@ print(f"Mean Test MAE: {test_maes.mean():.3f}") # Clean up model file - remove this if you want to keep the model -os.remove(model_info_path) os.remove(model_path) diff --git a/pyproject.toml b/pyproject.toml index 7ad393b0..c238fbd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,9 +56,9 @@ dependencies = [ "pyyaml ==6.0.*", "seaborn ==0.12.*", "tf2onnx ==1.14.*", + "onnx ==1.14.1", "tqdm ==4.65.*", "numpy ==1.23.*", - "onnxruntime ==1.15.*", "pandas ==2.0.*", "parmap ==1.6.*", "pythonnet ==3.0.*; platform_system == 'Windows'", diff --git a/riid/__init__.py b/riid/__init__.py index 93ae0207..4e471878 100644 --- a/riid/__init__.py +++ b/riid/__init__.py @@ -20,6 +20,8 @@ SAMPLESET_HDF_FILE_EXTENSION = ".h5" SAMPLESET_JSON_FILE_EXTENSION = ".json" PCF_FILE_EXTENSION = ".pcf" +ONNX_MODEL_FILE_EXTENSION = ".onnx" +TFLITE_MODEL_FILE_EXTENSION = ".tflite" RIID = "riid" __version__ = get_distribution(RIID).version diff --git a/riid/data/sampleset.py b/riid/data/sampleset.py index 4931219f..9e520428 100644 --- a/riid/data/sampleset.py +++ b/riid/data/sampleset.py @@ -32,7 +32,7 @@ _pcf_to_dict, _unpack_compressed_text_buffer) -class SpectraState(Enum): +class SpectraState(int, Enum): """States in which SampleSet spectra can exist.""" Unknown = 0 Counts = 1 @@ -40,7 +40,7 @@ class SpectraState(Enum): L2Normalized = 3 -class SpectraType(Enum): +class SpectraType(int, Enum): """Types for SampleSet spectra.""" Unknown = 0 Background = 1 diff --git a/riid/models/__init__.py b/riid/models/__init__.py index 261d0519..1ab110e4 100644 --- a/riid/models/__init__.py +++ b/riid/models/__init__.py @@ -5,22 +5,29 @@ import json import os import uuid -import warnings +from abc import abstractmethod from enum import Enum import numpy as np -import onnxruntime -import pandas as pd import tensorflow as tf import tf2onnx +from keras.models import Model +from keras.utils import get_custom_objects import riid from riid.data.labeling import label_to_index_element from riid.data.sampleset import SampleSet, SpectraState +from riid.losses import mish from riid.metrics import multi_f1, single_f1 +get_custom_objects().update({ + "multi_f1": multi_f1, + "single_f1": single_f1, + "mish": mish, +}) -class ModelInput(Enum): + +class ModelInput(int, Enum): """Enumerates the potential input sources for a model.""" GrossSpectrum = 0 BackgroundSpectrum = 1 @@ -28,14 +35,13 @@ class ModelInput(Enum): class PyRIIDModel: - """Base class for TensorFlow models.""" - - CUSTOM_OBJECTS = {"multi_f1": multi_f1, "single_f1": single_f1} - SUPPORTED_SAVE_EXTS = {"H5": ".h5", "ONNX": ".onnx"} + """Base 
class for PyRIID models.""" def __init__(self, *args, **kwargs): self._info = {} - self._temp_file_path = "temp_model_file" + riid.SAMPLESET_HDF_FILE_EXTENSION + self._temp_file_path = "temp_model.json" + self._custom_objects = {} + self._initialize_info() @property def seeds(self): @@ -68,6 +74,22 @@ def target_level(self, value): ) raise ValueError(msg) + @property + def model(self) -> Model: + return self._model + + @model.setter + def model(self, value: Model): + self._model = value + + @property + def model_id(self): + return self._info["model_id"] + + @model_id.setter + def model_id(self, value): + self._info["model_id"] = value + @property def model_inputs(self): return self._info["model_inputs"] @@ -82,151 +104,136 @@ def model_outputs(self): @model_outputs.setter def model_outputs(self, value): - n_levels = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) + 1 - if all([len(v) == n_levels for v in value]): - self._info["model_outputs"] = value - else: - self._info["model_outputs"] = [ - label_to_index_element(v, self.target_level) for v in value - ] + self._info["model_outputs"] = value + + def get_model_outputs_as_label_tuples(self): + return [ + label_to_index_element(v, self.target_level) for v in self.model_outputs + ] + + def _get_model_dict(self) -> dict: + model_json = self.model.to_json() + model_dict = json.loads(model_json) + model_weights = self.model.get_weights() + model_dict = { + "info": self._info, + "model": model_dict, + "weights": model_weights, + } + return model_dict - def to_tflite(self, file_path: str = None, quantize: bool = False): - """Convert the model to a TFLite model and optionally save or quantize it. + def _get_model_str(self) -> str: + model_dict = self._get_model_dict() + model_str = json.dumps(model_dict, indent=4, cls=PyRIIDModelJsonEncoder) + return model_str - Args: - file_path: file path at which to save the model - quantize: whether to apply quantization + def _initialize_info(self): + init_info = { + "model_id": str(uuid.uuid4()), + "model_type": self.__class__.__name__, + "normalization": SpectraState.Unknown, + "pyriid_version": riid.__version__, + } + self._update_info(**init_info) - Returns: - bytes object representing the model in TFLite form - """ - converter = tf.lite.TFLiteConverter.from_keras_model(self.model) - if quantize: - converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] - tflite_model = converter.convert() - if file_path: - open(file_path, "wb").write(tflite_model) - return tflite_model + def _update_info(self, **kwargs): + self._info.update(kwargs) - def save(self, file_path: str): - """Save the model to a file. + def _update_custom_objects(self, key, value): + self._custom_objects.update({key: value}) - Args: - file_path: file path at which to save the model, can be either .h5 or - .onnx format + def load(self, model_path: str): + """Load the model from a path. - Raises: - `ValueError` when the given file path already exists + Args: + model_path: path from which to load the model. 
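+
+        Example (a minimal sketch; the concrete model class and file name are
+        illustrative):
+
+            model = MLPClassifier()
+            model.load("./model.json")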
""" - if os.path.exists(file_path): - raise ValueError("Path already exists.") + if not os.path.exists(model_path): + raise ValueError("Model file does not exist.") - root, ext = os.path.splitext(file_path) - if ext.lower() not in self.SUPPORTED_SAVE_EXTS.values(): - raise NameError("Model must be an .onnx or .h5 file.") + with open(model_path) as fin: + model = json.load(fin) - warnings.filterwarnings("ignore") + model_str = json.dumps(model["model"]) + self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects) + self.model.set_weights([np.array(x) for x in model["weights"]]) + self.info = model["info"] - if ext.lower() == self.SUPPORTED_SAVE_EXTS["H5"]: - self.model.save(file_path, save_format="h5") - pd.DataFrame( - [[v] for v in self.info.values()], - self.info.keys() - ).to_hdf(file_path, "_info") - else: - model_path = root + self.SUPPORTED_SAVE_EXTS["ONNX"] - model_info_path = root + "_info.json" + def save(self, model_path: str, overwrite=False): + """Save the model to a path. - model_info_df = pd.DataFrame( - [[v] for v in self.info.values()], - self.info.keys() - ) - model_info_df[0].to_json(model_info_path, indent=4) + Args: + model_path: path at which to save the model. + overwrite: whether to overwrite an existing file if it already exists. - tf2onnx.convert.from_keras( - self.model, - input_signature=None, - output_path=model_path - ) + Raises: + `ValueError` when the given path already exists + """ + if os.path.exists(model_path) and not overwrite: + raise ValueError("Model file already exists.") - warnings.resetwarnings() + model_str = self._get_model_str() + with open(model_path, "w") as fout: + fout.write(model_str) - def load(self, file_path: str): - """Load the model from a file. + def to_onnx(self, model_path: str = None, **tf2onnx_kwargs: dict): + """Convert the model to an ONNX model. Args: - file_path: file path from which to load the model, must be either an - .h5 or .onnx file + model_path: path at which to save the model + tf2onnx_kwargs: additional kwargs to pass to the conversion """ + if not model_path.endswith(riid.ONNX_MODEL_FILE_EXTENSION): + raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}") + if os.path.exists(model_path): + raise ValueError("Model file already exists.") - root, ext = os.path.splitext(file_path) - if ext.lower() not in self.SUPPORTED_SAVE_EXTS.values(): - raise NameError("Model must be an .onnx or .h5 file.") + tf2onnx.convert.from_keras( + self.model, + output_path=model_path, + **tf2onnx_kwargs + ) - warnings.filterwarnings("ignore", category=DeprecationWarning) + def to_tflite(self, model_path: str, quantize: bool = False, prune: bool = False): + """Convert the model to a TFLite model and optionally applying quantization or pruning. 
- if ext.lower() == self.SUPPORTED_SAVE_EXTS["H5"]: - self.model = tf.keras.models.load_model( - file_path, - custom_objects=self.CUSTOM_OBJECTS - ) - self._info = pd.read_hdf(file_path, "_info")[0].to_dict() - else: - model_path = root + self.SUPPORTED_SAVE_EXTS["ONNX"] - model_info_path = root + "_info.json" - with open(model_info_path) as fin: - model_info = json.load(fin) - self._info = model_info - self.onnx_session = onnxruntime.InferenceSession(model_path) - - warnings.resetwarnings() - - def get_predictions(self, x, **kwargs): - if self.model is not None: - outputs = self.model.predict(x, **kwargs) - elif self.onnx_session is not None: - outputs = self.onnx_session.run( - [self.onnx_session.get_outputs()[0].name], - {self.onnx_session.get_inputs()[0].name: x.astype(np.float32)} - )[0] - else: - raise ValueError("No model found with which to obtain predictions.") - return outputs + Args: + model_path: file path at which to save the model + quantize: whether to apply quantization + prune: whether to apply pruning + """ + if not model_path.endswith(riid.TFLITE_MODEL_FILE_EXTENSION): + raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}") + if os.path.exists(model_path): + raise ValueError("Model file already exists.") - def serialize(self) -> bytes: - """Convert model to a bytes object. + optimizations = [] + if quantize: + optimizations.append(tf.lite.Optimize.DEFAULT) + if prune: + optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY) - Returns: - bytes object representing the model on disk - """ - self.save(self._temp_file_path) - try: - with open(self._temp_file_path, "rb") as f: - data = f.read() - finally: - os.remove(self._temp_file_path) + converter = tf.lite.TFLiteConverter.from_keras_model(self.model) + converter.optimizations = optimizations + tflite_model = converter.convert() - return data + with open(model_path, "wb") as fout: + fout.write(tflite_model) - def deserialize(self, stream: bytes): - """Populate the current model with the given bytes object. + @abstractmethod + def fit(self): + pass - Args: - stream: bytes object containing the model information - """ - try: - with open(self._temp_file_path, "wb") as f: - f.write(stream) - self.load(self._temp_file_path) - finally: - os.remove(self._temp_file_path) - - def initialize_info(self): - """Initialize model information with default values.""" - info = { - "model_id": str(uuid.uuid4()), - "model_type": self.__class__.__name__, - "normalization": SpectraState.Unknown, - "pyriid_version": riid.__version__, - } - self.info.update(info) + @abstractmethod + def predict(self): + pass + + +class PyRIIDModelJsonEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, np.ndarray): + return o.tolist() + elif isinstance(o, np.float32): + return o.astype(float) + else: + return super().default(o) diff --git a/riid/models/bayes.py b/riid/models/bayes.py index 225e2559..520ecc98 100644 --- a/riid/models/bayes.py +++ b/riid/models/bayes.py @@ -29,7 +29,27 @@ class PoissonBayesClassifier(PyRIIDModel): def __init__(self): super().__init__() - def _create_model(self, seeds_ss: SampleSet): + def fit(self, seeds_ss: SampleSet): + """Construct a TF-based implementation of a poisson-bayes classifier in terms + of the given seeds. + + Args: + seeds_ss: `SampleSet` of `n` foreground seed spectra where `n` >= 1. 
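+                Each seed becomes one model output; outputs are tracked at the
+                "Seed" level of the sources columns.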
+ + Raises: + - `ValueError` when no seeds are provided + - `NegativeSpectrumError` when any seed spectrum has negative counts in any bin + - `ZeroTotalCountsError` when any seed spectrum contains zero total counts + """ + if seeds_ss.n_samples <= 0: + raise ValueError("Argument 'seeds_ss' must contain at least one seed.") + if (seeds_ss.spectra.values < 0).any(): + msg = "Argument 'seeds_ss' can't contain any spectra with negative values." + raise NegativeSpectrumError(msg) + if (seeds_ss.spectra.values.sum(axis=1) <= 0).any(): + msg = "Argument 'seeds_ss' can't contain any spectra with zero total counts." + raise ZeroTotalCountsError(msg) + self._seeds = tf.constant(tf.convert_to_tensor( seeds_ss.spectra.values, dtype=tf.float32 @@ -37,7 +57,7 @@ def _create_model(self, seeds_ss: SampleSet): # Inputs gross_spectrum_input = tf.keras.layers.Input( - shape=self.seeds_ss.n_channels, + shape=seeds_ss.n_channels, name="gross_spectrum" ) gross_live_time_input = tf.keras.layers.Input( @@ -45,7 +65,7 @@ def _create_model(self, seeds_ss: SampleSet): name="gross_live_time" ) bg_spectrum_input = tf.keras.layers.Input( - shape=self.seeds_ss.n_channels, + shape=seeds_ss.n_channels, name="bg_spectrum" ) bg_live_time_input = tf.keras.layers.Input( @@ -96,29 +116,9 @@ def _create_model(self, seeds_ss: SampleSet): self.model = tf.keras.Model(model_inputs, prediction_probas) self.model.compile() - def fit(self, seeds_ss: SampleSet = None): - """Construct a TF-based implementation of a poisson-bayes classifier in terms - of the given seeds. - - Args: - seeds_ss: `SampleSet` of `n` foreground seed spectra where `n` >= 1. - - Raises: - - `ValueError` when no seeds are provided - - `NegativeSpectrumError` when any seed spectrum has negative counts in any bin - - `ZeroTotalCountsError` when any seed spectrum contains zero total counts - """ - if seeds_ss.n_samples <= 0: - raise ValueError("Argument 'seeds_ss' must contain at least one seed.") - if (seeds_ss.spectra.values < 0).any(): - msg = "Argument 'seeds_ss' can't contain any spectra with negative values." - raise NegativeSpectrumError(msg) - if (seeds_ss.spectra.values.sum(axis=1) <= 0).any(): - msg = "Argument 'seeds_ss' can't contain any spectra with zero total counts." 
- raise ZeroTotalCountsError(msg) - - self.seeds_ss = seeds_ss - self._create_model(self.seeds_ss) + self.target_level = "Seed" + sources_df = seeds_ss.sources.groupby(axis=1, level=self.target_level, sort=False).sum() + self.model_outputs = sources_df.columns.values.tolist() def predict(self, gross_ss: SampleSet, bg_ss: SampleSet, normalize_scores: bool = False, verbose: bool = False): @@ -139,7 +139,7 @@ def predict(self, gross_ss: SampleSet, bg_ss: SampleSet, bg_spectra = tf.convert_to_tensor(bg_ss.spectra.values, dtype=tf.float32) bg_lts = tf.convert_to_tensor(bg_ss.info.live_time.values, dtype=tf.float32) - prediction_probas = self.get_predictions(( + prediction_probas = self.model.predict(( gross_spectra, gross_lts, bg_spectra, bg_lts ), batch_size=512, verbose=verbose) @@ -150,7 +150,10 @@ def predict(self, gross_ss: SampleSet, bg_ss: SampleSet, gross_ss.prediction_probas = pd.DataFrame( prediction_probas, - columns=self.seeds_ss.sources.columns + columns=pd.MultiIndex.from_tuples( + self.get_model_outputs_as_label_tuples(), + names=SampleSet.SOURCES_MULTI_INDEX_NAMES + ) ) diff --git a/riid/models/neural_nets/__init__.py b/riid/models/neural_nets/__init__.py index b020ef14..08c60c60 100644 --- a/riid/models/neural_nets/__init__.py +++ b/riid/models/neural_nets/__init__.py @@ -2,25 +2,20 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. """This module contains multi-layer perceptron classifiers and regressors.""" -import json -import os -from typing import Any, List, Tuple +from typing import Any, List import numpy as np -import onnxruntime import pandas as pd import tensorflow as tf -import tf2onnx from keras.callbacks import EarlyStopping -from keras.layers import Activation, Dense, Dropout +from keras.layers import Dense, Dropout from keras.optimizers import Adam from keras.regularizers import L1L2, l1, l2 -from keras.utils import get_custom_objects from scipy.interpolate import UnivariateSpline from riid.data.sampleset import SampleSet from riid.losses import (build_keras_semisupervised_loss_func, - chi_squared_diff, jensen_shannon_divergence, mish, + chi_squared_diff, jensen_shannon_divergence, normal_nll_diff, poisson_nll_diff, reconstruction_error, sse_diff, weighted_sse_diff) from riid.losses.sparsemax import SparsemaxLoss, sparsemax @@ -28,24 +23,6 @@ single_f1) from riid.models import ModelInput, PyRIIDModel -tf2onnx.logging.basicConfig(level=tf2onnx.logging.WARNING) - -get_custom_objects().update({"mish": Activation(mish)}) - - -def _get_reordered_spectra(old_spectra_df: pd.DataFrame, old_sources_df: pd.DataFrame, - new_sources_columns, target_level) -> pd.DataFrame: - collapsed_sources_df = old_sources_df\ - .groupby(axis=1, level=target_level)\ - .sum() - reordered_spectra_df = old_spectra_df.iloc[ - collapsed_sources_df[ - new_sources_columns - ].idxmax() - ].reset_index(drop=True) - - return reordered_spectra_df - class MLPClassifier(PyRIIDModel): """Multi-layer perceptron classifier.""" @@ -221,12 +198,12 @@ def fit(self, ss: SampleSet, bg_ss: SampleSet = None, batch_size=batch_size, ) - # Initialize model information - self.target_level = target_level - self.model_outputs = source_contributions_df.columns.values - self.initialize_info() - # TODO: get rid of the following line in favor of a normalization layer - self._info["normalization"] = ss.spectra_state + # Update model information + self._update_info( + target_level=target_level, + 
model_outputs=source_contributions_df.columns.values.tolist(), + normalization=ss.spectra_state, + ) return history @@ -246,29 +223,26 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False): else: X = x_test - results = self.get_predictions(X, verbose=verbose) + results = self.model.predict(X, verbose=verbose) col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] ss.prediction_probas = pd.DataFrame( data=results, columns=pd.MultiIndex.from_tuples( - self.model_outputs, names=col_level_subset + self.get_model_outputs_as_label_tuples(), + names=col_level_subset ) ) - ss.classified_by = self.info["model_id"] + ss.classified_by = self.model_id class MultiEventClassifier(PyRIIDModel): """A classifier for spectra from multiple detectors observing the same event.""" - def __init__(self, hidden_layers: tuple = (512,), activation: str = "relu", loss: str = "categorical_crossentropy", - optimizer: Any = Adam( - learning_rate=0.01, - clipnorm=0.001 - ), + optimizer: Any = Adam(learning_rate=0.01, clipnorm=0.001), metrics: list = ["accuracy", "categorical_crossentropy", multi_f1, single_f1], l2_alpha: float = 1e-4, activity_regularizer: tf.keras.regularizers = l1(0), dropout: float = 0.0, learning_rate: float = 0.01): @@ -415,11 +389,15 @@ def fit(self, list_of_ss: List[SampleSet], target_contributions: pd.DataFrame, ) # Initialize model info, update output/input information - self.target_level = target_level - self.model_outputs = target_contributions.columns.values - self.initialize_info() - self.info["model_inputs"] = tuple( - [(ss.classified_by, ss.prediction_probas.shape[1]) for ss in list_of_ss] + self._update_info( + target_level=target_level, + model_outputs=target_contributions.columns.values.tolist(), + model_inputs=tuple( + [(ss.classified_by, ss.prediction_probas.shape[1]) for ss in list_of_ss] + ), + normalization=tuple( + [(ss.classified_by, ss.spectra_state) for ss in list_of_ss] + ), ) return history @@ -434,15 +412,15 @@ def predict(self, list_of_ss: List[SampleSet], verbose=False) -> pd.DataFrame: `DataFrame` of predicted results for the `Sampleset`(s) """ X = [ss.prediction_probas for ss in list_of_ss] - # output size will be n_samples by n_labels - results = self.get_predictions(X, verbose=verbose) + results = self.model.predict(X, verbose=verbose) col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] results_df = pd.DataFrame( data=results, columns=pd.MultiIndex.from_tuples( - self.model_outputs, names=col_level_subset + self.get_model_outputs_as_label_tuples(), + names=col_level_subset ) ) return results_df @@ -483,8 +461,6 @@ class LabelProportionEstimator(PyRIIDModel): ) } INFO_KEYS = ( - # model metadata - "_info", # model architecture "hidden_layers", "learning_rate", @@ -509,7 +485,6 @@ class LabelProportionEstimator(PyRIIDModel): # dictionaries "source_dict", # populated when loading model - "history", "spline_snrs", "spline_recon_errors", ) @@ -517,14 +492,13 @@ class LabelProportionEstimator(PyRIIDModel): def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_loss="sse", metrics=("mae", "categorical_crossentropy",), beta=0.9, source_dict=None, optimizer="adam", optimizer_kwargs=None, learning_rate: float = 1e-3, - epsilon: float = 0.05, hidden_layer_activation: str = "mish", + hidden_layer_activation: str = "mish", kernel_l1_regularization: 
float = 0.0, kernel_l2_regularization: float = 0.0, bias_l1_regularization: float = 0.0, bias_l2_regularization: float = 0.0, activity_l1_regularization: float = 0.0, activity_l2_regularization: float = 0.0, - dropout: float = 0.0, target_level: str = "Seed", ood_fp_rate: float = 0.05, + dropout: float = 0.0, ood_fp_rate: float = 0.05, fit_spline: bool = True, spline_bins: int = 15, spline_k: int = 3, - spline_s: int = 0, spline_snrs=None, spline_recon_errors=None, history=None, - _info=None, **base_kwargs): + spline_s: int = 0, spline_snrs=None, spline_recon_errors=None): """ Args: hidden_layers: tuple defining the number and size of dense layers @@ -538,8 +512,7 @@ def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_lo optimizer: tensorflow optimizer or optimizer name to use for training optimizer_kwargs: kwargs for optimizer learning_rate: learning rate for the optimizer - epsilon: epsilon constant for the Adam optimizer - hidden_layer_activation: activattion function to use for each dense layer + hidden_layer_activation: activation function to use for each dense layer kernel_l1_regularization: l1 regularization value for the kernel regularizer kernel_l2_regularization: l2 regularization value for the kernel regularizer bias_l1_regularization: l1 regularization value for the bias regularizer @@ -547,7 +520,6 @@ def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_lo activity_l1_regularization: l1 regularization value for the activity regularizer activity_l2_regularization: l2 regularization value for the activity regularizer dropout: amount of dropout to apply to each dense layer - target_level: `SampleSet.sources` column level to use ood_fp_rate: false positive rate used to determine threshold for out-of-distribution (OOD) detection fit_spline: whether or not to fit UnivariateSpline for OOD threshold function @@ -560,10 +532,8 @@ def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_lo spline_snrs: SNRs from training used as the x-values to fit the UnivariateSpline spline_recon_errors: reconstruction errors from training used as the y-values to fit the UnivariateSpline - history: dictionary of training/val history, automatically filled when loading model - _info: internal dictionary uses to store target level and output columns """ - super().__init__(**base_kwargs) + super().__init__() self.hidden_layers = hidden_layers self.sup_loss = sup_loss @@ -588,7 +558,6 @@ def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_lo self.beta = beta self.source_dict = source_dict self.semisup_loss_func_name = "semisup_loss" - self.model = None self.hidden_layer_activation = hidden_layer_activation self.kernel_l1_regularization = kernel_l1_regularization self.kernel_l2_regularization = kernel_l2_regularization @@ -597,17 +566,24 @@ def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_lo self.activity_l1_regularization = activity_l1_regularization self.activity_l2_regularization = activity_l2_regularization self.dropout = dropout - self.target_level = target_level self.ood_fp_rate = ood_fp_rate self.fit_spline = fit_spline self.spline_bins = spline_bins self.spline_k = spline_k self.spline_s = spline_s - self.history = history self.spline_snrs = spline_snrs self.spline_recon_errors = spline_recon_errors - if _info: - self.info = _info + self.model = None + + self._update_custom_objects("L1NormLayer", L1NormLayer) + + @property + def source_dict(self) -> dict: + return 
self.info["source_dict"] + + @source_dict.setter + def source_dict(self, value: dict): + self.info["source_dict"] = value def _get_sup_loss_func(self, loss_func_str, prefix): if loss_func_str not in self.SUPERVISED_LOSS_FUNCS: @@ -623,10 +599,7 @@ def _get_unsup_loss_func(self, loss_func_str): def _initialize_model(self, input_size, output_size): spectra_input = tf.keras.layers.Input(input_size, name="input_spectrum") - - spectra_norm = tf.keras.layers.Lambda(_l1_norm, name="normalized_input_spectrum")( - spectra_input - ) + spectra_norm = L1NormLayer(name="normalized_input_spectrum")(spectra_input) x = spectra_norm for layer, nodes in enumerate(self.hidden_layers): x = tf.keras.layers.Dense( @@ -661,20 +634,23 @@ def _initialize_model(self, input_size, output_size): ) def _get_info_as_dict(self): - info_dict = {k: v for k, v in vars(self).items() if k in self.INFO_KEYS} + info_dict = {} + for k, v in vars(self).items(): + if k not in self.INFO_KEYS: + continue + if isinstance(v, np.ndarray): + info_dict[k] = v.tolist() + else: + info_dict[k] = v return info_dict - def _get_model_file_paths(self, save_path): - SUPPORTED_ONNX_EXT = ".onnx" - - root, ext = os.path.splitext(save_path) - if ext.lower() != SUPPORTED_ONNX_EXT: - raise NameError("Model must be an .onnx file.") - - model_path = root + SUPPORTED_ONNX_EXT - model_info_path = root + "_info.json" - - return model_info_path, model_path + def _get_spline_threshold_func(self): + return UnivariateSpline( + self.info["avg_snrs"], + self.info["thresholds"], + k=self.spline_k, + s=self.spline_s + ) def _fit_spline_threshold_func(self): out = pd.qcut( @@ -689,11 +665,11 @@ def _fit_spline_threshold_func(self): avg_snrs = [ np.mean(np.array(self.spline_snrs)[out == int(i)]) for i in range(self.spline_bins) ] - self.ood_threshold_func = UnivariateSpline( - avg_snrs, - thresholds, - k=self.spline_k, - s=self.spline_s + self._update_info( + avg_snrs=avg_snrs, + thresholds=thresholds, + spline_k=self.spline_k, + spline_s=self.spline_s, ) def _get_snrs(self, ss: SampleSet, bg_cps: float, is_gross: bool) -> np.ndarray: @@ -709,7 +685,7 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b callbacks=None, patience: int = 15, es_monitor: str = "val_loss", es_mode: str = "min", es_verbose=0, es_min_delta: float = 0.0, normalize_sup_loss: bool = True, normalize_func=tf.math.tanh, - normalize_scaler: float = 1.0, verbose: bool = False): + normalize_scaler: float = 1.0, target_level="Isotope", verbose: bool = False): """Fit a model to the given SampleSet(s). 
Args: @@ -730,10 +706,11 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b normalize_sup_loss: whether to normalize the supervised loss term normalize_func: normalization function used for supervised loss term normalize_scaler: scalar that sets the steepness of the normalization function + target_level: source level to target for model output verbose: whether model training output is printed to the terminal """ spectra = ss.get_samples().astype(float) - sources_df = ss.sources.groupby(axis=1, level=self._info["target_level"], sort=False).sum() + sources_df = ss.sources.groupby(axis=1, level=target_level, sort=False).sum() sources = sources_df.values.astype(float) self.sources_columns = sources_df.columns @@ -745,7 +722,7 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b seeds_ss.spectra, seeds_ss.sources, self.sources_columns, - target_level=self._info["target_level"] + target_level=target_level ).values if not self.model: @@ -826,7 +803,6 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b shuffle=True, batch_size=batch_size ) - self.history = history.history if self.fit_spline: if verbose: @@ -840,17 +816,20 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b self.source_dict, self.unsup_loss_func ).numpy() - self.spline_snrs = self._get_snrs(ss, bg_cps, is_gross) - self._fit_spline_threshold_func() - self.model_outputs = sources_df.columns.values + info = self._get_info_as_dict() + self._update_info( + target_level=target_level, + model_outputs=sources_df.columns.values.tolist(), + normalization=ss.spectra_state, + **info, + ) return history - def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, - verbose=False): + def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, verbose=False): """Estimate the proportions of counts present in each sample of the provided SampleSet. Results are stored inside the SampleSet's prediction_probas property. @@ -864,16 +843,15 @@ def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, """ test_spectra = ss.get_samples().astype(float) - logits = self.get_predictions(test_spectra, verbose=verbose) - + logits = self.model.predict(test_spectra, verbose=verbose) lpes = self.activation(tf.convert_to_tensor(logits, dtype=tf.float32)) - col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self._info["target_level"]) + col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] ss.prediction_probas = pd.DataFrame( data=lpes, columns=pd.MultiIndex.from_tuples( - self._info["model_outputs"], + self.get_model_outputs_as_label_tuples(), names=col_level_subset ) ) @@ -885,67 +863,35 @@ def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, self.source_dict, self.unsup_loss_func ).numpy() - ss.info[self.unsup_loss_func_name] = recon_errors if self.fit_spline: snrs = self._get_snrs(ss, bg_cps, is_gross) + thresholds = self._get_spline_threshold_func()(snrs) + is_ood = recon_errors > thresholds + ss.info["ood"] = is_ood - # Generate OOD predictions - try: - thresholds = self.ood_threshold_func(snrs) - except AttributeError: - self._fit_spline_threshold_func() - thresholds = self.ood_threshold_func(snrs) - ss.info["ood"] = recon_errors > thresholds - - def save(self, file_path) -> Tuple[str, str]: - """Save the model in ONNX format. 
- - Args: - file_path: file path at which to save the model - - Returns: - Tuple containing path to model and additional info - """ - model_info_path, model_path = \ - self._get_model_file_paths(file_path) - - dir_path = os.path.dirname(model_path) - if not os.path.exists(dir_path): - os.mkdir(dir_path) - - model_info = self._get_info_as_dict() - model_info_df = pd.DataFrame( - [[v] for v in model_info.values()], - model_info.keys() - ) - model_info_df[0].to_json(model_info_path, indent=4) - - tf2onnx.convert.from_keras( - self.model, - input_signature=None, - output_path=model_path - ) + ss.info["recon_error"] = recon_errors - return model_info_path, model_path - def load(self, file_path): - """Load the model from ONNX format in place. - - Args: - file_path: path from which to load the model - """ - model_info_path, model_path = \ - self._get_model_file_paths(file_path) +def _get_reordered_spectra(old_spectra_df: pd.DataFrame, old_sources_df: pd.DataFrame, + new_sources_columns, target_level) -> pd.DataFrame: + collapsed_sources_df = old_sources_df\ + .groupby(axis=1, level=target_level)\ + .sum() + reordered_spectra_df = old_spectra_df.iloc[ + collapsed_sources_df[ + new_sources_columns + ].idxmax() + ].reset_index(drop=True) - with open(model_info_path) as fin: - model_info = json.load(fin) - self.__init__(**model_info) + return reordered_spectra_df - self.onnx_session = onnxruntime.InferenceSession(model_path) +class L1NormLayer(tf.keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) -def _l1_norm(x): - sums = tf.reduce_sum(x, axis=-1) - l1_norm = x / tf.reshape(sums, (-1, 1)) - return l1_norm + def call(self, inputs): + sums = tf.reduce_sum(inputs, axis=-1) + l1_norm = inputs / tf.reshape(sums, (-1, 1)) + return l1_norm diff --git a/riid/models/neural_nets/arad.py b/riid/models/neural_nets/arad.py index c9f6a843..df386572 100644 --- a/riid/models/neural_nets/arad.py +++ b/riid/models/neural_nets/arad.py @@ -2,8 +2,11 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. 
"""This module contains implementations of the ARAD model architecture.""" +from typing import List + +import pandas as pd import tensorflow as tf -from keras.activations import sigmoid +from keras.activations import sigmoid, softplus from keras.callbacks import EarlyStopping, ReduceLROnPlateau from keras.initializers import GlorotNormal, HeNormal from keras.layers import (BatchNormalization, Concatenate, Conv1D, @@ -65,17 +68,21 @@ def __init__(self, latent_dim: int = 5): b3 = self._get_branch(encoder_input, b3_config, 0.1, "softplus", "B3", 5) x = Concatenate(axis=1)([b1, b2, b3]) x = Reshape((15,), name="reshape")(x) - latent_space = Dense(units=latent_dim, name="D1_latent_space")(x) + latent_space = Dense( + units=latent_dim, + name="D1_latent_space", + activation=softplus + )(x) encoder_output = BatchNormalization(name="D1_batch_norm")(latent_space) encoder = Model(encoder_input, encoder_output, name="encoder") # Decoder decoder_input = Input(shape=(latent_dim,), name="decoder_input") - x = Dense(units=40, name="D2")(decoder_input) + x = Dense(units=40, name="D2", activation=softplus)(decoder_input) x = Dropout(rate=0.1, name="D2_dropout")(x) decoder_output = Dense( units=128, - activation=sigmoid, # unclear from paper, seems to be necessary + activation=softplus, name="D3" )(x) decoder = Model(decoder_input, decoder_output, name="decoder") @@ -116,7 +123,11 @@ def _get_branch(self, input_layer, config, dropout_rate, activation, branch_name x = BatchNormalization(name=f"{layer_name}_batch_norm")(x) x = Dropout(rate=dropout_rate, name=f"{layer_name}_dropout")(x) x = Flatten(name=f"{branch_name}_flatten")(x) - x = Dense(units=dense_units, name=f"{branch_name}_D1")(x) + x = Dense( + units=dense_units, + name=f"{branch_name}_D1", + activation=activation + )(x) x = BatchNormalization(name=f"{branch_name}_batch_norm")(x) return x @@ -184,7 +195,7 @@ def __init__(self, latent_dim: int = 8): # Decoder decoder_input = Input(shape=latent_dim, name="decoder_input") - x = Dense(units=32, name="D2")(decoder_input) + x = Dense(units=32, name="D2", activation=mish)(decoder_input) x = BatchNormalization(name="D2_batch_norm")(x) x = Reshape((4, 8), name="reshape")(x) reversed_config = enumerate(reversed(config[1:]), start=1) @@ -235,32 +246,19 @@ def call(self, x): return decoded -class ARAD(PyRIIDModel): - """PyRIID-compatible ARAD model to work with SampleSets. +class ARADv1(PyRIIDModel): + """PyRIID-compatible ARAD v1 model supporting SampleSets. """ - def __init__(self, model: Model = ARADv2TF()): + def __init__(self, model: ARADv1TF = None): """ Args: - model: instantiated model of the desired version of ARAD to use. + model: a previously initialized TF implementation of ARADv1 """ super().__init__() self.model = model - def _check_spectra(self, ss): - """Checks if SampleSet spectra are compatible with ARAD models.""" - if ss.n_samples <= 0: - raise ValueError("No spectr[a|um] provided!") - if not ss.all_spectra_sum_to_one(): - raise ValueError("All spectra must sum to one.") - if not ss.spectra_state == SpectraState.L1Normalized: - raise ValueError( - f"SpectraState must be L1Normalzied, provided SpectraState is {ss.spectra_state}." - ) - if not ss.n_channels == 128: - raise ValueError( - f"Spectra must have 128 channels, provided spectra have {ss.n_channels} channels." 
- ) + self._update_custom_objects("ARADv1TF", ARADv1TF) def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, es_verbose: int = 0, verbose: bool = False): @@ -276,70 +274,52 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, Returns: reconstructed_spectra: output of ARAD model """ - self._check_spectra(ss) + _check_spectra(ss) x = ss.get_samples().astype(float) - is_v1 = isinstance(self.model, ARADv1TF) - is_v2 = isinstance(self.model, ARADv2TF) - if is_v1: - optimizer = tf.keras.optimizers.Nadam( - learning_rate=1e-4 - ) - loss_func = None - es_patience = 5 - es_min_delta = 1e-7 - lr_sched_patience = 3 - lr_sched_min_delta = 1e-8 - batch_size = 64 - y = None - elif is_v2: - optimizer = tf.keras.optimizers.Adam( - learning_rate=0.01, - epsilon=0.05 - ) - loss_func = jensen_shannon_distance - es_patience = 6 - es_min_delta = 1e-4 - lr_sched_patience = 3 - lr_sched_min_delta = 1e-4 - batch_size = 32 - y = x - else: - raise ValueError("Invalid model provided, must be ARADv1TF or ARADv2TF.") + optimizer = tf.keras.optimizers.Nadam( + learning_rate=1e-4 + ) + if not self.model: + self.model = ARADv1TF() self.model.compile( - loss=loss_func, + loss=None, optimizer=optimizer ) + callbacks = [ EarlyStopping( monitor="val_loss", - patience=es_patience, + patience=5, verbose=es_verbose, restore_best_weights=True, mode="min", - min_delta=es_min_delta + min_delta=1e-7 ), ReduceLROnPlateau( monitor="val_loss", factor=0.1, - patience=lr_sched_patience, - min_delta=lr_sched_min_delta + patience=3, + min_delta=1e-8 ) ] history = self.model.fit( x=x, - y=y, + y=None, epochs=epochs, verbose=verbose, validation_split=validation_split, callbacks=callbacks, shuffle=True, - batch_size=batch_size + batch_size=64 + ) + + self._update_info( + normalization=ss.spectra_state, ) - self.history = history.history return history @@ -352,21 +332,339 @@ def predict(self, ss: SampleSet, verbose=False): Returns: reconstructed_spectra: output of ARAD model """ - self._check_spectra(ss) + _check_spectra(ss) + + x = ss.get_samples().astype(float) + + reconstructed_spectra = self.model.predict(x, verbose=verbose) + reconstruction_errors = entropy(x, reconstructed_spectra, axis=1) + ss.info["recon_error"] = reconstruction_errors + + return reconstructed_spectra + + +class ARADv2(PyRIIDModel): + """PyRIID-compatible ARAD v2 model supporting SampleSets. + """ + def __init__(self, model: ARADv2TF = None): + """ + Args: + model: a previously initialized TF implementation of ARADv1 + """ + super().__init__() + + self.model = model + + self._update_custom_objects("ARADv2TF", ARADv2TF) + + def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, + es_verbose: int = 0, verbose: bool = False): + """Fit a model to the given `SampleSet`. 
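+        Training uses the v2 recipe: Adam optimizer (learning rate 0.01,
+        epsilon 0.05), Jensen-Shannon distance loss, early stopping, and a
+        learning-rate-reduction schedule.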
+ + Args: + ss: `SampleSet` of `n` spectra where `n` >= 1 + epochs: maximum number of training epochs + validation_split: percentage of the training data to use as validation data + es_verbose: verbosity level for `tf.keras.callbacks.EarlyStopping` + verbose: whether to show detailed model training output + + Returns: + reconstructed_spectra: output of ARAD model + """ + _check_spectra(ss) x = ss.get_samples().astype(float) - reconstructed_spectra = self.get_predictions(x, verbose=verbose) + optimizer = tf.keras.optimizers.Adam( + learning_rate=0.01, + epsilon=0.05 + ) + + if not self.model: + self.model = ARADv2TF() + self.model.compile( + loss=jensen_shannon_distance, + optimizer=optimizer + ) + + callbacks = [ + EarlyStopping( + monitor="val_loss", + patience=6, + verbose=es_verbose, + restore_best_weights=True, + mode="min", + min_delta=1e-4 + ), + ReduceLROnPlateau( + monitor="val_loss", + factor=0.1, + patience=3, + min_delta=1e-4 + ) + ] + + history = self.model.fit( + x=x, + y=x, + epochs=epochs, + verbose=verbose, + validation_split=validation_split, + callbacks=callbacks, + shuffle=True, + batch_size=32 + ) + + self._update_info( + normalization=ss.spectra_state, + ) - is_v1 = isinstance(self.model, ARADv1TF) - is_v2 = isinstance(self.model, ARADv2TF) - if is_v1: - # Entropy is equivalent to KL Divergence with how it is used here - reconstruction_metric = entropy - elif is_v2: - reconstruction_metric = jensenshannon + return history + + def predict(self, ss: SampleSet, verbose=False): + """Generate reconstructions for given `SampleSet`. - reconstruction_errors = reconstruction_metric(x, reconstructed_spectra, axis=1) + Args: + ss: `SampleSet` of `n` spectra where `n` >= 1 + + Returns: + reconstructed_spectra: output of ARAD model + """ + _check_spectra(ss) + + x = ss.get_samples().astype(float) + + reconstructed_spectra = self.model.predict(x, verbose=verbose) + reconstruction_errors = jensenshannon(x, reconstructed_spectra, axis=1) ss.info["recon_error"] = reconstruction_errors return reconstructed_spectra + + +class ARADLatentPredictor(PyRIIDModel): + """PyRIID-compatible model for branching from the latent space of a pre-trained + ARAD latent space for a separate prediction task. 
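+
+    Example (a minimal sketch; assumes `arad_v2` is a previously fit ARADv2 and
+    the `SampleSet`s are L1-normalized):
+
+        predictor = ARADLatentPredictor()
+        predictor.fit(arad_v2.model, train_ss, target_info_columns=["live_time"])
+        predictions = predictor.predict(test_ss)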
+ """ + def __init__(self, hidden_layers: tuple = (8, 4,), + hidden_activation: str = "relu", final_activation: str = "linear", + loss: str = "mse", optimizer="adam", optimizer_kwargs=None, + learning_rate: float = 1e-3, metrics: tuple = ("mse", ), + kernel_l1_regularization: float = 0.0, kernel_l2_regularization: float = 0.0, + bias_l1_regularization: float = 0.0, bias_l2_regularization: float = 0.0, + activity_l1_regularization: float = 0.0, activity_l2_regularization: float = 0.0, + dropout: float = 0.0, **base_kwargs): + """ + Args: + hidden_layers: tuple defining the number and size of dense layers + hidden_activation: activation function to use for each dense layer + final_activation: activation function to use for final layer + loss: loss function to use for training + optimizer: tensorflow optimizer or optimizer name to use for training + optimizer_kwargs: kwargs for optimizer + learning_rate: optional learning rate for the optimizer + metrics: list of metrics to be evaluating during training + kernel_l1_regularization: l1 regularization value for the kernel regularizer + kernel_l2_regularization: l2 regularization value for the kernel regularizer + bias_l1_regularization: l1 regularization value for the bias regularizer + bias_l2_regularization: l2 regularization value for the bias regularizer + activity_l1_regularization: l1 regularization value for the activity regularizer + activity_l2_regularization: l2 regularization value for the activity regularizer + dropout: amount of dropout to apply to each dense layer + """ + super().__init__(**base_kwargs) + + self.hidden_layers = hidden_layers + self.hidden_activation = hidden_activation + self.final_activation = final_activation + self.loss = loss + self.optimizer = optimizer + if isinstance(optimizer, str): + self.optimizer = tf.keras.optimizers.get(optimizer) + if optimizer_kwargs is not None: + for key, value in optimizer_kwargs.items(): + setattr(self.optimizer, key, value) + self.optimizer.learning_rate = learning_rate + self.metrics = metrics + self.kernel_l1_regularization = kernel_l1_regularization + self.kernel_l2_regularization = kernel_l2_regularization + self.bias_l1_regularization = bias_l1_regularization + self.bias_l2_regularization = bias_l2_regularization + self.activity_l1_regularization = activity_l1_regularization + self.activity_l2_regularization = activity_l2_regularization + self.dropout = dropout + self.model = None + self.encoder = None + + def _initialize_model(self, arad: Model, output_size: int): + """Build Keras MLP model. 
+ """ + encoder = arad.get_layer("encoder") + encoder_input = encoder.get_layer(index=0).input + encoder_output = encoder.get_layer(index=-1).output + encoder_output_shape = encoder_output.shape + + predictor_input = Input(shape=encoder_output_shape, name="predictor_input") + x = predictor_input + for layer, nodes in enumerate(self.hidden_layers): + x = tf.keras.layers.Dense( + nodes, + activation=self.hidden_activation, + kernel_regularizer=L1L2( + l1=self.kernel_l1_regularization, + l2=self.kernel_l2_regularization + ), + bias_regularizer=L1L2( + l1=self.bias_l1_regularization, + l2=self.bias_l2_regularization + ), + activity_regularizer=L1L2( + l1=self.activity_l1_regularization, + l2=self.activity_l2_regularization + ), + name=f"dense_{layer}" + )(x) + if self.dropout > 0: + x = tf.keras.layers.Dropout(self.dropout)(x) + predictor_output = tf.keras.layers.Dense( + output_size, + activation=self.final_activation, + name="output" + )(x) + predictor = Model(predictor_input, predictor_output, name="predictor") + + encoded_spectrum = encoder(encoder_input) + predictions = predictor(encoded_spectrum) + self.model = Model(encoder_input, predictions, name="predictor") + # Freeze the layers corresponding to the autoencoder + # Note: setting trainable to False is recursive to sub-layers per TF docs: + # https://www.tensorflow.org/guide/keras/transfer_learning#recursive_setting_of_the_trainable_attribute + for layer in self.model.layers[:-1]: + layer.trainable = False + + def _check_targets(self, target_info_columns, target_level): + """Check that valid target options are provided.""" + if target_info_columns and target_level: + raise ValueError(( + "You have specified both target_info_columns (regression task) and " + "a target_level (classification task), but only one can be set." + )) + if not target_info_columns and not target_level: + raise ValueError(( + "You must specify either target_info_columns (regression task) or " + "a target_level (classification task)." + )) + + def fit(self, arad: Model, ss: SampleSet, target_info_columns: List[str] = None, + target_level: str = None, batch_size: int = 10, epochs: int = 20, + validation_split: float = 0.2, callbacks=None, patience: int = 15, + es_monitor: str = "val_loss", es_mode: str = "min", es_verbose=0, + es_min_delta: float = 0.0, verbose: bool = False): + """Fit a model to the given SampleSet(s). 
+ + Args: + arad: a pretrained ARAD model (a TensorFlow Model object, not a PyRIIDModel wrapper) + ss: `SampleSet` of `n` spectra where `n` >= 1 + target_info_columns: list of columns names from SampleSet info dataframe which + denote what values the model should target + target_level: `SampleSet.sources` column level to target for classification + batch_size: number of samples per gradient update + epochs: maximum number of training iterations + validation_split: proportion of training data to use as validation data + callbacks: list of callbacks to be passed to TensorFlow Model.fit() method + patience: number of epochs to wait for tf.keras.callbacks.EarlyStopping object + es_monitor: quantity to be monitored for tf.keras.callbacks.EarlyStopping object + es_mode: mode for tf.keras.callbacks.EarlyStopping object + es_verbose: verbosity level for tf.keras.callbacks.EarlyStopping object + es_min_delta: minimum change to count as an improvement for early stopping + verbose: whether model training output is printed to the terminal + """ + self._check_targets(target_info_columns, target_level) + + x_train = ss.get_samples().astype(float) + if target_info_columns: + y_train = ss.info[target_info_columns].values.astype(float) + else: + source_contributions_df = ss.sources.groupby( + axis=1, + level=target_level, + sort=False + ).sum() + y_train = source_contributions_df.values.astype(float) + + if not self.model: + self._initialize_model(arad=arad, output_size=y_train.shape[1]) + + self.model.compile( + loss=self.loss, + optimizer=self.optimizer, + metrics=self.metrics + ) + es = EarlyStopping( + monitor=es_monitor, + patience=patience, + verbose=es_verbose, + restore_best_weights=True, + mode=es_mode, + min_delta=es_min_delta + ) + if callbacks: + callbacks.append(es) + else: + callbacks = [es] + history = self.model.fit( + x_train, + y_train, + epochs=epochs, + verbose=verbose, + validation_split=validation_split, + callbacks=callbacks, + shuffle=True, + batch_size=batch_size + ) + + self._update_info( + normalization=ss.spectra_state, + target_level=target_level, + model_outputs=target_info_columns, + ) + if target_level: + self._update_info( + model_outputs=source_contributions_df.columns.values.tolist(), + ) + + return history + + def predict(self, ss: SampleSet, verbose=False): + spectra = ss.get_samples().astype(float) + predictions = self.model.predict(spectra, verbose=verbose) + + if self.target_level: + col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) + col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] + ss.prediction_probas = pd.DataFrame( + data=predictions, + columns=pd.MultiIndex.from_tuples( + self.get_model_outputs_as_label_tuples(), + names=col_level_subset + ) + ) + + ss.classified_by = self.model_id + + return predictions + + +def _check_spectra(ss: SampleSet): + """Checks if SampleSet spectra are compatible with ARAD models.""" + if ss.n_samples <= 0: + raise ValueError("No spectr[a|um] provided!") + if not ss.all_spectra_sum_to_one(): + raise ValueError("All spectra must sum to one.") + if not ss.spectra_state == SpectraState.L1Normalized: + raise ValueError( + f"SpectraState must be L1Normalzied, provided SpectraState is {ss.spectra_state}." + ) + if not ss.n_channels == 128: + raise ValueError( + f"Spectra must have 128 channels, provided spectra have {ss.n_channels} channels." 
+ ) diff --git a/tests/anomaly_tests.py b/tests/anomaly_tests.py index 289be2db..617384df 100644 --- a/tests/anomaly_tests.py +++ b/tests/anomaly_tests.py @@ -61,7 +61,8 @@ def test_event_detector(self): _ = ed.add_measurement( measurement_id, noisy_bg_measurement, - SAMPLE_INTERVAL + SAMPLE_INTERVAL, + verbose=False ) measurement_id += 1 @@ -73,6 +74,7 @@ def test_event_detector(self): measurement_id=measurement_id, measurement=gross_spectrum, duration=SAMPLE_INTERVAL, + verbose=False ) measurement_id += 1 if event_result: diff --git a/tests/model_tests.py b/tests/model_tests.py index 7b145891..39357f3f 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -2,6 +2,7 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. """This module tests the bayes module.""" +import os import unittest import numpy as np @@ -11,11 +12,12 @@ from riid.data.synthetic import get_dummy_seeds from riid.data.synthetic.seed import SeedMixer from riid.data.synthetic.static import StaticSynthesizer +from riid.models import PyRIIDModel from riid.models.bayes import (NegativeSpectrumError, PoissonBayesClassifier, ZeroTotalCountsError) from riid.models.neural_nets import (LabelProportionEstimator, MLPClassifier, MultiEventClassifier) -from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF +from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv1, ARADv2 class TestModels(unittest.TestCase): @@ -24,6 +26,24 @@ def setUp(self): """Test setup.""" pass + @classmethod + def setUpClass(self): + self.seeds_ss = get_dummy_seeds(n_channels=128) + self.fg_seeds_ss, self.bg_seeds_ss = self.seeds_ss.split_fg_and_bg() + self.mixed_bg_seeds_ss = SeedMixer(self.bg_seeds_ss, mixture_size=3).generate(1) + self.static_synth = StaticSynthesizer(samples_per_seed=5) + self.train_ss, _ = self.static_synth.generate(self.fg_seeds_ss, self.mixed_bg_seeds_ss, + verbose=False) + self.train_ss.prediction_probas = self.train_ss.sources + self.train_ss.normalize() + self.test_ss, _ = self.static_synth.generate(self.fg_seeds_ss, self.mixed_bg_seeds_ss, + verbose=False) + self.test_ss.normalize() + + @classmethod + def tearDownClass(self): + pass + def test_pb_constructor_errors(self): """Testing for constructor errors when different arguments are provided.""" pb_model = PoissonBayesClassifier() @@ -54,7 +74,7 @@ def test_pb_constructor_errors(self): ss.spectra = pd.DataFrame(spectra) self.assertRaises(ZeroTotalCountsError, pb_model.fit, ss) - def test_pb_constructor_and_predict(self): + def test_pb_predict(self): """Tests the constructor with a valid SampleSet.""" seeds_ss = get_dummy_seeds() fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg() @@ -80,18 +100,73 @@ def test_pb_constructor_and_predict(self): pb_model.predict(test_gross_ss, test_bg_ss) truth_labels = fg_seeds_ss.get_labels() - predictions_labels = test_gross_ss.get_predictions() - assert (truth_labels == predictions_labels).all() - - def test_all_constructors(self): - _ = PoissonBayesClassifier() - _ = MLPClassifier() - _ = LabelProportionEstimator() - _ = MultiEventClassifier() - arad_v1 = ARADv1TF() - _ = ARAD(arad_v1) - arad_v2 = ARADv2TF() - _ = ARAD(arad_v2) + predicted_labels = test_gross_ss.get_predictions() + assert (truth_labels == predicted_labels).all() + + def test_pb_fit_save_load(self): + _test_model_fit_save_load_predict(self, PoissonBayesClassifier, None, self.fg_seeds_ss) + + def test_mlp_fit_save_load_predict(self): + _test_model_fit_save_load_predict(self, 
MLPClassifier, self.test_ss, self.train_ss, + epochs=1) + + def test_mec_fit_save_load_predict(self): + test_copy_ss = self.test_ss[:] + test_copy_ss.prediction_probas = test_copy_ss.sources + _test_model_fit_save_load_predict( + self, + MultiEventClassifier, + [test_copy_ss], + [self.train_ss], + self.train_ss.sources.groupby(axis=1, level="Isotope", sort=False).sum(), + epochs=1 + ) + + def test_lpe_fit_save_load_predict(self): + _test_model_fit_save_load_predict(self, LabelProportionEstimator, self.test_ss, + self.fg_seeds_ss, self.train_ss, epochs=1) + + def test_aradv1_fit_save_load_predict(self): + _test_model_fit_save_load_predict(self, ARADv1, self.test_ss, self.train_ss, epochs=1) + + def test_aradv2_fit_save_load_predict(self): + _test_model_fit_save_load_predict(self, ARADv2, self.test_ss, self.train_ss, epochs=1) + + def test_alp_fit_save_load_predict(self): + arad_v2 = ARADv2() + arad_v2.fit(self.train_ss, epochs=1) + _test_model_fit_save_load_predict(self, ARADLatentPredictor, self.test_ss, arad_v2.model, + self.train_ss, target_info_columns=["snr"], epochs=1) + + +def _try_remove_model_and_info(model_path: str): + if os.path.exists(model_path): + if os.path.isdir(model_path): + os.rmdir(model_path) + else: + os.remove(model_path) + + +def _test_model_fit_save_load_predict(test_case: unittest.TestCase, model_class: PyRIIDModel, + test_ss: SampleSet = None, *args_for_fit, **kwargs_for_fit): + m1 = model_class() + m2 = model_class() + + m1.fit(*args_for_fit, **kwargs_for_fit) + + model_path = m1._temp_file_path + + _try_remove_model_and_info(model_path) + test_case.assertRaises(ValueError, m2.load, model_path) + + m1.save(model_path) + test_case.assertRaises(ValueError, m1.save, model_path) + + m2.load(model_path) + _try_remove_model_and_info(model_path) + + if test_ss is not None: + m1.predict(test_ss) if __name__ == "__main__":
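     unittest.main()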