Skip to content
This repository has been archived by the owner on Apr 4, 2023. It is now read-only.

Commit

Permalink
Automatic merge of updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ChWick committed Jul 7, 2021
1 parent 4db5955 commit 5d8c1fa
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 20 deletions.
4 changes: 2 additions & 2 deletions examples/tutorial/full/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class TutorialModelParams(ModelBaseParams):
def cls():
return TutorialModel

@staticmethod
def graph_cls():
@classmethod
def graph_cls(cls):
return TutorialGraph

n_classes: int = field(
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ GitPython
imageio
nptyping
opencv-python-headless
openpyxl
paiargparse==1.1.1
pandas
pillow
Expand Down
2 changes: 1 addition & 1 deletion tfaip/lav/lav.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def run(
keras_model = keras.models.load_model(
os.path.join(self._params.model_path, "serve"),
compile=False,
custom_objects=model.all_custom_objects() if run_eagerly else None,
custom_objects=model.all_custom_objects() if run_eagerly else model.base_custom_objects(),
)

# create a new keras model that uses the inputs and outputs of the loaded model but adds the targets of the
Expand Down
9 changes: 8 additions & 1 deletion tfaip/model/modelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,17 @@ def params_cls(cls) -> Type[TMP]:

@classmethod
def all_custom_objects(cls) -> Dict[str, Type[tf.keras.layers.Layer]]:
"""Custom objects required to instantiate saved keras models"""
"""Custom objects required to instantiate saved keras models in eager mode (reinstantiation)"""
root_graph = cls.root_graph_cls()
return {
root_graph.__name__: root_graph,
**cls.base_custom_objects(),
}

@classmethod
def base_custom_objects(cls) -> Dict[str, Type[tf.keras.layers.Layer]]:
"""Custom objects required to instantiate saved keras models even in graph mode"""
return {
"TensorboardWriter": TensorboardWriter,
}

Expand Down
17 changes: 4 additions & 13 deletions tfaip/model/tensorboardwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,20 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
assert func is not None
self.initial_input_shape = input_shape
self.n_storage = n_storage
self.handle_fn = func
self.store_w = None
assert func is not None
if self.initial_input_shape is not None:
self.setup(self.initial_input_shape)

def setup(self, input_shape):
initial_value = np.zeros([0 if s is None else s for s in input_shape])
self.store_w = tf.Variable(
initial_value=initial_value,
shape=input_shape,
initial_value=[],
shape=tf.TensorShape(None), # dynamic shape
trainable=False,
validate_shape=False,
name="store",
dtype=self.dtype,
)

def update_state(self, y_true, y_pred, **kwargs):
if self.store_w is None:
self.setup(y_pred.shape)

del y_true # not used, the actual data is in y_pred, y_true is dummy data
return self.store_w.assign(y_pred)

Expand All @@ -84,7 +75,7 @@ def result(self):
return tf.stack(self.store_w)

def reset_states(self):
self.store_w.assign_sub(self.store_w) # set to zero
self.store_w.assign([])

def handle(self, name: str, value: np.ndarray, step: int):
return self.handle_fn(name, value, step)
3 changes: 2 additions & 1 deletion tfaip/predict/predictorbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tfaip.data.pipeline.datagenerator import DataGenerator
from tfaip.data.pipeline.datapipeline import DataPipeline
from tfaip.device.device_config import DeviceConfig, distribute_strategy
from tfaip.model.modelbase import ModelBase
from tfaip.predict.raw_predictor import RawPredictor
from tfaip.trainer.callbacks.benchmark_callback import BenchmarkResults
from tfaip.util.multiprocessing.parallelmap import tqdm_wrapper
Expand Down Expand Up @@ -92,7 +93,7 @@ def data(self):

def _load_model(self, model: Union[str, keras.Model]):
if isinstance(model, str):
model = keras.models.load_model(model, compile=False)
model = keras.models.load_model(model, compile=False, custom_objects=ModelBase.base_custom_objects())

return model

Expand Down
6 changes: 4 additions & 2 deletions tfaip/scenario/scenariobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,13 @@ def create_predictor(cls, model: str, params: "PredictorParams") -> "Predictor":
data_params = cls.params_from_path(model).data
post_init(data_params)
predictor = cls.predictor_cls()(params, cls.data_cls()(data_params))
model_cls = cls.model_cls()
run_eagerly = params.run_eagerly
if isinstance(model, str):
model = keras.models.load_model(
os.path.join(model, "serve"),
compile=False,
custom_objects=cls.model_cls().all_custom_objects() if params.run_eagerly else None,
custom_objects=model_cls.all_custom_objects() if run_eagerly else model_cls.base_custom_objects(),
)

predictor.set_model(model)
Expand Down Expand Up @@ -605,7 +607,7 @@ def assert_unique_keys(keys):
tmp, include_optimizer=False, options=tf.saved_model.SaveOptions(namespace_whitelist=["Addons"])
)
logger.info("Prediction model successfully saved. Attempting to load it")
keras.models.load_model(tmp, custom_objects=self.model.all_custom_objects())
keras.models.load_model(tmp, custom_objects=self._model.base_custom_objects())
logger.info("Model can be successfully loaded")

def _wrap_data(
Expand Down

0 comments on commit 5d8c1fa

Please sign in to comment.