Skip to content

Commit

Permalink
Add ARADLatentPredictor; standardize model format.
Browse files Browse the repository at this point in the history
Standardizing model format involved the following:

- Remove ONNX runtime dependency
- Change saving as ONNX to a one-way export
- Change model saving and loading so that all models use the same JSON format
- Fix various bugs with

Co-authored-by: Tyler Morrow <[email protected]>
  • Loading branch information
alanjvano and tymorrow committed Feb 14, 2024
1 parent 5f33251 commit 32a50a9
Show file tree
Hide file tree
Showing 12 changed files with 841 additions and 441 deletions.
54 changes: 20 additions & 34 deletions examples/modeling/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF
from riid.models.neural_nets.arad import ARADv1, ARADv2

# Config
rng = np.random.default_rng(42)
Expand Down Expand Up @@ -37,47 +37,33 @@
_, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
gross_train_ss.normalize()

# Train the models
print("Training ARADv1...")
arad_v1 = ARAD(model=ARADv1TF())
arad_v1.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v1.predict(gross_train_ss)
v1_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

print("Training ARADv2...")
arad_v2 = ARAD(model=ARADv2TF())
arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v2.predict(gross_train_ss)
v2_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

# Generate test data
static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
_, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
test_ss.normalize()

# Predict
# Train the models
results = {}
models = [ARADv1, ARADv2]
for model_class in models:
arad = model_class()
model_name = arad.__class__.__name__

arad_v1_reconstructions = arad_v1.predict(test_ss, verbose=True)
arad_v1_ood = test_ss.info.recon_error.values > v1_ood_threshold
arad_v1_false_positive_rate = arad_v1_ood.mean()
arad_v1_mean_recon_error = test_ss.info.recon_error.values.mean()
print(f"Training and testing {model_name}...")
arad.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad.predict(gross_train_ss)
ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

arad_v2_reconstructions = arad_v2.predict(test_ss, verbose=True)
arad_v2_ood = test_ss.info.recon_error.values > v2_ood_threshold
arad_v2_false_positive_rate = arad_v2_ood.mean()
arad_v2_mean_recon_error = test_ss.info.recon_error.values.mean()
reconstructions = arad.predict(test_ss, verbose=True)
ood = test_ss.info.recon_error.values > ood_threshold
false_positive_rate = ood.mean()
mean_recon_error = test_ss.info.recon_error.values.mean()

results = {
"ARADv1": {
"ood_threshold": f"KLD={v1_ood_threshold:.4f}",
"mean_recon_error": arad_v1_mean_recon_error,
"false_positive_rate": arad_v1_false_positive_rate,
},
"ARADv2": {
"ood_threshold": f"JSD={v2_ood_threshold:.4f}",
"mean_recon_error": arad_v2_mean_recon_error,
"false_positive_rate": arad_v2_false_positive_rate,
results[model_name] = {
"ood_threshold": f"{ood_threshold:.4f}",
"mean_recon_error": mean_recon_error,
"false_positive_rate": false_positive_rate,
}
}

print(f"Target False Positive Rate: {1-OOD_QUANTILE:.4f}")
print(pd.DataFrame.from_dict(results))
82 changes: 82 additions & 0 deletions examples/modeling/arad_latent_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
"""This example demonstrates how to train a regressor or classifier branch
from an ARAD latent space.
"""
import numpy as np
from sklearn.metrics import f1_score, mean_squared_error

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARADv2, ARADLatentPredictor

# Config
# A single seeded generator is passed to seed creation, seed mixing, and synthesis below.
rng = np.random.default_rng(42)
VERBOSE = True
# Some of the following parameters are set low because this example runs on GitHub Actions and
# we don't want it taking a bunch of time.
# When running this locally, change the values per their corresponding comment, otherwise
# the results likely will not be meaningful.
EPOCHS = 5 # Change this to 20+
N_MIXTURES = 50 # Change this to 1000+
TRAIN_SAMPLES_PER_SEED = 5 # Change this to 20+
TEST_SAMPLES_PER_SEED = 5

# Generate training data
# Dummy seeds are created with 128 channels and split into foreground and background sets.
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=128, rng=rng).split_fg_and_bg()
# Build N_MIXTURES mixed background seeds, each from a mixture of size 3.
mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng).generate(N_MIXTURES)
# Only gross spectra are requested (return_fg=False, return_gross=True).
# NOTE(review): snr_function_args=(0, 0) presumably pins the SNR at zero - confirm
# against the StaticSynthesizer documentation.
static_synth = StaticSynthesizer(
samples_per_seed=TRAIN_SAMPLES_PER_SEED,
snr_function_args=(0, 0),
return_fg=False,
return_gross=True,
rng=rng,
)
# The first element of the returned pair is discarded; only the gross SampleSet is kept.
# Only index 0 of the foreground seeds is passed to the synthesizer.
_, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
gross_train_ss.normalize()

# Generate test data
# The same synthesizer (and the same rng) is reused; only samples_per_seed is changed.
static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
_, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
test_ss.normalize()

# Train ARAD model
# The autoencoder is fit first; both latent predictors below are fit against its
# underlying model object (`arad_v2.model`).
print("Training ARAD")
arad_v2 = ARADv2()
arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)

# Train regressor to predict live time
# The regression target is the `live_time` column of the SampleSet's info table.
print("Training Regressor")
arad_regressor = ARADLatentPredictor()
_ = arad_regressor.fit(
    arad_v2.model,
    gross_train_ss,
    target_info_columns=["live_time"],
    epochs=10,
    batch_size=5,
    verbose=VERBOSE,
)
regression_predictions = arad_regressor.predict(test_ss)
# Score against the live times of the set that was actually predicted on (test_ss).
# Comparing with the training set's live times only "worked" while the train and test
# sets happened to have the same number of samples, and produced a meaningless MSE.
regression_score = mean_squared_error(test_ss.info.live_time, regression_predictions)
print("Regressor MSE: {:.3f}".format(regression_score))

# Train classifier to predict isotope
# A softmax output with categorical cross-entropy turns the predictor into a classifier.
print("Training Classifier")
arad_classifier = ARADLatentPredictor(
    loss="categorical_crossentropy",
    metrics=("accuracy", "categorical_crossentropy"),
    final_activation="softmax",
)
arad_classifier.fit(
    arad_v2.model,
    gross_train_ss,
    target_level="Isotope",
    epochs=10,
    batch_size=5,
    verbose=VERBOSE,
)
# The return value is not used; the score below reads predictions back from test_ss.
# NOTE(review): this assumes predict() stores its predictions on test_ss - confirm.
arad_classifier.predict(test_ss)
classification_score = f1_score(test_ss.get_labels(), test_ss.get_predictions(), average="micro")
print("Classification F1 Score: {:.3f}".format(classification_score))
9 changes: 4 additions & 5 deletions examples/modeling/label_proportion_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
optimizer="RMSprop",
learning_rate=1e-2,
hidden_layer_activation="relu",
l2_alpha=1e-4,
dropout=0.05,
)

Expand All @@ -54,7 +53,7 @@
bg_seeds_ss,
bg_ss,
batch_size=10,
epochs=10,
epochs=2,
validation_split=0.2,
verbose=True,
bg_cps=300
Expand All @@ -74,8 +73,9 @@
)
print(f"Mean Test MAE: {test_meas.mean():.3f}")

# Save model in ONNX format
model_info_path, model_path = model.save("./model.onnx")
# Save model
model_path = "./model.json"
model.save(model_path, overwrite=True)

loaded_model = LabelProportionEstimator()
loaded_model.load(model_path)
Expand All @@ -89,5 +89,4 @@
print(f"Mean Test MAE: {test_maes.mean():.3f}")

# Clean up model file - remove this if you want to keep the model
os.remove(model_info_path)
os.remove(model_path)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ dependencies = [
"pyyaml ==6.0.*",
"seaborn ==0.12.*",
"tf2onnx ==1.14.*",
"onnx ==1.14.1",
"tqdm ==4.65.*",
"numpy ==1.23.*",
"onnxruntime ==1.15.*",
"pandas ==2.0.*",
"parmap ==1.6.*",
"pythonnet ==3.0.*; platform_system == 'Windows'",
Expand Down
2 changes: 2 additions & 0 deletions riid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
SAMPLESET_HDF_FILE_EXTENSION = ".h5"
SAMPLESET_JSON_FILE_EXTENSION = ".json"
PCF_FILE_EXTENSION = ".pcf"
ONNX_MODEL_FILE_EXTENSION = ".onnx"
TFLITE_MODEL_FILE_EXTENSION = ".tflite"
RIID = "riid"

__version__ = get_distribution(RIID).version
Expand Down
4 changes: 2 additions & 2 deletions riid/data/sampleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
_pcf_to_dict, _unpack_compressed_text_buffer)


class SpectraState(Enum):
class SpectraState(int, Enum):
"""States in which SampleSet spectra can exist."""
Unknown = 0
Counts = 1
L1Normalized = 2
L2Normalized = 3


class SpectraType(Enum):
class SpectraType(int, Enum):
"""Types for SampleSet spectra."""
Unknown = 0
Background = 1
Expand Down
Loading

0 comments on commit 32a50a9

Please sign in to comment.