Skip to content

Commit

Permalink
Merge pull request #2524 from devitocodes/fix-derivs-7865
Browse files Browse the repository at this point in the history
api: Fix custom coefficients inlining
  • Loading branch information
mloubout authored Jan 23, 2025
2 parents 8cec681 + 0698917 commit f71764a
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 10 deletions.
9 changes: 5 additions & 4 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
weights = [weights._subs(wdim, i) for i in range(len(indices))]

# Enforce fixed precision FD coefficients to avoid variations in results
weights = [sympify(w).evalf(_PRECISION) for w in weights]
if scale:
scale = dim.spacing**(-deriv_order)
else:
scale = 1
weights = [sympify(scale * w).evalf(_PRECISION) for w in weights]

# Transpose the FD, if necessary
if matvec == transpose:
Expand Down Expand Up @@ -208,7 +212,4 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici

deriv = EvalDerivative(*terms, base=expr)

if scale:
deriv = dim.spacing**(-deriv_order) * deriv

return deriv
2 changes: 1 addition & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _apply_coeffs(cls, expr, coefficients):
if not mapper:
return expr

return expr.xreplace(mapper)
return expr.subs(mapper)

def _evaluate(self, **kwargs):
"""
Expand Down
2 changes: 1 addition & 1 deletion examples/seismic/tutorials/07_DRP_schemes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Eq(-u(t, x, y)/dt + u(t + dt, x, y)/dt + (0.1*u(t, x, y) - 0.6*u(t, x - h_x, y) + 0.6*u(t, x + h_x, y))/h_x, 0)\n"
"Eq(-u(t, x, y)/dt + u(t + dt, x, y)/dt + 0.1*u(t, x, y)/h_x - 0.6*u(t, x - h_x, y)/h_x + 0.6*u(t, x + h_x, y)/h_x, 0)\n"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_symbolic_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def test_staggered_equation(self):

eq_f = Eq(f, f.dx2(weights=weights))

expected = 'Eq(f(x + h_x/2), (1.0*f(x - h_x/2) - 2.0*f(x + h_x/2)'\
' + 1.0*f(x + 3*h_x/2))/h_x**2)'
expected = 'Eq(f(x + h_x/2), 1.0*f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\
'+ 1.0*f(x + 3*h_x/2)/h_x**2)'
assert(str(eq_f.evaluate) == expected)

@pytest.mark.parametrize('stagger', [True, False])
Expand Down
28 changes: 26 additions & 2 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_numeric_coeffs(self):
Operator(Eq(u, (v*u.dx).dy(weights=w)), opt=opt).cfunction

@pytest.mark.parametrize('coeffs,expected', [
((7, 7, 7), 1), # We've had a bug triggered by identical coeffs
((7, 7, 7), 3), # We've had a bug triggered by identical coeffs
((5, 7, 9), 3),
])
def test_multiple_cross_derivs(self, coeffs, expected):
Expand Down Expand Up @@ -89,7 +89,8 @@ def test_legacy_api(self, order, nweight):
coefficients='symbolic')

w0 = np.arange(so + 1 + nweight) + 1
wstr = '{' + ', '.join([f"{w:1.1f}F" for w in w0]) + '}'
s = f'({x.spacing}*{x.spacing})' if order == 2 else f'{x.spacing}'
wstr = f'{{{w0[0]:1.1f}F/{s},'
wdef = f'[{so + 1 + nweight}] __attribute__ ((aligned (64)))'

coeffs_x_p1 = Coefficient(order, u, x, w0)
Expand All @@ -105,6 +106,29 @@ def test_legacy_api(self, order, nweight):
op = Operator(eqn, opt=('advanced', {'expand': False}))
assert f'{wdef} = {wstr}' in str(op)

def test_legacy_api_v2(self):
grid = Grid(shape=(10, 10, 10))
x, y, z = grid.dimensions

u = TimeFunction(name='u', grid=grid, space_order=4)

cc = np.array([2, 2, 2, 2, 2])
coeffs = [Coefficient(1, u, d, cc) for d in grid.dimensions]
coeffs = Substitutions(*coeffs)

eq0 = Eq(u.forward, u.dx.dz + 1.0)
eq1 = Eq(u.forward, u.dx.dz + 1.0, coefficients=coeffs)

op0 = Operator(eq0, opt=('advanced', {'expand': False}))
op1 = Operator(eq1, opt=('advanced', {'expand': False}))

assert (op0._profiler._sections['section0'].sops ==
op1._profiler._sections['section0'].sops)
weights = [i for i in FindSymbols().visit(op1) if isinstance(i, Weights)]
w0, w1 = sorted(weights, key=lambda i: i.name)
assert all(i.args[1] == 1/x.spacing for i in w0.weights)
assert all(i.args[1] == 1/z.spacing for i in w1.weights)


class Test1Pass:

Expand Down

0 comments on commit f71764a

Please sign in to comment.