LLVM update 43d71ba #3086

Draft: wants to merge 18 commits into base: main
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
cd llvm-project && git checkout 43d71baae36c8d8b5a9995aa35efebe09cc9c2d6 && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
cd llvm-project && git checkout 43d71baae36c8d8b5a9995aa35efebe09cc9c2d6 && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
@@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
auto int16Ty = IntegerType::get(context, 16);
auto int32Ty = IntegerType::get(context, 32);
auto int64Ty = IntegerType::get(context, 64);
auto float32Ty = FloatType::getF32(context);
auto float32Ty = Float32Type::get(context);

// Declare API type as an enum value, its string name and an LLVM Type
// specifying its signature.
@@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) {
Type llvmI64Ty = IntegerType::get(context, 64);
Type llvmI1Ty = IntegerType::get(context, 1);
Type llvmI8Ty = IntegerType::get(context, 8);
Type llvmF32Ty = FloatType::getF32(context);
Type llvmF32Ty = Float32Type::get(context);
Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3);
Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20);
Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
@@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float");
create.llvm.store(recScale, recScalePtr);
} else {
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, recScalePtr);
}

@@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float");
create.llvm.store(offset, offsetPtr);
} else {
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, offsetPtr);
}

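Most of the changes in this file, and in several files below, follow the same mechanical substitution: the static FloatType::getFxx helpers are replaced by per-width builtin type classes. A minimal sketch of the pattern, assuming the updated MLIR builtin-type headers used at this LLVM level (illustrative only, not part of the diff):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Each builtin float type is now constructed through its own class instead of
// the static FloatType::getFxx(context) helpers.
static void buildFloatTypes(mlir::MLIRContext *context) {
  mlir::Type f16Ty = mlir::Float16Type::get(context);   // was FloatType::getF16(context)
  mlir::Type bf16Ty = mlir::BFloat16Type::get(context); // was FloatType::getBF16(context)
  mlir::Type f32Ty = mlir::Float32Type::get(context);   // was FloatType::getF32(context)
  mlir::Type f64Ty = mlir::Float64Type::get(context);   // was FloatType::getF64(context)
  (void)f16Ty;
  (void)bf16Ty;
  (void)f32Ty;
  (void)f64Ty;
}
```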
10 changes: 0 additions & 10 deletions src/Compiler/CompilerOptions.cpp
@@ -92,7 +92,6 @@ std::vector<std::string> extraLibPaths; // onnx-mlir only
std::vector<std::string> extraLibs; // onnx-mlir only
ProfileIRs profileIR; // onnx-mlir only
OptReport optReport; // onnx-mlir only
bool useOldBufferization; // onnx-mlir only
bool enableTiming; // onnx-mlir only
bool enableBoundCheck; // onnx-mlir only
bool split_input_file; // onnx-mlir-opt only
@@ -755,15 +754,6 @@ static llvm::cl::opt<bool, true> allowUnregisteredDialectsOpt(
llvm::cl::location(allowUnregisteredDialects), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptOptions));

// Removed once the new LLVM bufferization works without performance regression.
static llvm::cl::opt<bool, true> useOldBufferizationOpt("use-old-bufferization",
llvm::cl::desc(
"Enable the old LLVM bufferization mechanism (default=true).\n"
"This option should be removed once the new LLVM bufferization works "
"well in onnx-mlir."),
llvm::cl::location(useOldBufferization), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirOptions));

// Configuration states associated with certain options.
// For example, when maccel is specified, NNPA can register
// dependent libdnn.
1 change: 0 additions & 1 deletion src/Compiler/CompilerOptions.hpp
@@ -138,7 +138,6 @@ extern std::vector<std::string> extraLibPaths; // onnx-mlir only
extern std::vector<std::string> extraLibs; // onnx-mlir only
extern ProfileIRs profileIR; // onnx-mlir only
extern OptReport optReport; // onnx-mlir only
extern bool useOldBufferization; // onnx-mlir only
extern bool enableTiming; // onnx-mlir only
extern bool enableBoundCheck; // onnx-mlir only
extern bool debugTestCompilerOpt; // onnx-mlir only
13 changes: 4 additions & 9 deletions src/Compiler/CompilerPasses.cpp
@@ -251,15 +251,10 @@ void addKrnlToLLVMPasses(
// Currently this has to be done *after* lowering the affine dialect because
// operations in that dialect do not conform to the requirements explained
// in https://mlir.llvm.org/docs/BufferDeallocationInternals.
if (useOldBufferization) {
pm.addNestedPass<func::FuncOp>(
mlir::bufferization::createBufferDeallocationPass());
} else {
bufferization::BufferDeallocationPipelineOptions bufferDeallocOptions;
mlir::bufferization::buildBufferDeallocationPipeline(
pm, bufferDeallocOptions);
pm.addPass(mlir::createBufferizationToMemRefPass());
}
bufferization::BufferDeallocationPipelineOptions bufferDeallocOptions;
mlir::bufferization::buildBufferDeallocationPipeline(
pm, bufferDeallocOptions);
pm.addPass(bufferization::createOwnershipBasedBufferDeallocationPass());

// Late introduction of OpenMP, after bufferization.
if (enableParallel) {
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp
@@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern {
// or
// (memref<3x4x5xf64>, index, f64, f64, f64)
Type llvmVoidTy = LLVM::LLVMVoidType::get(context);
Type llvmOptionsTy = FloatType::getF32(context);
Type llvmOptionsTy = Float32Type::get(context);
Type llvmOutputTy = getPointerType(context, llvmOptionsTy);
if (inType.isF64()) {
llvmOptionsTy = FloatType::getF64(context);
llvmOptionsTy = Float64Type::get(context);
llvmOutputTy = getPointerType(context, llvmOptionsTy);
}
Type llvmI64Ty = IntegerType::get(context, 64);
13 changes: 6 additions & 7 deletions src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp
@@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
Type outType = op->getResultTypes().front();
Type llvmInType, llvmOutType;
if (inType.isF16())
llvmInType = FloatType::getF16(context);
llvmInType = Float16Type::get(context);
else if (inType.isF32())
llvmInType = FloatType::getF32(context);
llvmInType = Float32Type::get(context);
else if (inType.isF64())
llvmInType = FloatType::getF64(context);
llvmInType = Float64Type::get(context);
else if (inType.isBF16())
llvmInType = FloatType::getBF16(context);
llvmInType = BFloat16Type::get(context);
if (outType.isInteger(1))
llvmOutType = IntegerType::get(context, 1);
else if (outType.isF32())
llvmOutType = FloatType::getF32(context);
llvmOutType = Float32Type::get(context);
else if (outType.isF64())
llvmOutType = FloatType::getF64(context);
llvmOutType = Float64Type::get(context);

// Insert and/or get reference to elementary math function declaration.
assert(
@@ -214,7 +214,7 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
return SymbolRefAttr::get(context, mathFuncName);

// Create function declaration.
// auto llvmF32Ty = FloatType::get(context);
auto llvmFnType =
LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef<Type>({llvmInType}));

4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
@@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

// Get memRefDescriptor, the new memref descriptor.
MemRefDescriptor memRefDescriptor =
MemRefDescriptor::undef(rewriter, loc, targetStructType);
MemRefDescriptor::poison(rewriter, loc, targetStructType);
auto targetElementPtrType = memRefDescriptor.getElementPtrType();

// Set the new memref to the same buffer as the source memref.
@@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(targetType, strides, offset)))
if (failed(targetType.getStridesAndOffset(strides, offset)))
return failure();

// Unhandled dynamic offset.
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
@@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering
SmallVector<int64_t, 4> strides;
int64_t alignmentOffset; // not used, just to make the function call
// completed.
if (getStridesAndOffset(memRefType, strides, alignmentOffset)
if (memRefType.getStridesAndOffset(strides, alignmentOffset)
.failed())
llvm_unreachable("Failed to get strides");
Value stringMemRef =
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/Math/LRN.cpp
@@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern<ONNXLRNOp> {
float alphaLit = adaptor.getAlpha().convertToFloat();
float betaLit = adaptor.getBeta().convertToFloat();
int sizeLit = adaptor.getSize();
auto f32Type = FloatType::getF32(rewriter.getContext());
auto f32Type = Float32Type::get(rewriter.getContext());
Value biasValue = create.math.constant(f32Type, biasLit);
Value alphaDivSizeValue =
create.math.constant(f32Type, alphaLit / static_cast<float>(sizeLit));
24 changes: 16 additions & 8 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {

Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef<int64_t> size,
llvm::ArrayRef<int64_t> start) {
DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size);
DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start);
auto startVal =
mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start));
auto sizeVal =
mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size));
Value newSliceInput =
tosa::CreateOpAndInfer<mlir::tosa::SliceOp>(rewriter(), loc(),
RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(size.size(), ShapedType::kDynamic),
mlir::cast<ShapedType>(inputConst.getType()).getElementType()),
inputConst, startAttr, sizeAttr);
inputConst, startVal, sizeVal);
return newSliceInput;
}

@@ -164,8 +167,9 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef<int64_t> shape) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(shape.size(), ShapedType::kDynamic),
valueType.getElementType());
return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(
rewriter(), loc(), newValueType, value, shapeAttr);
return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter(), loc(),
newValueType, value,
mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr));
}

Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
@@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());

auto int8Type = rewriter().getI8Type();
auto shiftValue = TosaBuilder::createConst(
ArrayRef<int8_t>{static_cast<int8_t>(shift)}, {1}, int8Type);
Review thread on this hunk:

Collaborator: IMO it's better to change the TosaBuilder::mul parameter shift to int8_t instead of having this hidden cast inside the function.

Collaborator (author): Can you elaborate a bit more on what you mean by "change the TosaBuilder::mul parameter shift" -- change it where?

@jorickert (Collaborator, Mar 3, 2025): I mean the function signature: Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift). Right now it takes an int32_t parameter and internally casts it to int8_t. I would prefer to change the function signature to Value TosaBuilder::mul(Value &lhs, Value &rhs, int8_t shift) so that it is clear to callers that it only supports int8 shift values.
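A minimal sketch of the change proposed above, reusing the createConst and CreateOpAndInfer helpers that appear in this diff; this is a hypothetical variant, not part of the PR:

```cpp
// Hypothetical signature: take shift as int8_t so the narrowing is visible at
// call sites instead of being hidden inside the builder.
Value TosaBuilder::mul(Value &lhs, Value &rhs, int8_t shift) {
  // (broadcast/rank handling from the existing implementation elided)
  auto lhsType = mlir::cast<ShapedType>(lhs.getType());
  Type newValueType = RankedTensorType::get(
      llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
      lhsType.getElementType());
  // No static_cast needed any more; shift already has the width TOSA expects.
  auto shiftValue = TosaBuilder::createConst(
      ArrayRef<int8_t>{shift}, {1}, rewriter().getI8Type());
  return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
      rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
}
```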

return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter(), loc(), newValueType, lhs, rhs, shift);
rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
}

Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
@@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(Value &lhs, Value &rhs);
// Return null if none is found.
ElementsAttr IndexExprBuilderForTosa::getConst(Value value) {
auto definingOp = value.getDefiningOp();
// If we have a cast between index/integer, skip it, i.e. get the defining op
// that is the input to the cast.
// If we have a cast between index/integer, skip it, i.e. get the defining
// op that is the input to the cast.
if (auto castOp = dyn_cast_or_null<arith::IndexCastOp>(definingOp)) {
Value input = castOp.getIn();
definingOp = input.getDefiningOp();
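The slice and reshape changes above, and the Gemm changes further down, share one pattern: shape-like operands are no longer passed as DenseI64ArrayAttr but materialized as TOSA const-shape values. A minimal sketch based on the calls visible in this diff (illustrative only):

```cpp
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

// Shape-like operands (slice start/size, reshape target shape) are built as
// const-shape values via getTosaConstShape and passed as operands, replacing
// the former rewriter.getDenseI64ArrayAttr(...) attributes.
static mlir::Value makeConstShape(mlir::PatternRewriter &rewriter,
    mlir::Location loc, llvm::ArrayRef<int64_t> dims) {
  return mlir::tosa::getTosaConstShape(rewriter, loc, llvm::to_vector(dims));
}
```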
20 changes: 15 additions & 5 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
// Quantized types are not supported right now (in type conversion).
// Once they are, the input should be rescaled for quantized types. (TBD)
// Maps to `tosa.clamp` which has both int and fp limits.
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(op, op.getType(), input,
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
auto inputElementType =
llvm::cast<TensorType>(op.getType()).getElementType();
if (llvm::isa<IntegerType>(inputElementType)) {
auto minClamp = rewriter.getI64IntegerAttr(0);
auto maxClamp =
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
op, op.getType(), input, minClamp, maxClamp);
} else {
auto minClamp = rewriter.getF32FloatAttr(0.0f);
auto maxClamp =
rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
op, op.getType(), input, minClamp, maxClamp);
}
return success();
}
};
76 changes: 4 additions & 72 deletions src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -31,9 +32,6 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
LogicalResult matchAndRewrite(ONNXGemmOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
TosaBuilder tosaBuilder(rewriter, op->getLoc());
// If legal, create a FullyConnected operator instead
if (rewriteToTosaFC(op, adaptor, rewriter, tosaBuilder))
return success();
return rewriteToTosaMatMul(op, adaptor, rewriter, tosaBuilder);
}

@@ -67,13 +65,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {

llvm::SmallVector<int64_t> dynamicTensorShape = {
ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};

A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A,
rewriter.getDenseI64ArrayAttr(newShapeA))
mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA))
.getResult();
B = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B,
rewriter.getDenseI64ArrayAttr(newShapeB))
mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB))
.getResult();

// If transA or transB are present, create Transpose operators.
@@ -149,73 +148,6 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
// only need to check C[0].
return CShape[0] == AShape[0] || CShape[0] == BShape[0];
}

/// The GEMM can be described as a FullyConnected operator.
/// Y = AB^T + C if we perform a transpose on B only with.
/// alpha and beta factors set to 1.
/// Input A must be of rank 2 (input).
/// Input B must be of rank 2 (weights).
/// Input C must be of rank 1 (bias).
bool rewriteToTosaFC(ONNXGemmOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, TosaBuilder &tosaBuilder) const {
Value A = op.getA();
Value B = op.getB();
Value C = op.getC();

auto AType = mlir::cast<TensorType>(A.getType());
auto BType = mlir::cast<TensorType>(B.getType());

bool isCPresent = !mlir::isa<mlir::NoneType>(C.getType());
// If C is present, it can only be of rank 1, if the rank is not 1, return
// false.
if (mlir::isa<RankedTensorType>(C.getType()) &&
mlir::cast<RankedTensorType>(C.getType()).getRank() != 1)
return false;

// Input tensor must be of rank 2.
// Weights must also be of rank 2.
if (AType.getRank() != 2 || BType.getRank() != 2)
return false;

// Both alpha and beta must be 1.
if ((adaptor.getAlpha().convertToFloat() != 1.0F) ||
(adaptor.getBeta().convertToFloat() != 1.0F))
return false;

// Only Transpose B must be enabled.
if (adaptor.getTransA() != 0 || adaptor.getTransB() != 1)
return false;

// If all check passed, we replace the GEMM by a FC operator
Type resultType = getTypeConverter()->convertType(op.getResult().getType());

// Because the bias is not broadcastable for TOSA while it is for ONNX,
// we create an empty bias and use an add (broadcastable for tosa)
// afterwards.
// Base dummy C shape on B[0] shape.
bool needsBroadcasting = !hasCCorrectShape(AType, BType, C);
Value dummyC = C;
if (!isCPresent || needsBroadcasting) {
ArrayRef<int64_t> cformat(
mlir::cast<TensorType>(resultType).getShape()[1]);
std::vector<float> elements = {};
for (int i = 0; i < cformat[0]; ++i)
elements.push_back(0.0F);
dummyC = tosaBuilder.getConst(elements, cformat);
}

Value fcRes = tosa::CreateOpAndInfer<mlir::tosa::FullyConnectedOp>(
rewriter, op->getLoc(), resultType, A, B, dummyC)
.getResult();
// If C was present in the original GEMM, we create an add to take the bias
// into account.
if (isCPresent && needsBroadcasting)
fcRes = tosaBuilder.binaryOp<mlir::tosa::AddOp>(fcRes, C);

rewriter.replaceOp(op, fcRes);

return true;
}
};

} // namespace