Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
handle non-array simple allocatables
Browse files Browse the repository at this point in the history
  • Loading branch information
agozillon committed Nov 6, 2023
1 parent f741bbb commit 27808fb
Showing 1 changed file with 52 additions and 37 deletions.
89 changes: 52 additions & 37 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,8 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
// possible to use the BoundsOp's getStride, which points to the
// same field, however, unsure if the stride will change based
// on a user specified stride, whereas, the element size
// should never change.
// of the descriptor should never change and can be used to
// calculate other things.
if (memberClause.getIsFortranAllocatable().value_or(false)) {
llvm::Value *memberEleByteSize = builder.CreateLoad(
builder.getInt64Ty(),
Expand All @@ -1824,6 +1825,18 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
return builder.CreateMul(elementCount,
builder.getInt64(underlyingTypeSzInBits / 8));
}

// The case for allocatables that are not arrays, could perhaps treat these
// as just pointer size
if (memberClause.getBounds().empty() &&
memberClause.getIsFortranAllocatable().value_or(false)) {
return builder.CreateLoad(
builder.getInt64Ty(),
builder.CreateGEP(baseType, basePointer,
std::vector<llvm::Value *>{builder.getInt64(0),
builder.getInt32(1)},
"element_size"));
}
}

return builder.getInt64(underlyingTypeSzInBits / 8);
Expand Down Expand Up @@ -2147,51 +2160,53 @@ static void generateAllocatablesMapInfo(

combinedInfo.BasePointers.emplace_back(memberBase);

assert(!mapClauseInfo.getBounds().empty() &&
"missing map bounds for fortran allocatable, required for lowering");

llvm::Value *memberEleByteSize = builder.CreateLoad(
builder.getInt64Ty(),
builder.CreateGEP(
mapData.BaseType[mapDataIndex], mapData.BasePointers[mapDataIndex],
std::vector<llvm::Value *>{builder.getInt64(0), builder.getInt32(1)},
"size"));

std::vector<llvm::Value *> dimensionIndexSizeOffset{memberEleByteSize};
for (size_t i = 1; i < mapClauseInfo.getBounds().size(); ++i) {
llvm::Value *ubExtent = builder.CreateLoad(
llvm::Value *offsetAddress = nullptr;
if (!mapClauseInfo.getBounds().empty()) {
llvm::Value *memberEleByteSize = builder.CreateLoad(
builder.getInt64Ty(),
builder.CreateGEP(
mapData.BaseType[mapDataIndex], mapData.BasePointers[mapDataIndex],
std::vector<llvm::Value *>{builder.getInt64(0), builder.getInt32(7),
builder.getInt32(i),
builder.getInt32(1)},
"ub_extent"));
dimensionIndexSizeOffset.push_back(
builder.CreateMul(ubExtent, dimensionIndexSizeOffset[i - 1]));
}
builder.CreateGEP(mapData.BaseType[mapDataIndex],
mapData.BasePointers[mapDataIndex],
std::vector<llvm::Value *>{builder.getInt64(0),
builder.getInt32(1)},
"size"));

std::vector<llvm::Value *> dimensionIndexSizeOffset{memberEleByteSize};
for (size_t i = 1; i < mapClauseInfo.getBounds().size(); ++i) {
llvm::Value *ubExtent = builder.CreateLoad(
builder.getInt64Ty(),
builder.CreateGEP(mapData.BaseType[mapDataIndex],
mapData.BasePointers[mapDataIndex],
std::vector<llvm::Value *>{
builder.getInt64(0), builder.getInt32(7),
builder.getInt32(i), builder.getInt32(1)},
"ub_extent"));
dimensionIndexSizeOffset.push_back(
builder.CreateMul(ubExtent, dimensionIndexSizeOffset[i - 1]));
}

llvm::Value *offsetAddress = nullptr;
for (int i = mapClauseInfo.getBounds().size() - 1; i >= 0; --i) {
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
mapClauseInfo.getBounds()[i].getDefiningOp())) {
if (!offsetAddress)
offsetAddress = builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]);
else
offsetAddress = builder.CreateAdd(
offsetAddress, builder.CreateMul(moduleTranslation.lookupValue(
boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));
for (int i = mapClauseInfo.getBounds().size() - 1; i >= 0; --i) {
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
mapClauseInfo.getBounds()[i].getDefiningOp())) {
if (!offsetAddress)
offsetAddress = builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]);
else
offsetAddress = builder.CreateAdd(
offsetAddress, builder.CreateMul(moduleTranslation.lookupValue(
boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));
}
}
}

llvm::Value *loadMember =
builder.CreateLoad(memberBase->getType(), memberBase);
llvm::Value *memberIdx = builder.CreateGEP(
builder.getInt8Ty(), loadMember,
std::vector<llvm::Value *>{offsetAddress}, "member_idx");
std::vector<llvm::Value *>{offsetAddress ? offsetAddress
: builder.getInt64(0)},
"member_idx");
combinedInfo.Pointers.emplace_back(memberIdx);
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
}
Expand Down

0 comments on commit 27808fb

Please sign in to comment.