Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
Squashed Commit Of All Map Changes in amd-trunk-dev-map-b4-squash
Browse files Browse the repository at this point in the history
Consists of:
   1-D array sectioning mapping
   initial declare target mapping
   initial 1-D pointer/alloca/target mapping
   A start at structure lowering with contained pointers (can't test yet due to some required frontend work, but it's utilised to lower pointers etc. at the moment, which is similar in Fortran)
   tidying up of the above
   Sergio's CSE patch
  • Loading branch information
agozillon committed Oct 14, 2023
1 parent c57822a commit 710dc8e
Show file tree
Hide file tree
Showing 25 changed files with 1,102 additions and 309 deletions.
7 changes: 6 additions & 1 deletion flang/include/flang/Lower/OpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ConvertOp;
namespace Fortran {
namespace parser {
struct OpenMPConstruct;
struct OpenMPBlockConstruct;
struct OpenMPDeclarativeConstruct;
struct OmpEndLoopDirective;
struct OmpClauseList;
Expand All @@ -51,7 +52,6 @@ struct Variable;
// Generate the OpenMP terminator for Operation at Location.
void genOpenMPTerminator(fir::FirOpBuilder &, mlir::Operation *,
mlir::Location);

void genOpenMPConstruct(AbstractConverter &, semantics::SemanticsContext &,
pft::Evaluation &, const parser::OpenMPConstruct &);
void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
Expand All @@ -61,6 +61,11 @@ void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
void genOpenMPReduction(AbstractConverter &,
const Fortran::parser::OmpClauseList &clauseList);
void genImplicitMapsForTarget(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPBlockConstruct &ompBlock);

mlir::Operation *findReductionChain(mlir::Value, mlir::Value * = nullptr);
fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Frontend/FrontendActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ bool CodeGenAction::beginSourceFileAction() {
}

pm.enableVerifier(/*verifyPasses=*/true);

pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());

if (mlir::failed(pm.run(*mlirModule))) {
Expand Down
5 changes: 5 additions & 0 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2363,6 +2363,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genOpenMPReduction(*this, blockClauses);
}

if (ompBlock) {
genImplicitMapsForTarget(*this, bridge.getSemanticsContext(), getEval(),
*ompBlock);
}

localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);

Expand Down
1 change: 1 addition & 0 deletions flang/lib/Lower/ConvertVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ static fir::GlobalOp declareGlobal(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
if (fir::GlobalOp global = builder.getNamedGlobal(globalName))
return global;

// Always define linkonce data since it may be optimized out from the module
// that actually owns the variable if it does not refers to it.
if (linkage == builder.createLinkOnceODRLinkage() ||
Expand Down
320 changes: 249 additions & 71 deletions flang/lib/Lower/OpenMP.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,7 @@ fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv) {
if (auto *boxValue = exv.getBoxOf<fir::BoxValue>())
return *boxValue;

mlir::Value box = builder.createBox(loc, exv);
llvm::SmallVector<mlir::Value> lbounds;
llvm::SmallVector<mlir::Value> explicitTypeParams;
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,7 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
bool useInputType = fir::isPolymorphicType(boxTy) || isUnlimitedPolymorphic;
mlir::Value descriptor =
rewriter.create<mlir::LLVM::UndefOp>(loc, llvmBoxTy);

descriptor =
insertField(rewriter, loc, descriptor, {kElemLenPosInBox}, eleSize);
descriptor = insertField(rewriter, loc, descriptor, {kVersionPosInBox},
Expand Down Expand Up @@ -1483,6 +1484,7 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
fir::unwrapIfDerived(boxTy));
}
}

if (typeDesc)
descriptor =
insertField(rewriter, loc, descriptor, {typeDescFieldId}, typeDesc,
Expand Down Expand Up @@ -1919,6 +1921,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
mlir::LogicalResult
matchAndRewrite(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

mlir::Location loc = rebox.getLoc();
mlir::Type idxTy = lowerTy().indexType();
mlir::Value loweredBox = adaptor.getOperands()[0];
Expand Down Expand Up @@ -3785,6 +3788,9 @@ class FIRToLLVMLowering
if (mlir::failed(runPipeline(mathConvertionPM, mod)))
return signalPassFailure();

// If a load op spawns an alloca when it's a box, why does the alloca for the
// target op's load op end up external to the target op?

auto *context = getModule().getContext();
fir::LLVMTypeConverter typeConverter{getModule(),
options.applyTBAA || applyTBAA,
Expand Down
12 changes: 11 additions & 1 deletion flang/lib/Optimizer/Transforms/OMPEarlyOutlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,17 @@ class OMPEarlyOutliningPass
llvm::SetVector<mlir::Value> inputs;
mlir::Region &targetRegion = targetOp.getRegion();
mlir::getUsedValuesDefinedAbove(targetRegion, inputs);


// Collect all map info; even unused maps must be collected to avoid
// ICEs.
for (mlir::Value oper : targetOp->getOperands()) {
if (auto mapEntry =
mlir::dyn_cast<mlir::omp::MapInfoOp>(oper.getDefiningOp())) {
if (!inputs.contains(mapEntry.getVarPtr()))
inputs.insert(mapEntry.getVarPtr());
}
}

// filter out declareTarget and map entries which are specially handled
// at the moment, so we do not wish these to end up as function arguments
// which would just be more noise in the IR.
Expand Down
1 change: 1 addition & 0 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ func.func @_QPomp_target() {
// CHECK: llvm.return
// CHECK: }


// -----

func.func @_QPsimdloop_with_nested_loop() {
Expand Down
14 changes: 7 additions & 7 deletions flang/test/Lower/OpenMP/FIR/array-bounds.f90
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
!DEVICE: %[[C8:.*]] = arith.constant 1 : index
!DEVICE: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C8]] : index) start_idx(%[[C8]] : index)
!DEVICE: %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG2]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
!DEVICE: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>) {
!DEVICE: omp.target map_entries(%[[MAP0]], %[[MAP1]], {{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, {{.*}}) {

!HOST-LABEL: func.func @_QPread_write_section() {
!HOST: %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFread_write_sectionEi"}
Expand All @@ -31,8 +31,8 @@
!HOST: %[[C6:.*]] = arith.constant 4 : index
!HOST: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
!HOST: %[[MAP1:.*]] = omp.map_info var_ptr(%[[WRITE]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
!HOST: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>) {

!HOST: %[[MAP2:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {implicit = true, name = "i"}
!HOST: omp.target map_entries(%[[MAP0]], %[[MAP1]], %[[MAP2]] : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>) {
subroutine read_write_section()
integer :: sp_read(10) = (/1,2,3,4,5,6,7,8,9,10/)
integer :: sp_write(10) = (/0,0,0,0,0,0,0,0,0,0/)
Expand All @@ -57,7 +57,7 @@ module assumed_array_routines
!DEVICE: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C0]] : index) upper_bound(%[[C1]] : index) stride(%[[C3]]#2 : index) start_idx(%[[C4]] : index) {stride_in_bytes = true}
!DEVICE: %[[ARGADDR:.*]] = fir.box_addr %[[ARG1]] : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
!DEVICE: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARGADDR]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
!DEVICE: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
!DEVICE: omp.target map_entries(%[[MAP]], {{.*}} : !fir.ref<!fir.array<?xi32>>, {{.*}}) {

!HOST-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
!HOST-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
Expand All @@ -70,7 +70,7 @@ module assumed_array_routines
!HOST: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C3]] : index) upper_bound(%[[C4]] : index) stride(%[[C2]]#2 : index) start_idx(%[[C0]] : index) {stride_in_bytes = true}
!HOST: %[[ADDROF:.*]] = fir.box_addr %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
!HOST: %[[MAP:.*]] = omp.map_info var_ptr(%[[ADDROF]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
!HOST: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
!HOST: omp.target map_entries(%[[MAP]], {{.*}} : !fir.ref<!fir.array<?xi32>>, {{.*}}) {
subroutine assumed_shape_array(arr_read_write)
integer, intent(inout) :: arr_read_write(:)

Expand All @@ -89,7 +89,7 @@ end subroutine assumed_shape_array
!DEVICE: %[[C3:.*]] = arith.constant 1 : index
!DEVICE: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C0]] : index) upper_bound(%[[C1]] : index) stride(%[[C3]] : index) start_idx(%[[C3]] : index)
!DEVICE: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG1]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
!DEVICE: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
!DEVICE: omp.target map_entries(%[[MAP]], {{.*}} : !fir.ref<!fir.array<?xi32>>, {{.*}}) {

!HOST-LABEL: func.func @_QMassumed_array_routinesPassumed_size_array(
!HOST-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
Expand All @@ -99,7 +99,7 @@ end subroutine assumed_shape_array
!HOST: %[[C2:.*]] = arith.constant 4 : index
!HOST: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C0]] : index) start_idx(%[[C0]] : index)
!HOST: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG0]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
!HOST: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
!HOST: omp.target map_entries(%[[MAP]], {{.*}} : !fir.ref<!fir.array<?xi32>>, {{.*}}) {
subroutine assumed_size_array(arr_read_write)
integer, intent(inout) :: arr_read_write(*)

Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/FIR/location.f90
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ subroutine sub_parallel()
!CHECK-LABEL: sub_target
subroutine sub_target()
print *, x
!CHECK: omp.target {
!CHECK: omp.target {{.*}} {
!$omp target
print *, x
!CHECK: omp.terminator loc(#[[TAR_LOC:.*]])
Expand Down
49 changes: 25 additions & 24 deletions flang/test/Lower/OpenMP/FIR/omp-target-early-outlining.f90
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,31 @@ SUBROUTINE TARGET_FUNCTION()
!CHECK: %[[C1_1:.*]] = arith.constant 1 : index
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C1]] : index) upper_bound(%[[C4]] : index) stride(%[[C1_1]] : index) start_idx(%[[C1_1]] : index)
!CHECK: %[[ENTRY:.*]] = omp.map_info var_ptr(%[[ARG1]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
!CHECK: omp.target map_entries(%[[ENTRY]] : !fir.ref<!fir.array<10xi32>>) {
!CHECK: %c2_i32 = arith.constant 2 : i32
!CHECK: %2 = fir.convert %c2_i32 : (i32) -> index
!CHECK: %c5_i32 = arith.constant 5 : i32
!CHECK: %3 = fir.convert %c5_i32 : (i32) -> index
!CHECK: %c1_2 = arith.constant 1 : index
!CHECK: %4 = fir.convert %2 : (index) -> i32
!CHECK: %5:2 = fir.do_loop %arg2 = %2 to %3 step %c1_2 iter_args(%arg3 = %4) -> (index, i32) {
!CHECK: fir.store %arg3 to %[[ARG0]] : !fir.ref<i32>
!CHECK: %6 = fir.load %[[ARG0]] : !fir.ref<i32>
!CHECK: %7 = fir.load %[[ARG0]] : !fir.ref<i32>
!CHECK: %8 = fir.convert %7 : (i32) -> i64
!CHECK: %c1_i64 = arith.constant 1 : i64
!CHECK: %9 = arith.subi %8, %c1_i64 : i64
!CHECK: %10 = fir.coordinate_of %[[ARG1]], %9 : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
!CHECK: fir.store %6 to %10 : !fir.ref<i32>
!CHECK: %11 = arith.addi %arg2, %c1_2 : index
!CHECK: %12 = fir.convert %c1_2 : (index) -> i32
!CHECK: %13 = fir.load %[[ARG0]] : !fir.ref<i32>
!CHECK: %14 = arith.addi %13, %12 : i32
!CHECK: fir.result %11, %14 : index, i32
!CHECK: }
!CHECK: fir.store %5#1 to %[[ARG0]] : !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: %[[ENTRY_IMP:.*]] = omp.map_info var_ptr(%[[ARG0]] : !fir.ref<i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {implicit = true, name = "i"}
!CHECK: omp.target map_entries(%[[ENTRY]], %[[ENTRY_IMP]] : !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>) {
!CHECK: %c2_i32 = arith.constant 2 : i32
!CHECK: %3 = fir.convert %c2_i32 : (i32) -> index
!CHECK: %c5_i32 = arith.constant 5 : i32
!CHECK: %4 = fir.convert %c5_i32 : (i32) -> index
!CHECK: %c1_2 = arith.constant 1 : index
!CHECK: %5 = fir.convert %3 : (index) -> i32
!CHECK: %6:2 = fir.do_loop %arg2 = %3 to %4 step %c1_2 iter_args(%arg3 = %5) -> (index, i32) {
!CHECK: fir.store %arg3 to %[[ARG0]] : !fir.ref<i32>
!CHECK: %7 = fir.load %[[ARG0]] : !fir.ref<i32>
!CHECK: %8 = fir.load %[[ARG0]] : !fir.ref<i32>
!CHECK: %9 = fir.convert %8 : (i32) -> i64
!CHECK: %c1_i64 = arith.constant 1 : i64
!CHECK: %10 = arith.subi %9, %c1_i64 : i64
!CHECK: %11 = fir.coordinate_of %[[ARG1]], %10 : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
!CHECK: fir.store %7 to %11 : !fir.ref<i32>
!CHECK: %12 = arith.addi %arg2, %c1_2 : index
!CHECK: %13 = fir.convert %c1_2 : (index) -> i32
!CHECK: %14 = fir.load %[[ARG0]] : !fir.ref<i32>
!CHECK: %15 = arith.addi %14, %13 : i32
!CHECK: fir.result %12, %15 : index, i32
!CHECK: }
!CHECK: fir.store %6#1 to %[[ARG0]] : !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK:return
!CHECK:}
Expand Down
8 changes: 4 additions & 4 deletions flang/test/Lower/OpenMP/FIR/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ end subroutine omp_target_exit_simple
!===============================================================================
! Target_Exit Map types
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_exit_mt() {
subroutine omp_target_exit_mt
integer :: a(1024)
Expand Down Expand Up @@ -245,11 +244,12 @@ end subroutine omp_target_device_ptr
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_device_addr() {
subroutine omp_target_device_addr
subroutine omp_target_device_addr
integer, pointer :: a
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_device_addrEa"}
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
!CHECK: omp.target_data map_entries(%[[MAP]] : {{.*}}) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
!CHECK: %[[MAP1:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
!CHECK: %[[MAP2:.*]] = omp.map_info var_ptr({{.*}}) var_ptr_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
!CHECK: omp.target_data map_entries(%[[MAP1]], %[[MAP2]] : {{.*}}) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
!$omp target data map(tofrom: a) use_device_addr(a)
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
!CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/location.f90
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ subroutine sub_parallel()
!CHECK-LABEL: sub_target
subroutine sub_target()
print *, x
!CHECK: omp.target {
!CHECK: omp.target {{.*}} {
!$omp target
print *, x
!CHECK: omp.terminator loc(#[[TAR_LOC:.*]])
Expand Down
3 changes: 0 additions & 3 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,6 @@ class OffloadEntriesInfoManager {

/// Kind of device clause for declare target variables
/// and functions
/// NOTE: Currently not used as a part of a variable entry
/// used for Flang and Clang to interface with the variable
/// related registration functions
enum OMPTargetDeviceClauseKind : uint32_t {
/// The target is marked for all devices
OMPTargetDeviceClauseAny = 0x0,
Expand Down
41 changes: 35 additions & 6 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4946,10 +4946,12 @@ static Function *createOutlinedFunction(
ParameterTypes.push_back(Arg->getType());
}

auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
/*isVarArg*/ false);
auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
Builder.GetInsertBlock()->getModule());
FunctionType *FuncType =
FunctionType::get(Builder.getVoidTy(), ParameterTypes,
/*isVarArg*/ false);
Function *Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
Builder.GetInsertBlock()->getModule());

if (OMPBuilder.Config.isTargetDevice()) {
std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
Expand Down Expand Up @@ -4993,11 +4995,38 @@ static Function *createOutlinedFunction(
Builder.restoreIP(
ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));

// Collect all the instructions
// Things like GEPs can come in the form of Constants. Constants and
// ConstantExprs do not know what they are contained in, so we must dig a
// little to find an instruction in order to tell whether they are used
// inside of the function we're outlining. We also replace the original
// constant expression with a new, equivalent instruction: an instruction
// allows easy modification in the following loop, as we then know the
// constant (now an instruction) is owned by our target function and
// replaceUsesOfWith can be invoked on it (this does not seem to be possible
// with constants). Creating a brand new instruction also lets us be
// cautious, as it is perhaps possible that the old expression is used
// inside of the function but also exists and is used externally (unlikely
// by the nature of a Constant, but still).
auto ReplaceConstantUsedInFunction = [](Constant *Const, Function *Func) {
if (auto *ConstExpr = dyn_cast<ConstantExpr>(Const))
for (User *User : make_early_inc_range(ConstExpr->users()))
if (auto *Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(ConstExpr,
ConstExpr->getAsInstruction(Instr));
};

for (User *User : make_early_inc_range(Input->users()))
if (auto Instr = dyn_cast<Instruction>(User))
if (auto Const = dyn_cast<Constant>(User))
ReplaceConstantUsedInFunction(Const, Func);

// Collect all the instructions
for (User *User : make_early_inc_range(Input->users())) {
if (auto *Instr = dyn_cast<Instruction>(User)) {
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, InputCopy);
}
}
}

// Restore insert point.
Expand Down
2 changes: 1 addition & 1 deletion llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5672,7 +5672,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
};

llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP)
-> llvm::OpenMPIRBuilder::MapInfosTy & {
CreateDefaultMapInfos(OMPBuilder, CapturedArgs, CombinedInfos);
return CombinedInfos;
Expand Down
4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
//===---------------------------------------------------------------------===//
// 2.14.2 target data Construct
//===---------------------------------------------------------------------===//

def Target_DataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments]>{
let summary = "target data construct";
let description = [{
Expand Down Expand Up @@ -1284,9 +1284,7 @@ def Target_DataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments]>{
Variadic<OpenMP_PointerLikeType>:$use_device_ptr,
Variadic<OpenMP_PointerLikeType>:$use_device_addr,
Variadic<OpenMP_PointerLikeType>:$map_operands);

let regions = (region AnyRegion:$region);

let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
Expand Down
Loading

0 comments on commit 710dc8e

Please sign in to comment.