Skip to content

Commit

Permalink
Improve standardization of model saving and loading.
Browse files Browse the repository at this point in the history
  • Loading branch information
tymorrow committed Jan 31, 2024
1 parent d071a48 commit b4ea1f2
Show file tree
Hide file tree
Showing 12 changed files with 597 additions and 587 deletions.
54 changes: 20 additions & 34 deletions examples/modeling/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARAD
from riid.models.neural_nets.arad import ARADv1, ARADv2

# Config
rng = np.random.default_rng(42)
Expand Down Expand Up @@ -37,47 +37,33 @@
_, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
gross_train_ss.normalize()

# Train the models
print("Training ARADv1...")
arad_v1 = ARAD(arad_version="v1")
arad_v1.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v1.predict(gross_train_ss)
v1_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

print("Training ARADv2...")
arad_v2 = ARAD(arad_version="v2")
arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v2.predict(gross_train_ss)
v2_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

# Generate test data
static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
_, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
test_ss.normalize()

# Predict
# Train the models
results = {}
models = [ARADv1, ARADv2]
for model_class in models:
arad = model_class()
model_name = arad.__class__.__name__

arad_v1_reconstructions = arad_v1.predict(test_ss, verbose=True)
arad_v1_ood = test_ss.info.recon_error.values > v1_ood_threshold
arad_v1_false_positive_rate = arad_v1_ood.mean()
arad_v1_mean_recon_error = test_ss.info.recon_error.values.mean()
print(f"Training and testing {model_name}...")
arad.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad.predict(gross_train_ss)
ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

arad_v2_reconstructions = arad_v2.predict(test_ss, verbose=True)
arad_v2_ood = test_ss.info.recon_error.values > v2_ood_threshold
arad_v2_false_positive_rate = arad_v2_ood.mean()
arad_v2_mean_recon_error = test_ss.info.recon_error.values.mean()
reconstructions = arad.predict(test_ss, verbose=True)
ood = test_ss.info.recon_error.values > ood_threshold
false_positive_rate = ood.mean()
mean_recon_error = test_ss.info.recon_error.values.mean()

results = {
"ARADv1": {
"ood_threshold": f"KLD={v1_ood_threshold:.4f}",
"mean_recon_error": arad_v1_mean_recon_error,
"false_positive_rate": arad_v1_false_positive_rate,
},
"ARADv2": {
"ood_threshold": f"JSD={v2_ood_threshold:.4f}",
"mean_recon_error": arad_v2_mean_recon_error,
"false_positive_rate": arad_v2_false_positive_rate,
results[model_name] = {
"ood_threshold": f"{ood_threshold:.4f}",
"mean_recon_error": mean_recon_error,
"false_positive_rate": false_positive_rate,
}
}

print(f"Target False Positive Rate: {1-OOD_QUANTILE:.4f}")
print(pd.DataFrame.from_dict(results))
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,23 @@
"""This example demonstrates how to train a regressor or classifier branch
from an ARAD latent space.
"""
import os

import numpy as np
from sklearn.metrics import f1_score
from sklearn.metrics import f1_score, mean_squared_error

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARAD, ARADBranch
from riid.models.neural_nets.arad import ARADv2, ARADLatentPredictor

# Config
rng = np.random.default_rng(42)
OOD_QUANTILE = 0.99
ARAD_MODEL_PATH = "./arad_model.onnx"
ARAD_REGRESSOR_PATH = "./arad_reg_model.onnx"
ARAD_CLASSIFIER_PATH = "./arad_cla_model.onnx"
VERBOSE = True
# Some of the following parameters are set low because this example runs on GitHub Actions and
# we don't want it taking a bunch of time.
# When running this locally, change the values per their corresponding comment, otherwise
# the results likely will not be meaningful.
EPOCHS = 5 # Change this to 20+
N_MIXTURES = 50 # Changes this to 1000+
N_MIXTURES = 50 # Change this to 1000+
TRAIN_SAMPLES_PER_SEED = 5 # Change this to 20+
TEST_SAMPLES_PER_SEED = 5

Expand All @@ -49,59 +43,40 @@
test_ss.normalize()

# Train ARAD model
arad_v2 = ARAD(arad_version="v2")
print("Training ARAD")
arad_v2 = ARADv2()
arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v2.predict(gross_train_ss)
v2_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

# Save ARAD model
arad_v2.save(ARAD_MODEL_PATH)

# Train branched model to predict live-time and real-time
arad_regressor = ARADBranch(
ARAD_MODEL_PATH
)
# Train regressor to predict live time
print("Training Regressor")
arad_regressor = ARADLatentPredictor()
_ = arad_regressor.fit(
arad_v2.model,
gross_train_ss,
target_info_columns=["live_time", "real_time"],
target_info_columns=["live_time"],
epochs=10,
verbose=True,
batch_size=5
batch_size=5,
verbose=VERBOSE,
)
regression_predictions = arad_regressor.predict(test_ss)
regression_score = mean_squared_error(gross_train_ss.info.live_time, regression_predictions)
print("Regressor MSE: {:.3f}".format(regression_score))

# Save, load, and predict with regressor
arad_regressor.save(ARAD_REGRESSOR_PATH)
arad_regressor = ARADBranch()
arad_regressor.load(ARAD_REGRESSOR_PATH)
preds = arad_regressor.predict(test_ss)
for idx, target_name in enumerate(arad_regressor.info["target_info_columns"]):
print(f"{target_name}: {preds[:5, idx]}...")

# Train branched model to classify isotopes
arad_classifier = ARADBranch(
ARAD_MODEL_PATH,
# Train classifier to predict isotope
print("Training Classifier")
arad_classifier = ARADLatentPredictor(
loss="categorical_crossentropy",
metrics=("accuracy", "categorical_crossentropy"),
final_activation="softmax"
)
_ = arad_classifier.fit(
arad_classifier.fit(
arad_v2.model,
gross_train_ss,
target_level="Isotope",
epochs=10,
verbose=True,
batch_size=5
batch_size=5,
verbose=VERBOSE,
)

# Save, load, and predict with classifier
arad_classifier.save(ARAD_CLASSIFIER_PATH)
arad_classifier = ARADBranch()
arad_classifier.load(ARAD_CLASSIFIER_PATH)
arad_classifier.predict(test_ss)
score = f1_score(test_ss.get_labels(), test_ss.get_predictions(), average="micro")
print("F1 Score: {:.3f}".format(score))

# Clean up
for path in [ARAD_MODEL_PATH, ARAD_REGRESSOR_PATH, ARAD_CLASSIFIER_PATH]:
info_path = os.path.splitext(path)[0] + "_info.json"
os.remove(path)
os.remove(info_path)
classification_score = f1_score(test_ss.get_labels(), test_ss.get_predictions(), average="micro")
print("Classification F1 Score: {:.3f}".format(classification_score))
9 changes: 4 additions & 5 deletions examples/modeling/label_proportion_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
optimizer="RMSprop",
learning_rate=1e-2,
hidden_layer_activation="relu",
l2_alpha=1e-4,
dropout=0.05,
)

Expand All @@ -54,7 +53,7 @@
bg_seeds_ss,
bg_ss,
batch_size=10,
epochs=10,
epochs=2,
validation_split=0.2,
verbose=True,
bg_cps=300
Expand All @@ -74,8 +73,9 @@
)
print(f"Mean Test MAE: {test_meas.mean():.3f}")

# Save model in ONNX format
model_info_path, model_path = model.save("./model.onnx")
# Save model
model_path = "./model.json"
model.save(model_path, overwrite=True)

loaded_model = LabelProportionEstimator()
loaded_model.load(model_path)
Expand All @@ -89,5 +89,4 @@
print(f"Mean Test MAE: {test_maes.mean():.3f}")

# Clean up model file - remove this if you want to keep the model
os.remove(model_info_path)
os.remove(model_path)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ dependencies = [
"onnx ==1.14.1",
"tqdm ==4.65.*",
"numpy ==1.23.*",
"onnxruntime ==1.15.*",
"pandas ==2.0.*",
"parmap ==1.6.*",
"pythonnet ==3.0.*; platform_system == 'Windows'",
Expand Down
2 changes: 2 additions & 0 deletions riid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
SAMPLESET_HDF_FILE_EXTENSION = ".h5"
SAMPLESET_JSON_FILE_EXTENSION = ".json"
PCF_FILE_EXTENSION = ".pcf"
ONNX_MODEL_FILE_EXTENSION = ".onnx"
TFLITE_MODEL_FILE_EXTENSION = ".tflite"
RIID = "riid"

__version__ = get_distribution(RIID).version
Expand Down
4 changes: 2 additions & 2 deletions riid/data/sampleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
_pcf_to_dict, _unpack_compressed_text_buffer)


class SpectraState(Enum):
class SpectraState(int, Enum):
"""States in which SampleSet spectra can exist."""
Unknown = 0
Counts = 1
L1Normalized = 2
L2Normalized = 3


class SpectraType(Enum):
class SpectraType(int, Enum):
"""Types for SampleSet spectra."""
Unknown = 0
Background = 1
Expand Down
Loading

0 comments on commit b4ea1f2

Please sign in to comment.