Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
Fix rebase breakages I made, and fix some test changes/additions sinc…
Browse files Browse the repository at this point in the history
…e previous rebase
  • Loading branch information
agozillon committed Aug 22, 2023
1 parent ba0fdc8 commit 38aabd4
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 101 deletions.
120 changes: 39 additions & 81 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2805,9 +2805,9 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,

static mlir::omp::DataOp
genDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand;
llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
Expand All @@ -2825,16 +2825,12 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
useDeviceSymbols);
cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
useDeviceSymbols);
// cp.processMap(mapOperands, mapTypes);

llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
mapTypes.end());
mlir::ArrayAttr mapTypesArrayAttr =
mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
semanticsContext, stmtCtx, mapOperands);

auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
deviceAddrOperands, mapOperands, mapTypesArrayAttr);
deviceAddrOperands, mapOperands);
createBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
useDeviceSymbols, currentLocation);
return dataOp;
Expand All @@ -2843,20 +2839,23 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
template <typename OpTy>
static OpTy
genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands;
llvm::SmallVector<mlir::IntegerAttr> mapTypes;

Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
llvm::omp::Directive directive;
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
directive = llvm::omp::Directive::OMPD_target_enter_data;
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
directive = llvm::omp::Directive::OMPD_target_exit_data;
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
} else {
Expand All @@ -2867,27 +2866,20 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
cp.processIf(stmtCtx, directiveName, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processNowait(nowaitAttr);
// cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
// mapOperands, mapTypes, mapCaptureKinds, mapUShape, mapUBounds,
// mapLShape, mapLBounds);

llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
mapTypes.end());
mlir::ArrayAttr mapTypesArrayAttr =
mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
mapOperands);

return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
deviceOperand, nowaitAttr, mapOperands,
mapTypesArrayAttr);
return converter.getFirOpBuilder().create<OpTy>(
currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, mapOperands);
}

static mlir::omp::TargetOp
genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
llvm::omp::Directive directive, bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
mlir::UnitAttr nowaitAttr;
Expand All @@ -2901,29 +2893,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processNowait(nowaitAttr);
// cp.processMap(mapOperands, mapTypes);

llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
mapTypes.end());
mlir::ArrayAttr mapTypesArrayAttr =
mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
mlir::ArrayAttr mapCaptureKindsArr =
mlir::ArrayAttr::get(firOpBuilder.getContext(),
llvm::SmallVector<mlir::Attribute>{
mapCaptureKinds.begin(), mapCaptureKinds.end()});
mlir::DenseIntElementsAttr uBound = mlir::DenseIntElementsAttr::get(
mlir::VectorType::get(llvm::ArrayRef<int64_t>(mapUShape),
firOpBuilder.getI64Type()),
llvm::ArrayRef<int64_t>{mapUBounds});
mlir::DenseIntElementsAttr lBound = mlir::DenseIntElementsAttr::get(
mlir::VectorType::get(llvm::ArrayRef<int64_t>(mapLShape),
firOpBuilder.getI64Type()),
llvm::ArrayRef<int64_t>{mapLBounds});
cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
mapOperands);

return genOpWithBody<mlir::omp::TargetOp>(
converter, eval, currentLocation, outerCombined, &clauseList,
ifClauseOperand, deviceOperand, threadLimitOperand, nowaitAttr,
mapOperands, mapTypesArrayAttr);
mapOperands);
}

static mlir::omp::TeamsOp
Expand Down Expand Up @@ -2964,11 +2940,11 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
// genOMP() Code generation helper functions
//===----------------------------------------------------------------------===//

static void
genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSimpleStandaloneConstruct
&simpleStandaloneConstruct) {
static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
const Fortran::parser::OpenMPSimpleStandaloneConstruct
&simpleStandaloneConstruct) {
const auto &directive =
std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
simpleStandaloneConstruct.t);
Expand All @@ -2984,25 +2960,21 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation);
break;
case llvm::omp::Directive::OMPD_taskwait:
ClauseProcessor(converter, opClauseList)
.processTODO<Fortran::parser::OmpClause::Depend,
Fortran::parser::OmpClause::Nowait>(
currentLocation, llvm::omp::Directive::OMPD_taskwait);
firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation);
break;
case llvm::omp::Directive::OMPD_taskyield:
firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
break;
case llvm::omp::Directive::OMPD_target_data:
genDataOp(converter, currentLocation, opClauseList);
genDataOp(converter, semanticsContext, currentLocation, opClauseList);
break;
case llvm::omp::Directive::OMPD_target_enter_data:
genEnterExitDataOp<mlir::omp::EnterDataOp>(converter, currentLocation,
opClauseList);
genEnterExitDataOp<mlir::omp::EnterDataOp>(converter, semanticsContext,
currentLocation, opClauseList);
break;
case llvm::omp::Directive::OMPD_target_exit_data:
genEnterExitDataOp<mlir::omp::ExitDataOp>(converter, currentLocation,
opClauseList);
genEnterExitDataOp<mlir::omp::ExitDataOp>(converter, semanticsContext,
currentLocation, opClauseList);
break;
case llvm::omp::Directive::OMPD_target_update:
TODO(currentLocation, "OMPD_target_update");
Expand All @@ -3011,34 +2983,16 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
}
}

static void
genOmpFlush(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
llvm::SmallVector<mlir::Value, 4> operandRange;
if (const auto &ompObjectList =
std::get<std::optional<Fortran::parser::OmpObjectList>>(
flushConstruct.t))
genObjectList(*ompObjectList, converter, operandRange);
const auto &memOrderClause =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
if (memOrderClause && memOrderClause->size() > 0)
TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause");
converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
converter.getCurrentLocation(), operandRange);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
std::visit(
Fortran::common::visitors{
[&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
&simpleStandaloneConstruct) {
genOMP(converter, semanticsContext, eval,
genOMP(converter, eval, semanticsContext,
simpleStandaloneConstruct);
},
[&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
Expand Down Expand Up @@ -3069,6 +3023,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,

static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
Expand Down Expand Up @@ -3103,7 +3058,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
.test(ompDirective)) {
validDirective = true;
genTargetOp(converter, eval, currentLocation, loopOpClauseList,
genTargetOp(converter, semanticsContext, eval, currentLocation,
loopOpClauseList, ompDirective,
/*outerCombined=*/true);
}
if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
Expand Down Expand Up @@ -3284,10 +3240,11 @@ genOMP(Fortran::lower::AbstractConverter &converter,
endClauseList);
break;
case llvm::omp::Directive::OMPD_target:
genTargetOp(converter, eval, currentLocation, beginClauseList);
genTargetOp(converter, semanticsContext, eval, currentLocation,
beginClauseList, directive.v);
break;
case llvm::omp::Directive::OMPD_target_data:
genDataOp(converter, currentLocation, beginClauseList);
genDataOp(converter, semanticsContext, currentLocation, beginClauseList);
break;
case llvm::omp::Directive::OMPD_task:
genTaskOp(converter, eval, currentLocation, beginClauseList);
Expand All @@ -3307,7 +3264,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
bool combinedDirective = false;
if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
.test(directive.v)) {
genTargetOp(converter, eval, currentLocation, beginClauseList,
genTargetOp(converter, semanticsContext, eval, currentLocation,
beginClauseList, directive.v,
/*outerCombined=*/true);
combinedDirective = true;
}
Expand Down Expand Up @@ -3977,7 +3935,7 @@ void Fortran::lower::genOpenMPConstruct(
common::visitors{
[&](const Fortran::parser::OpenMPStandaloneConstruct
&standaloneConstruct) {
genOMP(converter, semanticsContext, eval, standaloneConstruct);
genOMP(converter, eval, semanticsContext, standaloneConstruct);
},
[&](const Fortran::parser::OpenMPSectionsConstruct
&sectionsConstruct) {
Expand All @@ -3987,7 +3945,7 @@ void Fortran::lower::genOpenMPConstruct(
genOMP(converter, eval, sectionConstruct);
},
[&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
genOMP(converter, eval, loopConstruct);
genOMP(converter, eval, semanticsContext, loopConstruct);
},
[&](const Fortran::parser::OpenMPDeclarativeAllocate
&execAllocConstruct) {
Expand Down
10 changes: 5 additions & 5 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -271,16 +271,16 @@ func.func @_QPomp_target_data() {
// CHECK-LABEL: llvm.func @_QPomp_target_data() {
// CHECK: %0 = llvm.mlir.constant(1024 : index) : i64
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<1024 x i32> {bindc_name = "a", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEa"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<1024 x i32> {bindc_name = "a", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEa"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %3 = llvm.mlir.constant(1024 : index) : i64
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x !llvm.array<1024 x i32> {bindc_name = "b", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEb"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x !llvm.array<1024 x i32> {bindc_name = "b", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEb"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %6 = llvm.mlir.constant(1024 : index) : i64
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.array<1024 x i32> {bindc_name = "c", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEc"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.array<1024 x i32> {bindc_name = "c", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEc"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %9 = llvm.mlir.constant(1024 : index) : i64
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_7:.*]] = llvm.alloca %[[VAL_6]] x !llvm.array<1024 x i32> {bindc_name = "d", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEd"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_7:.*]] = llvm.alloca %[[VAL_6]] x !llvm.array<1024 x i32> {bindc_name = "d", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEd"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %12 = llvm.mlir.constant(1 : index) : i64
// CHECK: %13 = llvm.mlir.constant(0 : index) : i64
// CHECK: %14 = llvm.mlir.constant(1023 : index) : i64
Expand Down Expand Up @@ -364,7 +364,7 @@ func.func @_QPopenmp_target_data_region() {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<1024 x i32> {bindc_name = "a", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_data_regionEa"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_data_regionEi"} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x i32 {bindc_name = "i", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_data_regionEi"} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[VAL_MAX:.*]] = llvm.mlir.constant(1024 : index) : i64
// CHECK: %[[VAL_ONE:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[VAL_ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
Expand Down
10 changes: 8 additions & 2 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,17 @@ end subroutine omp_target_device_addr

!CHECK-LABEL: func.func @_QPomp_target_parallel_do() {
subroutine omp_target_parallel_do
!CHECK: %[[C1024:.*]] = arith.constant 1024 : index
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_parallel_doEa"}
integer :: a(1024)
!CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFomp_target_parallel_doEi"}
integer :: i
!CHECK: omp.target map((tofrom -> %[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>)) {
!CHECK: %[[C1:.*]] = arith.constant 1 : index
!CHECK: %[[C0:.*]] = arith.constant 0 : index
!CHECK: %[[SUB:.*]] = arith.subi %[[C1024]], %[[C1]] : index
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C0]] : index) upper_bound(%[[SUB]] : index) extent(%[[C1024]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
!CHECK: %[[MAP:.*]] = omp.map_entry var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>) map_type_value(35) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target map_entries((tofrom -> %[[MAP]] : !fir.ref<!fir.array<1024xi32>>)) {
!CHECK-NEXT: omp.parallel
!$omp target parallel do map(tofrom: a)
!CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
Expand All @@ -278,7 +284,7 @@ subroutine omp_target_parallel_do
!CHECK: omp.wsloop for (%[[VAL_6:.*]]) : i32 = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_5]]) {
!CHECK: fir.store %[[VAL_6]] to %[[VAL_2]] : !fir.ref<i32>
!CHECK: %[[VAL_7:.*]] = arith.constant 10 : i32
!CHECK: %[[VAL_8:.*]] = fir.load %2 : !fir.ref<i32>
!CHECK: %[[VAL_8:.*]] = fir.load %5 : !fir.ref<i32>
!CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i32) -> i64
!CHECK: %[[VAL_10:.*]] = arith.constant 1 : i64
!CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_9]], %[[VAL_10]] : i64
Expand Down
Loading

0 comments on commit 38aabd4

Please sign in to comment.