Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/lospn2std opt select #61

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ namespace mlir {
ConversionPatternRewriter& rewriter) const override;
};

/// Conversion pattern lowering a non-vectorized LoSPN SPNSelectLeaf
/// (threshold-based two-way select leaf) to CPU code.
/// The vectorized counterpart is handled by VectorizeSelectLeaf.
struct SelectLowering : public OpConversionPattern<low::SPNSelectLeaf> {

using OpConversionPattern<low::SPNSelectLeaf>::OpConversionPattern;

LogicalResult matchAndRewrite(low::SPNSelectLeaf op,
ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override;
};

struct ResolveConvertToVector : public OpConversionPattern<low::SPNConvertToVector> {

using OpConversionPattern<low::SPNConvertToVector>::OpConversionPattern;
Expand Down Expand Up @@ -181,6 +190,7 @@ namespace mlir {
patterns.insert<MulLowering, AddLowering>(typeConverter, context);
patterns.insert<MulLogLowering, AddLogLowering>(typeConverter, context);
patterns.insert<CategoricalLowering, HistogramLowering>(typeConverter, context);
patterns.insert<SelectLowering, CategoricalLowering, HistogramLowering>(typeConverter, context);
patterns.insert<GaussianLowering, GaussianLogLowering>(typeConverter, context);
patterns.insert<ResolveConvertToVector, ResolveStripLog, ResolveConvertLog>(typeConverter, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ namespace mlir {

LogicalResult matchAndRewrite(low::SPNConvertLog op,
ArrayRef <Value> operands,
ConversionPatternRewriter& rewriter) const override;
};

/// Conversion pattern lowering a vectorized LoSPN SPNSelectLeaf to an
/// element-wise compare + vector select.
/// The scalar counterpart is handled by SelectLowering.
struct VectorizeSelectLeaf : public OpConversionPattern<low::SPNSelectLeaf> {

using OpConversionPattern<low::SPNSelectLeaf>::OpConversionPattern;

LogicalResult matchAndRewrite(low::SPNSelectLeaf op,
ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override;
};

Expand All @@ -177,6 +186,7 @@ namespace mlir {
patterns.insert<VectorizeLogAdd, VectorizeLogMul>(typeConverter, context, 2);
patterns.insert<VectorizeConstant>(typeConverter, context, 2);
patterns.insert<ResolveVectorizedStripLog, ResolveVectorizedConvertLog>(typeConverter, context, 2);
patterns.insert<VectorizeSelectLeaf>(typeConverter, context, 2);
}

}
Expand Down
1 change: 1 addition & 0 deletions mlir/include/Dialect/LoSPN/LoSPNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
sommerlukas marked this conversation as resolved.
Show resolved Hide resolved

#define GET_OP_CLASSES
#include "LoSPN/LoSPNOps.h.inc"
Expand Down
26 changes: 25 additions & 1 deletion mlir/include/Dialect/LoSPN/LoSPNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,9 @@ def SPNHistogramLeaf : LoSPNBodyOp<"histogram", [NoSideEffect,
let arguments = (ins LoSPNInputType:$index, BucketListAttr:$buckets,
UI32Attr:$bucketCount, BoolAttr:$supportMarginal);

let results = (outs LoSPNComputeType);
let hasCanonicalizeMethod = 1;

let results = (outs LoSPNComputeType);
}

///
Expand All @@ -362,6 +363,8 @@ def SPNCategoricalLeaf : LoSPNBodyOp<"categorical", [NoSideEffect,

let arguments = (ins LoSPNInputType:$index, F64ArrayAttr:$probabilities, BoolAttr:$supportMarginal);

let hasCanonicalizeMethod = 1;

let results = (outs LoSPNComputeType);
}

Expand All @@ -383,4 +386,25 @@ def SPNGaussianLeaf : LoSPNBodyOp<"gaussian", [NoSideEffect,
let results = (outs LoSPNComputeType);
}

///
/// Select of an SPN leaf node value.
/// Corresponds to: ($input < $input_true_threshold) ? $val_true : $val_false;
/// Created by canonicalization of a two-valued Categorical or a two-bucket
/// contiguous Histogram leaf (see the canonicalize methods in LoSPNOps.cpp).
/// NOTE: the threshold and both values are always stored as F64 attributes,
/// even when $input is an integer type; lowerings truncate as needed.
///
def SPNSelectLeaf : LoSPNBodyOp<"select", [NoSideEffect,
DeclareOpInterfaceMethods<LoSPNVectorizable>,
DeclareOpInterfaceMethods<LeafNodeInterface>]> {

let summary = "Leaf node value select";

let description = [{
Single value select of a Categorical or Histogram leaf.
}];

let arguments = (ins LoSPNInputType:$input, F64Attr:$input_true_threshold,
F64Attr:$val_true, F64Attr:$val_false, BoolAttr:$supportMarginal);

let results = (outs LoSPNComputeType);

}

#endif // LoSPN_Ops
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ void mlir::spn::LoSPNNodeVectorizationPass::runOnOperation() {
target.addLegalOp<mlir::spn::low::SPNConvertToVector>();

OwningRewritePatternList patterns(&getContext());
patterns.insert<SelectLowering>(typeConverter, &getContext());
mhalk marked this conversation as resolved.
Show resolved Hide resolved
mlir::spn::populateLoSPNCPUVectorizationNodePatterns(patterns, &getContext(), typeConverter);

auto op = getOperation();
Expand Down
59 changes: 59 additions & 0 deletions mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ mlir::LogicalResult mlir::spn::GaussianLogLowering::matchAndRewrite(mlir::spn::l
Value gaussian = rewriter.create<mlir::AddFOp>(op->getLoc(), coefficientConst, fraction);
if (op.supportMarginal()) {
auto isNan = rewriter.create<mlir::CmpFOp>(op->getLoc(), CmpFPredicate::UNO, index, index);
// FixMe / Question: Could this be a bug? (Either rename to 'constZero' -OR- set to 1.0 (instead of 0.0))
mhalk marked this conversation as resolved.
Show resolved Hide resolved
auto constOne = rewriter.create<mlir::ConstantOp>(op.getLoc(), rewriter.getFloatAttr(resultType, 0.0));
gaussian = rewriter.create<mlir::SelectOp>(op.getLoc(), isNan, constOne, gaussian);
}
Expand Down Expand Up @@ -485,6 +486,64 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l
values, resultType, "categorical_", computesLog);
}

mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op,
                                                               llvm::ArrayRef<mlir::Value> operands,
                                                               mlir::ConversionPatternRewriter& rewriter) const {
  // Lowers a non-vectorized SPNSelectLeaf to a scalar compare + select:
  //   result = (input < input_true_threshold) ? val_true : val_false
  if (op.checkVectorized()) {
    return rewriter.notifyMatchFailure(op, "Pattern only matches non-vectorized SelectLeaf");
  }
  // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails.
  mlir::Value cond;
  auto inputTy = op.input().getType();
  if (inputTy.isa<mlir::FloatType>()) {
    auto thresholdAttr = FloatAttr::get(inputTy, op.input_true_thresholdAttr().getValueAsDouble());
    auto input_true_threshold = rewriter.create<mlir::ConstantOp>(op->getLoc(), inputTy, thresholdAttr);
    // NOTE(review): ULT is an *unordered* predicate, i.e. it yields true for NaN
    // inputs; NaN inputs are handled explicitly below when supportMarginal is set.
    cond = rewriter.create<mlir::CmpFOp>(op->getLoc(), IntegerType::get(op.getContext(), 1),
                                         mlir::CmpFPredicate::ULT, op.input(), input_true_threshold);
  } else if (inputTy.isa<mlir::IntegerType>()) {
    // The threshold is declared as an F64Attr on the op; make the truncation
    // to an integer explicit instead of relying on an implicit narrowing conversion.
    auto thresholdAttr = IntegerAttr::get(inputTy,
                                          static_cast<int64_t>(op.input_true_thresholdAttr().getValueAsDouble()));
    auto input_true_threshold = rewriter.create<mlir::ConstantOp>(op->getLoc(), inputTy, thresholdAttr);
    cond = rewriter.create<mlir::CmpIOp>(op->getLoc(), IntegerType::get(op.getContext(), 1),
                                         mlir::CmpIPredicate::ult, op.input(), input_true_threshold);
  } else {
    return rewriter.notifyMatchFailure(op, "Expected condition-value to be either Float- or IntegerType");
  }

  // Strip a potential log-space wrapper type; remember whether the result is
  // computed in log-space so the constant values can be transformed accordingly.
  Type resultType = op.getResult().getType();
  bool computesLog = false;
  if (auto logType = resultType.dyn_cast<low::LogType>()) {
    resultType = logType.getBaseType();
    computesLog = true;
  }

  // Materialize both select values as constants; in log-space, store log(p).
  ConstantOp val_true, val_false;
  if (computesLog) {
    val_true = rewriter.create<mlir::ConstantOp>(op->getLoc(), resultType,
                                                 FloatAttr::get(resultType, log(op.val_trueAttr().getValueAsDouble())));
    val_false = rewriter.create<mlir::ConstantOp>(op->getLoc(), resultType,
                                                  FloatAttr::get(resultType, log(op.val_falseAttr().getValueAsDouble())));
  } else {
    val_true = rewriter.create<mlir::ConstantOp>(op->getLoc(), op.val_trueAttr().getType(), op.val_trueAttr());
    val_false = rewriter.create<mlir::ConstantOp>(op->getLoc(), op.val_falseAttr().getType(), op.val_falseAttr());
  }

  // 'rewriter' is not a dependent name here, so no 'template' keyword is needed.
  mlir::Value leaf = rewriter.create<SelectOp>(op.getLoc(), cond, val_true, val_false);
  if (op.supportMarginal()) {
    // Marginalized evaluation: a NaN input yields the neutral element of the
    // enclosing combination, i.e. 1.0 (or log(1.0) == 0.0 in log-space).
    assert(inputTy.isa<mlir::FloatType>());
    auto isNan = rewriter.create<mlir::CmpFOp>(op->getLoc(), mlir::CmpFPredicate::UNO, op.input(), op.input());
    auto marginalValue = (computesLog) ? 0.0 : 1.0;
    auto neutralConst = rewriter.create<mlir::ConstantOp>(op.getLoc(),
                                                          rewriter.getFloatAttr(resultType, marginalValue));
    leaf = rewriter.create<mlir::SelectOp>(op.getLoc(), isNan, neutralConst, leaf);
  }
  rewriter.replaceOp(op, leaf);

  return success();
}

mlir::LogicalResult mlir::spn::ResolveConvertToVector::matchAndRewrite(mlir::spn::low::SPNConvertToVector op,
llvm::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter& rewriter) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -642,5 +642,72 @@ mlir::LogicalResult mlir::spn::ResolveVectorizedConvertLog::matchAndRewrite(mlir
return rewriter.notifyMatchFailure(op, "Could not resolve ConvertLog trivially");
}
rewriter.replaceOp(op, operands[0]);
return success();
}

mlir::LogicalResult mlir::spn::VectorizeSelectLeaf::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op,
                                                                    llvm::ArrayRef<mlir::Value> operands,
                                                                    mlir::ConversionPatternRewriter& rewriter) const {
  // Lowers a vectorized SPNSelectLeaf to an element-wise compare + vector select:
  //   result[i] = (input[i] < input_true_threshold) ? val_true : val_false
  if (!op.checkVectorized()) {
    return rewriter.notifyMatchFailure(op, "Pattern only matches vectorized Select");
  }

  auto inputVec = operands.front();
  auto inputVecTy = inputVec.getType();

  // Input should be a vector
  if (!inputVecTy.isa<VectorType>()) {
    return rewriter.notifyMatchFailure(op, "Vectorization pattern did not match, input was not a vector");
  }

  auto inputTy = op.input().getType();
  VectorType vectorType = inputVecTy.cast<VectorType>();

  // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails.
  // Compare every lane against the broadcast threshold.
  mlir::Value cond;
  auto booleanVectorTy = VectorType::get(vectorType.getShape(), IntegerType::get(op.getContext(), 1));
  if (inputTy.isa<mlir::FloatType>()) {
    Value thresholdVec =
        broadcastVectorConstant(vectorType, op.input_true_thresholdAttr().getValueAsDouble(), rewriter, op.getLoc());
    cond =
        rewriter.create<mlir::CmpFOp>(op->getLoc(), booleanVectorTy, mlir::CmpFPredicate::ULT, inputVec, thresholdVec);
  } else if (inputTy.isa<mlir::IntegerType>()) {
    // Create the threshold directly as a splat of integer constants instead of
    // materializing a float vector and converting it via FPToSIOp (review feedback:
    // a constant value should never be created in one type and then converted).
    auto intVectorTy = VectorType::get(vectorType.getShape(), inputTy);
    auto intThreshold = static_cast<int64_t>(op.input_true_thresholdAttr().getValueAsDouble());
    Value thresholdVec = rewriter.create<mlir::ConstantOp>(
        op->getLoc(), DenseElementsAttr::get(intVectorTy, IntegerAttr::get(inputTy, intThreshold)));
    cond =
        rewriter.create<mlir::CmpIOp>(op->getLoc(), booleanVectorTy, mlir::CmpIPredicate::ult, inputVec, thresholdVec);
  } else {
    return rewriter.notifyMatchFailure(op, "Expected condition-value to be either Float- or IntegerType");
  }

  // Strip a potential log-space wrapper type; remember whether the result is
  // computed in log-space so the constant values can be transformed accordingly.
  Type resultType = op.getResult().getType();
  bool computesLog = false;
  if (auto logType = resultType.dyn_cast<low::LogType>()) {
    resultType = logType.getBaseType();
    computesLog = true;
  }
  // All value constants must use the *result* element type, not the input vector
  // type: input and result element types may differ (e.g. integer input with a
  // floating-point result, or a log-space base type).
  auto resultVecTy = VectorType::get(vectorType.getShape(), resultType);

  ConstantOp val_true, val_false;
  if (computesLog) {
    val_true = broadcastVectorConstant(resultVecTy, log(op.val_trueAttr().getValueAsDouble()), rewriter, op.getLoc());
    val_false = broadcastVectorConstant(resultVecTy, log(op.val_falseAttr().getValueAsDouble()), rewriter, op.getLoc());
  } else {
    val_true = broadcastVectorConstant(resultVecTy, op.val_trueAttr().getValueAsDouble(), rewriter, op.getLoc());
    val_false = broadcastVectorConstant(resultVecTy, op.val_falseAttr().getValueAsDouble(), rewriter, op.getLoc());
  }

  // 'rewriter' is not a dependent name here, so no 'template' keyword is needed.
  mlir::Value leaf = rewriter.create<SelectOp>(op.getLoc(), cond, val_true, val_false);
  if (op.supportMarginal()) {
    // Marginalized evaluation: NaN lanes yield the neutral element of the
    // enclosing combination, i.e. 1.0 (or log(1.0) == 0.0 in log-space).
    assert(inputTy.isa<mlir::FloatType>());
    auto isNan = rewriter.create<mlir::CmpFOp>(op->getLoc(), CmpFPredicate::UNO, inputVec, inputVec);
    auto marginalValue = (computesLog) ? 0.0 : 1.0;
    auto neutralConst = broadcastVectorConstant(resultVecTy, marginalValue, rewriter, op.getLoc());
    leaf = rewriter.create<mlir::SelectOp>(op.getLoc(), isNan, neutralConst, leaf);
  }
  rewriter.replaceOp(op, leaf);

  return success();
}
61 changes: 60 additions & 1 deletion mlir/lib/Dialect/LoSPN/LoSPNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace spn {
Expand Down Expand Up @@ -367,5 +366,65 @@ ::mlir::OpFoldResult mlir::spn::low::SPNAdd::fold(::llvm::ArrayRef<::mlir::Attri
return nullptr;
}

//===----------------------------------------------------------------------===//
// SPNCategoricalLeaf
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCategoricalLeaf op,
                                                                       PatternRewriter& rewriter) {
  // Rewrite Categoricals which contain exactly two probabilities into a LoSPN Select:
  //   (index < 1.0) ? probabilities[0] : probabilities[1]
  auto probabilities = op.probabilities().getValue();
  if (probabilities.size() != 2) {
    return rewriter.notifyMatchFailure(op, "Only Categoricals with exactly two probabilities can be "
                                           "rewritten into a Select");
  }
  auto pTrue = probabilities[0].dyn_cast<FloatAttr>();
  auto pFalse = probabilities[1].dyn_cast<FloatAttr>();
  if (!pTrue || !pFalse) {
    // Guard against malformed attributes instead of creating a Select with null values.
    return rewriter.notifyMatchFailure(op, "Expected both probabilities to be FloatAttr");
  }
  // SPNSelectLeaf declares the threshold as F64Attr, so it must always be created
  // as an f64 FloatAttr. Using op.index().getType() here would be invalid for
  // integer index types and inconsistent with the Histogram canonicalization.
  auto threshold_max_true = FloatAttr::get(Float64Type::get(op.getContext()), 1.0);

  rewriter.replaceOpWithNewOp<SPNSelectLeaf>(op,
                                             op.getResult().getType(),
                                             op.index(),
                                             threshold_max_true,
                                             pTrue,
                                             pFalse,
                                             op.supportMarginalAttr());
  return success();
}

//===----------------------------------------------------------------------===//
// SPNHistogramLeaf
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::spn::low::SPNHistogramLeaf::canonicalize(SPNHistogramLeaf op, PatternRewriter& rewriter) {
  // Rewrite certain Histograms which contain exactly two buckets into a LoSPN Select.
  // Buckets' index range must be 1 and buckets have to be consecutive / contiguous.
  // i.e.: (UB_0-LB_0 == 1) && (UB_1-LB_1 == 1) && (UB_0 == LB_1)
  auto buckets = op.buckets();
  if (buckets.size() != 2) {
    return rewriter.notifyMatchFailure(op, "Only Histograms with exactly two buckets can be "
                                           "rewritten into a Select");
  }
  auto b0 = buckets[0].cast<mlir::spn::low::Bucket>();
  auto b1 = buckets[1].cast<mlir::spn::low::Bucket>();

  // Each bucket must cover exactly one index ...
  bool isQualifiedIndexRange = ((b0.ub().getInt() - b0.lb().getInt()) == 1) &&
                               ((b1.ub().getInt() - b1.lb().getInt()) == 1);
  // ... and the two buckets must be contiguous.
  bool isContiguous = (b0.ub().getInt() == b1.lb().getInt());
  if (!isQualifiedIndexRange || !isContiguous) {
    return rewriter.notifyMatchFailure(op, "Histogram buckets must each cover exactly one index and "
                                           "be contiguous to be rewritten into a Select");
  }

  auto pTrue = b0.val();
  auto pFalse = b1.val();
  // Indices below UB_0 select bucket 0; the threshold is created as f64 to
  // match the F64Attr declaration on SPNSelectLeaf.
  auto threshold_max_true = FloatAttr::get(Float64Type::get(op.getContext()), b0.ub().getInt());

  rewriter.replaceOpWithNewOp<SPNSelectLeaf>(op,
                                             op.getResult().getType(),
                                             op.index(),
                                             threshold_max_true,
                                             pTrue,
                                             pFalse,
                                             op.supportMarginalAttr());
  return success();
}

#define GET_OP_CLASSES
#include "LoSPN/LoSPNOps.cpp.inc"
Loading