diff --git a/graphium/nn/utils.py b/graphium/nn/utils.py index 68a8779c4..e9ac4fa0c 100644 --- a/graphium/nn/utils.py +++ b/graphium/nn/utils.py @@ -40,7 +40,7 @@ def scale_kwargs(self, scale_factor: Real, scale_in_dim: bool = False): divide_factor = 1 / scale_factor - if scale_in_dim is None: + if not scale_in_dim: return self.make_mup_base_kwargs(divide_factor=divide_factor) # If scale_in_dim passed, need to check it can be forwarded