Skip to content

Commit

Permalink
Merge pull request #2520 from devitocodes/wood-run-stable
Browse files Browse the repository at this point in the history
api: Misc fixes for builtins and harmonic averaging
  • Loading branch information
mloubout authored Jan 21, 2025
2 parents fa903e4 + d437b25 commit 490a627
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 171 deletions.
19 changes: 19 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,23 @@ class Mod(DifferentiableOp, sympy.Mod):
__sympy_class__ = sympy.Mod


class SafeInv(Differentiable, sympy.core.function.Application):
_fd_priority = 0

@property
def base(self):
return self.args[1]

@property
def val(self):
return self.args[0]

def __str__(self):
return Pow(self.args[0], -1).__str__()

__repr__ = __str__


class IndexSum(sympy.Expr, Evaluable):

"""
Expand Down Expand Up @@ -675,6 +692,8 @@ def __repr__(self):
def _sympystr(self, printer):
return str(self)

_latex = _sympystr

def _hashable_content(self):
return super()._hashable_content() + (self.dimensions,)

Expand Down
5 changes: 4 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,14 +894,17 @@ def apply(self, **kwargs):
>>> op = Operator(Eq(u3.forward, u3 + 1))
>>> summary = op.apply(time_M=10)
"""
# Compile the operator before building the arguments list
# to avoid out of memory with greedy compilers
cfunction = self.cfunction

# Build the arguments list to invoke the kernel function
with self._profiler.timer_on('arguments'):
args = self.arguments(**kwargs)

# Invoke kernel function with args
arg_values = [args[p.name] for p in self.parameters]
try:
cfunction = self.cfunction
with self._profiler.timer_on('apply', comm=args.comm):
retval = cfunction(*arg_values)
except ctypes.ArgumentError as e:
Expand Down
8 changes: 8 additions & 0 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sympy

from devito.finite_differences import Max, Min
from devito.finite_differences.differentiable import SafeInv
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
FindApplications, FindNodes, FindSymbols, Transformer,
Uxreplace, filter_iterations, retrieve_iteration_tree,
Expand Down Expand Up @@ -225,6 +226,13 @@ def _(expr):
return ()


@_lower_macro_math.register(SafeInv)
def _(expr):
eps = np.finfo(np.float32).resolution**2
return (('SAFEINV(a, b)',
f'(((a) < {eps} || (b) < {eps}) ? (0.0F) : (1.0F / (a)))'),)


@iet_pass
def minimize_symbols(iet):
"""
Expand Down
3 changes: 2 additions & 1 deletion devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,11 @@ def estimate_cost(exprs, estimate=False):
estimate_values = {
'elementary': 100,
'pow': 50,
'SafeInv': 10,
'div': 5,
'Abs': 5,
'floor': 1,
'ceil': 1
'ceil': 1,
}


Expand Down
6 changes: 6 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ def _print_Pow(self, expr):
else:
return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})'

def _print_SafeInv(self, expr):
"""Print a SafeInv as a C-like division with a check for zero."""
base = self._print(expr.base)
val = self._print(expr.val)
return f'SAFEINV({val}, {base})'

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
args = ['(%s)' % self._print(a) for a in expr.args]
Expand Down
23 changes: 17 additions & 6 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,12 +1167,16 @@ def _evaluate(self, **kwargs):
# Apply interpolation from inner most dim
for d, i in self._grid_map.items():
retval = retval.diff(d, deriv_order=0, fd_order=2, x0={d: i})
if self._avg_mode == 'harmonic':
retval = 1 / retval

# Evaluate. Since we used `self.function` it will be on the grid when evaluate
# is called again within FD
return retval.evaluate.expand()
if self._avg_mode == 'harmonic':
from devito.finite_differences.differentiable import SafeInv
retval = SafeInv(retval.evaluate, self.function)
else:
retval = retval.evaluate

return retval

@property
def shape(self):
Expand Down Expand Up @@ -1450,12 +1454,19 @@ def indexify(self, indices=None, subs=None):
# Indices after substitutions
indices = []
for a, d, o, s in zip(self.args, self.dimensions, self.origin, subs):
if d in a.free_symbols:
if a.is_Function and len(a.args) == 1:
# E.g. Abs(expr)
arg = a.args[0]
func = a.func
else:
arg = a
func = lambda x: x
if d in arg.free_symbols:
# Shift by origin d -> d - o.
indices.append(sympy.sympify(a.subs(d, d - o).xreplace(s)))
indices.append(func(sympy.sympify(arg.subs(d, d - o).xreplace(s))))
else:
# Dimension has been removed, e.g. u[10], plain shift by origin
indices.append(sympy.sympify(a - o).xreplace(s))
indices.append(func(sympy.sympify(arg - o).xreplace(s)))

indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]
Expand Down
3 changes: 0 additions & 3 deletions examples/seismic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,6 @@ def _initialize_physics(self, vp, space_order, **kwargs):
vs = kwargs.pop('vs')
self.lam = self._gen_phys_param((vp**2 - 2. * vs**2)/b, 'lam', space_order,
is_param=True)
# Need to add small value to avoid division by zero
if isinstance(vs, np.ndarray):
vs = vs + 1e-12
self.mu = self._gen_phys_param(vs**2 / b, 'mu', space_order, is_param=True,
avg_mode='harmonic')
else:
Expand Down
280 changes: 122 additions & 158 deletions examples/seismic/tutorials/06_elastic_varying_parameters.ipynb

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pytest

from devito import Function, Grid, Differentiable, NODE
from devito.finite_differences.differentiable import Add, Mul, Pow, diffify, interp_for_fd
from devito.finite_differences.differentiable import (Add, Mul, Pow, diffify,
interp_for_fd, SafeInv)


def test_differentiable():
Expand Down Expand Up @@ -113,4 +114,7 @@ def test_avg_mode(ndim):
assert sympy.simplify(a_avg - 0.5**ndim * sum(a.subs(arg) for arg in args)) == 0

# Harmonic average, h(a[.5]) = 1/(.5/a[0] + .5/a[1])
assert sympy.simplify(b_avg - 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args))) == 0
expected = 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args))
assert sympy.simplify(1/b_avg.args[0] - expected) == 0
assert isinstance(b_avg, SafeInv)
assert b_avg.base == b
15 changes: 15 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
Min, Max)
from devito.finite_differences.differentiable import SafeInv
from devito.ir import Expression, FindNodes
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
CallFromPointer, Cast, DefFunction, FieldFromPointer,
Expand Down Expand Up @@ -345,6 +346,20 @@ def test_intdiv():
assert ccode(v) == 'b*((a + b) / 2) + 3'


def test_safeinv():
grid = Grid(shape=(11, 11))
x, y = grid.dimensions

u1 = Function(name='u', grid=grid)
u2 = Function(name='u', grid=grid, dtype=np.float64)

op1 = Operator(Eq(u1, SafeInv(u1, u1)))
op2 = Operator(Eq(u2, SafeInv(u2, u2)))

assert 'SAFEINV' in str(op1)
assert 'SAFEINV' in str(op2)


def test_def_function():
foo0 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
foo1 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
Expand Down

0 comments on commit 490a627

Please sign in to comment.