From 507e04be553fb86cbf6ac1d607b57e6bf857a132 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 8 Dec 2023 16:40:58 +0000 Subject: [PATCH 1/2] compiler: Check DeviceFunctions for SubDimensions --- devito/passes/iet/misc.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 31576566e4..2e819160a9 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -8,6 +8,7 @@ FindNodes, FindSymbols, Transformer, Uxreplace, filter_iterations, retrieve_iteration_tree, pull_dims) from devito.passes.iet.engine import iet_pass +from devito.ir.iet.efunc import DeviceFunction from devito.symbolics import evalrel, has_integer_args from devito.tools import as_mapper, filter_ordered, split @@ -210,23 +211,37 @@ def remove_redundant_moddims(iet): def abridge_dim_names(iet): - mapper = {} - # Build a mapper replacing SubDimension names with respective root dimension - # names where possible - for tree in retrieve_iteration_tree(iet): + + def rename_subdims(target, dimensions): # Find SubDimensions or SubDimension-derived dimensions used as indices in - # the expression in the innermost loop - indexeds = FindSymbols('indexeds').visit(tree.inner) + # the expression + indexeds = FindSymbols('indexeds').visit(target) dims = pull_dims(indexeds, flag=False) dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])] dims = [d for d in dims if not d.is_SubIterator] names = [d.root.name for d in dims] # Rename them to use the name of their root dimension if this will not cause a - # clash with Dimensions or other renamed SubDimensions within this innermost - # loop - mapper.update({d: d._rebuild(d.root.name) for d in dims - if d.root not in tree.dimensions and names.count(d.root.name) < 2}) + # clash with Dimensions or other renamed SubDimensions + return {d: d._rebuild(d.root.name) for d in dims + if d.root not in dimensions + and names.count(d.root.name) < 2} + + if isinstance(iet, DeviceFunction): + # Catch SubDimensions not in EntryFunction + mapper = rename_subdims(iet, FindSymbols('dimensions').visit(iet)) + + if not mapper: # No SubDimensions to be renamed + return iet + + return Uxreplace(mapper, nested=True).visit(iet) + + mapper = {} + # Build a mapper replacing SubDimension names with respective root dimension + # names where possible + for tree in retrieve_iteration_tree(iet): + # Rename SubDimensions present as indices in innermost loop + mapper.update(rename_subdims(tree.inner, tree.dimensions)) # Update unbound index parents with renamed SubDimensions dims = set().union(*[i.uindices for i in tree]) From 054f34f443e9ccc4a47e76b6687798b4f673efcf Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 11 Dec 2023 11:24:36 +0000 Subject: [PATCH 2/2] compiler: Refactor abridge_dim_names as singledispatch --- devito/passes/iet/misc.py | 54 ++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 2e819160a9..95a30c93c0 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -8,7 +8,7 @@ FindNodes, FindSymbols, Transformer, Uxreplace, filter_iterations, retrieve_iteration_tree, pull_dims) from devito.passes.iet.engine import iet_pass -from devito.ir.iet.efunc import DeviceFunction +from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import evalrel, has_integer_args from devito.tools import as_mapper, filter_ordered, split @@ -210,38 +210,27 @@ def remove_redundant_moddims(iet): return iet +@singledispatch def abridge_dim_names(iet): + return iet - def rename_subdims(target, dimensions): - # Find SubDimensions or SubDimension-derived dimensions used as indices in - # the expression - indexeds = FindSymbols('indexeds').visit(target) - dims = pull_dims(indexeds, flag=False) - dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])] - dims = [d for d in dims if not d.is_SubIterator] - names = [d.root.name for d in dims] - - # Rename them to use the name of their root dimension if this will not cause a - # clash with Dimensions or other renamed SubDimensions - return {d: d._rebuild(d.root.name) for d in dims - if d.root not in dimensions - and names.count(d.root.name) < 2} - - if isinstance(iet, DeviceFunction): - # Catch SubDimensions not in EntryFunction - mapper = rename_subdims(iet, FindSymbols('dimensions').visit(iet)) - if not mapper: # No SubDimensions to be renamed - return iet +@abridge_dim_names.register(DeviceFunction) +def _(iet): + # Catch SubDimensions not in EntryFunction + mapper = _rename_subdims(iet, FindSymbols('dimensions').visit(iet)) + return Uxreplace(mapper, nested=True).visit(iet) - return Uxreplace(mapper, nested=True).visit(iet) +@abridge_dim_names.register(EntryFunction) +def _(iet): + # SubDimensions in the main loop nests mapper = {} # Build a mapper replacing SubDimension names with respective root dimension # names where possible for tree in retrieve_iteration_tree(iet): # Rename SubDimensions present as indices in innermost loop - mapper.update(rename_subdims(tree.inner, tree.dimensions)) + mapper.update(_rename_subdims(tree.inner, tree.dimensions)) # Update unbound index parents with renamed SubDimensions dims = set().union(*[i.uindices for i in tree]) @@ -253,7 +242,20 @@ def rename_subdims(target, dimensions): dims = [d for d in dims if d.is_Modulo and d.parent in mapper] mapper.update({d: d._rebuild(parent=mapper[d.parent]) for d in dims}) - if not mapper: # No SubDimensions to be renamed - return iet - return Uxreplace(mapper, nested=True).visit(iet) + + +def _rename_subdims(target, dimensions): + # Find SubDimensions or SubDimension-derived dimensions used as indices in + # the expression + indexeds = FindSymbols('indexeds').visit(target) + dims = pull_dims(indexeds, flag=False) + dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])] + dims = [d for d in dims if not d.is_SubIterator] + names = [d.root.name for d in dims] + + # Rename them to use the name of their root dimension if this will not cause a + # clash with Dimensions or other renamed SubDimensions + return {d: d._rebuild(d.root.name) for d in dims + if d.root not in dimensions + and names.count(d.root.name) < 2}