change the per-100-epochs monitoring of chi2 to avoid having to recompute losses
scarlehoff committed Mar 5, 2024
1 parent baefa5f commit 9e6dd2c
Showing 4 changed files with 56 additions and 52 deletions.
10 changes: 8 additions & 2 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
@@ -410,12 +410,18 @@ def load_identical_replicas(self, model_file):

def save_weights(self, file, save_format="h5"):
"""
Compatibility function for tf < 2.16
Compatibility function for:
- tf < 2.16, keras < 3: argument save format needed for h5
- tf >= 2.16, keras >= 3: save format is deduced from the file extension
In both cases, the final weights are finally copied to the ``file`` path.
"""
try:
# Keras 2, tf < 2.16
super().save_weights(file, save_format=save_format)
except TypeError:
new_file = file.with_suffix(".weights.h5")
# Newer versions of keras (>=3) drop the ``save_format`` argument
# and instead take the format from the extension of the file
new_file = file.with_suffix(f".weights.{save_format}")
super().save_weights(new_file)
shutil.move(new_file, file)

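For context, the compatibility pattern above can be exercised outside Keras entirely: call the old signature first, and fall back to the extension-based one when the keyword is rejected. The sketch below is a minimal stand-alone illustration; `LegacySaver`, `ModernSaver` and `save_weights_compat` are hypothetical stand-ins, not part of n3fit or Keras.

```python
import shutil
from pathlib import Path


class LegacySaver:
    """Mimics Keras 2: accepts a ``save_format`` argument."""

    def save_weights(self, file, save_format="h5"):
        Path(file).write_text(f"weights ({save_format})")


class ModernSaver:
    """Mimics Keras 3: no ``save_format``, format deduced from the extension."""

    def save_weights(self, file):
        file = Path(file)
        if not file.name.endswith(".weights.h5"):
            raise ValueError("expects a .weights.h5 extension")
        file.write_text("weights")


def save_weights_compat(model, file, save_format="h5"):
    """Try the Keras-2 call first; fall back to the Keras-3 signature."""
    file = Path(file)
    try:
        model.save_weights(file, save_format=save_format)
    except TypeError:
        # The newer API drops ``save_format`` and requires the extension,
        # so save under the expected name and move it to the requested path
        new_file = file.with_suffix(f".weights.{save_format}")
        model.save_weights(new_file)
        shutil.move(new_file, file)


save_weights_compat(LegacySaver(), "legacy.h5")
save_weights_compat(ModernSaver(), "modern.h5")
```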
12 changes: 5 additions & 7 deletions n3fit/src/n3fit/backends/keras_backend/callbacks.py
@@ -76,15 +76,13 @@ def __init__(self, stopping_object, log_freq=100):
         super().__init__()
         self.log_freq = log_freq
         self.stopping_object = stopping_object
-        self._current_loss = None
 
-    def on_epoch_begin(self, epoch, logs=None):
-        # TODO This is an unnecessary performance hit, just for testing
-        self._current_loss = self.model.compute_losses()
-
     def on_epoch_end(self, epoch, logs=None):
-        """Function to be called at the end of every epoch"""
-        logs = self._current_loss
+        """Function to be called at the end of every epoch
+        Every ``log_freq`` number of epochs, the ``monitor_chi2`` method of the ``stopping_object``
+        will be called and the validation loss (broken down by experiment) will be logged.
+        For the training model only the total loss is logged during the training.
+        """
         print_stats = ((epoch + 1) % self.log_freq) == 0
         # Note that the input logs correspond to the fit before the weights are updated
         self.stopping_object.monitor_chi2(logs, epoch, print_stats=print_stats)
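The point of this change is that `on_epoch_end` now reuses the `logs` dictionary the training step already produced, instead of calling `compute_losses()` again on every epoch. A self-contained sketch of that control flow, with `FakeStopping` and a stripped-down `StoppingCallback` as hypothetical stand-ins mirroring only the shape of the calls:

```python
class FakeStopping:
    """Stand-in for the real stopping_object."""

    def monitor_chi2(self, logs, epoch, print_stats=False):
        if print_stats:
            print(f"epoch {epoch + 1}: loss={logs['loss']:.4f}")


class StoppingCallback:
    def __init__(self, stopping_object, log_freq=100):
        self.log_freq = log_freq
        self.stopping_object = stopping_object

    def on_epoch_end(self, epoch, logs=None):
        # ``logs`` is whatever the training loop just computed for this
        # epoch, so no extra loss evaluation is needed here
        print_stats = ((epoch + 1) % self.log_freq) == 0
        self.stopping_object.monitor_chi2(logs, epoch, print_stats=print_stats)


# Drive the callback by hand to show the flow; prints at epochs 2 and 4
callback = StoppingCallback(FakeStopping(), log_freq=2)
for epoch in range(4):
    callback.on_epoch_end(epoch, logs={"loss": 1.0 / (epoch + 1)})
```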
14 changes: 9 additions & 5 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
@@ -37,6 +37,14 @@

 from validphys.convolution import OP
 
+# Select a concatenate function depending on the tensorflow version
+try:
+    # For tensorflow >= 2.16, Keras >= 3
+    concatenate_function = keras.ops.concatenate
+except AttributeError:
+    # keras.ops was introduced in keras 3
+    concatenate_function = tf.concat
+
 
 def evaluate(tensor):
     """Evaluate input tensor using the backend"""
@@ -251,11 +259,7 @@ def concatenate(tensor_list, axis=-1, target_shape=None, name=None):
     Concatenates a list of numbers or tensors into a bigger tensor
     If the target shape is given, the output is reshaped to said shape
     """
-    try:
-        # For tensorflow >= 2.16, Keras >= 3
-        concatenated_tensor = keras.ops.concatenate(tensor_list, axis=axis)
-    except AttributeError:
-        concatenated_tensor = tf.concat(tensor_list, axis=axis)
+    concatenated_tensor = concatenate_function(tensor_list, axis=axis)
 
     if target_shape is None:
         return concatenated_tensor
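Hoisting the try/except to module scope means the version check runs once at import time rather than on every `concatenate` call, which would otherwise raise and catch an `AttributeError` per call on older installs. A minimal sketch of the same dispatch pattern using only the standard library; here `json.ops` deliberately does not exist, playing the role of `keras.ops` on a pre-Keras-3 install:

```python
import json

# Resolved once at import time: prefer a (non-existent) newer API,
# fall back to the one that is always present
try:
    dumps_function = json.ops.dumps  # hypothetical newer API
except AttributeError:
    dumps_function = json.dumps  # always-available fallback


def serialize(obj):
    # Each call is now a direct function call with no version check
    return dumps_function(obj)


print(serialize({"a": 1}))  # -> {"a": 1}
```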
72 changes: 34 additions & 38 deletions n3fit/src/n3fit/stopping.py
@@ -146,13 +146,16 @@ class FitState:
         all losses for the training model
     validation_info: dict
         all losses for the validation model
+    training_loss: float
+        total training loss, this can be given if per-exp ``training_info``
+        is not available
     """
 
     vl_ndata = None
     tr_ndata = None
     vl_suffix = None
 
-    def __init__(self, training_info, validation_info):
+    def __init__(self, training_info, validation_info, training_loss=None):
         if self.vl_ndata is None or self.tr_ndata is None or self.vl_suffix is None:
             raise ValueError(
                 "FitState cannot be instantiated until vl_ndata, tr_ndata and vl_suffix are filled"
@@ -164,6 +167,8 @@ def __init__(self, training_info, validation_info):
         self._tr_chi2 = None  # This is an overall training chi2
         self._vl_dict = None
         self._tr_dict = None
+        # This can be given if ``training_info`` is not given
+        self._training_loss = training_loss
 
     @property
     def vl_loss(self):
@@ -173,6 +178,8 @@ def vl_loss(self):
     @property
     def tr_loss(self):
         """Return the total training loss as it comes from the info dictionaries"""
+        if self._training is None:
+            return self._training_loss
         return self._training.get("loss")
 
     def _parse_chi2(self):
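A hedged sketch isolating the fallback introduced here: when no per-experiment training dictionary exists, the stored total loss is reported instead. `MiniFitState` is a hypothetical reduction of `FitState` to the `tr_loss` logic only:

```python
class MiniFitState:
    def __init__(self, training_info=None, training_loss=None):
        self._training = training_info
        self._training_loss = training_loss

    @property
    def tr_loss(self):
        # Fall back to the stored total when no dictionary is available
        if self._training is None:
            return self._training_loss
        return self._training.get("loss")


full = MiniFitState(training_info={"loss": 3.14})
light = MiniFitState(training_loss=3.14)
assert full.tr_loss == light.tr_loss == 3.14
```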
@@ -223,7 +230,7 @@ def total_partial_tr_chi2(self):

     def total_partial_vl_chi2(self):
         """Return the vl chi2 summed over replicas per experiment"""
-        return {k: np.sum(v) for k, v in self.all_tr_chi2.items()}
+        return {k: np.sum(v) for k, v in self.all_vl_chi2.items()}
 
     def total_tr_chi2(self):
         """Return the total tr chi2 summed over replicas"""
@@ -273,27 +280,12 @@ def get_state(self, epoch):
f"Tried to get obtain the state for epoch {epoch} when only {len(self._history)} epochs have been saved"
) from e

def register(self, epoch, training_info, validation_info):
"""Save a new fitstate and updates the current final epoch
Parameters
----------
epoch: int
the current epoch of the fit
training_info: dict
all losses for the training model
validation_info: dict
all losses for the validation model
Returns
-------
FitState
def register(self, epoch, fitstate):
"""Save the current fitstate and the associated epoch
and set the current epoch as the final one should the fit end now
"""
# Save all the information in a fitstate object
fitstate = FitState(training_info, validation_info)
self.final_epoch = epoch
self._history.append(fitstate)
return fitstate


class Stopping:
@@ -425,8 +417,8 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
         Parameters
         ----------
         training_info: dict
-            output of a .fit() call, dictionary of the total loss (summed over replicas) for
-            each experiment
+            output of a .fit() call, dictionary of the total training loss
+            (summed over replicas and experiments)
         epoch: int
             index of the epoch
@@ -436,7 +428,7 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
             true/false according to the status of the run
         """
         # Step 1. Check whether the fit has NaN'd and stop it if so
-        if np.isnan(training_info["loss"]):
+        if np.isnan(training_loss := training_info["loss"]):
             log.warning(" > NaN found, stopping activated")
             self.make_stop()
             return False
@@ -445,7 +437,9 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
         validation_info = self._validation.compute_losses()
 
         # Step 3. Register the current point in (the) history
-        fitstate = self._history.register(epoch, training_info, validation_info)
+        # and set the current epoch as the final one
+        fitstate = FitState(None, validation_info, training_loss)
+        self._history.register(epoch, fitstate)
         if print_stats:
             self.print_current_stats(epoch, fitstate)
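A minimal illustration of the walrus-operator guard introduced in Step 1: the loss is captured for reuse in Step 3 at the same time as it is tested for NaN. The `logs` dictionary below is made up for illustration:

```python
import numpy as np

logs = {"loss": float("nan")}
# Capture and test in one expression
if np.isnan(training_loss := logs["loss"]):
    print("NaN found, stopping activated")
else:
    print(f"continuing with loss {training_loss}")
```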

@@ -496,21 +490,23 @@ def _restore_best_weights(self):

     def print_current_stats(self, epoch, fitstate):
         """
-        Prints ``fitstate`` training and validation chi2s
+        Prints ``fitstate`` validation chi2 for every experiment
+        and the current total training loss as well as the validation loss
+        after the training step
         """
         epoch_index = epoch + 1
-        tr_chi2 = fitstate.total_tr_chi2()
         vl_chi2 = fitstate.total_vl_chi2()
-        total_str = f"At epoch {epoch_index}/{self.total_epochs}, total chi2: {tr_chi2}\n"
+        total_str = f"""Epoch {epoch_index}/{self.total_epochs}: loss: {fitstate.tr_loss:.7f}
+Validation loss after training step: {vl_chi2:.7f}.
+Validation chi2s: """
 
         # The partial chi2 makes no sense for more than one replica at once:
         if self._n_replicas == 1:
-            partial_tr_chi2 = fitstate.total_partial_tr_chi2()
+            partial_vl_chi2 = fitstate.total_partial_vl_chi2()
             partials = []
-            for experiment, chi2 in partial_tr_chi2.items():
+            for experiment, chi2 in partial_vl_chi2.items():
                 partials.append(f"{experiment}: {chi2:.3f}")
-            total_str += ", ".join(partials) + "\n"
-            total_str += f"Validation chi2 at this point: {vl_chi2}"
+            total_str += ", ".join(partials)
         log.info(total_str)

def stop_here(self):
@@ -525,6 +521,7 @@ def stop_here(self):
     def chi2exps_json(self, i_replica=0, log_each=100):
         """
         Returns an apt-for-json dictionary with the status of the fit every `log_each` epochs
+        It reports the total training loss and the validation loss broken down by experiment.
 
         Parameters
         ----------
@@ -543,16 +540,14 @@ def chi2exps_json(self, i_replica=0, log_each=100):

         for epoch in range(log_each - 1, final_epoch + 1, log_each):
             fitstate = self._history.get_state(epoch)
-            all_tr = fitstate.all_tr_chi2_for_replica(i_replica)
-            all_vl = fitstate.all_vl_chi2_for_replica(i_replica)
+            # Get the training and validation losses
+            tmp = {"training_loss": fitstate.tr_loss, "validation_loss": fitstate.vl_loss.tolist()}
 
-            tmp = {exp: {"training": tr_chi2} for exp, tr_chi2 in all_tr.items()}
-            for exp, vl_chi2 in all_vl.items():
-                if exp not in tmp:
-                    tmp[exp] = {"training": None}
-                tmp[exp]["validation"] = vl_chi2
+            # And the validation chi2 broken down by experiment
+            tmp["validation_chi2s"] = fitstate.all_vl_chi2_for_replica(i_replica)
             json_dict[epoch + 1] = tmp
 
         return json_dict
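For orientation, a sketch of the dictionary shape `chi2exps_json` now produces. The epochs, loss values and experiment names below are invented purely for illustration:

```python
# Hypothetical output for log_each=100; keys are epoch numbers (1-indexed)
example_json_dict = {
    100: {
        "training_loss": 2.31,
        "validation_loss": [2.45],  # one entry per replica
        "validation_chi2s": {"NMC": 1.12, "ATLASZPT8TEV": 1.87},
    },
    200: {
        "training_loss": 2.05,
        "validation_loss": [2.21],
        "validation_chi2s": {"NMC": 1.08, "ATLASZPT8TEV": 1.74},
    },
}
```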


@@ -586,6 +581,7 @@ def check_positivity(self, history_object):
         otherwise, it passes.
         It returns an array of booleans which are True if positivity passed
         history_object[key_loss] < self.threshold
+
         Parameters
         ----------
         history_object: dict
