diff --git a/scenic/model_lib/matchers/common.py b/scenic/model_lib/matchers/common.py index 0001de45..9dacf0aa 100644 --- a/scenic/model_lib/matchers/common.py +++ b/scenic/model_lib/matchers/common.py @@ -110,7 +110,7 @@ def matching_fn_hcb_vjp_fwd(cost, n_cols): return matching_fn_hcb(cost, n_cols), None def matching_fn_hcb_vjp_bwd(*_): - return (None,) # Return no gradient. + return (None, None) # Return no gradient. matching_fn_hcb.defvjp(matching_fn_hcb_vjp_fwd, matching_fn_hcb_vjp_bwd)