Skip to content

Commit

Permalink
Merge pull request #2279 from devitocodes/dim_abridging_fix
Browse files Browse the repository at this point in the history
compiler: Check DeviceFunctions for SubDimensions
  • Loading branch information
mloubout authored Dec 11, 2023
2 parents b95947e + 054f34f commit 9810a8f
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
FindNodes, FindSymbols, Transformer, Uxreplace,
filter_iterations, retrieve_iteration_tree, pull_dims)
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import evalrel, has_integer_args
from devito.tools import as_mapper, filter_ordered, split

Expand Down Expand Up @@ -209,24 +210,27 @@ def remove_redundant_moddims(iet):
return iet


@singledispatch
def abridge_dim_names(iet):
return iet


@abridge_dim_names.register(DeviceFunction)
def _(iet):
# Catch SubDimensions not in EntryFunction
mapper = _rename_subdims(iet, FindSymbols('dimensions').visit(iet))
return Uxreplace(mapper, nested=True).visit(iet)


@abridge_dim_names.register(EntryFunction)
def _(iet):
# SubDimensions in the main loop nests
mapper = {}
# Build a mapper replacing SubDimension names with respective root dimension
# names where possible
for tree in retrieve_iteration_tree(iet):
# Find SubDimensions or SubDimension-derived dimensions used as indices in
# the expression in the innermost loop
indexeds = FindSymbols('indexeds').visit(tree.inner)
dims = pull_dims(indexeds, flag=False)
dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])]
dims = [d for d in dims if not d.is_SubIterator]
names = [d.root.name for d in dims]

# Rename them to use the name of their root dimension if this will not cause a
# clash with Dimensions or other renamed SubDimensions within this innermost
# loop
mapper.update({d: d._rebuild(d.root.name) for d in dims
if d.root not in tree.dimensions and names.count(d.root.name) < 2})
# Rename SubDimensions present as indices in innermost loop
mapper.update(_rename_subdims(tree.inner, tree.dimensions))

# Update unbound index parents with renamed SubDimensions
dims = set().union(*[i.uindices for i in tree])
Expand All @@ -238,7 +242,20 @@ def abridge_dim_names(iet):
dims = [d for d in dims if d.is_Modulo and d.parent in mapper]
mapper.update({d: d._rebuild(parent=mapper[d.parent]) for d in dims})

if not mapper: # No SubDimensions to be renamed
return iet

return Uxreplace(mapper, nested=True).visit(iet)


def _rename_subdims(target, dimensions):
# Find SubDimensions or SubDimension-derived dimensions used as indices in
# the expression
indexeds = FindSymbols('indexeds').visit(target)
dims = pull_dims(indexeds, flag=False)
dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])]
dims = [d for d in dims if not d.is_SubIterator]
names = [d.root.name for d in dims]

# Rename them to use the name of their root dimension if this will not cause a
# clash with Dimensions or other renamed SubDimensions
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}

0 comments on commit 9810a8f

Please sign in to comment.