From 98b8a66369fa2cc24409b58e0b6eae586880b2ea Mon Sep 17 00:00:00 2001 From: ianmnz <99749877+ianmnz@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:42:41 +0100 Subject: [PATCH] Feature/improve output value with tolerance (#14) * Add tolerance to ExpectedOutput for floating point values * Add possibility to ignore component/variables * Expose tolerance for float comparison to the user * Explicitly call a approx function Revert "Expose tolerance for float comparison to the user" This reverts commit d7bda60748596a7c2f2908d484e567fc6677c72e. Revert "Add tolerance to ExpectedOutput for floating point values" This reverts commit 5781c90b387aaa9567aa96d4570bd71a84499750. --- src/andromede/simulation/output_values.py | 99 ++++++++++++++++++++--- tests/andromede/test_output_values.py | 60 ++++++++++++-- 2 files changed, 143 insertions(+), 16 deletions(-) diff --git a/src/andromede/simulation/output_values.py b/src/andromede/simulation/output_values.py index 21d0ec57..a569cdb2 100644 --- a/src/andromede/simulation/output_values.py +++ b/src/andromede/simulation/output_values.py @@ -13,15 +13,13 @@ """ Util class to obtain solver results """ +import math from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast +from typing import Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast from andromede.simulation.optimization import SolverAndContext from andromede.study.data import TimeScenarioIndex -T = TypeVar("T") -K = TypeVar("K") - @dataclass class OutputValues: @@ -43,18 +41,45 @@ class Variable: _name: str _value: Dict[TimeScenarioIndex, float] = field(init=False, default_factory=dict) _size: Tuple[int, int] = field(init=False, default=(0, 0)) + ignore: bool = field(default=False, init=False) def __eq__(self, other: object) -> bool: if not isinstance(other, OutputValues.Variable): return NotImplemented - return ( + return (self.ignore or other.ignore) or ( self._name == other._name and self._size == other._size and self._value == other._value ) + def is_close( + self, + other: "OutputValues.Variable", + *, + rel_tol: float = 1.0e-9, + abs_tol: float = 0.0, + ) -> bool: + # From the docs in https://docs.python.org/3/library/math.html#math.isclose + # math.isclose(a, b) returns abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) + return (self.ignore or other.ignore) or ( + self._name == other._name + and self._size == other._size + and self._value.keys() == other._value.keys() + and all( + math.isclose( + self._value[key], + other._value[key], + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + for key in self._value + ) + ) + def __str__(self) -> str: - return f"{self._name} : {str(self.value)}" + return ( + f"{self._name} : {str(self.value)} {'(ignored)' if self.ignore else ''}" + ) @property def value(self) -> Union[None, float, List[float], List[List[float]]]: @@ -111,14 +136,29 @@ class Component: _variables: Dict[str, "OutputValues.Variable"] = field( init=False, default_factory=dict ) + ignore: bool = field(default=False, init=False) def __eq__(self, other: object) -> bool: if not isinstance(other, OutputValues.Component): return NotImplemented - return self._id == other._id and self._variables == other._variables + return self.is_close(other, rel_tol=0.0, abs_tol=0.0) + + def is_close( + self, + other: "OutputValues.Component", + *, + rel_tol: float = 1.0e-9, + abs_tol: float = 0.0, + ) -> bool: + return (self.ignore or other.ignore) or ( + self._id == other._id + and _are_mappings_close( + self._variables, other._variables, rel_tol, abs_tol + ) + ) def __str__(self) -> str: - string = f"{self._id} :\n" + string = f"{self._id} : {'(ignored)' if self.ignore else ''}\n" for var in self._variables.values(): string += f" {str(var)}\n" return string @@ -139,7 +179,14 @@ def __post_init__(self) -> None: def __eq__(self, other: object) -> bool: if not isinstance(other, OutputValues): return NotImplemented - return self._components == other._components + return _are_mappings_close(self._components, other._components, 0.0, 0.0) + + def is_close( + self, other: "OutputValues", *, rel_tol: float = 1.0e-9, abs_tol: float = 0.0 + ) -> bool: + return _are_mappings_close( + self._components, other._components, rel_tol, abs_tol + ) def __str__(self) -> str: string = "\n" @@ -166,3 +213,37 @@ def component(self, component_id: str) -> "OutputValues.Component": if component_id not in self._components: self._components[component_id] = OutputValues.Component(component_id) return self._components[component_id] + + +Comparable = TypeVar("Comparable", OutputValues.Component, OutputValues.Variable) + + +def _are_mappings_close( + lhs: Mapping[str, Comparable], + rhs: Mapping[str, Comparable], + rel_tol: float, + abs_tol: float, +) -> bool: + lhs_keys = lhs.keys() + rhs_keys = rhs.keys() + + if (lhs_only_keys := lhs_keys - rhs_keys) and any( + not lhs[key].ignore for key in lhs_only_keys + ): + return False + + elif (rhs_only_keys := rhs_keys - lhs_keys) and any( + not rhs[key].ignore for key in rhs_only_keys + ): + return False + + elif intersect_keys := lhs_keys & rhs_keys: + if rel_tol == abs_tol == 0.0: + return all(lhs[key] == rhs[key] for key in intersect_keys) + else: + return all( + lhs[key].is_close(rhs[key], rel_tol=rel_tol, abs_tol=abs_tol) + for key in intersect_keys + ) + else: + return True diff --git a/tests/andromede/test_output_values.py b/tests/andromede/test_output_values.py index 14bb2c43..feea3a1e 100644 --- a/tests/andromede/test_output_values.py +++ b/tests/andromede/test_output_values.py @@ -37,7 +37,13 @@ def test_component_and_flow_output_object() -> None: variable_name="component_var_name", block_timestep=0, scenario=0, - ): mock_variable_component + ): mock_variable_component, + TimestepComponentVariableKey( + component_id="component_id_test", + variable_name="component_approx_var_name", + block_timestep=0, + scenario=0, + ): mock_variable_component, } opt_context.block_length.return_value = 1 @@ -45,16 +51,56 @@ def test_component_and_flow_output_object() -> None: problem = SolverAndContext(mock_variable_flow, opt_context) output = OutputValues(problem) - wrong_output = OutputValues() - wrong_output.component("component_id_test").var( + test_output = OutputValues() + assert output != test_output, f"Output is equal to empty output: {output}" + + test_output.component("component_id_test").ignore = True + assert ( + output == test_output + ), f"Output differs from the expected output after 'ignore': {output}" + + test_output.component("component_id_test").ignore = False + test_output.component("component_id_test").var("component_var_name").value = 1.0 + test_output.component("component_id_test").var( + "component_approx_var_name" + ).ignore = True + + assert ( + output == test_output + ), f"Output differs from the expected after 'var_name': {output}" + + test_output.component("component_id_test").var( + "component_approx_var_name" + ).ignore = False + test_output.component("component_id_test").var( + "component_approx_var_name" + ).value = 1.000_000_001 + + assert output != test_output and not output.is_close( + test_output + ), f"Output is equal to expected outside tolerance: {output}" + + test_output.component("component_id_test").var( + "component_approx_var_name" + ).value = 1.000_000_000_1 + + assert output != test_output and output.is_close( + test_output + ), f"Output differs from the expected inside tolerance: {output}" + + test_output.component("component_id_test").var( + "component_approx_var_name" + ).ignore = True + test_output.component("component_id_test").var( "wrong_component_var_name" ).value = 1.0 - assert output != wrong_output, f"Output is equal to wrong output: {output}" + assert output != test_output, f"Output is equal to wrong output: {output}" - expected_output = OutputValues() - expected_output.component("component_id_test").var("component_var_name").value = 1.0 + test_output.component("component_id_test").var( + "wrong_component_var_name" + ).ignore = True - assert output == expected_output, f"Output differs from expected: {output}" + assert output == test_output, f"Output differs from expected: {output}" print(output)