diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 31576566e4..95a30c93c0 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -8,6 +8,7 @@ FindNodes, FindSymbols, Transformer, Uxreplace, filter_iterations, retrieve_iteration_tree, pull_dims) from devito.passes.iet.engine import iet_pass +from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import evalrel, has_integer_args from devito.tools import as_mapper, filter_ordered, split @@ -209,24 +210,27 @@ def remove_redundant_moddims(iet): return iet +@singledispatch def abridge_dim_names(iet): + return iet + + +@abridge_dim_names.register(DeviceFunction) +def _(iet): + # Catch SubDimensions not in EntryFunction + mapper = _rename_subdims(iet, FindSymbols('dimensions').visit(iet)) + return Uxreplace(mapper, nested=True).visit(iet) + + +@abridge_dim_names.register(EntryFunction) +def _(iet): + # SubDimensions in the main loop nests mapper = {} # Build a mapper replacing SubDimension names with respective root dimension # names where possible for tree in retrieve_iteration_tree(iet): - # Find SubDimensions or SubDimension-derived dimensions used as indices in - # the expression in the innermost loop - indexeds = FindSymbols('indexeds').visit(tree.inner) - dims = pull_dims(indexeds, flag=False) - dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])] - dims = [d for d in dims if not d.is_SubIterator] - names = [d.root.name for d in dims] - - # Rename them to use the name of their root dimension if this will not cause a - # clash with Dimensions or other renamed SubDimensions within this innermost - # loop - mapper.update({d: d._rebuild(d.root.name) for d in dims - if d.root not in tree.dimensions and names.count(d.root.name) < 2}) + # Rename SubDimensions present as indices in innermost loop + mapper.update(_rename_subdims(tree.inner, tree.dimensions)) # Update unbound index parents with renamed SubDimensions dims = set().union(*[i.uindices for i in tree]) @@ -238,7 +242,20 @@ def abridge_dim_names(iet): dims = [d for d in dims if d.is_Modulo and d.parent in mapper] mapper.update({d: d._rebuild(parent=mapper[d.parent]) for d in dims}) - if not mapper: # No SubDimensions to be renamed - return iet - return Uxreplace(mapper, nested=True).visit(iet) + + +def _rename_subdims(target, dimensions): + # Find SubDimensions or SubDimension-derived dimensions used as indices in + # the expression + indexeds = FindSymbols('indexeds').visit(target) + dims = pull_dims(indexeds, flag=False) + dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])] + dims = [d for d in dims if not d.is_SubIterator] + names = [d.root.name for d in dims] + + # Rename them to use the name of their root dimension if this will not cause a + # clash with Dimensions or other renamed SubDimensions + return {d: d._rebuild(d.root.name) for d in dims + if d.root not in dimensions + and names.count(d.root.name) < 2}