From 6f1a0552388f25727563908abde9f1b405b6e4b0 Mon Sep 17 00:00:00 2001 From: John Siirola Date: Wed, 7 Feb 2024 15:39:39 -0700 Subject: [PATCH 1/3] Update ExitNodeDispatcher to be compatible with inherited expression types --- pyomo/repn/util.py | 67 +++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/pyomo/repn/util.py b/pyomo/repn/util.py index b65aa9427d5..108bb0ab972 100644 --- a/pyomo/repn/util.py +++ b/pyomo/repn/util.py @@ -387,42 +387,49 @@ def __init__(self, *args, **kwargs): super().__init__(None, *args, **kwargs) def __missing__(self, key): - return functools.partial(self.register_dispatcher, key=key) - - def register_dispatcher(self, visitor, node, *data, key=None): + if type(key) is tuple: + node_class = key[0] + else: + node_class = key + bases = node_class.__mro__ + # Note: if we add an `etype`, then this special-case can be removed if ( - isinstance(node, _named_subexpression_types) - or type(node) is kernel.expression.noclone + issubclass(node_class, _named_subexpression_types) + or node_class is kernel.expression.noclone ): - base_type = Expression - elif not node.is_potentially_variable(): - base_type = node.potentially_variable_base_class() - else: - base_type = node.__class__ - if isinstance(key, tuple): - base_key = (base_type,) + key[1:] - # Only cache handlers for unary, binary and ternary operators - cache = len(key) <= 4 - else: - base_key = base_type - cache = True - if base_key in self: - fcn = self[base_key] - elif base_type in self: - fcn = self[base_type] - elif any((k[0] if k.__class__ is tuple else k) is base_type for k in self): - raise DeveloperError( - f"Base expression key '{base_key}' not found when inserting dispatcher" - f" for node '{type(node).__name__}' while walking expression tree." - ) - else: + bases = [Expression] + fcn = None + for base_type in bases: + if isinstance(key, tuple): + base_key = (base_type,) + key[1:] + # Only cache handlers for unary, binary and ternary operators + cache = len(key) <= 4 + else: + base_key = base_type + cache = True + if base_key in self: + fcn = self[base_key] + elif base_type in self: + fcn = self[base_type] + elif any((k[0] if type(k) is tuple else k) is base_type for k in self): + raise DeveloperError( + f"Base expression key '{base_key}' not found when inserting " + f"dispatcher for node '{node_class.__name__}' while walking " + "expression tree." + ) + if fcn is None: + if type(key) is tuple: + node_class = key[0] + else: + node_class = key raise DeveloperError( - f"Unexpected expression node type '{type(node).__name__}' " - "found while walking expression tree." + f"Unexpected expression node type '{node_class.__name__}' " + f"found while walking expression tree." ) + return self.unexpected_expression_type(key) if cache: self[key] = fcn - return fcn(visitor, node, *data) + return fcn def apply_node_operation(node, args): From 51d23370f198a98481d0260c4b72f93b757c9406 Mon Sep 17 00:00:00 2001 From: John Siirola Date: Wed, 7 Feb 2024 15:40:24 -0700 Subject: [PATCH 2/3] Refactor ExitNodeDispatcher to provide hook for unknown classes --- pyomo/repn/util.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pyomo/repn/util.py b/pyomo/repn/util.py index 108bb0ab972..cb67dd92494 100644 --- a/pyomo/repn/util.py +++ b/pyomo/repn/util.py @@ -418,19 +418,21 @@ def __missing__(self, key): "expression tree." ) if fcn is None: - if type(key) is tuple: - node_class = key[0] - else: - node_class = key - raise DeveloperError( - f"Unexpected expression node type '{node_class.__name__}' " - f"found while walking expression tree." - ) return self.unexpected_expression_type(key) if cache: self[key] = fcn return fcn + def unexpected_expression_type(self, key): + if type(key) is tuple: + node_class = key[0] + else: + node_class = key + raise DeveloperError( + f"Unexpected expression node type '{node_class.__name__}' " + f"found while walking expression tree." + ) + def apply_node_operation(node, args): try: From ce7a6b54256f03c24a1089223d355003869bfa63 Mon Sep 17 00:00:00 2001 From: John Siirola Date: Wed, 7 Feb 2024 15:40:39 -0700 Subject: [PATCH 3/3] Add tests for inherited classes --- pyomo/repn/tests/test_util.py | 36 +++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/pyomo/repn/tests/test_util.py b/pyomo/repn/tests/test_util.py index 47cc6b1a63a..3f455aad13f 100644 --- a/pyomo/repn/tests/test_util.py +++ b/pyomo/repn/tests/test_util.py @@ -19,6 +19,7 @@ from pyomo.common.errors import DeveloperError, InvalidValueError from pyomo.common.log import LoggingIntercept from pyomo.core.expr import ( + NumericExpression, ProductExpression, NPV_ProductExpression, SumExpression, @@ -671,16 +672,6 @@ def test_ExitNodeDispatcher_registration(self): self.assertEqual(len(end), 4) self.assertIn(NPV_ProductExpression, end) - class NewProductExpression(ProductExpression): - pass - - node = NewProductExpression((6, 7)) - with self.assertRaisesRegex( - DeveloperError, r".*Unexpected expression node type 'NewProductExpression'" - ): - end[node.__class__](None, node, *node.args) - self.assertEqual(len(end), 4) - end[SumExpression, 2] = lambda v, n, *d: 2 * sum(d) self.assertEqual(len(end), 5) @@ -710,6 +701,31 @@ class NewProductExpression(ProductExpression): self.assertEqual(len(end), 7) self.assertNotIn((SumExpression, 3, 4, 5, 6), end) + class NewProductExpression(ProductExpression): + pass + + node = NewProductExpression((6, 7)) + self.assertEqual(end[node.__class__](None, node, *node.args), 42) + self.assertEqual(len(end), 8) + self.assertIn(NewProductExpression, end) + + class UnknownExpression(NumericExpression): + pass + + node = UnknownExpression((6, 7)) + with self.assertRaisesRegex( + DeveloperError, r".*Unexpected expression node type 'UnknownExpression'" + ): + end[node.__class__](None, node, *node.args) + self.assertEqual(len(end), 8) + + node = UnknownExpression((6, 7)) + with self.assertRaisesRegex( + DeveloperError, r".*Unexpected expression node type 'UnknownExpression'" + ): + end[node.__class__, 6, 7](None, node, *node.args) + self.assertEqual(len(end), 8) + def test_BeforeChildDispatcher_registration(self): class BeforeChildDispatcherTester(BeforeChildDispatcher): @staticmethod