From 8820fc93ba41fcc16b8f970ed30d30b78c454291 Mon Sep 17 00:00:00 2001 From: Scenic Authors Date: Fri, 24 Jan 2025 14:45:55 -0800 Subject: [PATCH] Match the returned # of arguments/gradients between CPU hungarian matching custom VJP primal function and bwd rule. PiperOrigin-RevId: 719440273 --- scenic/model_lib/matchers/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scenic/model_lib/matchers/common.py b/scenic/model_lib/matchers/common.py index 0001de45..9dacf0aa 100644 --- a/scenic/model_lib/matchers/common.py +++ b/scenic/model_lib/matchers/common.py @@ -110,7 +110,7 @@ def matching_fn_hcb_vjp_fwd(cost, n_cols): return matching_fn_hcb(cost, n_cols), None def matching_fn_hcb_vjp_bwd(*_): - return (None,) # Return no gradient. + return (None, None) # Return no gradient. matching_fn_hcb.defvjp(matching_fn_hcb_vjp_fwd, matching_fn_hcb_vjp_bwd)