Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SymPy errors when using reserved symbols in differentiate2c #1531

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,51 @@ def make_symbol(var, /):
return sp.Symbol(var, real=True) if isinstance(var, str) else var


def rename_to_sympy(var, /):
"""
Symbol-aware renaming of reserved or built-in symbols.

Notes
-----
For renaming variables in complex expressions, see
:ref:``regex_rename_to_sympy``.
"""
var = make_symbol(var)

if var.name in forbidden_var:
new_name = f"_sympy_{var.name}"

# there is no "Pythonic" way of doing this since the SymPy API does not
# allow changing names of symbols, or actually copying objects properly
if isinstance(var, sp.Symbol):
assumptions = var.assumptions0
return sp.Symbol(new_name, **assumptions)

elif isinstance(var, sp.IndexedBase):
return sp.IndexedBase(new_name, shape=var.shape)

elif isinstance(var, sp.FunctionClass):
return sp.Function(new_name)

return var


def regex_rename_to_sympy(expression, /):
"""
Rename expression containing reserved or built-in symbols using a regex.

Notes
-----
For renaming single variables, see :ref:``rename_to_sympy``.
"""

for var in forbidden_var:
pattern = re.compile(rf"\b{var}\b")
expression = re.sub(pattern, f"_sympy_{var}", expression)

return expression


def solve_lin_system(
eq_strings,
vars,
Expand Down Expand Up @@ -622,18 +667,23 @@ def differentiate2c(
if stepsize <= 0:
raise ValueError("arg `stepsize` must be > 0")
prev_expressions = prev_expressions or []
# every symbol (a.k.a variable) that SymPy
# is going to manipulate needs to be declared
# explicitly
x = make_symbol(dependent_var)
vars = set(vars)
vars.discard(dependent_var)

# we keep the original symbol around as well so we can rename it back
x_original = make_symbol(dependent_var)
sympy_vars_original = {
**{str(var): make_symbol(var) for var in vars},
str(x_original): x_original,
}

# declare all other supplied variables
sympy_vars = {str(var): make_symbol(var) for var in vars}
sympy_vars[dependent_var] = x
x = rename_to_sympy(dependent_var)
sympy_vars = {
**{str(rename_to_sympy(var)): rename_to_sympy(var) for var in vars},
str(x): x,
}

# parse string into SymPy equation
expr = sp.sympify(expression, locals=sympy_vars)
expr = sp.sympify(regex_rename_to_sympy(expression), locals=sympy_vars)

# parse previous expressions in the order that they came in
# substitute any x-dependent vars in rhs with their rhs expressions,
Expand Down Expand Up @@ -672,6 +722,16 @@ def differentiate2c(
.evalf()
)

# once we have the derivative, it's safe to put back the original variables
reverse_map = {
old_var: new_var
for old_var, new_var in zip(
sympy_vars.values(),
sympy_vars_original.values(),
)
}
diff = diff.subs(reverse_map)

# the codegen method does not like undefined function calls, so we extract
# them here
custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)}
Expand Down
6 changes: 5 additions & 1 deletion test/unit/ode/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from nmodl.ode import differentiate2c, integrate2c, make_symbol
from nmodl.ode import differentiate2c, integrate2c, make_symbol, forbidden_var
import pytest

import sympy as sp
Expand Down Expand Up @@ -156,6 +156,10 @@ def test_differentiate2c():
stepsize=-1,
)

# test reserved symbols
for var in forbidden_var:
assert sp.parse_expr(differentiate2c(var, var, {})) == sp.parse_expr("1.")


def test_integrate2c():

Expand Down
Loading