Skip to content

Commit

Permalink
compiler: Refactor abridge_dim_names as singledispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
EdCaunt committed Dec 11, 2023
1 parent 507e04b commit 054f34f
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
FindNodes, FindSymbols, Transformer, Uxreplace,
filter_iterations, retrieve_iteration_tree, pull_dims)
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import evalrel, has_integer_args
from devito.tools import as_mapper, filter_ordered, split

Expand Down Expand Up @@ -210,38 +210,27 @@ def remove_redundant_moddims(iet):
return iet


@singledispatch
def abridge_dim_names(iet):
return iet

def rename_subdims(target, dimensions):
# Find SubDimensions or SubDimension-derived dimensions used as indices in
# the expression
indexeds = FindSymbols('indexeds').visit(target)
dims = pull_dims(indexeds, flag=False)
dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])]
dims = [d for d in dims if not d.is_SubIterator]
names = [d.root.name for d in dims]

# Rename them to use the name of their root dimension if this will not cause a
# clash with Dimensions or other renamed SubDimensions
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}

if isinstance(iet, DeviceFunction):
# Catch SubDimensions not in EntryFunction
mapper = rename_subdims(iet, FindSymbols('dimensions').visit(iet))

if not mapper: # No SubDimensions to be renamed
return iet
@abridge_dim_names.register(DeviceFunction)
def _(iet):
# Catch SubDimensions not in EntryFunction
mapper = _rename_subdims(iet, FindSymbols('dimensions').visit(iet))
return Uxreplace(mapper, nested=True).visit(iet)

return Uxreplace(mapper, nested=True).visit(iet)

@abridge_dim_names.register(EntryFunction)
def _(iet):
# SubDimensions in the main loop nests
mapper = {}
# Build a mapper replacing SubDimension names with respective root dimension
# names where possible
for tree in retrieve_iteration_tree(iet):
# Rename SubDimensions present as indices in innermost loop
mapper.update(rename_subdims(tree.inner, tree.dimensions))
mapper.update(_rename_subdims(tree.inner, tree.dimensions))

# Update unbound index parents with renamed SubDimensions
dims = set().union(*[i.uindices for i in tree])
Expand All @@ -253,7 +242,20 @@ def rename_subdims(target, dimensions):
dims = [d for d in dims if d.is_Modulo and d.parent in mapper]
mapper.update({d: d._rebuild(parent=mapper[d.parent]) for d in dims})

if not mapper: # No SubDimensions to be renamed
return iet

return Uxreplace(mapper, nested=True).visit(iet)


def _rename_subdims(target, dimensions):
# Find SubDimensions or SubDimension-derived dimensions used as indices in
# the expression
indexeds = FindSymbols('indexeds').visit(target)
dims = pull_dims(indexeds, flag=False)
dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])]
dims = [d for d in dims if not d.is_SubIterator]
names = [d.root.name for d in dims]

# Rename them to use the name of their root dimension if this will not cause a
# clash with Dimensions or other renamed SubDimensions
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}

0 comments on commit 054f34f

Please sign in to comment.