Hi Clement,

I think there is a problem with the `support_set_loss` attribute in the `Classification` layer.
Here is a minimal example to reproduce the bug:
```python
import tensorflow as tf
from keras_fsl.layers import Classification
from keras_fsl.models.head_models import LearntNorms
kernel = LearntNorms(input_shape=(512,), activation="sigmoid")
classifier = Classification(kernel=kernel)
classifier.set_support_set(
    support_tensors=tf.random.uniform(shape=(10, 512)),
    support_labels_name=tf.constant(list("AAABBCCCDD")),
    overwrite=tf.constant(True),
)
```
Here is the error I got:
```
Traceback (most recent call last):
  File "<input>", line 4, in <module>
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 506, in _initialize
    *args, **kwds))
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3299, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
    /Users/toubi/.pyenv/versions/3.6.9/envs/totem/src/keras-fsl/keras_fsl/layers/classification.py:83 set_support_set  *
        self.support_set_loss.assign(class_consistency_loss(support_labels_one_hot, pair_wise_scores))
    /Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:846 assign  **
        self._shape.assert_is_compatible_with(value_tensor.shape)
    /Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py:1117 assert_is_compatible_with
        raise ValueError("Shapes %s and %s are incompatible" % (self, other))
    ValueError: Shapes () and (None, None) are incompatible
```
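The problem is that the new support set loss cannot be assigned to the `support_set_loss` variable because the shapes do not match: the variable has shape `()` (a scalar), while the value returned by `class_consistency_loss` has shape `(None, None)`. The failing statement is line 83 of `keras_fsl/layers/classification.py`: `self.support_set_loss.assign(class_consistency_loss(support_labels_one_hot, pair_wise_scores))`.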
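To isolate the failure from keras_fsl, here is a minimal plain-TensorFlow sketch. Assumptions: `pairwise_loss` is a hypothetical stand-in for `class_consistency_loss` that returns one loss value per pair of samples, and the scalar `tf.Variable` only mimics `support_set_loss`. Tracing with a `(None, None)` input signature triggers the same shape check as in the traceback. The second function shows one possible fix, reducing the loss to a scalar before `assign`; whether a mean is the intended reduction is also an assumption.

```python
import tensorflow as tf

# Scalar variable standing in for support_set_loss (shape () per the traceback).
support_set_loss = tf.Variable(0.0, trainable=False)


def pairwise_loss(labels_one_hot, scores):
    """Hypothetical stand-in for class_consistency_loss: one loss per pair."""
    # 1.0 where samples i and j share a class, else 0.0.
    same_class = tf.matmul(labels_one_hot, labels_one_hot, transpose_b=True)
    return tf.square(same_class - scores)  # shape (n, n), not a scalar


SIGNATURE = [
    tf.TensorSpec(shape=(None, None), dtype=tf.float32),  # one-hot labels
    tf.TensorSpec(shape=(None, None), dtype=tf.float32),  # pair-wise scores
]


@tf.function(input_signature=SIGNATURE)
def assign_unreduced(labels_one_hot, scores):
    # The traced value has shape (None, None), so assign() rejects it while tracing.
    support_set_loss.assign(pairwise_loss(labels_one_hot, scores))


@tf.function(input_signature=SIGNATURE)
def assign_reduced(labels_one_hot, scores):
    # Reducing to a scalar makes the value compatible with the variable's shape ().
    support_set_loss.assign(tf.reduce_mean(pairwise_loss(labels_one_hot, scores)))


unreduced_failed = False
try:
    assign_unreduced.get_concrete_function()  # tracing alone raises
except ValueError as err:
    unreduced_failed = True
    print("unreduced assign failed:", err)

labels = tf.one_hot([0, 0, 1], depth=2)  # three samples, two classes
scores = tf.fill((3, 3), 0.5)            # dummy pair-wise scores
assign_reduced(labels, scores)
print(float(support_set_loss.numpy()))  # 0.25
```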
I was able to fix this by redefining the `support_set_loss` variable before calling `set_support_set`.
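One way to redefine such a variable, sketched here in plain TensorFlow rather than in keras_fsl (how `Classification` actually creates `support_set_loss` is not shown in this issue), is to declare it with an unspecified shape via `shape=tf.TensorShape(None)`, which `tf.Variable` documents as allowing values of different shapes to be assigned. The function name `store_loss` is hypothetical.

```python
import tensorflow as tf

# shape=tf.TensorShape(None) leaves the variable's shape unspecified, so
# assign() accepts values of any shape, including the traced (None, None).
support_set_loss = tf.Variable(0.0, trainable=False, shape=tf.TensorShape(None))


@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.float32)])
def store_loss(pairwise_loss):
    # No shape error: an unknown shape is compatible with (None, None).
    support_set_loss.assign(pairwise_loss)


store_loss(tf.ones((3, 3)))
print(support_set_loss.numpy().shape)  # (3, 3)
```

The trade-off is that the variable then holds the full pair-wise matrix rather than a single number, and its static shape stays unknown, so any code expecting a scalar loss would still need a reduction.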
Also, there is a problem with the `Classification` layer docstring: the args it lists do not correspond to the actual `__init__` arguments.