diff --git a/public/dimensional_analysis.py b/public/dimensional_analysis.py index a5eac9b3..bb03db12 100644 --- a/public/dimensional_analysis.py +++ b/public/dimensional_analysis.py @@ -63,7 +63,10 @@ ceiling, sign, sqrt, - factorial + factorial, + Basic, + Rational, + Integer ) class ExprWithAssumptions(Expr): @@ -115,34 +118,13 @@ class ImplicitParameter(TypedDict): original_value: str si_value: str - -# generated on the fly in evaluate_statements function, does in exist in incoming json -class UnitlessSubExpressionName(TypedDict): - name: str - unitlessContext: str - -class UnitlessSubExpression(TypedDict): - type: Literal["assignment"] - name: str - sympy: str - params: list[str] - isUnitlessSubExpression: Literal[True] - unitlessContext: str - isFunctionArgument: Literal[False] - isFunction: Literal[False] - unitlessSubExpressions: list['UnitlessSubExpression | UnitlessSubExpressionName'] - index: int # added in Python, not pressent in json - expression: Expr # added in Python, not pressent in json - class BaseUserFunction(TypedDict): type: Literal["assignment"] name: str sympy: str params: list[str] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[True] - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] functionParameters: list[str] index: int # added in Python, not pressent in json expression: Expr # added in Python, not pressent in json @@ -162,10 +144,8 @@ class UserFunctionRange(BaseUserFunction): class FunctionUnitsQuery(TypedDict): type: Literal["query"] sympy: str - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] params: list[str] units: Literal[""] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[False] isUnitsQuery: Literal[True] @@ -196,9 +176,7 @@ class FunctionArgumentAssignment(TypedDict): type: Literal["assignment"] name: str sympy: str - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] params: list[str] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[True] isFunction: Literal[False] index: int # added in Python, not pressent in json @@ -207,10 +185,8 @@ class FunctionArgumentAssignment(TypedDict): class FunctionArgumentQuery(TypedDict): type: Literal["query"] sympy: str - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] params: list[str] name: str - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[True] isFunction: Literal[False] isUnitsQuery: Literal[False] @@ -225,7 +201,6 @@ class BlankStatement(TypedDict): type: Literal["blank"] params: list[str] # will be empty list implicitParams: list[ImplicitParameter] # will be empty list - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] # will be empty list isFromPlotCell: Literal[False] index: int # added in Python, not pressent in json @@ -235,7 +210,6 @@ class QueryAssignmentCommon(TypedDict): functions: list[UserFunction | UserFunctionRange | FunctionUnitsQuery] arguments: list[FunctionArgumentQuery | FunctionArgumentAssignment] localSubs: list[LocalSubstitution | LocalSubstitutionRange] - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] params: list[str] index: int # added in Python, not pressent in json expression: Expr # added in Python, not pressent in json @@ -243,7 +217,6 @@ class QueryAssignmentCommon(TypedDict): class AssignmentStatement(QueryAssignmentCommon): type: Literal["assignment"] name: str - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[False] isFromPlotCell: Literal[False] @@ -258,7 +231,6 @@ class SystemSolutionAssignmentStatement(AssignmentStatement): class BaseQueryStatement(QueryAssignmentCommon): type: Literal["query"] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[False] isUnitsQuery: Literal[False] @@ -316,7 +288,6 @@ class ScatterXValuesQueryStatement(QueryAssignmentCommon): isDataTableQuery: Literal[False] isCodeFunctionQuery: Literal[False] isCodeFunctionRawQuery: Literal[False] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[False] isUnitsQuery: Literal[False] @@ -336,7 +307,6 @@ class ScatterYValuesQueryStatement(QueryAssignmentCommon): isDataTableQuery: Literal[False] isCodeFunctionQuery: Literal[False] isCodeFunctionRawQuery: Literal[False] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[False] isUnitsQuery: Literal[False] @@ -361,7 +331,6 @@ class ScatterQueryStatement(TypedDict): arguments: list[FunctionArgumentQuery | FunctionArgumentAssignment] localSubs: list[LocalSubstitution | LocalSubstitutionRange] implicitParams: list[ImplicitParameter] - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] xValuesQuery: ScatterXValuesQueryStatement yValuesQuery: ScatterYValuesQueryStatement xName: str @@ -396,7 +365,6 @@ class EqualityUnitsQueryStatement(QueryAssignmentCommon): isDataTableQuery: Literal[False] isCodeFunctionQuery: Literal[False] isCodeFunctionRawQuery: Literal[False] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[False] isUnitsQuery: Literal[False] @@ -407,7 +375,6 @@ class EqualityUnitsQueryStatement(QueryAssignmentCommon): class EqualityStatement(QueryAssignmentCommon): type: Literal["equality"] - isUnitlessSubExpression: Literal[False] isFunctionArgument: Literal[False] isFunction: Literal[False] isFromPlotCell: Literal[False] @@ -487,14 +454,13 @@ class LocalSubstitutionStatement(TypedDict): name: str params: list[str] function_subs: dict[str, dict[str, str]] - isUnitlessSubExpression: Literal[False] index: int InputStatement = AssignmentStatement | QueryStatement | RangeQueryStatement | BlankStatement | \ CodeFunctionQueryStatement | ScatterQueryStatement | SubQueryStatement InputAndSystemStatement = InputStatement | EqualityUnitsQueryStatement | GuessAssignmentStatement | \ SystemSolutionAssignmentStatement -Statement = InputStatement | UnitlessSubExpression | UserFunction | UserFunctionRange | FunctionUnitsQuery | \ +Statement = InputStatement | UserFunction | UserFunctionRange | FunctionUnitsQuery | \ FunctionArgumentQuery | FunctionArgumentAssignment | \ SystemSolutionAssignmentStatement | LocalSubstitutionStatement | \ GuessAssignmentStatement | EqualityUnitsQueryStatement | CodeFunctionRawQuery | \ @@ -626,7 +592,6 @@ class CombinedExpressionBlank(TypedDict): isBlank: Literal[True] isRange: Literal[False] isScatter: Literal[False] - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] isSubQuery: Literal[False] subQueryName: Literal[""] @@ -634,7 +599,6 @@ class CombinedExpressionNoRange(TypedDict): index: int name: str expression: Expr - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] isBlank: Literal[False] isRange: Literal[False] isScatter: Literal[False] @@ -653,7 +617,6 @@ class CombinedExpressionRange(TypedDict): index: int name: str expression: Expr - unitlessSubExpressions: list[UnitlessSubExpression | UnitlessSubExpressionName] isBlank: Literal[False] isRange: Literal[True] isParametric: bool @@ -690,6 +653,10 @@ class CombinedExpressionScatter(TypedDict): CombinedExpression = CombinedExpressionBlank | CombinedExpressionNoRange | CombinedExpressionRange | \ CombinedExpressionScatter +class DimValues(TypedDict): + args: list[Expr] + result: Expr + # maps from mathjs dimensions object to sympy dimensions dim_map: dict[int, Dimension] = { 0: mass, @@ -766,6 +733,9 @@ def get_base_units(custom_base_units: CustomBaseUnits | None= None) -> dict[tupl # precision for sympy evalf calls to convert expressions to floating point values PRECISION = 64 +# very large rationals are inefficient for exponential calculations +LARGE_RATIONAL = 1000000 + # num of digits to round to for unit exponents # this makes sure units with a very small difference are identified as the same EXP_NUM_DIGITS = 12 @@ -897,35 +867,6 @@ def custom_latex(expression: Expr) -> str: _range = Function("_range") -def walk_tree(grandparent_func, parent_func, expr) -> Expr: - - if is_matrix(expr): - rows = [] - for i in range(expr.rows): - row = [] - rows.append(row) - for j in range(expr.cols): - row.append(walk_tree(parent_func, Matrix, expr[i,j])) - - return cast(Expr, Matrix(rows)) - - if len(expr.args) == 0: - if parent_func is not Pow and parent_func is not Inverse and expr.is_negative: - return -1*expr - else: - return expr - - if expr.func == _range: - new_args = expr.args - else: - new_args = (walk_tree(parent_func, expr.func, arg) for arg in expr.args) - - return expr.func(*new_args) - -def subtraction_to_addition(expression: Expr | Matrix) -> Expr: - return walk_tree("root", "root", expression) - - def ensure_dims_all_compatible(*args): if args[0].is_zero: if all(arg.is_zero for arg in args): @@ -1019,11 +960,32 @@ def custom_matmul(exp1: Expr, exp2: Expr): else: return Mul(exp1, exp2) -def custom_matmul_dims(*args: Expr): - if len(args) == 2 and is_matrix(args[0]) and is_matrix(args[1]) and \ +def custom_multiply_dims(matmult: bool, *args: Expr): + matrix_args: list[Matrix] = [] + scalar_args: list[Expr] = [] + for arg in args: + if is_matrix(arg): + matrix_args.append(arg) + else: + scalar_args.append(arg) + + if len(matrix_args) > 0 and len(scalar_args) > 0: + first_matrix = matrix_args[0] + scalar = Mul(*scalar_args) + new_rows = [] + for i in range(first_matrix.rows): + new_row = [] + new_rows.append(new_row) + for j in range(first_matrix.cols): + new_row.append(scalar*first_matrix[i,j]) # type: ignore + + matrix_args[0] = Matrix(new_rows) + args = cast(tuple[Expr], matrix_args) + + if matmult and len(args) == 2 and is_matrix(args[0]) and is_matrix(args[1]) and \ (((args[0].rows == 3 and args[0].cols == 1) and (args[1].rows == 3 and args[1].cols == 1)) or \ ((args[0].rows == 1 and args[0].cols == 3) and (args[1].rows == 1 and args[1].cols == 3))): - + # cross product detected for matrix multiplication operator result = Matrix([Add(Mul(args[0][1],args[1][2]),Mul(args[0][2],args[1][1])), Add(Mul(args[0][2],args[1][0]),Mul(args[0][0],args[1][2])), Add(Mul(args[0][0],args[1][1]),Mul(args[0][1],args[1][0]))]) @@ -1115,6 +1077,9 @@ def custom_range(*args: Expr): return Matrix(values) +def custom_range_dims(dim_values: DimValues, *args: Expr): + return Matrix([ensure_dims_all_compatible(*args)]*len(cast(Matrix, dim_values["result"]))) + class PlaceholderFunction(TypedDict): dim_func: Callable | Function sympy_func: object @@ -1129,6 +1094,16 @@ def IndexMatrix(expression: Expr, i: Expr, j: Expr) -> Expr: return expression[i-1, j-1] # type: ignore +def IndexMatrix_dims(dim_values: DimValues, expression: Expr, i: Expr, j: Expr) -> Expr: + if custom_get_dimensional_dependencies(i) != {} or \ + custom_get_dimensional_dependencies(j) != {}: + raise TypeError('Matrix Index Not Dimensionless') + + i_value = dim_values["args"][1] + j_value = dim_values["args"][2] + + return expression[i_value-1, j_value-1] # type: ignore + class CustomFactorial(Function): is_real = True @@ -1184,7 +1159,25 @@ def custom_integral_dims(local_expr: Expr, global_expr: Expr, dummy_integral_var return global_expr * lower_limit_dims # type: ignore else: return global_expr * integral_var # type: ignore + +def custom_add_dims(*args: Expr): + return Add(*[Abs(arg) for arg in args]) +def custom_pow(base: Expr, exponent: Expr): + large_rational = False + for atom in (exponent.atoms(Rational) | base.atoms(Rational)): + if abs(atom.q) > LARGE_RATIONAL: + large_rational = True + + if large_rational: + return Pow(base.evalf(PRECISION), exponent.evalf(PRECISION)) + else: + return Pow(base, exponent) + +def custom_pow_dims(dim_values: DimValues, base: Expr, exponent: Expr): + if custom_get_dimensional_dependencies(exponent) != {}: + raise TypeError('Exponent Not Dimensionless') + return Pow(base.evalf(PRECISION), (dim_values["args"][1]).evalf(PRECISION)) CP = None @@ -1452,6 +1445,9 @@ def __init__(self): def get_next_id(self): self._next_id += 1 return self._next_id-1 + +dim_needs_values_wrapper = Function('_dim_needs_values_wrapper') +function_id_wrapper = Function('_function_id_wrapper') global_placeholder_map: dict[Function, PlaceholderFunction] = { cast(Function, Function('_StrictLessThan')) : {"dim_func": ensure_dims_all_compatible, "sympy_func": StrictLessThan}, @@ -1481,9 +1477,9 @@ def get_next_id(self): cast(Function, Function('_Inverse')) : {"dim_func": ensure_inverse_dims, "sympy_func": UniversalInverse}, cast(Function, Function('_Transpose')) : {"dim_func": custom_transpose, "sympy_func": custom_transpose}, cast(Function, Function('_Determinant')) : {"dim_func": custom_determinant, "sympy_func": custom_determinant}, - cast(Function, Function('_mat_multiply')) : {"dim_func": custom_matmul_dims, "sympy_func": custom_matmul}, - cast(Function, Function('_multiply')) : {"dim_func": Mul, "sympy_func": Mul}, - cast(Function, Function('_IndexMatrix')) : {"dim_func": IndexMatrix, "sympy_func": IndexMatrix}, + cast(Function, Function('_mat_multiply')) : {"dim_func": partial(custom_multiply_dims, True), "sympy_func": custom_matmul}, + cast(Function, Function('_multiply')) : {"dim_func": partial(custom_multiply_dims, False), "sympy_func": Mul}, + cast(Function, Function('_IndexMatrix')) : {"dim_func": IndexMatrix_dims, "sympy_func": IndexMatrix}, cast(Function, Function('_Eq')) : {"dim_func": Eq, "sympy_func": Eq}, cast(Function, Function('_norm')) : {"dim_func": custom_norm, "sympy_func": custom_norm}, cast(Function, Function('_dot')) : {"dim_func": custom_dot, "sympy_func": custom_dot}, @@ -1492,13 +1488,14 @@ def get_next_id(self): cast(Function, Function('_round')) : {"dim_func": ensure_unitless_in, "sympy_func": custom_round}, cast(Function, Function('_Derivative')) : {"dim_func": custom_derivative_dims, "sympy_func": custom_derivative}, cast(Function, Function('_Integral')) : {"dim_func": custom_integral_dims, "sympy_func": custom_integral}, - cast(Function, Function('_range')) : {"dim_func": custom_range, "sympy_func": custom_range}, + cast(Function, Function('_range')) : {"dim_func": custom_range_dims, "sympy_func": custom_range}, cast(Function, Function('_factorial')) : {"dim_func": factorial, "sympy_func": CustomFactorial}, + cast(Function, Function('_add')) : {"dim_func": custom_add_dims, "sympy_func": Add}, + cast(Function, Function('_Pow')) : {"dim_func": custom_pow_dims, "sympy_func": custom_pow}, } global_placeholder_set = set(global_placeholder_map.keys()) dummy_var_placeholder_set = (Function('_Derivative'), Function('_Integral')) -multiply_placeholder_set = (Function('_multiply'), Function('_mat_multiply')) placeholder_inverse_map = { value["sympy_func"]: key for key, value in reversed(global_placeholder_map.items()) } placeholder_inverse_set = set(placeholder_inverse_map.keys()) @@ -1511,11 +1508,22 @@ def replace_sympy_funcs_with_placeholder_funcs(expression: Expr) -> Expr: return expression + def replace_placeholder_funcs(expr: Expr, func_key: Literal["dim_func"] | Literal["sympy_func"], placeholder_map: dict[Function, PlaceholderFunction], placeholder_set: set[Function], + dim_values_dict: dict[tuple[Basic,...], DimValues], + function_parents: list[Basic], data_table_subs: DataTableSubs | None) -> Expr: + + if (not is_matrix(expr)) and expr.func == function_id_wrapper: + function_parents.append(expr.args[0]) + expr = cast(Expr, expr.args[1]) + + if (not is_matrix(expr)) and isinstance(expr, Symbol) and expr.name == "_zero_delayed_substitution": + return sympify('0') + if is_matrix(expr): rows = [] for i in range(expr.rows): @@ -1524,50 +1532,49 @@ def replace_placeholder_funcs(expr: Expr, for j in range(expr.cols): row.append(replace_placeholder_funcs(cast(Expr, expr[i,j]), func_key, placeholder_map, placeholder_set, + dim_values_dict, function_parents, data_table_subs) ) return cast(Expr, Matrix(rows)) + expr = cast(Expr,expr) + if len(expr.args) == 0: return expr - if func_key == "dim_func" and expr.func in multiply_placeholder_set: - processed_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args] - matrix_args = [] - scalar_args = [] - for arg in processed_args: - if is_matrix(cast(Expr, arg)): - matrix_args.append(arg) + if expr.func == dim_needs_values_wrapper: + if func_key == "sympy_func": + child_expr = expr.args[1] + function_parents_snapshot = list(function_parents) + dim_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in child_expr.args] + result = cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(*dim_args)) + if data_table_subs is not None and len(data_table_subs.subs_stack) > 0: + dim_args_snapshot = list(dim_args) + for i, value in enumerate(dim_args_snapshot): + dim_args_snapshot[i] = cast(Expr, value.subs({key: cast(Matrix, value)[0,0] for key, value in data_table_subs.subs_stack[-1].items()})) + result_snapshot = cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(*dim_args_snapshot)) + dim_values_dict[(expr.args[0], *function_parents_snapshot)] = DimValues(args=dim_args_snapshot, result=result_snapshot) else: - scalar_args.append(arg) - - if len(matrix_args) > 0 and len(scalar_args) > 0: - first_matrix = matrix_args[0] - scalar = math.prod(scalar_args) - new_rows = [] - for i in range(first_matrix.rows): - new_row = [] - new_rows.append(new_row) - for j in range(first_matrix.cols): - new_row.append(scalar*first_matrix[i,j]) - - matrix_args[0] = Matrix(new_rows) - - return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*matrix_args)) + dim_values_dict[(expr.args[0], *function_parents_snapshot)] = DimValues(args=dim_args, result=result) + return result else: - return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*processed_args)) + child_expr = expr.args[1] + dim_values = dim_values_dict[(expr.args[0],*function_parents)] + child_processed_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in child_expr.args] + return cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(dim_values, *child_processed_args)) elif expr.func in dummy_var_placeholder_set and func_key == "dim_func": - return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args)))) + return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args)))) elif expr.func in placeholder_set: - return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args))) + return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in expr.args))) + elif data_table_subs is not None and expr.func == data_table_calc_wrapper: if len(expr.args[0].atoms(data_table_id_wrapper)) == 0: - return replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, data_table_subs) + return replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) data_table_subs.subs_stack.append({}) data_table_subs.shortest_col_stack.append(None) - sub_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, data_table_subs) + sub_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) subs = data_table_subs.subs_stack.pop() shortest_col = data_table_subs.shortest_col_stack.pop() @@ -1588,7 +1595,7 @@ def replace_placeholder_funcs(expr: Expr, return cast(Expr, Matrix([sub_expr,]*shortest_col)) elif data_table_subs is not None and expr.func == data_table_id_wrapper: - current_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, data_table_subs) + current_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) new_var = Symbol(f"_data_table_var_{data_table_subs.get_next_id()}") if not is_matrix(current_expr): @@ -1606,16 +1613,15 @@ def replace_placeholder_funcs(expr: Expr, return cast(Expr, current_expr[0,0]) else: - return cast(Expr, expr.func(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args))) + return cast(Expr, expr.func(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in expr.args))) def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr], expression: Expr, placeholder_map: dict[Function, PlaceholderFunction], - placeholder_set: set[Function]) -> tuple[Expr | None, Exception | None]: - # need to remove any subtractions or unary negative since this may - # lead to unintentional cancellation during the parameter substitution process - positive_only_expression = subtraction_to_addition(expression) - expression_with_parameter_subs = cast(Expr, positive_only_expression.xreplace(parameter_subs)) + placeholder_set: set[Function], + dim_values_dict: dict[tuple[Basic,...], DimValues]) -> tuple[Expr | None, Exception | None]: + + expression_with_parameter_subs = cast(Expr, expression.xreplace(parameter_subs)) error = None final_expression = None @@ -1623,7 +1629,7 @@ def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr], try: final_expression = replace_placeholder_funcs(expression_with_parameter_subs, "dim_func", placeholder_map, placeholder_set, - DataTableSubs()) + dim_values_dict, [], DataTableSubs()) except Exception as e: error = e @@ -1660,9 +1666,9 @@ def dimensional_analysis(dimensional_analysis_expression: Expr | None, dim_sub_e custom_units_defined = True except TypeError as e: - print(f"Dimension Error: {e}") - result = "Dimension Error" - result_latex = "Dimension Error" + result = f"Dimension Error: {e}" + result_latex = result + print(result) return result, result_latex, custom_units_defined, custom_units, custom_units_latex @@ -1733,7 +1739,7 @@ def get_sorted_statements(statements: list[Statement], custom_definition_names: zero_place_holder: ImplicitParameter = { "dimensions": [0]*9, "original_value": "0", - "si_value": "0", + "si_value": "_zero_delayed_substitution", "name": ZERO_PLACEHOLDER, "units": "" } @@ -1751,16 +1757,7 @@ def expand_with_sub_statements(statements: list[InputAndSystemStatement]): local_sub_statements: dict[str, LocalSubstitutionStatement] = {} - included_unitless_sub_expressions: set[str] = set() - for statement in statements: - # need to prevent inclusion of already included exponents since solving a system of equations - # will repeat exponents for each variable that is solved for - for unitless_sub_expression in cast(list[UnitlessSubExpression], statement["unitlessSubExpressions"]): - if unitless_sub_expression["name"] not in included_unitless_sub_expressions: - new_statements.append(unitless_sub_expression) - included_unitless_sub_expressions.update([unitless_sub_expression["name"] for unitless_sub_expression in statement["unitlessSubExpressions"]]) - new_statements.extend(statement.get("functions", [])) new_statements.extend(statement.get("arguments", [])) for local_sub in statement.get("localSubs", []): @@ -1770,7 +1767,6 @@ def expand_with_sub_statements(statements: list[InputAndSystemStatement]): "index": 0, # placeholder, will be set in sympy_statements "params": [], "function_subs": {}, - "isUnitlessSubExpression": False }) combined_sub["params"].append(local_sub["argument"]) function_subs = combined_sub["function_subs"] @@ -1802,25 +1798,22 @@ def get_parameter_subs(parameters: list[ImplicitParameter], convert_floats_to_fr return parameter_subs -def sympify_statements(statements: list[Statement] | list[EqualityStatement], - sympify_unitless_sub_expressions=False, convert_floats_to_fractions=True): +def sympify_statements(statements: list[Statement] | list[EqualityStatement], convert_floats_to_fractions=True): for i, statement in enumerate(statements): statement["index"] = i if statement["type"] != "local_sub" and statement["type"] != "blank" and \ statement["type"] != "scatterQuery": try: statement["expression"] = sympify(statement["sympy"], rational=convert_floats_to_fractions) - if sympify_unitless_sub_expressions: - for unitless_sub_expression in cast(list[UnitlessSubExpression], statement["unitlessSubExpressions"]): - unitless_sub_expression["expression"] = sympify(unitless_sub_expression["sympy"], rational=convert_floats_to_fractions) + except SyntaxError: print(f"Parsing error for equation {statement['sympy']}") raise ParsingError -def remove_implicit_and_unitless_sub_expression(input_set: set[str]) -> set[str]: +def remove_implicit(input_set: set[str]) -> set[str]: return {variable for variable in input_set - if not variable.startswith( ("implicit_param__", "unitless__") )} + if not variable.startswith("implicit_param__")} def solve_system(statements: list[EqualityStatement], variables: list[str], @@ -1829,8 +1822,7 @@ def solve_system(statements: list[EqualityStatement], variables: list[str], parameters = get_all_implicit_parameters(statements) parameter_subs = get_parameter_subs(parameters, convert_floats_to_fractions) - sympify_statements(statements, sympify_unitless_sub_expressions=True, - convert_floats_to_fractions=convert_floats_to_fractions) + sympify_statements(statements, convert_floats_to_fractions=convert_floats_to_fractions) # give all of the statements an index so that they can be re-ordered for i, statement in enumerate(statements): @@ -1838,26 +1830,22 @@ def solve_system(statements: list[EqualityStatement], variables: list[str], # define system of equations for sympy.solve function # substitute in all exponents and placeholder functions - system_unitless_sub_expressions: list[UnitlessSubExpression | UnitlessSubExpressionName] = [] system_implicit_params: list[ImplicitParameter] = [] system_variables: set[str] = set() system: list[Expr] = [] for statement in statements: system_variables.update(statement["params"]) - system_unitless_sub_expressions.extend(statement["unitlessSubExpressions"]) system_implicit_params.extend(statement["implicitParams"]) - equality = cast(Expr, statement["expression"]).subs( - {unitless_sub_expression["name"]:unitless_sub_expression["expression"] for unitless_sub_expression in cast(list[UnitlessSubExpression], statement["unitlessSubExpressions"])}) - equality = replace_placeholder_funcs(cast(Expr, equality), + equality = replace_placeholder_funcs(cast(Expr, statement["expression"]), "sympy_func", - placeholder_map, placeholder_set, None) + placeholder_map, placeholder_set, {}, [], None) system.append(cast(Expr, equality.doit())) # remove implicit parameters before solving - system_variables = remove_implicit_and_unitless_sub_expression(system_variables) + system_variables = remove_implicit(system_variables) solutions: list[dict[Symbol, Expr]] = [] solutions = solve(system, variables, dict=True) @@ -1888,8 +1876,6 @@ def solve_system(statements: list[EqualityStatement], variables: list[str], "expression": expression, "implicitParams": system_implicit_params if counter == 0 else [], # only include for one variable in solution to prevent dups "params": [variable.name for variable in cast(list[Symbol], expression.free_symbols)], - "unitlessSubExpressions": system_unitless_sub_expressions, - "isUnitlessSubExpression": False, "isFunction": False, "isFunctionArgument": False, "isRange": False, @@ -1919,8 +1905,7 @@ def solve_system_numerical(statements: list[EqualityStatement], variables: list[ parameters = get_all_implicit_parameters([*statements, *guess_statements]) parameter_subs = get_parameter_subs(parameters, convert_floats_to_fractions) - sympify_statements(statements, sympify_unitless_sub_expressions=True, - convert_floats_to_fractions=convert_floats_to_fractions) + sympify_statements(statements, convert_floats_to_fractions=convert_floats_to_fractions) # give all of the statements an index so that they can be re-ordered for i, statement in enumerate(statements): @@ -1929,25 +1914,21 @@ def solve_system_numerical(statements: list[EqualityStatement], variables: list[ # define system of equations for sympy.solve function # substitute in all exponents, implicit params, and placeholder functions # add equalityUnitsQueries to new_statements that will be added to the whole sheet - system_unitless_sub_expressions: list[UnitlessSubExpression | UnitlessSubExpressionName] = [] system_variables: set[str] = set() system: list[Expr] = [] new_statements: list[EqualityUnitsQueryStatement | GuessAssignmentStatement] = [] for statement in statements: system_variables.update(statement["params"]) - system_unitless_sub_expressions.extend(statement["unitlessSubExpressions"]) - equality = cast(Expr, statement["expression"]).subs( - {unitless_sub_expression["name"]: unitless_sub_expression["expression"] for unitless_sub_expression in cast(list[UnitlessSubExpression], statement["unitlessSubExpressions"])}) - equality = equality.subs(parameter_subs) + equality = cast(Expr, statement["expression"]).subs(parameter_subs) equality = replace_placeholder_funcs(cast(Expr, equality), "sympy_func", - placeholder_map, placeholder_set, None) + placeholder_map, placeholder_set, {}, [], None) system.append(cast(Expr, equality.doit())) new_statements.extend(statement["equalityUnitsQueries"]) # remove implicit parameters before solving - system_variables = remove_implicit_and_unitless_sub_expression(system_variables) + system_variables = remove_implicit(system_variables) solutions: list[dict[Symbol, float]] | list[Any] = [] try: @@ -2373,12 +2354,13 @@ def get_evaluated_expression(expression: Expr, parameter_subs: dict[Symbol, Expr], simplify_symbolic_expressions: bool, placeholder_map: dict[Function, PlaceholderFunction], - placeholder_set: set[Function]) -> tuple[ExprWithAssumptions, str | list[list[str]]]: + placeholder_set: set[Function]) -> tuple[ExprWithAssumptions, str | list[list[str]], dict[tuple[Basic,...],DimValues]]: expression = cast(Expr, expression.xreplace(parameter_subs)) + dim_values_dict: dict[tuple[Basic,...], DimValues] = {} expression = replace_placeholder_funcs(expression, "sympy_func", placeholder_map, - placeholder_set, + placeholder_set, dim_values_dict, [], DataTableSubs()) if not is_matrix(expression): if simplify_symbolic_expressions: @@ -2403,13 +2385,11 @@ def get_evaluated_expression(expression: Expr, row.append(custom_latex(cast(Expr, expression[i,j]))) evaluated_expression = cast(ExprWithAssumptions, expression.evalf(PRECISION)) - return evaluated_expression, symbolic_expression + return evaluated_expression, symbolic_expression, dim_values_dict def get_result(evaluated_expression: ExprWithAssumptions, dimensional_analysis_expression: Expr | None, dim_sub_error: Exception | None, symbolic_expression: str, - unitless_sub_expressions: list[UnitlessSubExpression | UnitlessSubExpressionName], - isRange: bool, unitless_sub_expression_dimensionless: dict[str, bool], - custom_base_units: CustomBaseUnits | None, + isRange: bool, custom_base_units: CustomBaseUnits | None, isSubQuery: bool, subQueryName: str ) -> Result | FiniteImagResult: @@ -2417,12 +2397,7 @@ def get_result(evaluated_expression: ExprWithAssumptions, dimensional_analysis_e custom_units = "" custom_units_latex = "" - if not all([unitless_sub_expression_dimensionless[local_item["name"]] for local_item in unitless_sub_expressions]): - context_set = {local_item["unitlessContext"] for local_item in unitless_sub_expressions if not unitless_sub_expression_dimensionless[local_item["name"]]} - context_combined = ", ".join(context_set) - dim = f"Dimension Error: {context_combined} Not Dimensionless" - dim_latex = f"Dimension Error: {context_combined} Not Dimensionless" - elif isRange: + if isRange: # a separate unitsQuery function is used for plots, no need to perform dimensional analysis before subs are made dim = "" dim_latex = "" @@ -2508,21 +2483,16 @@ def evaluate_statements(statements: list[InputAndSystemStatement], expanded_statements = get_sorted_statements(expanded_statements, custom_definition_names) combined_expressions: list[CombinedExpression] = [] - unitless_sub_expression_subs: dict[str, Expr | float] = {} - unit_sub_expression_dimensionless: dict[str, bool] = {} - function_unitless_sub_expression_replacements: dict[str, dict[Symbol, Symbol]] = {} - function_unitless_sub_expression_context: dict[str, str] = {} + for i, statement in enumerate(expanded_statements): if statement["type"] == "local_sub" or statement["type"] == "blank": continue - if statement["type"] == "assignment" and not statement["isUnitlessSubExpression"] and \ - not statement.get("isFunction", False): + if statement["type"] == "assignment" and not statement.get("isFunction", False): combined_expressions.append({"index": statement["index"], "isBlank": True, "isRange": False, "isScatter": False, - "unitlessSubExpressions": [], "isSubQuery": False, "subQueryName": ""}) continue @@ -2547,110 +2517,34 @@ def evaluate_statements(statements: list[InputAndSystemStatement], # sub equations into each other in topological order if there are more than one function_name = "" - unitless_sub_expression_name = "" - unitless_sub_expression_context = "" + if statement["isFunction"] is True: is_function = True function_name = statement["name"] - is_unitless_sub_expression = False - elif statement["isUnitlessSubExpression"] is True: - is_unitless_sub_expression = True - unitless_sub_expression_name = statement["name"] - unitless_sub_expression_context = statement["unitlessContext"] - is_function = False else: - is_unitless_sub_expression = False is_function = False - dependency_unitless_sub_expressions = statement["unitlessSubExpressions"] - new_function_unitless_sub_expressions: dict[str, Expr] = {} + final_expression = statement["expression"] for sub_statement in reversed(temp_statements[0:-1]): - if (sub_statement["type"] == "assignment" or ((is_function or is_unitless_sub_expression) and sub_statement["type"] == "local_sub")) \ - and not sub_statement["isUnitlessSubExpression"]: + if (sub_statement["type"] == "assignment" or (is_function and sub_statement["type"] == "local_sub")): if sub_statement["type"] == "local_sub": if is_function: current_local_subs = sub_statement["function_subs"].get(function_name, {}) if len(current_local_subs) > 0: final_expression = subs_wrapper(final_expression, current_local_subs) - elif is_unitless_sub_expression: - for local_sub_function_name, function_local_subs in sub_statement["function_subs"].items(): - function_unitless_sub_expression = new_function_unitless_sub_expressions.setdefault(local_sub_function_name, final_expression) - new_function_unitless_sub_expressions[local_sub_function_name] = subs_wrapper(function_unitless_sub_expression, function_local_subs) else: if sub_statement["name"] in map(lambda x: str(x), final_expression.free_symbols): - dependency_unitless_sub_expressions.extend(sub_statement["unitlessSubExpressions"]) final_expression = subs_wrapper(final_expression, {symbols(sub_statement["name"]): sub_statement["expression"]}) - - if is_unitless_sub_expression: - new_function_unitless_sub_expressions = { - key:subs_wrapper(expression, {symbols(sub_statement["name"]): sub_statement["expression"]}) for - key, expression in new_function_unitless_sub_expressions.items() - } - - if is_unitless_sub_expression: - for current_function_name in new_function_unitless_sub_expressions.keys(): - function_unitless_sub_expression_replacements.setdefault(current_function_name, {}).update( - {symbols(unitless_sub_expression_name): symbols(unitless_sub_expression_name+current_function_name)} - ) - function_unitless_sub_expression_context[unitless_sub_expression_name] = unitless_sub_expression_context - - new_function_unitless_sub_expressions[''] = final_expression - - for current_function_name, final_expression in new_function_unitless_sub_expressions.items(): - while(True): - available_unitless_subs = set(function_unitless_sub_expression_replacements.get(current_function_name, {}).keys()) & \ - final_expression.free_symbols - if len(available_unitless_subs) == 0: - break - final_expression = subs_wrapper(final_expression, function_unitless_sub_expression_replacements[current_function_name]) - final_expression = subs_wrapper(final_expression, unitless_sub_expression_subs) - - final_expression = subs_wrapper(final_expression, unitless_sub_expression_subs) - final_expression = cast(Expr, final_expression.doit()) - dimensional_analysis_expression, dim_sub_error = get_dimensional_analysis_expression(dimensional_analysis_subs, - final_expression, - placeholder_map, - placeholder_set) - dim, _, _, _, _ = dimensional_analysis(dimensional_analysis_expression, dim_sub_error) - if dim == "": - unit_sub_expression_dimensionless[unitless_sub_expression_name+current_function_name] = True - else: - unit_sub_expression_dimensionless[unitless_sub_expression_name+current_function_name] = False - - final_expression = cast(Expr, cast(Expr, final_expression).xreplace(parameter_subs)) - final_expression = replace_placeholder_funcs(final_expression, - "sympy_func", - placeholder_map, - placeholder_set, - None) - - unitless_sub_expression_subs[symbols(unitless_sub_expression_name+current_function_name)] = final_expression - - elif is_function: - while(True): - available_unitless_subs = set(function_unitless_sub_expression_replacements.get(function_name, {}).keys()) & \ - final_expression.free_symbols - if len(available_unitless_subs) == 0: - break - final_expression = subs_wrapper(final_expression, function_unitless_sub_expression_replacements[function_name]) - statement["unitlessSubExpressions"].extend([{"name": str(function_unitless_sub_expression_replacements[function_name][key]), - "unitlessContext": function_unitless_sub_expression_context[str(key)]} for key in available_unitless_subs]) - final_expression = subs_wrapper(final_expression, unitless_sub_expression_subs) - if function_name in function_unitless_sub_expression_replacements: - for unitless_sub_expression_i, unitless_sub_expression in enumerate(statement["unitlessSubExpressions"]): - if symbols(unitless_sub_expression["name"]) in function_unitless_sub_expression_replacements[function_name]: - statement["unitlessSubExpressions"][unitless_sub_expression_i] = UnitlessSubExpressionName(name = str(function_unitless_sub_expression_replacements[function_name][symbols(unitless_sub_expression["name"])]), - unitlessContext = unitless_sub_expression["unitlessContext"]) + if is_function: statement["expression"] = final_expression elif statement["type"] == "query": if statement["isRange"] is not True: current_combined_expression: CombinedExpression = {"index": statement["index"], - "expression": subs_wrapper(final_expression, unitless_sub_expression_subs), - "unitlessSubExpressions": dependency_unitless_sub_expressions, + "expression": final_expression, "isBlank": False, "isRange": False, "isScatter": False, @@ -2668,8 +2562,7 @@ def evaluate_statements(statements: list[InputAndSystemStatement], } else: current_combined_expression: CombinedExpression = {"index": statement["index"], - "expression": subs_wrapper(final_expression, unitless_sub_expression_subs), - "unitlessSubExpressions": dependency_unitless_sub_expressions, + "expression": final_expression, "isBlank": False, "isRange": True, "isParametric": statement.get("isParametric", False), @@ -2733,21 +2626,21 @@ def evaluate_statements(statements: list[InputAndSystemStatement], else: expression = cast(Expr, item["expression"].doit()) - evaluated_expression, symbolic_expression = get_evaluated_expression(expression, - parameter_subs, - simplify_symbolic_expressions, - placeholder_map, - placeholder_set) + evaluated_expression, symbolic_expression, dim_values_dict = get_evaluated_expression(expression, + parameter_subs, + simplify_symbolic_expressions, + placeholder_map, + placeholder_set) dimensional_analysis_expression, dim_sub_error = get_dimensional_analysis_expression(dimensional_analysis_subs, expression, placeholder_map, - placeholder_set) + placeholder_set, + dim_values_dict) if not is_matrix(evaluated_expression): results[index] = get_result(evaluated_expression, dimensional_analysis_expression, dim_sub_error, cast(str, symbolic_expression), - item["unitlessSubExpressions"], item["isRange"], - unit_sub_expression_dimensionless, + item["isRange"], custom_base_units, item["isSubQuery"], item["subQueryName"]) @@ -2772,8 +2665,8 @@ def evaluate_statements(statements: list[InputAndSystemStatement], current_result = get_result(cast(ExprWithAssumptions, evaluated_expression[i,j]), cast(Expr, current_dimensional_analysis_expression), - dim_sub_error, symbolic_expression[i][j], item["unitlessSubExpressions"], - item["isRange"], unit_sub_expression_dimensionless, + dim_sub_error, symbolic_expression[i][j], + item["isRange"], custom_base_units, item["isSubQuery"], item["subQueryName"]) @@ -2798,9 +2691,6 @@ def evaluate_statements(statements: list[InputAndSystemStatement], if item["isFunctionArgument"] or item["isUnitsQuery"]: range_dependencies[item["name"]] = cast(Result | FiniteImagResult | MatrixResult, results[index]) - - if item["isCodeFunctionRawQuery"]: - code_func_raw_results[item["name"]] = cast(CombinedExpressionNoRange, item) if item["isCodeFunctionRawQuery"]: current_result = item diff --git a/src/MathCell.svelte b/src/MathCell.svelte index 868477b0..af9ab5a1 100644 --- a/src/MathCell.svelte +++ b/src/MathCell.svelte @@ -335,7 +335,7 @@ const currentResultLatex = getLatexResult(createSubQuery(sympyVar), subResults.get(sympyVar), numberConfig); let newLatex: string; if (currentResultLatex.error) { - newLatex = String.raw`\text{${currentResultLatex.error}}`; + newLatex = String.raw`\text{${currentResultLatex.error.startsWith("Dimension Error:") ? "Dimension Error" : currentResultLatex.error}}`; } else { newLatex = ` ${currentResultLatex.resultLatex}${currentResultLatex.resultUnitsLatex} `; } diff --git a/src/cells/FluidCell.ts b/src/cells/FluidCell.ts index 21c834fa..37b656b9 100644 --- a/src/cells/FluidCell.ts +++ b/src/cells/FluidCell.ts @@ -294,10 +294,8 @@ export default class FluidCell extends BaseCell { name: this.mathField.statement.name, sympy: `${fluidFuncName}(0,0)`, params: [], - isUnitlessSubExpression: false, isFunctionArgument: false, isFunction: false, - unitlessSubExpressions: [], implicitParams: [], functions: [], arguments: [], diff --git a/src/parser/LatexToSympy.ts b/src/parser/LatexToSympy.ts index 9986b70e..657bea56 100644 --- a/src/parser/LatexToSympy.ts +++ b/src/parser/LatexToSympy.ts @@ -4,7 +4,7 @@ import LatexParserVisitor from "./LatexParserVisitor"; import type { FieldTypes, Statement, QueryStatement, RangeQueryStatement, UserFunctionRange, AssignmentStatement, ImplicitParameter, UserFunction, FunctionArgumentQuery, FunctionArgumentAssignment, LocalSubstitution, LocalSubstitutionRange, - UnitlessSubExpression, GuessAssignmentStatement, FunctionUnitsQuery, + GuessAssignmentStatement, FunctionUnitsQuery, SolveParametersWithGuesses, ErrorStatement, EqualityStatement, EqualityUnitsQueryStatement, SolveParameters, AssignmentList, InsertMatrix, @@ -13,13 +13,12 @@ import type { FieldTypes, Statement, QueryStatement, RangeQueryStatement, UserFu ScatterXValuesQueryStatement, ScatterYValuesQueryStatement, DataTableInfo, DataTableQueryStatement, BlankStatement, SubQueryStatement} from "./types"; -import { isInsertion, isReplacement, - type Insertion, type Replacement, applyEdits, - createSubQuery} from "./utility"; +import { type Insertion, type Replacement, applyEdits, + createSubQuery } from "./utility"; import { RESERVED, GREEK_CHARS, UNASSIGNABLE, COMPARISON_MAP, UNITS_WITH_OFFSET, TYPE_PARSING_ERRORS, BUILTIN_FUNCTION_MAP, - ZERO_PLACEHOLDER } from "./constants.js"; + BUILTIN_FUNCTION_NEEDS_VALUES, ZERO_PLACEHOLDER } from "./constants.js"; import { MAX_MATRIX_COLS } from "../constants"; @@ -64,7 +63,7 @@ type ParsingResult = { } export function getBlankStatement(): BlankStatement { - return { type: "blank", params: [], implicitParams: [], unitlessSubExpressions: [], isFromPlotCell: false}; + return { type: "blank", params: [], implicitParams: [], isFromPlotCell: false}; } export function parseLatex(latex: string, id: number, type: FieldTypes, @@ -180,14 +179,13 @@ export class LatexToSympy extends LatexParserVisitor(); - unitlessSubExpressions: UnitlessSubExpression[] = []; subQueries: SubQueryStatement[] = []; subQueryReplacements: [string, Replacement][] = []; inQueryStatement = false; @@ -269,11 +267,6 @@ export class LatexToSympy extends LatexParserVisitor { - const exponentVariableName = this.getNextUnitlessSubExpressionName(); - + visitExponent = (ctx: ExponentContext) => { let base: string; let cursor: number; let exponent: string @@ -1096,8 +1068,10 @@ export class LatexToSympy extends LatexParserVisitor { - const rowVariableName = this.getNextUnitlessSubExpressionName(); - - let cursor = this.params.length; const rowExpression = this.visit(ctx.expr(1)) as string; - - this.unitlessSubExpressions.push({ - type: "assignment", - name: rowVariableName, - sympy: rowExpression, - params: this.params.slice(cursor), - isUnitlessSubExpression: true, - unitlessContext: "Matrix Index", - isFunctionArgument: false, - isFunction: false, - unitlessSubExpressions: [] - }); - this.params.push(rowVariableName); - - const colVariableName = this.getNextUnitlessSubExpressionName(); - cursor = this.params.length; const colExpression = this.visit(ctx.expr(2)) as string; - this.unitlessSubExpressions.push({ - type: "assignment", - name: colVariableName, - sympy: colExpression, - params: this.params.slice(cursor), - isUnitlessSubExpression: true, - unitlessContext: "Matrix Index", - isFunctionArgument: false, - isFunction: false, - unitlessSubExpressions: [] - }); - this.params.push(colVariableName); - - return `_IndexMatrix(${this.visit(ctx.expr(0))}, ${rowVariableName}, ${colVariableName})`; + return `_dim_needs_values_wrapper(__unique_marker_${this.equationIndex}_${this.dimNeedsValuesIndex++},_IndexMatrix(${this.visit(ctx.expr(0))}, ${rowExpression}, ${colExpression}))`; } visitArgument = (ctx: ArgumentContext): (LocalSubstitution | LocalSubstitutionRange)[] => { @@ -1217,11 +1145,9 @@ export class LatexToSympy extends LatexParserVisitor { - return `Add(${this.visit(ctx.expr(0))}, ${this.visit(ctx.expr(1))})`; + return `_add(${this.visit(ctx.expr(0))}, ${this.visit(ctx.expr(1))})`; } visitSubtract = (ctx: SubtractContext) => { - return `Add(${this.visit(ctx.expr(0))}, -(${this.visit(ctx.expr(1))}))`; + return `_add(${this.visit(ctx.expr(0))}, -(${this.visit(ctx.expr(1))}))`; } visitVariable = (ctx: VariableContext) => { @@ -2059,6 +1979,10 @@ export class LatexToSympy extends LatexParserVisitor & { @@ -110,17 +107,9 @@ export type UserFunctionRange = Omit & { }; -export type UnitlessSubExpression = Omit & { - isUnitlessSubExpression: true; - unitlessContext: string; - isFunctionArgument: false; - isFunction: false; -}; - export type FunctionArgumentAssignment = Pick & { - isUnitlessSubExpression: false; + "params"> & { isFunctionArgument: true; isFunction: false; }; @@ -168,13 +157,11 @@ export type EqualityStatement = Omit & { type BaseQueryStatement = { type: "query"; sympy: string; - unitlessSubExpressions: UnitlessSubExpression[]; implicitParams: ImplicitParameter[]; params: string[]; functions: (UserFunction | UserFunctionRange | FunctionUnitsQuery)[]; arguments: (FunctionArgumentAssignment | FunctionArgumentQuery) []; localSubs: (LocalSubstitution | LocalSubstitutionRange)[]; - isUnitlessSubExpression: false; isFunctionArgument: false; isFunction: false; isUnitsQuery: false; @@ -264,7 +251,6 @@ export type ScatterQueryStatement = { arguments: (FunctionArgumentAssignment | FunctionArgumentQuery) []; localSubs: (LocalSubstitution | LocalSubstitutionRange)[]; implicitParams: ImplicitParameter[]; - unitlessSubExpressions: UnitlessSubExpression[]; equationIndex: number; cellNum: number; isFromPlotCell: boolean; @@ -298,9 +284,8 @@ export type CodeFunctionRawQuery = BaseQueryStatement & { isCodeFunctionRawQuery: true; } -export type FunctionArgumentQuery = Pick & { +export type FunctionArgumentQuery = Pick & { name: string; - isUnitlessSubExpression: false; isFunctionArgument: true; isFunction: false; isUnitsQuery: false; @@ -310,9 +295,8 @@ export type FunctionArgumentQuery = Pick & { +export type FunctionUnitsQuery = Pick & { units: ''; - isUnitlessSubExpression: false; isFunctionArgument: false; isFunction: false; isUnitsQuery: true; diff --git a/src/parser/utility.ts b/src/parser/utility.ts index f0909e3b..29e99ce8 100644 --- a/src/parser/utility.ts +++ b/src/parser/utility.ts @@ -89,7 +89,6 @@ export function applyEdits(source: string, pendingEdits: (Insertion | Replacemen export function createSubQuery(name: string): SubQueryStatement { return { type: "query", - unitlessSubExpressions: [], implicitParams: [], params: [name], functions: [], @@ -97,7 +96,6 @@ export function createSubQuery(name: string): SubQueryStatement { localSubs: [], units: "", unitsLatex: "", - isUnitlessSubExpression: false, isFunctionArgument: false, isFunction: false, isUnitsQuery: false, @@ -112,4 +110,5 @@ export function createSubQuery(name: string): SubQueryStatement { isCodeFunctionQuery: false, isCodeFunctionRawQuery: false }; -} \ No newline at end of file +} + diff --git a/tests/test_basic.spec.mjs b/tests/test_basic.spec.mjs index 7057420e..95ea83d8 100644 --- a/tests/test_basic.spec.mjs +++ b/tests/test_basic.spec.mjs @@ -750,6 +750,36 @@ test('Test function notation with exponents and units', async () => { }); +test('Test function notation with exponents and units and nested functions', async () => { + + await page.setLatex(0, String.raw`t\left(s=y\left(x=2\left\lbrack in\right\rbrack\right)\cdot1\left\lbrack in\right\rbrack\right)=`); + await page.click('#add-math-cell'); + await page.setLatex(1, String.raw`t=2^{\frac{s}{1\left\lbrack in\right\rbrack}}`); + await page.click('#add-math-cell'); + await page.setLatex(2, String.raw`y=3^{\frac{x}{1\left\lbrack in\right\rbrack}}`); + + await page.waitForSelector('text=Updating...', {state: 'detached'}); + + let content = await page.textContent('#result-value-0'); + expect(parseLatexFloat(content)).toBeCloseTo(512, precision); + content = await page.textContent('#result-units-0'); + expect(content).toBe(''); +}); + +test('Test zero canceling bug with exponent', async () => { + + await page.setLatex(0, String.raw`y=\frac{0\left\lbrack m\right\rbrack}{2^{x}}`); + await page.click('#add-math-cell'); + await page.setLatex(1, String.raw`y\left(x=1\right)=`); + + await page.waitForSelector('text=Updating...', {state: 'detached'}); + + let content = await page.textContent('#result-value-1'); + expect(parseLatexFloat(content)).toBeCloseTo(0, precision); + content = await page.textContent('#result-units-1'); + expect(content).toBe('m'); +}); + test('Test function notation with integrals', async () => { diff --git a/tests/test_matrix_functions.spec.mjs b/tests/test_matrix_functions.spec.mjs index 556fa90f..2e9e0818 100644 --- a/tests/test_matrix_functions.spec.mjs +++ b/tests/test_matrix_functions.spec.mjs @@ -347,8 +347,17 @@ test('Test range that includes zero value multiplied by dimensioned value', asyn expect(content).toBe(String.raw`\begin{bmatrix} 0\left\lbrack m\right\rbrack \\ 1\left\lbrack m\right\rbrack \\ 2\left\lbrack m\right\rbrack \\ 3\left\lbrack m\right\rbrack \\ 4\left\lbrack m\right\rbrack \\ 5\left\lbrack m\right\rbrack \end{bmatrix}`); }); -test('Test range input needs to be unitless', async () => { - await page.setLatex(0, String.raw`\mathrm{range}\left(1\left\lbrack m\right\rbrack,2\left\lbrack m\right\rbrack,.1\left\lbrack m\right\rbrack\right)=`); +test('Test range with consistent units', async () => { + await page.setLatex(0, String.raw`\mathrm{range}\left(1\left\lbrack m\right\rbrack,2\left\lbrack m\right\rbrack,1\left\lbrack m\right\rbrack\right)=`); + + await page.waitForSelector('text=Updating...', {state: 'detached'}); + + let content = await page.textContent(`#result-value-0`); + expect(content).toBe(String.raw`\begin{bmatrix} 1\left\lbrack m\right\rbrack \\ 2\left\lbrack m\right\rbrack \end{bmatrix}`); +}); + +test('Test range with inconsistent units', async () => { + await page.setLatex(0, String.raw`\mathrm{range}\left(1\left\lbrack m\right\rbrack,2\left\lbrack s\right\rbrack,1\left\lbrack m\right\rbrack\right)=`); await page.waitForSelector('text=Updating...', {state: 'detached'}); diff --git a/tests/test_number_format.spec.mjs b/tests/test_number_format.spec.mjs index 6054c883..ebb35318 100644 --- a/tests/test_number_format.spec.mjs +++ b/tests/test_number_format.spec.mjs @@ -53,6 +53,10 @@ test('Test symbolic format', async () => { await page.locator('#add-math-cell').click(); await page.setLatex(2, String.raw`\frac{-3\left\lbrack mm\right\rbrack}{\sqrt2}=`); + // symbolic expression with fractional exponent + await page.locator('#add-math-cell').click(); + await page.setLatex(3, String.raw`3.0^{.500}=`); + await page.waitForSelector('text=Updating...', {state: 'detached'}); // check all values rendered as floating point values first @@ -69,6 +73,9 @@ test('Test symbolic format', async () => { content = await page.textContent('#result-units-2'); expect(content).toBe('m'); + content = await page.textContent('#result-value-3'); + expect(parseLatexFloat(content)).toBeCloseTo(sqrt(3), precision); + // switch to symbolic formatting await page.getByRole('button', { name: 'Sheet Settings' }).click(); await page.locator('label').filter({ hasText: 'Display Symbolic Results' }).click(); @@ -87,6 +94,8 @@ test('Test symbolic format', async () => { content = await page.textContent('#result-units-2'); expect(content).toBe('m'); + content = await page.textContent('#result-value-3'); + expect(content).toBe(String.raw`\sqrt{3}`); }); test('Test disabling automatic expressions simplification', async () => {