Skip to content

Commit

Permalink
compiler: convert printer to f-string
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Feb 6, 2025
1 parent 79212c7 commit 61fb519
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 40 deletions.
78 changes: 44 additions & 34 deletions devito/ir/cgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,12 @@ def func_prefix(self, expr, abs=False):

def parenthesize(self, item, level, strict=False):
if isinstance(item, BooleanFunction):
return "(%s)" % self._print(item)
return f"({self._print(item)})"
return super().parenthesize(item, level, strict=strict)

def _print_PyCPointerType(self, expr):
return f'{self._print_type(expr._type_)} *'

def _print_type(self, expr):
try:
expr = dtype_to_ctype(expr)
Expand All @@ -120,7 +123,7 @@ def _print_Function(self, expr):
return super()._print_Function(expr)

def _print_CondEq(self, expr):
return "%s == %s" % (self._print(expr.lhs), self._print(expr.rhs))
return f"{self._print(expr.lhs)} == {self._print(expr.rhs)}"

def _print_Indexed(self, expr):
"""
Expand All @@ -131,7 +134,7 @@ def _print_Indexed(self, expr):
U[t,x,y,z] -> U[t][x][y][z]
"""
inds = ''.join(['[' + self._print(x) + ']' for x in expr.indices])
return '%s%s' % (self._print(expr.base.label), inds)
return f'{self._print(expr.base.label)}{inds}'

def _print_FIndexed(self, expr):
"""
Expand All @@ -146,7 +149,7 @@ def _print_FIndexed(self, expr):
label = expr.accessor.label
except AttributeError:
label = expr.base.label
return '%s(%s)' % (self._print(label), inds)
return f'{self._print(label)}({inds})'

def _print_Rational(self, expr):
"""Print a Rational as a C-like float/float division."""
Expand All @@ -155,10 +158,8 @@ def _print_Rational(self, expr):
# to be 32-bit floats.
# http://en.cppreference.com/w/cpp/language/floating_literal
p, q = int(expr.p), int(expr.q)
if self.dtype == np.float64:
return '%d.0/%d.0' % (p, q)
else:
return '%d.0F/%d.0F' % (p, q)
prec = self.prec_literal(expr)
return f'{p}.0{prec}/{q}.0{prec}'

def _print_math_func(self, expr, nest=False, known=None):
cls = type(expr)
Expand Down Expand Up @@ -208,16 +209,22 @@ def _print_SafeInv(self, expr):

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
args = ['(%s)' % self._print(a) for a in expr.args]
args = [f'({self._print(a)})' for a in expr.args]
return '%'.join(args)

def _print_Mul(self, expr):
term = super()._print_Mul(expr)
# avoid (-1)*...
term = term.replace("(-1)*", "-")
# Avoid (-1) / ...
term = term.replace("(-1)/", f"-{self._prec(expr)(1)}/")
return term
args = [a for a in expr.args if a != -1]
neg = (len(expr.args) - len(args)) % 2

if len(args) > 1:
term = super()._print_Mul(expr.func(*args, evaluate=False))
else:
term = self.parenthesize(args[0], precedence(expr))

if neg:
return f'-{term}'
else:
return term

def _print_fmath_func(self, name, expr):
args = ",".join([self._print(i) for i in expr.args])
Expand All @@ -230,7 +237,7 @@ def _print_Min(self, expr):
expr.func(*expr.args[1:]),
evaluate=False))
elif has_integer_args(*expr.args) and len(expr.args) == 2:
return "MIN(%s)" % self._print(expr.args)[1:-1]
return f"MIN({self._print(expr.args)[1:-1]})"
else:
return self._print_fmath_func('min', expr)

Expand All @@ -240,7 +247,7 @@ def _print_Max(self, expr):
expr.func(*expr.args[1:]),
evaluate=False))
elif has_integer_args(*expr.args) and len(expr.args) == 2:
return "MAX(%s)" % self._print(expr.args)[1:-1]
return f"MAX({self._print(expr.args)[1:-1]})"
else:
return self._print_fmath_func('max', expr)

Expand All @@ -251,7 +258,7 @@ def _print_Abs(self, expr):
# AOMPCC errors with abs, always use fabs
if isinstance(self.compiler, AOMPCompiler) and \
not np.issubdtype(self._prec(expr), np.integer):
return "fabs(%s)" % self._print(arg)
return f"fabs({self._print(arg)})"
return self._print_fmath_func('abs', expr)

def _print_Add(self, expr, order=None):
Expand All @@ -265,7 +272,7 @@ def _print_Add(self, expr, order=None):
for term in terms:
t = self._print(term)
if precedence(term) < PREC:
l.extend(["+", "(%s)" % t])
l.extend(["+", f"({t})"])
elif t.startswith('-'):
l.extend(["-", t[1:]])
else:
Expand Down Expand Up @@ -305,44 +312,44 @@ def _print_Float(self, expr):
return f'{rv}{self.prec_literal(expr)}'

def _print_Differentiable(self, expr):
return "(%s)" % self._print(expr._expr)
return f"({self._print(expr._expr)})"

_print_EvalDerivative = _print_Add

def _print_CallFromPointer(self, expr):
indices = [self._print(i) for i in expr.params]
return "%s->%s(%s)" % (expr.pointer, expr.call, ', '.join(indices))
return f"{expr.pointer}->{expr.call}({', '.join(indices)})"

def _print_CallFromComposite(self, expr):
indices = [self._print(i) for i in expr.params]
return "%s.%s(%s)" % (expr.pointer, expr.call, ', '.join(indices))
return f"{expr.pointer}.{expr.call}({', '.join(indices)})"

def _print_FieldFromPointer(self, expr):
return "%s->%s" % (expr.pointer, expr.field)
return f"{expr.pointer}->{expr.field}"

def _print_FieldFromComposite(self, expr):
return "%s.%s" % (expr.pointer, expr.field)
return f"{expr.pointer}.{expr.field}"

def _print_ListInitializer(self, expr):
return "{%s}" % ', '.join([self._print(i) for i in expr.params])
return f"{{{', '.join(self._print(i) for i in expr.params)}}}"

def _print_IndexedPointer(self, expr):
return "%s%s" % (expr.base, ''.join('[%s]' % self._print(i) for i in expr.index))
return f"{expr.base}{''.join(f'[{self._print(i)}]' for i in expr.index)}"

def _print_IntDiv(self, expr):
lhs = self._print(expr.lhs)
if not expr.lhs.is_Atom:
lhs = '(%s)' % (lhs)
lhs = f"({lhs})"
rhs = self._print(expr.rhs)
PREC = precedence(expr)
return self.parenthesize("%s / %s" % (lhs, rhs), PREC)
return self.parenthesize(f"{lhs} / {rhs}", PREC)

def _print_InlineIf(self, expr):
cond = self._print(expr.cond)
true_expr = self._print(expr.true_expr)
false_expr = self._print(expr.false_expr)
PREC = precedence(expr)
return self.parenthesize("(%s) ? %s : %s" % (cond, true_expr, false_expr), PREC)
return self.parenthesize(f"({cond}) ? {true_expr} : {false_expr}", PREC)

def _print_UnaryOp(self, expr, op=None, parenthesize=False):
op = op or expr._op
Expand All @@ -356,20 +363,23 @@ def _print_Cast(self, expr):
return self._print_UnaryOp(expr, op=cast)

def _print_ComponentAccess(self, expr):
return "%s.%s" % (self._print(expr.base), expr.sindex)
return f"{self._print(expr.base)}.{expr.sindex}"

def _print_DefFunction(self, expr):
arguments = [self._print(i) for i in expr.arguments]
if expr.template:
template = '<%s>' % ','.join([str(i) for i in expr.template])
ctemplate = ','.join([str(i) for i in expr.template])
template = f'<{ctemplate}>'
else:
template = ''
return "%s%s(%s)" % (expr.name, template, ','.join(arguments))
args = ','.join(arguments)
return f"{expr.name}{template}({args})"

def _print_SizeOf(self, expr):
return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})'

_print_MathFunction = _print_DefFunction
def _print_MathFunction(self, expr):
return f"{self._ns}{self._print_DefFunction(expr)}"

def _print_Fallback(self, expr):
return expr.__str__()
Expand All @@ -385,7 +395,7 @@ def _print_Fallback(self, expr):

# Lifted from SymPy so that we go through our own `_print_math_func`
for k in ('exp log sin cos tan ceiling floor').split():
setattr(BasePrinter, '_print_%s' % k, BasePrinter._print_math_func)
setattr(BasePrinter, f'_print_{k}', BasePrinter._print_math_func)


# Always parenthesize IntDiv and InlineIf within expressions
Expand Down
2 changes: 1 addition & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ def __setstate__(self, state):
self._lib.name = soname

self._allocator = default_allocator(
'%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform)
'%s.%s.%s' % (type(self._compiler).__name__, self._language, self._platform)
)


Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
List, Break, Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.symbolics import CondEq, MathFunction
from devito.tools import dtype_to_ctype
from devito.types import Eq, Inc, LocalObject, Symbol

Expand Down Expand Up @@ -58,7 +58,7 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
irs, byproduct = rcompile(eqns)

name = sregistry.make_name(prefix='is_finite')
retval = Return(DefFunction('isfinite', accumulator))
retval = Return(MathFunction('isfinite', accumulator))
body = irs.iet.body.body + (retval,)
efunc = make_callable(name, body, retval='int')

Expand Down
9 changes: 6 additions & 3 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,12 @@ def __new__(cls, intype, stars=None, **kwargs):
stars = stars or ''
if not isinstance(intype, (str, ReservedWord)):
ctype = dtype_to_ctype(intype)
if ctype in ctypes_vector_mapper.values():
idx = list(ctypes_vector_mapper.values()).index(ctype)
intype = list(ctypes_vector_mapper.keys())[idx]
for k, v in ctypes_vector_mapper.items():
if ctype is v:
intype = k
break
else:
intype = ctypes_to_cstr(ctype)

newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs)
newobj.stars = stars
Expand Down

0 comments on commit 61fb519

Please sign in to comment.