From 880366eedfa6bcd4c62218573bc81132c17c9088 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Apr 2024 15:52:06 +0100 Subject: [PATCH] tests: Add test for combining multiple ConditionalDimensions --- tests/test_dimension.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_dimension.py b/tests/test_dimension.py index 0c3746eb61..a5779e9076 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -1266,6 +1266,38 @@ def test_implicit_dims(self): assert np.all(f.data == F) + def test_implict_dims_multiple(self): + """Test supplying multiple ConditionalDimensions as implicit dimensions""" + shape = (50,) + start_value = 5 + stop_value = 20 + + time = Dimension(name='time') + f = TimeFunction(name='f', shape=shape, dimensions=[time]) + # The condition to start incrementing + cond0 = ConditionalDimension(name='cond0', + parent=time, condition=time > start_value) + # The condition to stop incrementing + cond1 = ConditionalDimension(name='cond1', + parent=time, condition=time < stop_value) + # Factor of 2 + cond2 = ConditionalDimension(name='cond2', parent=time, factor=2) + + eqs = [Eq(f.forward, f), Eq(f.forward, f.forward + 1, + implicit_dims=[cond0, cond1, cond2])] + op = Operator(eqs) + op.apply(time_M=shape[0] - 2) + + # Make the same calculation in python to assert the result + F = np.zeros(shape[0]) + val = 0 + for i in range(shape[0]): + F[i] = val + if i > start_value and i < stop_value and i % 2 == 0: + val += 1 + + assert np.all(f.data == F) + def test_grouping(self): """ Test that Clusters over the same set of ConditionalDimensions fall within