Skip to content

Commit

Permalink
Merge pull request stanfordnlp#1294 from theta-lin/fix-named-paramete…
Browse files Browse the repository at this point in the history
…rs-if

Fix conditional in Module.named_parameters() to prevent infinite recursion
  • Loading branch information
okhat authored Jul 20, 2024
2 parents 54bc449 + 3585efc commit 9b60ef2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
9 changes: 5 additions & 4 deletions dspy/primitives/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def named_parameters(self):
named_parameters = []

def add_parameter(param_name, param_value):
if isinstance(param_value, Parameter) and id(param_value) not in visited:
visited.add(id(param_value))
param_name = postprocess_parameter_name(param_name, param_value)
named_parameters.append((param_name, param_value))
if isinstance(param_value, Parameter):
if id(param_value) not in visited:
visited.add(id(param_value))
param_name = postprocess_parameter_name(param_name, param_value)
named_parameters.append((param_name, param_value))

elif isinstance(param_value, dspy.Module):
# When a sub-module is pre-compiled, keep it frozen.
Expand Down
12 changes: 12 additions & 0 deletions tests/primitives/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,15 @@ def test_complex_module_traversal():
assert (
found_names == expected_names
), f"Missing or extra modules found. Missing: {expected_names-found_names}, Extra: {found_names-expected_names}"

class DuplicateModule(Module):
def __init__(self):
super().__init__()
self.p0 = dspy.Predict("question -> answer")
self.p1 = self.p0

def test_named_parameters_duplicate_references():
module = DuplicateModule()
# Only testing for whether exceptions are thrown or not
# As Module.named_parameters() is recursive, this is mainly for catching infinite recursion
module.named_parameters()

0 comments on commit 9b60ef2

Please sign in to comment.