Skip to content

Commit

Permalink
new JaxRecorder with recording_schedule instead of the update_rate
Browse files Browse the repository at this point in the history
  • Loading branch information
StephanHaa9 committed Dec 13, 2024
1 parent a98206b commit 11b4166
Showing 1 changed file with 47 additions and 20 deletions.
67 changes: 47 additions & 20 deletions znnl/training_recording/papyrus_jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from znnl.analysis import JAXNTKComputation
from znnl.models import JaxModel

from typing import List, Union
import numpy as np


class JaxRecorder(BaseRecorder):
"""
Expand All @@ -51,8 +54,10 @@ class JaxRecorder(BaseRecorder):
The size of the chunks in which the data will be stored.
overwrite : bool (default=False)
Whether to overwrite the existing data in the database.
update_rate : int (default=1)
The rate at which the recorder will update the neural state.
recording_schedule : int or List[int]
The schedule at which the recorder will update the neural state.
- if int: The rate at which the recorder will update the neural state.
- if List[int]: At which epoch index the recorder will update the neural state.
neural_state_keys : List[str]
The keys of the neural state that the recorder takes as input.
A neural state is a dictionary of numpy arrays that represent the state of
Expand All @@ -71,15 +76,15 @@ class JaxRecorder(BaseRecorder):
An NTK computation module. For more information see the JAXNTKComputation
class.
"""

def __init__(
self,
name: str,
storage_path: str,
measurements: List[BaseMeasurement],
chunk_size: int = 1e5,
overwrite: bool = False,
update_rate: int = 1,
recording_schedule: Union[int, List[int]] = 1,
):
"""
Constructor method of the BaseRecorder class.
Expand All @@ -97,11 +102,11 @@ def __init__(
The size of the chunks in which the data will be stored.
overwrite : bool (default=False)
Whether to overwrite the existing data in the database.
update_rate : int (default=1)
The rate at which the recorder will update the neural state.
recording_schedule : Union[int, List[int]] (default=1)
The schedule at which the recorder will update the neural state.
"""
super().__init__(name, storage_path, measurements, chunk_size, overwrite)
self.update_rate = update_rate
self.recording_schedule = recording_schedule

self.neural_state = {}

Expand Down Expand Up @@ -218,6 +223,7 @@ def _compute_neural_state(self, model: JaxModel):
predictions = model(self._data_set[list(self._data_set.keys())[0]])
self.neural_state["predictions"] = [predictions]


def record(self, epoch: int, model: JaxModel, **kwargs):
"""
Perform the recording of a neural state.
Expand All @@ -240,16 +246,37 @@ def record(self, epoch: int, model: JaxModel, **kwargs):
result : onp.ndarray
The result of the recorder.
"""
if epoch % self.update_rate == 0:
# Compute the neural state
self._compute_neural_state(model)
# Add all other kwargs to the neural state dictionary
self.neural_state.update(kwargs)
for key, val in self._data_set.items():
self.neural_state[key] = [val]
# Check if incoming data is complete
self._check_keys()
# Perform measurements
self._measure(**self.neural_state)
# Store the measurements
self.store(ignore_chunk_size=False)

# Check the type of recording_schedule and determine if recording is needed
do_record = False
# if List[int]
if isinstance(self.recording_schedule, list):
# Check if the current epoch is in the schedule list
do_record = np.isin(epoch, self.recording_schedule)
# if int
elif isinstance(self.recording_schedule, int):
# Check if the current epoch is a multiple of the recording schedule
do_record = (epoch % self.recording_schedule == 0)
else:
raise ValueError(
f"Invalid type for recording_schedule: {type(self.recording_schedule)}. "
"Expected int or list of int."
)

# Perform recording if do_record is True
if do_record:
print(f"Recording at epoch {epoch}")
# Compute the neural state
self._compute_neural_state(model)
# Add all other kwargs to the neural state dictionary
self.neural_state.update(kwargs)
for key, val in self._data_set.items():
self.neural_state[key] = [val]
# Check if incoming data is complete
self._check_keys()
# Perform measurements
self._measure(**self.neural_state)
# Store the measurements
self.store(ignore_chunk_size=False)


0 comments on commit 11b4166

Please sign in to comment.