diff --git a/riid/models/__init__.py b/riid/models/__init__.py index e4cd4f9..e02e2c7 100644 --- a/riid/models/__init__.py +++ b/riid/models/__init__.py @@ -2,13 +2,17 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. """This module contains the base TFModel class.""" +import json import os import uuid import warnings from enum import Enum +import numpy as np +import onnxruntime import pandas as pd import tensorflow as tf +import tf2onnx import riid from riid.data.labeling import label_to_index_element @@ -23,7 +27,7 @@ class ModelInput(Enum): ForegroundSpectrum = 2 -class TFModelBase: +class PyRIIDModel: """Base class for TensorFlow models.""" CUSTOM_OBJECTS = {"multi_f1": multi_f1, "single_f1": single_f1} @@ -107,7 +111,8 @@ def save(self, file_path: str): """Save the model to a file. Args: - file_path: file path at which to save the model + file_path: file path at which to save the model, can be either .h5 or + .onnx format Raises: `ValueError` when the given file path already exists @@ -115,10 +120,38 @@ def save(self, file_path: str): if os.path.exists(file_path): raise ValueError("Path already exists.") + SUPPORTED_EXTS = { + "H5": ".h5", + "ONNX": ".onnx" + } + root, ext = os.path.splitext(file_path) + if ext.lower() not in SUPPORTED_EXTS.values(): + raise NameError("Model must be an .onnx or .h5 file.") + warnings.filterwarnings("ignore") - self.model.save(file_path, save_format="h5") - pd.DataFrame([[v] for v in self.info.values()], self.info.keys()).to_hdf(file_path, "_info") + if ext.lower() == SUPPORTED_EXTS["H5"]: + self.model.save(file_path, save_format="h5") + pd.DataFrame( + [[v] for v in self.info.values()], + self.info.keys() + ).to_hdf(file_path, "_info") + + else: + model_path = root + SUPPORTED_EXTS["ONNX"] + model_info_path = root + "_info.json" + + model_info_df = pd.DataFrame( + [[v] for v in self.info.values()], + self.info.keys() + ) + model_info_df[0].to_json(model_info_path, indent=4) + + tf2onnx.convert.from_keras( + self.model, + input_signature=None, + output_path=model_path + ) warnings.resetwarnings() @@ -126,18 +159,50 @@ def load(self, file_path: str): """Load the model from a file. Args: - file_path: file path from which to load the model + file_path: file path from which to load the model, must be either an + .h5 or .onnx file """ + SUPPORTED_EXTS = { + "H5": ".h5", + "ONNX": ".onnx" + } + root, ext = os.path.splitext(file_path) + if ext.lower() not in SUPPORTED_EXTS.values(): + raise NameError("Model must be an .onnx or .h5 file.") + warnings.filterwarnings("ignore", category=DeprecationWarning) - self.model = tf.keras.models.load_model( - file_path, - custom_objects=self.CUSTOM_OBJECTS - ) - self._info = pd.read_hdf(file_path, "_info")[0].to_dict() + if ext.lower() == SUPPORTED_EXTS["H5"]: + self.model = tf.keras.models.load_model( + file_path, + custom_objects=self.CUSTOM_OBJECTS + ) + self._info = pd.read_hdf(file_path, "_info")[0].to_dict() + + else: + model_path = root + SUPPORTED_EXTS["ONNX"] + model_info_path = root + "_info.json" + + with open(model_info_path) as fin: + model_info = json.load(fin) + self._info = model_info + + self.onnx_session = onnxruntime.InferenceSession(model_path) warnings.resetwarnings() + def get_predictions(self, x, **kwargs): + if self.model is None: + outputs = self.onnx_session.run( + [self.onnx_session.get_outputs()[0].name], + {self.onnx_session.get_inputs()[0].name: x.astype(np.float32)} + )[0] + + else: + outputs = self.model.predict(x, **kwargs) + + return outputs + def serialize(self) -> bytes: """Convert model to a bytes object. diff --git a/riid/models/bayes.py b/riid/models/bayes.py index c7d09fb..225e255 100644 --- a/riid/models/bayes.py +++ b/riid/models/bayes.py @@ -8,10 +8,10 @@ import tensorflow_probability as tfp from riid.data.sampleset import SampleSet -from riid.models import TFModelBase +from riid.models import PyRIIDModel -class PoissonBayesClassifier(TFModelBase): +class PoissonBayesClassifier(PyRIIDModel): """This Poisson-Bayes classifier calculates the conditional Poisson log probability of each seed spectrum given the measurement. @@ -139,7 +139,7 @@ def predict(self, gross_ss: SampleSet, bg_ss: SampleSet, bg_spectra = tf.convert_to_tensor(bg_ss.spectra.values, dtype=tf.float32) bg_lts = tf.convert_to_tensor(bg_ss.info.live_time.values, dtype=tf.float32) - prediction_probas = self.model.predict(( + prediction_probas = self.get_predictions(( gross_spectra, gross_lts, bg_spectra, bg_lts ), batch_size=512, verbose=verbose) diff --git a/riid/models/neural_nets/__init__.py b/riid/models/neural_nets/__init__.py index d2c42bd..b020ef1 100644 --- a/riid/models/neural_nets/__init__.py +++ b/riid/models/neural_nets/__init__.py @@ -26,7 +26,7 @@ from riid.losses.sparsemax import SparsemaxLoss, sparsemax from riid.metrics import (build_keras_semisupervised_metric_func, multi_f1, single_f1) -from riid.models import ModelInput, TFModelBase +from riid.models import ModelInput, PyRIIDModel tf2onnx.logging.basicConfig(level=tf2onnx.logging.WARNING) @@ -47,7 +47,7 @@ def _get_reordered_spectra(old_spectra_df: pd.DataFrame, old_sources_df: pd.Data return reordered_spectra_df -class MLPClassifier(TFModelBase): +class MLPClassifier(PyRIIDModel): """Multi-layer perceptron classifier.""" def __init__(self, hidden_layers: tuple = (512,), activation: str = "relu", loss: str = "categorical_crossentropy", @@ -245,7 +245,8 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False): X = [x_test, bg_ss.get_samples().astype(float)] else: X = x_test - results = self.model.predict(X, verbose=verbose) + + results = self.get_predictions(X, verbose=verbose) col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] @@ -259,7 +260,7 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False): ss.classified_by = self.info["model_id"] -class MultiEventClassifier(TFModelBase): +class MultiEventClassifier(PyRIIDModel): """A classifier for spectra from multiple detectors observing the same event.""" def __init__(self, hidden_layers: tuple = (512,), activation: str = "relu", @@ -423,7 +424,7 @@ def fit(self, list_of_ss: List[SampleSet], target_contributions: pd.DataFrame, return history - def predict(self, list_of_ss: List[SampleSet]) -> pd.DataFrame: + def predict(self, list_of_ss: List[SampleSet], verbose=False) -> pd.DataFrame: """Classify the spectra in the provided `SampleSet`(s) based on each one's results. Args: @@ -433,7 +434,8 @@ def predict(self, list_of_ss: List[SampleSet]) -> pd.DataFrame: `DataFrame` of predicted results for the `Sampleset`(s) """ X = [ss.prediction_probas for ss in list_of_ss] - results = self.model.predict(X) # output size will be n_samples by n_labels + # output size will be n_samples by n_labels + results = self.get_predictions(X, verbose=verbose) col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] @@ -446,7 +448,7 @@ def predict(self, list_of_ss: List[SampleSet]) -> pd.DataFrame: return results_df -class LabelProportionEstimator(TFModelBase): +class LabelProportionEstimator(PyRIIDModel): UNSUPERVISED_LOSS_FUNCS = { "poisson_nll": poisson_nll_diff, "normal_nll": normal_nll_diff, @@ -847,7 +849,8 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b return history - def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False): + def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, + verbose=False): """Estimate the proportions of counts present in each sample of the provided SampleSet. Results are stored inside the SampleSet's prediction_probas property. @@ -861,16 +864,9 @@ def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False): """ test_spectra = ss.get_samples().astype(float) - if self.model is None: - outputs = self.onnx_session.run( - [self.onnx_session.get_outputs()[0].name], - {self.onnx_session.get_inputs()[0].name: test_spectra.astype(np.float32)} - )[0] - lpes = self.activation(tf.convert_to_tensor(outputs, dtype=tf.float32)) + logits = self.get_predictions(test_spectra, verbose=verbose) - else: - logits = self.model.predict(test_spectra) - lpes = self.activation(tf.convert_to_tensor(logits, dtype=tf.float32)) + lpes = self.activation(tf.convert_to_tensor(logits, dtype=tf.float32)) col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self._info["target_level"]) col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] diff --git a/riid/models/neural_nets/arad.py b/riid/models/neural_nets/arad.py index a9bf725..2f452ca 100644 --- a/riid/models/neural_nets/arad.py +++ b/riid/models/neural_nets/arad.py @@ -16,7 +16,7 @@ from riid.data.sampleset import SampleSet from riid.losses import mish, jensen_shannon_distance -from riid.models import TFModelBase +from riid.models import PyRIIDModel @tf.keras.saving.register_keras_serializable(package="riid") @@ -235,7 +235,7 @@ def call(self, x): return decoded -class ARAD(TFModelBase): +class ARAD(PyRIIDModel): """PyRIID-compatible wrapper around ARAD models. """ def __init__(self, model: Model = ARADv2TF()): @@ -247,11 +247,20 @@ def __init__(self, model: Model = ARADv2TF()): self.model = model - # TODO: enable saving as ONNX + def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, + es_verbose: int = 0, verbose: bool = False): + """Fit a model to the given `SampleSet`. - def fit(self, ss: SampleSet, epochs: int = 300, es_verbose: int = 0, - verbose: bool = False): - """Fit a model to the given `SampleSet`.""" + Args: + ss: `SampleSet` of `n` spectra where `n` >= 1 + epochs: maximum number of training epochs + validation_split: percentage of the training data to use as validation data + es_verbose: verbosity level for `tf.keras.callbacks.EarlyStopping` + verbose: whether to show detailed model training output + + Returns: + reconstructed_spectra: output of ARAD model + """ if ss.n_samples <= 0: raise ValueError("No spectr[a|um] provided!") @@ -312,7 +321,7 @@ def fit(self, ss: SampleSet, epochs: int = 300, es_verbose: int = 0, spectra, epochs=epochs, verbose=verbose, - validation_split=0.2, + validation_split=validation_split, callbacks=callbacks, shuffle=True, batch_size=batch_size @@ -337,7 +346,7 @@ def predict(self, ss: SampleSet, ood_threshold: float = 0.5, norm_ss.normalize() spectra = norm_ss.get_samples().astype(float) - reconstructed_spectra = self.model.predict(spectra, verbose=verbose) + reconstructed_spectra = self.get_predictions(spectra, verbose=verbose) if isinstance(self.model, ARADv1TF): reconstruction_metric = entropy