diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index c1c907eae..55547101a 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -252,6 +252,51 @@ def make_symbol(var, /): return sp.Symbol(var, real=True) if isinstance(var, str) else var +def rename_to_sympy(var, /): + """ + Symbol-aware renaming of reserved or built-in symbols. + + Notes + ----- + For renaming variables in complex expressions, see + :ref:``regex_rename_to_sympy``. + """ + var = make_symbol(var) + + if var.name in forbidden_var: + new_name = f"_sympy_{var.name}" + + # there is no "Pythonic" way of doing this since the SymPy API does not + # allow changing names of symbols, or actually copying objects properly + if isinstance(var, sp.Symbol): + assumptions = var.assumptions0 + return sp.Symbol(new_name, **assumptions) + + elif isinstance(var, sp.IndexedBase): + return sp.IndexedBase(new_name, shape=var.shape) + + elif isinstance(var, sp.FunctionClass): + return sp.Function(new_name) + + return var + + +def regex_rename_to_sympy(expression, /): + """ + Rename expression containing reserved or built-in symbols using a regex. + + Notes + ----- + For renaming single variables, see :ref:``rename_to_sympy``. + """ + + for var in forbidden_var: + pattern = re.compile(rf"\b{var}\b") + expression = re.sub(pattern, f"_sympy_{var}", expression) + + return expression + + def solve_lin_system( eq_strings, vars, @@ -622,18 +667,23 @@ def differentiate2c( if stepsize <= 0: raise ValueError("arg `stepsize` must be > 0") prev_expressions = prev_expressions or [] - # every symbol (a.k.a variable) that SymPy - # is going to manipulate needs to be declared - # explicitly - x = make_symbol(dependent_var) - vars = set(vars) - vars.discard(dependent_var) + + # we keep the original symbol around as well so we can rename it back + x_original = make_symbol(dependent_var) + sympy_vars_original = { + **{str(var): make_symbol(var) for var in vars}, + str(x_original): x_original, + } + # declare all other supplied variables - sympy_vars = {str(var): make_symbol(var) for var in vars} - sympy_vars[dependent_var] = x + x = rename_to_sympy(dependent_var) + sympy_vars = { + **{str(rename_to_sympy(var)): rename_to_sympy(var) for var in vars}, + str(x): x, + } # parse string into SymPy equation - expr = sp.sympify(expression, locals=sympy_vars) + expr = sp.sympify(regex_rename_to_sympy(expression), locals=sympy_vars) # parse previous expressions in the order that they came in # substitute any x-dependent vars in rhs with their rhs expressions, @@ -672,6 +722,16 @@ def differentiate2c( .evalf() ) + # once we have the derivative, it's safe to put back the original variables + reverse_map = { + old_var: new_var + for old_var, new_var in zip( + sympy_vars.values(), + sympy_vars_original.values(), + ) + } + diff = diff.subs(reverse_map) + # the codegen method does not like undefined function calls, so we extract # them here custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)} diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 82e0358d2..ea787e29a 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from nmodl.ode import differentiate2c, integrate2c, make_symbol +from nmodl.ode import differentiate2c, integrate2c, make_symbol, forbidden_var import pytest import sympy as sp @@ -156,6 +156,10 @@ def test_differentiate2c(): stepsize=-1, ) + # test reserved symbols + for var in forbidden_var: + assert sp.parse_expr(differentiate2c(var, var, {})) == sp.parse_expr("1.") + def test_integrate2c():