Skip to content

Commit

Permalink
Add tests for inherited classes
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiirola committed Feb 7, 2024
1 parent 51d2337 commit ce7a6b5
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions pyomo/repn/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pyomo.common.errors import DeveloperError, InvalidValueError
from pyomo.common.log import LoggingIntercept
from pyomo.core.expr import (
NumericExpression,
ProductExpression,
NPV_ProductExpression,
SumExpression,
Expand Down Expand Up @@ -671,16 +672,6 @@ def test_ExitNodeDispatcher_registration(self):
self.assertEqual(len(end), 4)
self.assertIn(NPV_ProductExpression, end)

class NewProductExpression(ProductExpression):
pass

node = NewProductExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'NewProductExpression'"
):
end[node.__class__](None, node, *node.args)
self.assertEqual(len(end), 4)

end[SumExpression, 2] = lambda v, n, *d: 2 * sum(d)
self.assertEqual(len(end), 5)

Expand Down Expand Up @@ -710,6 +701,31 @@ class NewProductExpression(ProductExpression):
self.assertEqual(len(end), 7)
self.assertNotIn((SumExpression, 3, 4, 5, 6), end)

class NewProductExpression(ProductExpression):
pass

node = NewProductExpression((6, 7))
self.assertEqual(end[node.__class__](None, node, *node.args), 42)
self.assertEqual(len(end), 8)
self.assertIn(NewProductExpression, end)

class UnknownExpression(NumericExpression):
pass

node = UnknownExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__](None, node, *node.args)
self.assertEqual(len(end), 8)

node = UnknownExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__, 6, 7](None, node, *node.args)
self.assertEqual(len(end), 8)

def test_BeforeChildDispatcher_registration(self):
class BeforeChildDispatcherTester(BeforeChildDispatcher):
@staticmethod
Expand Down

0 comments on commit ce7a6b5

Please sign in to comment.