diff --git a/dspy/primitives/module.py b/dspy/primitives/module.py index 04bb47dd9..085dd4a2a 100644 --- a/dspy/primitives/module.py +++ b/dspy/primitives/module.py @@ -23,10 +23,11 @@ def named_parameters(self): named_parameters = [] def add_parameter(param_name, param_value): - if isinstance(param_value, Parameter) and id(param_value) not in visited: - visited.add(id(param_value)) - param_name = postprocess_parameter_name(param_name, param_value) - named_parameters.append((param_name, param_value)) + if isinstance(param_value, Parameter): + if id(param_value) not in visited: + visited.add(id(param_value)) + param_name = postprocess_parameter_name(param_name, param_value) + named_parameters.append((param_name, param_value)) elif isinstance(param_value, dspy.Module): # When a sub-module is pre-compiled, keep it frozen. diff --git a/tests/primitives/test_program.py b/tests/primitives/test_program.py index 219ad5259..63a37a75d 100644 --- a/tests/primitives/test_program.py +++ b/tests/primitives/test_program.py @@ -136,3 +136,15 @@ def test_complex_module_traversal(): assert ( found_names == expected_names ), f"Missing or extra modules found. Missing: {expected_names-found_names}, Extra: {found_names-expected_names}" + +class DuplicateModule(Module): + def __init__(self): + super().__init__() + self.p0 = dspy.Predict("question -> answer") + self.p1 = self.p0 + +def test_named_parameters_duplicate_references(): + module = DuplicateModule() + # Only testing for whether exceptions are thrown or not + # As Module.named_parameters() is recursive, this is mainly for catching infinite recursion + module.named_parameters()