diff --git a/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h b/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h
index 1963f0b5..1c7ac3e9 100644
--- a/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h
+++ b/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h
@@ -146,6 +146,15 @@ namespace mlir {
                                     ConversionPatternRewriter& rewriter) const override;
     };
 
+    struct SelectLowering : public OpConversionPattern<low::SPNSelectLeaf> {
+
+      using OpConversionPattern<low::SPNSelectLeaf>::OpConversionPattern;
+
+      LogicalResult matchAndRewrite(low::SPNSelectLeaf op,
+                                    ArrayRef<Value> operands,
+                                    ConversionPatternRewriter& rewriter) const override;
+    };
+
     struct ResolveConvertToVector : public OpConversionPattern<low::SPNConvertToVector> {
 
       using OpConversionPattern<low::SPNConvertToVector>::OpConversionPattern;
@@ -181,6 +190,7 @@ namespace mlir {
      patterns.insert(typeConverter, context);
      patterns.insert(typeConverter, context);
      patterns.insert(typeConverter, context);
+      patterns.insert<SelectLowering>(typeConverter, context);
      patterns.insert(typeConverter, context);
      patterns.insert(typeConverter, context);
     }
diff --git a/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h b/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h
index 35f7bcf6..ad27dabc 100644
--- a/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h
+++ b/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h
@@ -164,6 +164,15 @@ namespace mlir {
 
       LogicalResult matchAndRewrite(low::SPNConvertLog op,
                                     ArrayRef<Value> operands,
+                                    ConversionPatternRewriter& rewriter) const override;
+    };
+
+    struct VectorizeSelectLeaf : public OpConversionPattern<low::SPNSelectLeaf> {
+
+      using OpConversionPattern<low::SPNSelectLeaf>::OpConversionPattern;
+
+      LogicalResult matchAndRewrite(low::SPNSelectLeaf op,
+                                    ArrayRef<Value> operands,
                                     ConversionPatternRewriter& rewriter) const override;
     };
 
@@ -177,6 +186,7 @@ namespace mlir {
      patterns.insert(typeConverter, context, 2);
      patterns.insert(typeConverter, context, 2);
      patterns.insert(typeConverter, context, 2);
+      patterns.insert<VectorizeSelectLeaf>(typeConverter, context, 2);
     }
   }
diff --git a/mlir/include/Dialect/LoSPN/LoSPNOps.h b/mlir/include/Dialect/LoSPN/LoSPNOps.h
index e3038f23..d6f15ee8 100644
--- a/mlir/include/Dialect/LoSPN/LoSPNOps.h
+++ b/mlir/include/Dialect/LoSPN/LoSPNOps.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
 
 #define GET_OP_CLASSES
 #include "LoSPN/LoSPNOps.h.inc"
diff --git a/mlir/include/Dialect/LoSPN/LoSPNOps.td b/mlir/include/Dialect/LoSPN/LoSPNOps.td
index d5f4b878..e7b1faad 100644
--- a/mlir/include/Dialect/LoSPN/LoSPNOps.td
+++ b/mlir/include/Dialect/LoSPN/LoSPNOps.td
@@ -440,8 +440,9 @@ def SPNHistogramLeaf : LoSPNBodyOp<"histogram", [NoSideEffect,
   let arguments = (ins LoSPNInputType:$index, BucketListAttr:$buckets,
                        UI32Attr:$bucketCount, BoolAttr:$supportMarginal);
 
-  let results = (outs LoSPNComputeType);
+  let hasCanonicalizeMethod = 1;
+  let results = (outs LoSPNComputeType);
 }
 
 ///
@@ -459,6 +460,8 @@ def SPNCategoricalLeaf : LoSPNBodyOp<"categorical", [NoSideEffect,
   let arguments = (ins LoSPNInputType:$index, F64ArrayAttr:$probabilities,
                        BoolAttr:$supportMarginal);
 
+  let hasCanonicalizeMethod = 1;
+
   let results = (outs LoSPNComputeType);
 }
 
@@ -480,4 +483,25 @@ def SPNGaussianLeaf : LoSPNBodyOp<"gaussian", [NoSideEffect,
 
   let results = (outs LoSPNComputeType);
 }
 
+///
+/// Select of an SPN leaf node value.
+/// Corresponds to: ($input < $input_true_threshold) ?
$val_true : $val_false; +/// +def SPNSelectLeaf : LoSPNBodyOp<"select", [NoSideEffect, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + + let summary = "Leaf node value select"; + + let description = [{ + Single value select of a Categorical or Histogram leaf. + }]; + + let arguments = (ins LoSPNInputType:$input, F64Attr:$input_true_threshold, + F64Attr:$val_true, F64Attr:$val_false, BoolAttr:$supportMarginal); + + let results = (outs LoSPNComputeType); + +} + #endif // LoSPN_Ops \ No newline at end of file diff --git a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp index 24ef7906..eac14297 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp @@ -379,11 +379,28 @@ namespace { mlir::LogicalResult replaceOpWithGlobalMemref(SourceOp op, mlir::ConversionPatternRewriter& rewriter, mlir::Value indexOperand, llvm::ArrayRef arrayValues, mlir::Type resultType, const std::string& tablePrefix, - bool computesLog) { + bool computesLog, int lowerBound, int upperBound) { static int tableCount = 0; if (!resultType.isIntOrFloat()) { // Currently only handling Int and Float result types. - return mlir::failure(); + return rewriter.notifyMatchFailure(op, "Match failed because result is neither float nor integer"); + } + + // Convert input value from float to integer if necessary. + mlir::Value index = indexOperand; + if (!index.getType().isIntOrIndex()) { + // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. + if (!index.getType().isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, "Match failed because input is neither float nor integer/index"); + } + index = rewriter.template create(op.getLoc(), index, rewriter.getI64Type()); + } + auto integerIndex = index; + auto idxTy = index.getType(); + + // Cast input value to index if necessary. + if (!index.getType().isIndex()) { + index = rewriter.template create(op.getLoc(), rewriter.getIndexType(), index); } // Construct a DenseElementsAttr to hold the array values. @@ -404,34 +421,66 @@ namespace { // Restore insertion point rewriter.restoreInsertionPoint(restore); + auto idxMin = rewriter.create(op->getLoc(), + idxTy, + rewriter.getIntegerAttr(idxTy, lowerBound)); + auto idxMax = rewriter.create(op->getLoc(), + idxTy, + rewriter.getIntegerAttr(idxTy, upperBound)); + + auto boolTy = mlir::IntegerType::get(op.getContext(), 1); + auto inBoundsLB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::sge, integerIndex, idxMin); + auto inBoundsUB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::slt, integerIndex, idxMax); + auto inBounds = rewriter.create(op->getLoc(), inBoundsLB, inBoundsUB); + + // Perform range-check: if (inBoundsUB) { return memRefValue; } else { return defaultValue; } + auto boundaryIf = rewriter.create(op->getLoc(), resultType, inBounds, true); + + // If-Then branch: Legal range access -> memref load + rewriter.setInsertionPointToStart(&boundaryIf.thenRegion().front()); + // Use GetGlobalMemref operation to access the global created above. auto addressOf = rewriter.template create(op.getLoc(), memrefType, symbolName); - // Convert input value from float to integer if necessary. - mlir::Value index = indexOperand; - if (!index.getType().isIntOrIndex()) { - // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. 
- if (!index.getType().isIntOrFloat()) { - return mlir::failure(); - } - index = rewriter.template create(op.getLoc(), index, rewriter.getI64Type()); - } - // Cast input value to index if necessary. - if (!index.getType().isIndex()) { - index = rewriter.template create(op.getLoc(), rewriter.getIndexType(), index); - } + // Replace the source operation with a load from the global memref, // using the source operation's input value as index. - mlir::Value leaf = rewriter.template create(op.getLoc(), addressOf, mlir::ValueRange{index}); + mlir::Value leafThen = rewriter.template create(op.getLoc(), addressOf, mlir::ValueRange{index}); + + // If-Else branch: Illegal range access -> default return + rewriter.setInsertionPointToStart(&boundaryIf.elseRegion().front()); + double defaultValue = (computesLog) ? static_cast(-INFINITY) : 0.0; + mlir::Value leafElse = rewriter.create(op.getLoc(), resultType, rewriter.getFloatAttr(resultType, defaultValue)); + if (op.supportMarginal()) { - assert(indexOperand.getType().template isa()); + // Set the insertion point right before the if-op + rewriter.setInsertionPoint(boundaryIf); + // Define the values needed for marginal-support (for both [then/else] branches) auto isNan = rewriter.create(op->getLoc(), mlir::CmpFPredicate::UNO, indexOperand, indexOperand); auto marginalValue = (computesLog) ? 0.0 : 1.0; auto constOne = rewriter.create(op.getLoc(), rewriter.getFloatAttr(resultType, marginalValue)); - leaf = rewriter.create(op.getLoc(), isNan, constOne, leaf); + + // Place select and yield ops in the corresponding branches + rewriter.setInsertionPointToEnd(&boundaryIf.thenRegion().back()); + mlir::Value leaf = rewriter.create(op.getLoc(), isNan, constOne, leafThen); + rewriter.create(op->getLoc(), leaf); + + rewriter.setInsertionPointToEnd(&boundaryIf.elseRegion().back()); + leaf = rewriter.create(op.getLoc(), isNan, constOne, leafElse); + rewriter.create(op->getLoc(), leaf); + } else { + // Place yield ops in the corresponding branches + rewriter.setInsertionPointToEnd(&boundaryIf.thenRegion().back()); + rewriter.create(op->getLoc(), leafThen); + rewriter.setInsertionPointToEnd(&boundaryIf.elseRegion().back()); + rewriter.create(op->getLoc(), leafElse); } - rewriter.replaceOp(op, leaf); + + // Replace access by result(s) of created if-op + rewriter.replaceOp(op, boundaryIf.results()); return mlir::success(); } @@ -451,7 +500,7 @@ mlir::LogicalResult mlir::spn::HistogramLowering::matchAndRewrite(mlir::spn::low llvm::DenseMap values; int minLB = std::numeric_limits::max(); int maxUB = std::numeric_limits::min(); - for (auto& b : op.bucketsAttr()) { + for (auto& b: op.bucketsAttr()) { auto bucket = b.cast(); auto lb = bucket.lb().getInt(); auto ub = bucket.ub().getInt(); @@ -486,7 +535,7 @@ mlir::LogicalResult mlir::spn::HistogramLowering::matchAndRewrite(mlir::spn::low if (values.count(i)) { indexVal = (computesLog) ? log(values[i]) : values[i]; } else { - // Fill up with 0 if no value was defined by the histogram. + // Fill up with corresponding zero if no value was defined by the histogram. indexVal = (computesLog) ? static_cast(-INFINITY) : 0; } // Construct attribute with constant value. Need to distinguish cases here due to different builder methods. 
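For orientation, with the bounds check introduced above, a scalar categorical or histogram leaf is expected to lower to IR of roughly the following shape. This is a sketch mirroring the CHECK lines of the regression tests further down in this patch; the SSA names, the example bounds 0/3, the f32 element type and the table symbol @categorical_0 are illustrative placeholders.

%idx       = fptoui %feature : f32 to i64
%index     = index_cast %idx : i64 to index
%lb        = constant 0 : i64
%ub        = constant 3 : i64
%ge_lb     = cmpi sge, %idx, %lb : i64
%lt_ub     = cmpi slt, %idx, %ub : i64
%in_bounds = and %ge_lb, %lt_ub : i1
%leaf      = scf.if %in_bounds -> (f32) {
  // In-bounds: load the probability from the global lookup table.
  %tab = memref.get_global @categorical_0 : memref<3xf32>
  %val = memref.load %tab[%index] : memref<3xf32>
  scf.yield %val : f32
} else {
  // Out-of-bounds: yield the default, -INFINITY in log-space, 0.0 otherwise.
  %default = constant 0xFF800000 : f32
  scf.yield %default : f32
}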
@@ -498,7 +547,7 @@ mlir::LogicalResult mlir::spn::HistogramLowering::matchAndRewrite(mlir::spn::low } return replaceOpWithGlobalMemref(op, rewriter, operands[0], valArray, - resultType, "histogram_", computesLog); + resultType, "histogram_", computesLog, minLB, maxUB); } mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::low::SPNCategoricalLeaf op, llvm::ArrayRef operands, @@ -514,8 +563,9 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l resultType = logType.getBaseType(); computesLog = true; } + SmallVector values; - for (auto val : op.probabilities().getValue()) { + for (auto val: op.probabilities().getValue()) { if (computesLog) { auto floatVal = val.dyn_cast(); assert(floatVal); @@ -524,8 +574,115 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l values.push_back(val); } } - return replaceOpWithGlobalMemref(op, rewriter, operands[0], - values, resultType, "categorical_", computesLog); + return replaceOpWithGlobalMemref(op, + rewriter, + operands[0], + values, + resultType, + "categorical_", + computesLog, + 0, + op.probabilities().size()); +} + +mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, + llvm::ArrayRef operands, + mlir::ConversionPatternRewriter& rewriter) const { + if (op.checkVectorized()) { + return rewriter.notifyMatchFailure(op, "Pattern only matches non-vectorized SelectLeaf"); + } + + auto inputTy = op.input().getType(); + if (!inputTy.isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, "Input type should be either int or float."); + } + + mlir::Value cond; + mlir::Value idxMin, idxMax; + mlir::Value inBoundsLB, inBoundsUB; + auto boolTy = IntegerType::get(op.getContext(), 1); + // Regarding the actual LB/UB values we will use the knowledge that: + // (UB - LB = 2) && input_true_thresholdAttr == LB + 1 == UB - 1 + if (inputTy.isa()) { + idxMin = rewriter.create(op->getLoc(), + rewriter.getF64Type(), + rewriter.getFloatAttr(rewriter.getF64Type(), + op.input_true_thresholdAttr().getValueAsDouble() + - 1.0)); + idxMax = rewriter.create(op->getLoc(), + rewriter.getF64Type(), + rewriter.getFloatAttr(rewriter.getF64Type(), + op.input_true_thresholdAttr().getValueAsDouble() + + 1.0)); + inBoundsLB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpFPredicate::UGE, op.input(), idxMin); + inBoundsUB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpFPredicate::ULT, op.input(), idxMax); + auto thresholdAttr = FloatAttr::get(inputTy, op.input_true_thresholdAttr().getValueAsDouble()); + auto input_true_threshold = rewriter.create(op->getLoc(), inputTy, thresholdAttr); + cond = + rewriter.create(op->getLoc(), boolTy, mlir::CmpFPredicate::ULT, op.input(), input_true_threshold); + } else if (inputTy.isa()) { + idxMin = rewriter.create(op->getLoc(), + op.input().getType(), + rewriter.getIntegerAttr(op.input().getType(), + op.input_true_thresholdAttr().getValueAsDouble() + - 1.0)); + idxMax = rewriter.create(op->getLoc(), + op.input().getType(), + rewriter.getIntegerAttr(op.input().getType(), + op.input_true_thresholdAttr().getValueAsDouble() + + 1.0)); + inBoundsLB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::sge, op.input(), idxMin); + inBoundsUB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::slt, op.input(), idxMax); + auto thresholdAttr = IntegerAttr::get(inputTy, op.input_true_thresholdAttr().getValueAsDouble()); + auto input_true_threshold = rewriter.create(op->getLoc(), inputTy, thresholdAttr); + 
cond = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::ult, op.input(), input_true_threshold); + } + + Type resultType = op.getResult().getType(); + bool computesLog = false; + if (auto logType = resultType.dyn_cast()) { + resultType = logType.getBaseType(); + computesLog = true; + } + + ConstantOp val_true, val_false; + if (computesLog) { + val_true = rewriter.create(op->getLoc(), + resultType, + FloatAttr::get(resultType, log(op.val_trueAttr().getValueAsDouble()))); + val_false = rewriter.create(op->getLoc(), + resultType, + FloatAttr::get(resultType, + log(op.val_falseAttr().getValueAsDouble()))); + } else { + val_true = rewriter.create(op->getLoc(), op.val_trueAttr().getType(), op.val_trueAttr()); + val_false = rewriter.create(op->getLoc(), op.val_falseAttr().getType(), op.val_falseAttr()); + } + + mlir::Value leaf = rewriter.create(op.getLoc(), cond, val_true, val_false); + if (op.supportMarginal()) { + assert(inputTy.isa()); + auto isNan = rewriter.create(op->getLoc(), mlir::CmpFPredicate::UNO, op.input(), op.input()); + auto marginalValue = (computesLog) ? 0.0 : 1.0; + auto constOne = rewriter.create(op.getLoc(), + rewriter.getFloatAttr(resultType, marginalValue)); + leaf = rewriter.create(op.getLoc(), isNan, constOne, leaf); + } + + auto inBounds = rewriter.create(op->getLoc(), inBoundsLB, inBoundsUB); + double defaultValue = (computesLog) ? static_cast(-INFINITY) : 0.0; + mlir::Value defaultReturn = + rewriter.create(op.getLoc(), resultType, rewriter.getFloatAttr(resultType, defaultValue)); + + leaf = rewriter.create(op.getLoc(), inBounds, leaf, defaultReturn); + rewriter.replaceOp(op, leaf); + + return success(); } mlir::LogicalResult mlir::spn::ResolveConvertToVector::matchAndRewrite(mlir::spn::low::SPNConvertToVector op, diff --git a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp index 2c43dce1..d9f997f7 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp @@ -462,6 +462,8 @@ namespace { mlir::LogicalResult replaceOpWithGatherFromGlobalMemref(SourceOp op, mlir::ConversionPatternRewriter& rewriter, mlir::Value indexOperand, + int lowerBound, + int upperBound, llvm::ArrayRef arrayValues, mlir::Type resultType, const std::string& tablePrefix, bool computesLog) { @@ -469,7 +471,7 @@ namespace { auto inputType = indexOperand.getType(); if (!inputType.template isa()) { // This pattern only handles vectorized implementations and fails if the input is not a vector. - return mlir::failure(); + return rewriter.notifyMatchFailure(op, "Match failed because input is not a VectorType"); } // Construct a DenseElementsAttr to hold the array values. @@ -499,20 +501,53 @@ namespace { auto indexType = inputType.template dyn_cast().getElementType(); if (!indexType.isIntOrIndex()) { if (indexType.template isa()) { + indexType = rewriter.getI64Type(); index = rewriter.template create(op.getLoc(), index, - mlir::VectorType::get(vectorShape, rewriter.getI64Type())); + mlir::VectorType::get(vectorShape, indexType)); } else { // The input type is neither int/index nor float, conversion unknown, fail this pattern. - return mlir::failure(); + return rewriter.notifyMatchFailure(op, "Match failed because input is neither float nor integer/index"); } } // Construct the constant pass-thru value (values used if the mask is false for an element of the vector). 
auto vectorType = mlir::VectorType::get(vectorShape, resultType); - auto passThru = broadcastVectorConstant(vectorType, 0.0, + double defaultValue = (computesLog) ? static_cast(-INFINITY) : 0.0; + auto passThru = broadcastVectorConstant(vectorType, defaultValue, rewriter, op->getLoc()); - // Construct the constant mask. - auto mask = broadcastVectorConstant(mlir::VectorType::get(vectorShape, rewriter.getI1Type()), true, - rewriter, op->getLoc()); + + mlir::ConstantOp idxMinVec; + mlir::ConstantOp idxMaxVec; + if (indexType.isInteger(32)) { + idxMinVec = + broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), lowerBound, rewriter, op->getLoc()); + idxMaxVec = + broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), upperBound, rewriter, op->getLoc()); + } else { + idxMinVec = broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), + static_cast(lowerBound), + rewriter, + op->getLoc()); + idxMaxVec = broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), + static_cast(upperBound), + rewriter, + op->getLoc()); + } + + auto boolVecTy = mlir::VectorType::get(vectorShape, rewriter.getI1Type()); + auto inBoundsLB = + rewriter.create(op->getLoc(), boolVecTy, mlir::CmpIPredicate::sge, index, idxMinVec); + auto inBoundsUB = + rewriter.create(op->getLoc(), boolVecTy, mlir::CmpIPredicate::slt, index, idxMaxVec); + auto inBounds = rewriter.create(op->getLoc(), inBoundsLB, inBoundsUB); + auto trueVec = + broadcastVectorConstant(mlir::VectorType::get(vectorShape, rewriter.getI1Type()), true, rewriter, op->getLoc()); + auto falseVec = broadcastVectorConstant(mlir::VectorType::get(vectorShape, rewriter.getI1Type()), + false, + rewriter, + op->getLoc()); + // Create mask, dependant on an "inBounds" select -> Mask all out-of-bounds accesses. + auto mask = rewriter.create(op.getLoc(), inBounds, trueVec, falseVec); + // Replace the source operation with a gather load from the global memref. 
mlir::Value constIndex = rewriter.template create(op.getLoc(), rewriter.getIndexAttr(0)); mlir::Value leaf = rewriter.template create(op.getLoc(), vectorType, addressOf, @@ -555,7 +590,8 @@ mlir::LogicalResult mlir::spn::VectorizeCategorical::matchAndRewrite(mlir::spn:: values.push_back(val); } } - return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], + int idxMax = op.probabilities().size(); + return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], 0, idxMax, values, resultType, "categorical_vec_", computesLog); } @@ -574,7 +610,7 @@ mlir::LogicalResult mlir::spn::VectorizeHistogram::matchAndRewrite(mlir::spn::lo llvm::DenseMap values; int minLB = std::numeric_limits::max(); int maxUB = std::numeric_limits::min(); - for (auto& b : op.bucketsAttr()) { + for (auto& b: op.bucketsAttr()) { auto bucket = b.cast(); auto lb = bucket.lb().getInt(); auto ub = bucket.ub().getInt(); @@ -620,7 +656,7 @@ mlir::LogicalResult mlir::spn::VectorizeHistogram::matchAndRewrite(mlir::spn::lo } } - return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], valArray, + return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], minLB, maxUB, valArray, resultType, "histogram_vec_", computesLog); } @@ -658,5 +694,70 @@ mlir::LogicalResult mlir::spn::ResolveVectorizedConvertLog::matchAndRewrite(mlir return rewriter.notifyMatchFailure(op, "Could not resolve ConvertLog trivially"); } rewriter.replaceOp(op, operands[0]); + return success(); +} + +mlir::LogicalResult mlir::spn::VectorizeSelectLeaf::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, + llvm::ArrayRef operands, + mlir::ConversionPatternRewriter& rewriter) const { + if (!op.checkVectorized()) { + return rewriter.notifyMatchFailure(op, "Pattern only matches vectorized Select"); + } + + auto inputVec = operands.front(); + auto inputVecTy = inputVec.getType(); + + // Input should be a vector + if (!inputVecTy.isa()) { + return rewriter.notifyMatchFailure(op, "Vectorization pattern did not match, input was not a vector"); + } + + auto inputTy = op.input().getType(); + VectorType vectorType = inputVecTy.dyn_cast(); + + // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. 
+ mlir::Value cond; + auto booleanVectorTy = VectorType::get(vectorType.getShape(), IntegerType::get(op.getContext(), 1)); + if (inputTy.isa()) { + auto thresholdVec = + broadcastVectorConstant(vectorType, op.input_true_thresholdAttr().getValueAsDouble(), rewriter, op.getLoc()); + cond = + rewriter.create(op->getLoc(), booleanVectorTy, mlir::CmpFPredicate::ULT, inputVec, thresholdVec); + } else if (inputTy.isa()) { + auto thresholdVec = + broadcastVectorConstant(vectorType, op.input_true_thresholdAttr().getValueAsDouble(), rewriter, op.getLoc()); + cond = + rewriter.create(op->getLoc(), booleanVectorTy, mlir::CmpIPredicate::ult, inputVec, thresholdVec); + } else { + return rewriter.notifyMatchFailure(op, "Expected condition-value to be either Float- or IntegerType"); + } + + Type resultType = op.getResult().getType(); + bool computesLog = false; + if (auto logType = resultType.dyn_cast()) { + resultType = logType.getBaseType(); + computesLog = true; + } + + ConstantOp val_true, val_false; + if (computesLog) { + auto logVecTy = VectorType::get(vectorType.getShape(), resultType); + val_true = broadcastVectorConstant(logVecTy, log(op.val_trueAttr().getValueAsDouble()), rewriter, op.getLoc()); + val_false = broadcastVectorConstant(logVecTy, log(op.val_falseAttr().getValueAsDouble()), rewriter, op.getLoc()); + } else { + val_true = broadcastVectorConstant(vectorType, op.val_trueAttr().getValueAsDouble(), rewriter, op.getLoc()); + val_false = broadcastVectorConstant(vectorType, op.val_falseAttr().getValueAsDouble(), rewriter, op.getLoc()); + } + + mlir::Value leaf = rewriter.create(op.getLoc(), cond, val_true, val_false); + if (op.supportMarginal()) { + assert(inputTy.isa()); + auto isNan = rewriter.create(op->getLoc(), CmpFPredicate::UNO, inputVec, inputVec); + auto marginalValue = (computesLog) ? 0.0 : 1.0; + auto constOne = broadcastVectorConstant(vectorType, marginalValue, rewriter, op.getLoc()); + leaf = rewriter.create(op.getLoc(), isNan, constOne, leaf); + } + rewriter.replaceOp(op, leaf); + return success(); } \ No newline at end of file diff --git a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp index da91dea9..b14e874a 100644 --- a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp +++ b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp @@ -10,10 +10,10 @@ #include "LoSPN/LoSPNDialect.h" #include "LoSPN/LoSPNAttributes.h" #include "LoSPN/LoSPNInterfaces.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" namespace mlir { namespace spn { @@ -419,5 +419,65 @@ ::mlir::OpFoldResult mlir::spn::low::SPNAdd::fold(::llvm::ArrayRef<::mlir::Attri return nullptr; } +//===----------------------------------------------------------------------===// +// SPNCategoricalLeaf +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCategoricalLeaf op, + PatternRewriter& rewriter) { + // Rewrite Categoricals which contain exactly two probabilities into a LoSPN Select. 
+ auto probabilities = op.probabilities().getValue(); + if (probabilities.size() == 2) { + auto pTrue = probabilities[0].dyn_cast(); + auto pFalse = probabilities[1].dyn_cast(); + auto threshold_max_true = FloatAttr::get(op.index().getType(), 1.0); + + rewriter.replaceOpWithNewOp(op, + op.getResult().getType(), + op.index(), + threshold_max_true, + pTrue, + pFalse, + op.supportMarginalAttr()); + return success(); + } + return rewriter.notifyMatchFailure(op, "Categorical held != 2 probabilities (no reduction to select possible)"); +} + +//===----------------------------------------------------------------------===// +// SPNHistogramLeaf +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::spn::low::SPNHistogramLeaf::canonicalize(SPNHistogramLeaf op, PatternRewriter& rewriter) { + // Rewrite certain Histograms which contain exactly two buckets into a LoSPN Select. + // Buckets' index range must be 1 and buckets have to be consecutive / contiguous. + // i.e.: (UB_0-LB_0 == 1) && (UB_1-LB_1 == 1) && (UB_0 == LB_1) + auto buckets = op.buckets(); + if (buckets.size() == 2) { + auto b0 = buckets[0].cast(); + auto b1 = buckets[1].cast(); + + bool isQualifiedIndexRange = ((b0.ub().getInt() - b0.lb().getInt()) == 1) && + ((b1.ub().getInt() - b1.lb().getInt()) == 1); + bool isContiguous = (b0.ub().getInt() == b1.lb().getInt()); + + if (isQualifiedIndexRange && isContiguous) { + auto pTrue = b0.val(); + auto pFalse = b1.val(); + auto threshold_max_true = FloatAttr::get(Float64Type::get(op.getContext()), b0.ub().getInt()); + + rewriter.replaceOpWithNewOp(op, + op.getResult().getType(), + op.index(), + threshold_max_true, + pTrue, + pFalse, + op.supportMarginalAttr()); + return success(); + } + } + return rewriter.notifyMatchFailure(op, "Histogram was not eligible for reduction to select"); +} + #define GET_OP_CLASSES #include "LoSPN/LoSPNOps.cpp.inc" \ No newline at end of file diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-boundary-check.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-boundary-check.mlir new file mode 100644 index 00000000..5d85d270 --- /dev/null +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-boundary-check.mlir @@ -0,0 +1,76 @@ +// RUN: %optcall --convert-lospn-nodes-to-cpu %s | FileCheck %s + +module { + func @task_0(%arg0: memref, %arg1: memref<1x?xf64>) { + %c0 = constant 0 : index + %c0_0 = constant 0 : index + %0 = memref.dim %arg0, %c0_0 : memref + %c1 = constant 1 : index + scf.for %arg2 = %c0 to %0 step %c1 { + %1 = "lo_spn.batch_read"(%arg0, %arg2) {staticIndex = 0 : ui32} : (memref, index) -> i32 + %2 = "lo_spn.select"(%1) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (i32) -> f64 + %3 = "lo_spn.log"(%2) : (f64) -> f64 + "lo_spn.batch_write"(%arg1, %arg2, %3) {transposed = true} : (memref<1x?xf64>, index, f64) -> () + } + return + } + func @spn_kernel(%arg0: memref, %arg1: memref<1x?xf64>) { + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.alloc(%0) : memref<1x?xf64> + call @task_0(%arg0, %1) : (memref, memref<1x?xf64>) -> () + "lo_spn.copy"(%1, %arg1) : (memref<1x?xf64>, memref<1x?xf64>) -> () + "lo_spn.return"() : () -> () + } +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + + +// CHECK-LABEL: func @task_0( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x?xf64>) { +// CHECK: %[[VAL_2:.*]] = 
constant 0 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_5]] { +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref +// CHECK: %[[VAL_9:.*]] = constant 0 : i32 +// CHECK: %[[VAL_10:.*]] = constant 2 : i32 +// CHECK: %[[VAL_11:.*]] = cmpi sge, %[[VAL_8]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_12:.*]] = cmpi slt, %[[VAL_8]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_13:.*]] = constant 1 : i32 +// CHECK: %[[VAL_14:.*]] = cmpi ult, %[[VAL_8]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_15:.*]] = constant 2.500000e-01 : f64 +// CHECK: %[[VAL_16:.*]] = constant 7.500000e-01 : f64 +// CHECK: %[[VAL_17:.*]] = select %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : f64 +// CHECK: %[[VAL_18:.*]] = and %[[VAL_11]], %[[VAL_12]] : i1 +// CHECK: %[[VAL_19:.*]] = constant 0.000000e+00 : f64 +// CHECK: %[[VAL_20:.*]] = select %[[VAL_18]], %[[VAL_17]], %[[VAL_19]] : f64 +// CHECK: %[[VAL_21:.*]] = math.log %[[VAL_20]] : f64 +// CHECK: %[[VAL_22:.*]] = constant 0 : index +// CHECK: memref.store %[[VAL_21]], %[[VAL_1]]{{\[}}%[[VAL_22]], %[[VAL_6]]] : memref<1x?xf64> +// CHECK: } +// CHECK: return +// CHECK: } + +// CHECK-LABEL: func @spn_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x?xf64>) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref<1x?xf64> +// CHECK: call @task_0(%[[VAL_0]], %[[VAL_4]]) : (memref, memref<1x?xf64>) -> () +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_4]], %[[VAL_5]] : memref<1x?xf64> +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] { +// CHECK: %[[VAL_10:.*]] = constant 0 : index +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_10]], %[[VAL_9]]] : memref<1x?xf64> +// CHECK: memref.store %[[VAL_11]], %[[VAL_1]]{{\[}}%[[VAL_10]], %[[VAL_9]]] : memref<1x?xf64> +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir index 1cfed71b..095e354a 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir @@ -383,53 +383,97 @@ module { // CHECK: %[[VAL_165:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_155]], %[[VAL_164]]] : memref // CHECK: %[[VAL_166:.*]] = constant 5 : index // CHECK: %[[VAL_167:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_155]], %[[VAL_166]]] : memref -// CHECK: %[[VAL_168:.*]] = memref.get_global @categorical_0 : memref<3xf32> -// CHECK: %[[VAL_169:.*]] = fptoui %[[VAL_157]] : f32 to i64 -// CHECK: %[[VAL_170:.*]] = index_cast %[[VAL_169]] : i64 to index -// CHECK: %[[VAL_171:.*]] = memref.load %[[VAL_168]]{{\[}}%[[VAL_170]]] : memref<3xf32> -// CHECK: %[[VAL_172:.*]] = memref.get_global @categorical_1 : memref<3xf32> -// CHECK: %[[VAL_173:.*]] = fptoui %[[VAL_159]] : f32 to i64 -// CHECK: %[[VAL_174:.*]] = index_cast %[[VAL_173]] : i64 to index -// CHECK: %[[VAL_175:.*]] = memref.load %[[VAL_172]]{{\[}}%[[VAL_174]]] : memref<3xf32> -// CHECK: %[[VAL_176:.*]] = 
memref.get_global @histogram_0 : memref<2xf32> -// CHECK: %[[VAL_177:.*]] = fptoui %[[VAL_161]] : f32 to i64 -// CHECK: %[[VAL_178:.*]] = index_cast %[[VAL_177]] : i64 to index -// CHECK: %[[VAL_179:.*]] = memref.load %[[VAL_176]]{{\[}}%[[VAL_178]]] : memref<2xf32> -// CHECK: %[[VAL_180:.*]] = memref.get_global @histogram_1 : memref<2xf32> -// CHECK: %[[VAL_181:.*]] = fptoui %[[VAL_163]] : f32 to i64 -// CHECK: %[[VAL_182:.*]] = index_cast %[[VAL_181]] : i64 to index -// CHECK: %[[VAL_183:.*]] = memref.load %[[VAL_180]]{{\[}}%[[VAL_182]]] : memref<2xf32> -// CHECK: %[[VAL_184:.*]] = constant -5.000000e-01 : f32 -// CHECK: %[[VAL_185:.*]] = constant -0.918938517 : f32 -// CHECK: %[[VAL_186:.*]] = constant 5.000000e-01 : f32 -// CHECK: %[[VAL_187:.*]] = subf %[[VAL_165]], %[[VAL_186]] : f32 -// CHECK: %[[VAL_188:.*]] = mulf %[[VAL_187]], %[[VAL_187]] : f32 -// CHECK: %[[VAL_189:.*]] = mulf %[[VAL_188]], %[[VAL_184]] : f32 -// CHECK: %[[VAL_190:.*]] = addf %[[VAL_185]], %[[VAL_189]] : f32 -// CHECK: %[[VAL_191:.*]] = constant -5.000000e+01 : f32 -// CHECK: %[[VAL_192:.*]] = constant 1.38364661 : f32 -// CHECK: %[[VAL_193:.*]] = constant 2.500000e-01 : f32 -// CHECK: %[[VAL_194:.*]] = subf %[[VAL_167]], %[[VAL_193]] : f32 -// CHECK: %[[VAL_195:.*]] = mulf %[[VAL_194]], %[[VAL_194]] : f32 -// CHECK: %[[VAL_196:.*]] = mulf %[[VAL_195]], %[[VAL_191]] : f32 -// CHECK: %[[VAL_197:.*]] = addf %[[VAL_192]], %[[VAL_196]] : f32 -// CHECK: %[[VAL_198:.*]] = addf %[[VAL_171]], %[[VAL_175]] : f32 -// CHECK: %[[VAL_199:.*]] = addf %[[VAL_198]], %[[VAL_179]] : f32 -// CHECK: %[[VAL_200:.*]] = constant 1.000000e-01 : f32 -// CHECK: %[[VAL_201:.*]] = addf %[[VAL_199]], %[[VAL_200]] : f32 -// CHECK: %[[VAL_202:.*]] = addf %[[VAL_183]], %[[VAL_190]] : f32 -// CHECK: %[[VAL_203:.*]] = addf %[[VAL_202]], %[[VAL_197]] : f32 -// CHECK: %[[VAL_204:.*]] = constant 1.000000e-01 : f32 -// CHECK: %[[VAL_205:.*]] = addf %[[VAL_203]], %[[VAL_204]] : f32 -// CHECK: %[[VAL_206:.*]] = cmpf ogt, %[[VAL_201]], %[[VAL_205]] : f32 -// CHECK: %[[VAL_207:.*]] = select %[[VAL_206]], %[[VAL_201]], %[[VAL_205]] : f32 -// CHECK: %[[VAL_208:.*]] = select %[[VAL_206]], %[[VAL_205]], %[[VAL_201]] : f32 -// CHECK: %[[VAL_209:.*]] = subf %[[VAL_208]], %[[VAL_207]] : f32 -// CHECK: %[[VAL_210:.*]] = math.exp %[[VAL_209]] : f32 -// CHECK: %[[VAL_211:.*]] = math.log1p %[[VAL_210]] : f32 -// CHECK: %[[VAL_212:.*]] = addf %[[VAL_207]], %[[VAL_211]] : f32 -// CHECK: %[[VAL_213:.*]] = constant 0 : index -// CHECK: memref.store %[[VAL_212]], %[[VAL_1]]{{\[}}%[[VAL_213]], %[[VAL_155]]] : memref<1x?xf32> +// CHECK: %[[VAL_168:.*]] = fptoui %[[VAL_157]] : f32 to i64 +// CHECK: %[[VAL_169:.*]] = index_cast %[[VAL_168]] : i64 to index +// CHECK: %[[VAL_170:.*]] = constant 0 : i64 +// CHECK: %[[VAL_171:.*]] = constant 3 : i64 +// CHECK: %[[VAL_172:.*]] = cmpi sge, %[[VAL_168]], %[[VAL_170]] : i64 +// CHECK: %[[VAL_173:.*]] = cmpi slt, %[[VAL_168]], %[[VAL_171]] : i64 +// CHECK: %[[VAL_174:.*]] = and %[[VAL_172]], %[[VAL_173]] : i1 +// CHECK: %[[VAL_175:.*]] = scf.if %[[VAL_174]] -> (f32) { +// CHECK: %[[VAL_176:.*]] = memref.get_global @categorical_0 : memref<3xf32> +// CHECK: %[[VAL_177:.*]] = memref.load %[[VAL_176]]{{\[}}%[[VAL_169]]] : memref<3xf32> +// CHECK: scf.yield %[[VAL_177]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_178:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_178]] : f32 +// CHECK: } +// CHECK: %[[VAL_179:.*]] = fptoui %[[VAL_159]] : f32 to i64 +// CHECK: %[[VAL_180:.*]] = index_cast %[[VAL_179]] : i64 to index +// 
CHECK: %[[VAL_181:.*]] = constant 0 : i64 +// CHECK: %[[VAL_182:.*]] = constant 3 : i64 +// CHECK: %[[VAL_183:.*]] = cmpi sge, %[[VAL_179]], %[[VAL_181]] : i64 +// CHECK: %[[VAL_184:.*]] = cmpi slt, %[[VAL_179]], %[[VAL_182]] : i64 +// CHECK: %[[VAL_185:.*]] = and %[[VAL_183]], %[[VAL_184]] : i1 +// CHECK: %[[VAL_186:.*]] = scf.if %[[VAL_185]] -> (f32) { +// CHECK: %[[VAL_187:.*]] = memref.get_global @categorical_1 : memref<3xf32> +// CHECK: %[[VAL_188:.*]] = memref.load %[[VAL_187]]{{\[}}%[[VAL_180]]] : memref<3xf32> +// CHECK: scf.yield %[[VAL_188]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_189:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_189]] : f32 +// CHECK: } +// CHECK: %[[VAL_190:.*]] = fptoui %[[VAL_161]] : f32 to i64 +// CHECK: %[[VAL_191:.*]] = index_cast %[[VAL_190]] : i64 to index +// CHECK: %[[VAL_192:.*]] = constant 0 : i64 +// CHECK: %[[VAL_193:.*]] = constant 2 : i64 +// CHECK: %[[VAL_194:.*]] = cmpi sge, %[[VAL_190]], %[[VAL_192]] : i64 +// CHECK: %[[VAL_195:.*]] = cmpi slt, %[[VAL_190]], %[[VAL_193]] : i64 +// CHECK: %[[VAL_196:.*]] = and %[[VAL_194]], %[[VAL_195]] : i1 +// CHECK: %[[VAL_197:.*]] = scf.if %[[VAL_196]] -> (f32) { +// CHECK: %[[VAL_198:.*]] = memref.get_global @histogram_0 : memref<2xf32> +// CHECK: %[[VAL_199:.*]] = memref.load %[[VAL_198]]{{\[}}%[[VAL_191]]] : memref<2xf32> +// CHECK: scf.yield %[[VAL_199]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_200:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_200]] : f32 +// CHECK: } +// CHECK: %[[VAL_201:.*]] = fptoui %[[VAL_163]] : f32 to i64 +// CHECK: %[[VAL_202:.*]] = index_cast %[[VAL_201]] : i64 to index +// CHECK: %[[VAL_203:.*]] = constant 0 : i64 +// CHECK: %[[VAL_204:.*]] = constant 2 : i64 +// CHECK: %[[VAL_205:.*]] = cmpi sge, %[[VAL_201]], %[[VAL_203]] : i64 +// CHECK: %[[VAL_206:.*]] = cmpi slt, %[[VAL_201]], %[[VAL_204]] : i64 +// CHECK: %[[VAL_207:.*]] = and %[[VAL_205]], %[[VAL_206]] : i1 +// CHECK: %[[VAL_208:.*]] = scf.if %[[VAL_207]] -> (f32) { +// CHECK: %[[VAL_209:.*]] = memref.get_global @histogram_1 : memref<2xf32> +// CHECK: %[[VAL_210:.*]] = memref.load %[[VAL_209]]{{\[}}%[[VAL_202]]] : memref<2xf32> +// CHECK: scf.yield %[[VAL_210]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_211:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_211]] : f32 +// CHECK: } +// CHECK: %[[VAL_212:.*]] = constant -5.000000e-01 : f32 +// CHECK: %[[VAL_213:.*]] = constant -0.918938517 : f32 +// CHECK: %[[VAL_214:.*]] = constant 5.000000e-01 : f32 +// CHECK: %[[VAL_215:.*]] = subf %[[VAL_165]], %[[VAL_214]] : f32 +// CHECK: %[[VAL_216:.*]] = mulf %[[VAL_215]], %[[VAL_215]] : f32 +// CHECK: %[[VAL_217:.*]] = mulf %[[VAL_216]], %[[VAL_212]] : f32 +// CHECK: %[[VAL_218:.*]] = addf %[[VAL_213]], %[[VAL_217]] : f32 +// CHECK: %[[VAL_219:.*]] = constant -5.000000e+01 : f32 +// CHECK: %[[VAL_220:.*]] = constant 1.38364661 : f32 +// CHECK: %[[VAL_221:.*]] = constant 2.500000e-01 : f32 +// CHECK: %[[VAL_222:.*]] = subf %[[VAL_167]], %[[VAL_221]] : f32 +// CHECK: %[[VAL_223:.*]] = mulf %[[VAL_222]], %[[VAL_222]] : f32 +// CHECK: %[[VAL_224:.*]] = mulf %[[VAL_223]], %[[VAL_219]] : f32 +// CHECK: %[[VAL_225:.*]] = addf %[[VAL_220]], %[[VAL_224]] : f32 +// CHECK: %[[VAL_226:.*]] = addf %[[VAL_227:.*]], %[[VAL_228:.*]] : f32 +// CHECK: %[[VAL_229:.*]] = addf %[[VAL_226]], %[[VAL_230:.*]] : f32 +// CHECK: %[[VAL_231:.*]] = constant 1.000000e-01 : f32 +// CHECK: %[[VAL_232:.*]] = addf %[[VAL_229]], %[[VAL_231]] : f32 +// CHECK: %[[VAL_233:.*]] = addf %[[VAL_234:.*]], %[[VAL_218]] : f32 
+// CHECK: %[[VAL_235:.*]] = addf %[[VAL_233]], %[[VAL_225]] : f32 +// CHECK: %[[VAL_236:.*]] = constant 1.000000e-01 : f32 +// CHECK: %[[VAL_237:.*]] = addf %[[VAL_235]], %[[VAL_236]] : f32 +// CHECK: %[[VAL_238:.*]] = cmpf ogt, %[[VAL_232]], %[[VAL_237]] : f32 +// CHECK: %[[VAL_239:.*]] = select %[[VAL_238]], %[[VAL_232]], %[[VAL_237]] : f32 +// CHECK: %[[VAL_240:.*]] = select %[[VAL_238]], %[[VAL_237]], %[[VAL_232]] : f32 +// CHECK: %[[VAL_241:.*]] = subf %[[VAL_240]], %[[VAL_239]] : f32 +// CHECK: %[[VAL_242:.*]] = math.exp %[[VAL_241]] : f32 +// CHECK: %[[VAL_243:.*]] = math.log1p %[[VAL_242]] : f32 +// CHECK: %[[VAL_244:.*]] = addf %[[VAL_239]], %[[VAL_243]] : f32 +// CHECK: %[[VAL_245:.*]] = constant 0 : index +// CHECK: memref.store %[[VAL_244]], %[[VAL_1]]{{\[}}%[[VAL_245]], %[[VAL_155]]] : memref<1x?xf32> // CHECK: } // CHECK: return // CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir index 2da0b864..f415f2fa 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir @@ -377,50 +377,94 @@ module { // CHECK: %[[VAL_162:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_152]], %[[VAL_161]]] : memref // CHECK: %[[VAL_163:.*]] = constant 5 : index // CHECK: %[[VAL_164:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_152]], %[[VAL_163]]] : memref -// CHECK: %[[VAL_165:.*]] = memref.get_global @categorical_0 : memref<3xf64> -// CHECK: %[[VAL_166:.*]] = fptoui %[[VAL_154]] : f64 to i64 -// CHECK: %[[VAL_167:.*]] = index_cast %[[VAL_166]] : i64 to index -// CHECK: %[[VAL_168:.*]] = memref.load %[[VAL_165]]{{\[}}%[[VAL_167]]] : memref<3xf64> -// CHECK: %[[VAL_169:.*]] = memref.get_global @categorical_1 : memref<3xf64> -// CHECK: %[[VAL_170:.*]] = fptoui %[[VAL_156]] : f64 to i64 -// CHECK: %[[VAL_171:.*]] = index_cast %[[VAL_170]] : i64 to index -// CHECK: %[[VAL_172:.*]] = memref.load %[[VAL_169]]{{\[}}%[[VAL_171]]] : memref<3xf64> -// CHECK: %[[VAL_173:.*]] = memref.get_global @histogram_0 : memref<2xf64> -// CHECK: %[[VAL_174:.*]] = fptoui %[[VAL_158]] : f64 to i64 -// CHECK: %[[VAL_175:.*]] = index_cast %[[VAL_174]] : i64 to index -// CHECK: %[[VAL_176:.*]] = memref.load %[[VAL_173]]{{\[}}%[[VAL_175]]] : memref<2xf64> -// CHECK: %[[VAL_177:.*]] = memref.get_global @histogram_1 : memref<2xf64> -// CHECK: %[[VAL_178:.*]] = fptoui %[[VAL_160]] : f64 to i64 -// CHECK: %[[VAL_179:.*]] = index_cast %[[VAL_178]] : i64 to index -// CHECK: %[[VAL_180:.*]] = memref.load %[[VAL_177]]{{\[}}%[[VAL_179]]] : memref<2xf64> -// CHECK: %[[VAL_181:.*]] = constant 0.3989422804014327 : f64 -// CHECK: %[[VAL_182:.*]] = constant -5.000000e-01 : f64 -// CHECK: %[[VAL_183:.*]] = constant 5.000000e-01 : f64 -// CHECK: %[[VAL_184:.*]] = subf %[[VAL_162]], %[[VAL_183]] : f64 -// CHECK: %[[VAL_185:.*]] = mulf %[[VAL_184]], %[[VAL_184]] : f64 -// CHECK: %[[VAL_186:.*]] = mulf %[[VAL_185]], %[[VAL_182]] : f64 -// CHECK: %[[VAL_187:.*]] = math.exp %[[VAL_186]] : f64 -// CHECK: %[[VAL_188:.*]] = mulf %[[VAL_181]], %[[VAL_187]] : f64 -// CHECK: %[[VAL_189:.*]] = constant 3.9894228040143269 : f64 -// CHECK: %[[VAL_190:.*]] = constant -49.999999999999993 : f64 -// CHECK: %[[VAL_191:.*]] = constant 2.500000e-01 : f64 -// CHECK: %[[VAL_192:.*]] = subf %[[VAL_164]], %[[VAL_191]] : f64 -// CHECK: %[[VAL_193:.*]] = mulf %[[VAL_192]], %[[VAL_192]] : f64 -// CHECK: %[[VAL_194:.*]] = mulf %[[VAL_193]], %[[VAL_190]] : f64 -// CHECK: %[[VAL_195:.*]] = math.exp 
%[[VAL_194]] : f64 -// CHECK: %[[VAL_196:.*]] = mulf %[[VAL_189]], %[[VAL_195]] : f64 -// CHECK: %[[VAL_197:.*]] = mulf %[[VAL_168]], %[[VAL_172]] : f64 -// CHECK: %[[VAL_198:.*]] = mulf %[[VAL_197]], %[[VAL_176]] : f64 -// CHECK: %[[VAL_199:.*]] = constant 1.000000e-01 : f64 -// CHECK: %[[VAL_200:.*]] = mulf %[[VAL_198]], %[[VAL_199]] : f64 -// CHECK: %[[VAL_201:.*]] = mulf %[[VAL_180]], %[[VAL_188]] : f64 -// CHECK: %[[VAL_202:.*]] = mulf %[[VAL_201]], %[[VAL_196]] : f64 -// CHECK: %[[VAL_203:.*]] = constant 1.000000e-01 : f64 -// CHECK: %[[VAL_204:.*]] = mulf %[[VAL_202]], %[[VAL_203]] : f64 -// CHECK: %[[VAL_205:.*]] = addf %[[VAL_200]], %[[VAL_204]] : f64 -// CHECK: %[[VAL_206:.*]] = math.log %[[VAL_205]] : f64 -// CHECK: %[[VAL_207:.*]] = constant 0 : index -// CHECK: memref.store %[[VAL_206]], %[[VAL_1]]{{\[}}%[[VAL_207]], %[[VAL_152]]] : memref<1x?xf64> +// CHECK: %[[VAL_165:.*]] = fptoui %[[VAL_154]] : f64 to i64 +// CHECK: %[[VAL_166:.*]] = index_cast %[[VAL_165]] : i64 to index +// CHECK: %[[VAL_167:.*]] = constant 0 : i64 +// CHECK: %[[VAL_168:.*]] = constant 3 : i64 +// CHECK: %[[VAL_169:.*]] = cmpi sge, %[[VAL_165]], %[[VAL_167]] : i64 +// CHECK: %[[VAL_170:.*]] = cmpi slt, %[[VAL_165]], %[[VAL_168]] : i64 +// CHECK: %[[VAL_171:.*]] = and %[[VAL_169]], %[[VAL_170]] : i1 +// CHECK: %[[VAL_172:.*]] = scf.if %[[VAL_171]] -> (f64) { +// CHECK: %[[VAL_173:.*]] = memref.get_global @categorical_0 : memref<3xf64> +// CHECK: %[[VAL_174:.*]] = memref.load %[[VAL_173]]{{\[}}%[[VAL_166]]] : memref<3xf64> +// CHECK: scf.yield %[[VAL_174]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_175:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_175]] : f64 +// CHECK: } +// CHECK: %[[VAL_176:.*]] = fptoui %[[VAL_156]] : f64 to i64 +// CHECK: %[[VAL_177:.*]] = index_cast %[[VAL_176]] : i64 to index +// CHECK: %[[VAL_178:.*]] = constant 0 : i64 +// CHECK: %[[VAL_179:.*]] = constant 3 : i64 +// CHECK: %[[VAL_180:.*]] = cmpi sge, %[[VAL_176]], %[[VAL_178]] : i64 +// CHECK: %[[VAL_181:.*]] = cmpi slt, %[[VAL_176]], %[[VAL_179]] : i64 +// CHECK: %[[VAL_182:.*]] = and %[[VAL_180]], %[[VAL_181]] : i1 +// CHECK: %[[VAL_183:.*]] = scf.if %[[VAL_182]] -> (f64) { +// CHECK: %[[VAL_184:.*]] = memref.get_global @categorical_1 : memref<3xf64> +// CHECK: %[[VAL_185:.*]] = memref.load %[[VAL_184]]{{\[}}%[[VAL_177]]] : memref<3xf64> +// CHECK: scf.yield %[[VAL_185]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_186:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_186]] : f64 +// CHECK: } +// CHECK: %[[VAL_187:.*]] = fptoui %[[VAL_158]] : f64 to i64 +// CHECK: %[[VAL_188:.*]] = index_cast %[[VAL_187]] : i64 to index +// CHECK: %[[VAL_189:.*]] = constant 0 : i64 +// CHECK: %[[VAL_190:.*]] = constant 2 : i64 +// CHECK: %[[VAL_191:.*]] = cmpi sge, %[[VAL_187]], %[[VAL_189]] : i64 +// CHECK: %[[VAL_192:.*]] = cmpi slt, %[[VAL_187]], %[[VAL_190]] : i64 +// CHECK: %[[VAL_193:.*]] = and %[[VAL_191]], %[[VAL_192]] : i1 +// CHECK: %[[VAL_194:.*]] = scf.if %[[VAL_193]] -> (f64) { +// CHECK: %[[VAL_195:.*]] = memref.get_global @histogram_0 : memref<2xf64> +// CHECK: %[[VAL_196:.*]] = memref.load %[[VAL_195]]{{\[}}%[[VAL_188]]] : memref<2xf64> +// CHECK: scf.yield %[[VAL_196]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_197:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_197]] : f64 +// CHECK: } +// CHECK: %[[VAL_198:.*]] = fptoui %[[VAL_160]] : f64 to i64 +// CHECK: %[[VAL_199:.*]] = index_cast %[[VAL_198]] : i64 to index +// CHECK: %[[VAL_200:.*]] = constant 0 : i64 +// CHECK: 
%[[VAL_201:.*]] = constant 2 : i64 +// CHECK: %[[VAL_202:.*]] = cmpi sge, %[[VAL_198]], %[[VAL_200]] : i64 +// CHECK: %[[VAL_203:.*]] = cmpi slt, %[[VAL_198]], %[[VAL_201]] : i64 +// CHECK: %[[VAL_204:.*]] = and %[[VAL_202]], %[[VAL_203]] : i1 +// CHECK: %[[VAL_205:.*]] = scf.if %[[VAL_204]] -> (f64) { +// CHECK: %[[VAL_206:.*]] = memref.get_global @histogram_1 : memref<2xf64> +// CHECK: %[[VAL_207:.*]] = memref.load %[[VAL_206]]{{\[}}%[[VAL_199]]] : memref<2xf64> +// CHECK: scf.yield %[[VAL_207]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_208:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_208]] : f64 +// CHECK: } +// CHECK: %[[VAL_209:.*]] = constant 0.3989422804014327 : f64 +// CHECK: %[[VAL_210:.*]] = constant -5.000000e-01 : f64 +// CHECK: %[[VAL_211:.*]] = constant 5.000000e-01 : f64 +// CHECK: %[[VAL_212:.*]] = subf %[[VAL_162]], %[[VAL_211]] : f64 +// CHECK: %[[VAL_213:.*]] = mulf %[[VAL_212]], %[[VAL_212]] : f64 +// CHECK: %[[VAL_214:.*]] = mulf %[[VAL_213]], %[[VAL_210]] : f64 +// CHECK: %[[VAL_215:.*]] = math.exp %[[VAL_214]] : f64 +// CHECK: %[[VAL_216:.*]] = mulf %[[VAL_209]], %[[VAL_215]] : f64 +// CHECK: %[[VAL_217:.*]] = constant 3.9894228040143269 : f64 +// CHECK: %[[VAL_218:.*]] = constant -49.999999999999993 : f64 +// CHECK: %[[VAL_219:.*]] = constant 2.500000e-01 : f64 +// CHECK: %[[VAL_220:.*]] = subf %[[VAL_164]], %[[VAL_219]] : f64 +// CHECK: %[[VAL_221:.*]] = mulf %[[VAL_220]], %[[VAL_220]] : f64 +// CHECK: %[[VAL_222:.*]] = mulf %[[VAL_221]], %[[VAL_218]] : f64 +// CHECK: %[[VAL_223:.*]] = math.exp %[[VAL_222]] : f64 +// CHECK: %[[VAL_224:.*]] = mulf %[[VAL_217]], %[[VAL_223]] : f64 +// CHECK: %[[VAL_225:.*]] = mulf %[[VAL_226:.*]], %[[VAL_227:.*]] : f64 +// CHECK: %[[VAL_228:.*]] = mulf %[[VAL_225]], %[[VAL_229:.*]] : f64 +// CHECK: %[[VAL_230:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_231:.*]] = mulf %[[VAL_228]], %[[VAL_230]] : f64 +// CHECK: %[[VAL_232:.*]] = mulf %[[VAL_233:.*]], %[[VAL_216]] : f64 +// CHECK: %[[VAL_234:.*]] = mulf %[[VAL_232]], %[[VAL_224]] : f64 +// CHECK: %[[VAL_235:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_236:.*]] = mulf %[[VAL_234]], %[[VAL_235]] : f64 +// CHECK: %[[VAL_237:.*]] = addf %[[VAL_231]], %[[VAL_236]] : f64 +// CHECK: %[[VAL_238:.*]] = math.log %[[VAL_237]] : f64 +// CHECK: %[[VAL_239:.*]] = constant 0 : index +// CHECK: memref.store %[[VAL_238]], %[[VAL_1]]{{\[}}%[[VAL_239]], %[[VAL_152]]] : memref<1x?xf64> // CHECK: } // CHECK: return // CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir new file mode 100644 index 00000000..a63e7bd0 --- /dev/null +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir @@ -0,0 +1,50 @@ +// RUN: %optcall --vectorize-lospn-nodes %s | FileCheck %s + +module { + func @vec_task_0(%arg0: memref, %arg1: memref<1x?xf64>) { + %c0 = constant 0 : index + %c4 = constant 4 : index + scf.for %arg2 = %c0 to %c4 step %c4 { + %0 = "lo_spn.batch_read"(%arg0, %arg2) {staticIndex = 0 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 + %1 = "lo_spn.select"(%0) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 + "lo_spn.batch_write"(%arg1, %arg2, %1) {vector_width = 4 : i32, transposed = true} : (memref<1x?xf64>, index, f64) -> () + } + return + } +} + +// NOTE: Assertions have been autogenerated by 
utils/generate-test-checks.py + + +// CHECK-LABEL: func @vec_task_0( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x?xf64>) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = constant 4 : index +// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_3]] { +// CHECK: %[[VAL_5:.*]] = index_cast %[[VAL_4]] : index to i64 +// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : i64 to vector<4xi64> +// CHECK: %[[VAL_7:.*]] = constant dense<[0, 1, 2, 3]> : vector<4xi64> +// CHECK: %[[VAL_8:.*]] = constant dense<1> : vector<4xi64> +// CHECK: %[[VAL_9:.*]] = muli %[[VAL_6]], %[[VAL_8]] : vector<4xi64> +// CHECK: %[[VAL_10:.*]] = addi %[[VAL_9]], %[[VAL_7]] : vector<4xi64> +// CHECK: %[[VAL_11:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_12:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_13:.*]] = constant 0 : index +// CHECK: %[[VAL_14:.*]] = memref.dim %[[VAL_0]], %[[VAL_13]] : memref +// CHECK: %[[VAL_15:.*]] = constant 1 : index +// CHECK: %[[VAL_16:.*]] = muli %[[VAL_14]], %[[VAL_15]] : index +// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_16]]], strides: [1] : memref to memref +// CHECK: %[[VAL_18:.*]] = constant 0 : index +// CHECK: %[[VAL_19:.*]] = vector.gather %[[VAL_17]]{{\[}}%[[VAL_18]]] {{\[}}%[[VAL_10]]], %[[VAL_12]], %[[VAL_11]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_20:.*]] = constant dense<1.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_21:.*]] = cmpf ult, %[[VAL_19]], %[[VAL_20]] : vector<4xf64> +// CHECK: %[[VAL_22:.*]] = constant dense<3.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_23:.*]] = constant dense<5.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_24:.*]] = select %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : vector<4xi1>, vector<4xf64> +// CHECK: %[[VAL_25:.*]] = constant 0 : index +// CHECK: vector.transfer_write %[[VAL_24]], %[[VAL_1]]{{\[}}%[[VAL_25]], %[[VAL_4]]] : vector<4xf64>, memref<1x?xf64> +// CHECK: } +// CHECK: return +// CHECK: } + diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir new file mode 100644 index 00000000..4c237cc9 --- /dev/null +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir @@ -0,0 +1,37 @@ +// RUN: %optcall --convert-lospn-nodes-to-cpu %s | FileCheck %s + +module { + func @task_0(%arg0: memref, %arg1: memref<1x?xf64>) { + %cst1 = constant 1.000000e-01 : f64 + %ind1 = constant 1 : index + %0 = "lo_spn.select"(%cst1) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 + "lo_spn.batch_write"(%arg1, %ind1, %0) {transposed = true} : (memref<1x?xf64>, index, f64) -> () + return + } +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + + +// CHECK-LABEL: func @task_0( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x?xf64>) { +// CHECK: %[[VAL_2:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_3:.*]] = constant 1 : index +// CHECK: %[[VAL_4:.*]] = constant 0.000000e+00 : f64 +// CHECK: %[[VAL_5:.*]] = constant 2.000000e+00 : f64 +// CHECK: %[[VAL_6:.*]] = cmpf uge, %[[VAL_2]], %[[VAL_4]] : f64 +// CHECK: %[[VAL_7:.*]] = cmpf ult, %[[VAL_2]], %[[VAL_5]] : f64 +// CHECK: %[[VAL_8:.*]] = constant 1.000000e+00 : f64 +// CHECK: %[[VAL_9:.*]] = cmpf ult, %[[VAL_2]], %[[VAL_8]] : f64 +// CHECK: 
%[[VAL_10:.*]] = constant 3.500000e-01 : f64 +// CHECK: %[[VAL_11:.*]] = constant 5.500000e-01 : f64 +// CHECK: %[[VAL_12:.*]] = select %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : f64 +// CHECK: %[[VAL_13:.*]] = and %[[VAL_6]], %[[VAL_7]] : i1 +// CHECK: %[[VAL_14:.*]] = constant 0.000000e+00 : f64 +// CHECK: %[[VAL_15:.*]] = select %[[VAL_13]], %[[VAL_12]], %[[VAL_14]] : f64 +// CHECK: %[[VAL_16:.*]] = constant 0 : index +// CHECK: memref.store %[[VAL_15]], %[[VAL_1]]{{\[}}%[[VAL_16]], %[[VAL_3]]] : memref<1x?xf64> +// CHECK: return +// CHECK: } + diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir index fa3f44ed..2ef64e36 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir @@ -184,85 +184,113 @@ module { // CHECK: %[[VAL_99:.*]] = vector.gather %[[VAL_97]]{{\[}}%[[VAL_98]]] {{\[}}%[[VAL_90]]], %[[VAL_92]], %[[VAL_91]] : memref, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> // CHECK: %[[VAL_100:.*]] = memref.get_global @categorical_vec_0 : memref<3xf32> // CHECK: %[[VAL_101:.*]] = fptoui %[[VAL_24]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_102:.*]] = constant dense<0.000000e+00> : vector<8xf32> -// CHECK: %[[VAL_103:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_104:.*]] = constant 0 : index -// CHECK: %[[VAL_105:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_104]]] {{\[}}%[[VAL_101]]], %[[VAL_103]], %[[VAL_102]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_106:.*]] = memref.get_global @categorical_vec_1 : memref<3xf32> -// CHECK: %[[VAL_107:.*]] = fptoui %[[VAL_39]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_108:.*]] = constant dense<0.000000e+00> : vector<8xf32> -// CHECK: %[[VAL_109:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_110:.*]] = constant 0 : index -// CHECK: %[[VAL_111:.*]] = vector.gather %[[VAL_106]]{{\[}}%[[VAL_110]]] {{\[}}%[[VAL_107]]], %[[VAL_109]], %[[VAL_108]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_112:.*]] = memref.get_global @histogram_vec_0 : memref<2xf32> -// CHECK: %[[VAL_113:.*]] = fptoui %[[VAL_54]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_114:.*]] = constant dense<0.000000e+00> : vector<8xf32> -// CHECK: %[[VAL_115:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_116:.*]] = constant 0 : index -// CHECK: %[[VAL_117:.*]] = vector.gather %[[VAL_112]]{{\[}}%[[VAL_116]]] {{\[}}%[[VAL_113]]], %[[VAL_115]], %[[VAL_114]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_118:.*]] = memref.get_global @histogram_vec_1 : memref<2xf32> -// CHECK: %[[VAL_119:.*]] = fptoui %[[VAL_69]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_120:.*]] = constant dense<0.000000e+00> : vector<8xf32> +// CHECK: %[[VAL_102:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_103:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_104:.*]] = constant dense<3> : vector<8xi64> +// CHECK: %[[VAL_105:.*]] = cmpi sge, %[[VAL_101]], %[[VAL_103]] : vector<8xi64> +// CHECK: %[[VAL_106:.*]] = cmpi slt, %[[VAL_101]], %[[VAL_104]] : vector<8xi64> +// CHECK: %[[VAL_107:.*]] = and %[[VAL_105]], %[[VAL_106]] : vector<8xi1> +// CHECK: %[[VAL_108:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_109:.*]] = constant dense : vector<8xi1> +// CHECK: 
%[[VAL_110:.*]] = select %[[VAL_107]], %[[VAL_108]], %[[VAL_109]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_111:.*]] = constant 0 : index +// CHECK: %[[VAL_112:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_111]]] {{\[}}%[[VAL_101]]], %[[VAL_110]], %[[VAL_102]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_113:.*]] = memref.get_global @categorical_vec_1 : memref<3xf32> +// CHECK: %[[VAL_114:.*]] = fptoui %[[VAL_39]] : vector<8xf32> to vector<8xi64> +// CHECK: %[[VAL_115:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_116:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_117:.*]] = constant dense<3> : vector<8xi64> +// CHECK: %[[VAL_118:.*]] = cmpi sge, %[[VAL_114]], %[[VAL_116]] : vector<8xi64> +// CHECK: %[[VAL_119:.*]] = cmpi slt, %[[VAL_114]], %[[VAL_117]] : vector<8xi64> +// CHECK: %[[VAL_120:.*]] = and %[[VAL_118]], %[[VAL_119]] : vector<8xi1> // CHECK: %[[VAL_121:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_122:.*]] = constant 0 : index -// CHECK: %[[VAL_123:.*]] = vector.gather %[[VAL_118]]{{\[}}%[[VAL_122]]] {{\[}}%[[VAL_119]]], %[[VAL_121]], %[[VAL_120]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_124:.*]] = constant dense<-5.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_125:.*]] = constant dense<-0.918938517> : vector<8xf32> -// CHECK: %[[VAL_126:.*]] = constant dense<5.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_127:.*]] = subf %[[VAL_84]], %[[VAL_126]] : vector<8xf32> -// CHECK: %[[VAL_128:.*]] = mulf %[[VAL_127]], %[[VAL_127]] : vector<8xf32> -// CHECK: %[[VAL_129:.*]] = mulf %[[VAL_128]], %[[VAL_124]] : vector<8xf32> -// CHECK: %[[VAL_130:.*]] = addf %[[VAL_125]], %[[VAL_129]] : vector<8xf32> -// CHECK: %[[VAL_131:.*]] = constant dense<-5.000000e+01> : vector<8xf32> -// CHECK: %[[VAL_132:.*]] = constant dense<1.38364661> : vector<8xf32> -// CHECK: %[[VAL_133:.*]] = constant dense<2.500000e-01> : vector<8xf32> -// CHECK: %[[VAL_134:.*]] = subf %[[VAL_99]], %[[VAL_133]] : vector<8xf32> -// CHECK: %[[VAL_135:.*]] = mulf %[[VAL_134]], %[[VAL_134]] : vector<8xf32> -// CHECK: %[[VAL_136:.*]] = mulf %[[VAL_135]], %[[VAL_131]] : vector<8xf32> -// CHECK: %[[VAL_137:.*]] = addf %[[VAL_132]], %[[VAL_136]] : vector<8xf32> -// CHECK: %[[VAL_138:.*]] = addf %[[VAL_105]], %[[VAL_111]] : vector<8xf32> -// CHECK: %[[VAL_139:.*]] = addf %[[VAL_138]], %[[VAL_117]] : vector<8xf32> -// CHECK: %[[VAL_140:.*]] = constant dense<1.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_141:.*]] = addf %[[VAL_139]], %[[VAL_140]] : vector<8xf32> -// CHECK: %[[VAL_142:.*]] = addf %[[VAL_123]], %[[VAL_130]] : vector<8xf32> -// CHECK: %[[VAL_143:.*]] = addf %[[VAL_142]], %[[VAL_137]] : vector<8xf32> -// CHECK: %[[VAL_144:.*]] = constant dense<1.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_145:.*]] = addf %[[VAL_143]], %[[VAL_144]] : vector<8xf32> -// CHECK: %[[VAL_146:.*]] = cmpf ogt, %[[VAL_141]], %[[VAL_145]] : vector<8xf32> -// CHECK: %[[VAL_147:.*]] = select %[[VAL_146]], %[[VAL_141]], %[[VAL_145]] : vector<8xi1>, vector<8xf32> -// CHECK: %[[VAL_148:.*]] = select %[[VAL_146]], %[[VAL_145]], %[[VAL_141]] : vector<8xi1>, vector<8xf32> -// CHECK: %[[VAL_149:.*]] = subf %[[VAL_148]], %[[VAL_147]] : vector<8xf32> -// CHECK: %[[VAL_150:.*]] = math.exp %[[VAL_149]] : vector<8xf32> -// CHECK: %[[VAL_151:.*]] = math.log1p %[[VAL_150]] : vector<8xf32> -// CHECK: %[[VAL_152:.*]] = addf %[[VAL_147]], %[[VAL_151]] : vector<8xf32> -// CHECK: %[[VAL_153:.*]] = constant 0 : index 
-// CHECK: vector.transfer_write %[[VAL_152]], %[[VAL_1]]{{\[}}%[[VAL_153]], %[[VAL_9]]] : vector<8xf32>, memref<1x?xf32> +// CHECK: %[[VAL_122:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_123:.*]] = select %[[VAL_120]], %[[VAL_121]], %[[VAL_122]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_124:.*]] = constant 0 : index +// CHECK: %[[VAL_125:.*]] = vector.gather %[[VAL_113]]{{\[}}%[[VAL_124]]] {{\[}}%[[VAL_114]]], %[[VAL_123]], %[[VAL_115]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_126:.*]] = memref.get_global @histogram_vec_0 : memref<2xf32> +// CHECK: %[[VAL_127:.*]] = fptoui %[[VAL_54]] : vector<8xf32> to vector<8xi64> +// CHECK: %[[VAL_128:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_129:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_130:.*]] = constant dense<2> : vector<8xi64> +// CHECK: %[[VAL_131:.*]] = cmpi sge, %[[VAL_127]], %[[VAL_129]] : vector<8xi64> +// CHECK: %[[VAL_132:.*]] = cmpi slt, %[[VAL_127]], %[[VAL_130]] : vector<8xi64> +// CHECK: %[[VAL_133:.*]] = and %[[VAL_131]], %[[VAL_132]] : vector<8xi1> +// CHECK: %[[VAL_134:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_135:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_136:.*]] = select %[[VAL_133]], %[[VAL_134]], %[[VAL_135]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_137:.*]] = constant 0 : index +// CHECK: %[[VAL_138:.*]] = vector.gather %[[VAL_126]]{{\[}}%[[VAL_137]]] {{\[}}%[[VAL_127]]], %[[VAL_136]], %[[VAL_128]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_139:.*]] = memref.get_global @histogram_vec_1 : memref<2xf32> +// CHECK: %[[VAL_140:.*]] = fptoui %[[VAL_69]] : vector<8xf32> to vector<8xi64> +// CHECK: %[[VAL_141:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_142:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_143:.*]] = constant dense<2> : vector<8xi64> +// CHECK: %[[VAL_144:.*]] = cmpi sge, %[[VAL_140]], %[[VAL_142]] : vector<8xi64> +// CHECK: %[[VAL_145:.*]] = cmpi slt, %[[VAL_140]], %[[VAL_143]] : vector<8xi64> +// CHECK: %[[VAL_146:.*]] = and %[[VAL_144]], %[[VAL_145]] : vector<8xi1> +// CHECK: %[[VAL_147:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_148:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_149:.*]] = select %[[VAL_146]], %[[VAL_147]], %[[VAL_148]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_150:.*]] = constant 0 : index +// CHECK: %[[VAL_151:.*]] = vector.gather %[[VAL_139]]{{\[}}%[[VAL_150]]] {{\[}}%[[VAL_140]]], %[[VAL_149]], %[[VAL_141]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_152:.*]] = constant dense<-5.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_153:.*]] = constant dense<-0.918938517> : vector<8xf32> +// CHECK: %[[VAL_154:.*]] = constant dense<5.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_155:.*]] = subf %[[VAL_84]], %[[VAL_154]] : vector<8xf32> +// CHECK: %[[VAL_156:.*]] = mulf %[[VAL_155]], %[[VAL_155]] : vector<8xf32> +// CHECK: %[[VAL_157:.*]] = mulf %[[VAL_156]], %[[VAL_152]] : vector<8xf32> +// CHECK: %[[VAL_158:.*]] = addf %[[VAL_153]], %[[VAL_157]] : vector<8xf32> +// CHECK: %[[VAL_159:.*]] = constant dense<-5.000000e+01> : vector<8xf32> +// CHECK: %[[VAL_160:.*]] = constant dense<1.38364661> : vector<8xf32> +// CHECK: %[[VAL_161:.*]] = constant dense<2.500000e-01> : vector<8xf32> +// CHECK: %[[VAL_162:.*]] = subf %[[VAL_99]], %[[VAL_161]] : vector<8xf32> +// CHECK: %[[VAL_163:.*]] = mulf 
%[[VAL_162]], %[[VAL_162]] : vector<8xf32> +// CHECK: %[[VAL_164:.*]] = mulf %[[VAL_163]], %[[VAL_159]] : vector<8xf32> +// CHECK: %[[VAL_165:.*]] = addf %[[VAL_160]], %[[VAL_164]] : vector<8xf32> +// CHECK: %[[VAL_166:.*]] = addf %[[VAL_112]], %[[VAL_125]] : vector<8xf32> +// CHECK: %[[VAL_167:.*]] = addf %[[VAL_166]], %[[VAL_138]] : vector<8xf32> +// CHECK: %[[VAL_168:.*]] = constant dense<1.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_169:.*]] = addf %[[VAL_167]], %[[VAL_168]] : vector<8xf32> +// CHECK: %[[VAL_170:.*]] = addf %[[VAL_151]], %[[VAL_158]] : vector<8xf32> +// CHECK: %[[VAL_171:.*]] = addf %[[VAL_170]], %[[VAL_165]] : vector<8xf32> +// CHECK: %[[VAL_172:.*]] = constant dense<1.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_173:.*]] = addf %[[VAL_171]], %[[VAL_172]] : vector<8xf32> +// CHECK: %[[VAL_174:.*]] = cmpf ogt, %[[VAL_169]], %[[VAL_173]] : vector<8xf32> +// CHECK: %[[VAL_175:.*]] = select %[[VAL_174]], %[[VAL_169]], %[[VAL_173]] : vector<8xi1>, vector<8xf32> +// CHECK: %[[VAL_176:.*]] = select %[[VAL_174]], %[[VAL_173]], %[[VAL_169]] : vector<8xi1>, vector<8xf32> +// CHECK: %[[VAL_177:.*]] = subf %[[VAL_176]], %[[VAL_175]] : vector<8xf32> +// CHECK: %[[VAL_178:.*]] = math.exp %[[VAL_177]] : vector<8xf32> +// CHECK: %[[VAL_179:.*]] = math.log1p %[[VAL_178]] : vector<8xf32> +// CHECK: %[[VAL_180:.*]] = addf %[[VAL_175]], %[[VAL_179]] : vector<8xf32> +// CHECK: %[[VAL_181:.*]] = constant 0 : index +// CHECK: vector.transfer_write %[[VAL_180]], %[[VAL_1]]{{\[}}%[[VAL_181]], %[[VAL_9]]] : vector<8xf32>, memref<1x?xf32> // CHECK: } -// CHECK: %[[VAL_154:.*]] = constant 1 : index -// CHECK: scf.for %[[VAL_155:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_154]] { -// CHECK: %[[VAL_156:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_155]]) {staticIndex = 0 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_157:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_155]]) {staticIndex = 1 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_158:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_155]]) {staticIndex = 2 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_159:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_155]]) {staticIndex = 3 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_160:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_155]]) {staticIndex = 4 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_161:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_155]]) {staticIndex = 5 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_162:.*]] = "lo_spn.categorical"(%[[VAL_156]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_163:.*]] = "lo_spn.categorical"(%[[VAL_157]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_164:.*]] = "lo_spn.histogram"(%[[VAL_158]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_165:.*]] = "lo_spn.histogram"(%[[VAL_159]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_166:.*]] = "lo_spn.gaussian"(%[[VAL_160]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_167:.*]] = "lo_spn.gaussian"(%[[VAL_161]]) {mean = 
2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_168:.*]] = "lo_spn.mul"(%[[VAL_162]], %[[VAL_163]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_169:.*]] = "lo_spn.mul"(%[[VAL_168]], %[[VAL_164]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_170:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log -// CHECK: %[[VAL_171:.*]] = "lo_spn.mul"(%[[VAL_169]], %[[VAL_170]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_172:.*]] = "lo_spn.mul"(%[[VAL_165]], %[[VAL_166]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_173:.*]] = "lo_spn.mul"(%[[VAL_172]], %[[VAL_167]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_174:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log -// CHECK: %[[VAL_175:.*]] = "lo_spn.mul"(%[[VAL_173]], %[[VAL_174]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_176:.*]] = "lo_spn.add"(%[[VAL_171]], %[[VAL_175]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_177:.*]] = "lo_spn.strip_log"(%[[VAL_176]]) {target = f32} : (!lo_spn.log) -> f32 -// CHECK: "lo_spn.batch_write"(%[[VAL_1]], %[[VAL_155]], %[[VAL_177]]) {transposed = true} : (memref<1x?xf32>, index, f32) -> () +// CHECK: %[[VAL_182:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_183:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_182]] { +// CHECK: %[[VAL_184:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_183]]) {staticIndex = 0 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_185:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_183]]) {staticIndex = 1 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_186:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_183]]) {staticIndex = 2 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_187:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_183]]) {staticIndex = 3 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_188:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_183]]) {staticIndex = 4 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_189:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_183]]) {staticIndex = 5 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_190:.*]] = "lo_spn.categorical"(%[[VAL_184]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_191:.*]] = "lo_spn.categorical"(%[[VAL_185]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_192:.*]] = "lo_spn.histogram"(%[[VAL_186]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_193:.*]] = "lo_spn.histogram"(%[[VAL_187]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_194:.*]] = "lo_spn.gaussian"(%[[VAL_188]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_195:.*]] = "lo_spn.gaussian"(%[[VAL_189]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_196:.*]] = "lo_spn.mul"(%[[VAL_190]], %[[VAL_191]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// 
CHECK: %[[VAL_197:.*]] = "lo_spn.mul"(%[[VAL_196]], %[[VAL_192]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_198:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log +// CHECK: %[[VAL_199:.*]] = "lo_spn.mul"(%[[VAL_197]], %[[VAL_198]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_200:.*]] = "lo_spn.mul"(%[[VAL_193]], %[[VAL_194]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_201:.*]] = "lo_spn.mul"(%[[VAL_200]], %[[VAL_195]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_202:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log +// CHECK: %[[VAL_203:.*]] = "lo_spn.mul"(%[[VAL_201]], %[[VAL_202]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_204:.*]] = "lo_spn.add"(%[[VAL_199]], %[[VAL_203]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_205:.*]] = "lo_spn.strip_log"(%[[VAL_204]]) {target = f32} : (!lo_spn.log) -> f32 +// CHECK: "lo_spn.batch_write"(%[[VAL_1]], %[[VAL_183]], %[[VAL_205]]) {transposed = true} : (memref<1x?xf32>, index, f32) -> () // CHECK: } // CHECK: return // CHECK: } @@ -278,4 +306,4 @@ module { // CHECK: %[[VAL_6:.*]] = memref.buffer_cast %[[VAL_5]] : memref<1x?xf32> // CHECK: "lo_spn.copy"(%[[VAL_6]], %[[VAL_1]]) : (memref<1x?xf32>, memref<1x?xf32>) -> () // CHECK: "lo_spn.return"() : () -> () -// CHECK: } +// CHECK: } \ No newline at end of file diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir index 81b8b6bd..a6d6f2ff 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir @@ -185,81 +185,109 @@ module { // CHECK: %[[VAL_100:.*]] = memref.get_global @categorical_vec_0 : memref<3xf64> // CHECK: %[[VAL_101:.*]] = fptoui %[[VAL_24]] : vector<4xf64> to vector<4xi64> // CHECK: %[[VAL_102:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_103:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_104:.*]] = constant 0 : index -// CHECK: %[[VAL_105:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_104]]] {{\[}}%[[VAL_101]]], %[[VAL_103]], %[[VAL_102]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_106:.*]] = memref.get_global @categorical_vec_1 : memref<3xf64> -// CHECK: %[[VAL_107:.*]] = fptoui %[[VAL_39]] : vector<4xf64> to vector<4xi64> -// CHECK: %[[VAL_108:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_109:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_110:.*]] = constant 0 : index -// CHECK: %[[VAL_111:.*]] = vector.gather %[[VAL_106]]{{\[}}%[[VAL_110]]] {{\[}}%[[VAL_107]]], %[[VAL_109]], %[[VAL_108]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_112:.*]] = memref.get_global @histogram_vec_0 : memref<2xf64> -// CHECK: %[[VAL_113:.*]] = fptoui %[[VAL_54]] : vector<4xf64> to vector<4xi64> -// CHECK: %[[VAL_114:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_115:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_116:.*]] = constant 0 : index -// CHECK: %[[VAL_117:.*]] = vector.gather %[[VAL_112]]{{\[}}%[[VAL_116]]] {{\[}}%[[VAL_113]]], %[[VAL_115]], %[[VAL_114]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_118:.*]] = memref.get_global @histogram_vec_1 : memref<2xf64> -// CHECK: 
%[[VAL_119:.*]] = fptoui %[[VAL_69]] : vector<4xf64> to vector<4xi64> -// CHECK: %[[VAL_120:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_103:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_104:.*]] = constant dense<3> : vector<4xi64> +// CHECK: %[[VAL_105:.*]] = cmpi sge, %[[VAL_101]], %[[VAL_103]] : vector<4xi64> +// CHECK: %[[VAL_106:.*]] = cmpi slt, %[[VAL_101]], %[[VAL_104]] : vector<4xi64> +// CHECK: %[[VAL_107:.*]] = and %[[VAL_105]], %[[VAL_106]] : vector<4xi1> +// CHECK: %[[VAL_108:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_109:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_110:.*]] = select %[[VAL_107]], %[[VAL_108]], %[[VAL_109]] : vector<4xi1>, vector<4xi1> +// CHECK: %[[VAL_111:.*]] = constant 0 : index +// CHECK: %[[VAL_112:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_111]]] {{\[}}%[[VAL_101]]], %[[VAL_110]], %[[VAL_102]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_113:.*]] = memref.get_global @categorical_vec_1 : memref<3xf64> +// CHECK: %[[VAL_114:.*]] = fptoui %[[VAL_39]] : vector<4xf64> to vector<4xi64> +// CHECK: %[[VAL_115:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_116:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_117:.*]] = constant dense<3> : vector<4xi64> +// CHECK: %[[VAL_118:.*]] = cmpi sge, %[[VAL_114]], %[[VAL_116]] : vector<4xi64> +// CHECK: %[[VAL_119:.*]] = cmpi slt, %[[VAL_114]], %[[VAL_117]] : vector<4xi64> +// CHECK: %[[VAL_120:.*]] = and %[[VAL_118]], %[[VAL_119]] : vector<4xi1> // CHECK: %[[VAL_121:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_122:.*]] = constant 0 : index -// CHECK: %[[VAL_123:.*]] = vector.gather %[[VAL_118]]{{\[}}%[[VAL_122]]] {{\[}}%[[VAL_119]]], %[[VAL_121]], %[[VAL_120]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_124:.*]] = constant dense<0.3989422804014327> : vector<4xf64> -// CHECK: %[[VAL_125:.*]] = constant dense<-5.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_126:.*]] = constant dense<5.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_127:.*]] = subf %[[VAL_84]], %[[VAL_126]] : vector<4xf64> -// CHECK: %[[VAL_128:.*]] = mulf %[[VAL_127]], %[[VAL_127]] : vector<4xf64> -// CHECK: %[[VAL_129:.*]] = mulf %[[VAL_128]], %[[VAL_125]] : vector<4xf64> -// CHECK: %[[VAL_130:.*]] = math.exp %[[VAL_129]] : vector<4xf64> -// CHECK: %[[VAL_131:.*]] = mulf %[[VAL_124]], %[[VAL_130]] : vector<4xf64> -// CHECK: %[[VAL_132:.*]] = constant dense<3.9894228040143269> : vector<4xf64> -// CHECK: %[[VAL_133:.*]] = constant dense<-49.999999999999993> : vector<4xf64> -// CHECK: %[[VAL_134:.*]] = constant dense<2.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_135:.*]] = subf %[[VAL_99]], %[[VAL_134]] : vector<4xf64> -// CHECK: %[[VAL_136:.*]] = mulf %[[VAL_135]], %[[VAL_135]] : vector<4xf64> -// CHECK: %[[VAL_137:.*]] = mulf %[[VAL_136]], %[[VAL_133]] : vector<4xf64> -// CHECK: %[[VAL_138:.*]] = math.exp %[[VAL_137]] : vector<4xf64> -// CHECK: %[[VAL_139:.*]] = mulf %[[VAL_132]], %[[VAL_138]] : vector<4xf64> -// CHECK: %[[VAL_140:.*]] = mulf %[[VAL_105]], %[[VAL_111]] : vector<4xf64> -// CHECK: %[[VAL_141:.*]] = mulf %[[VAL_140]], %[[VAL_117]] : vector<4xf64> -// CHECK: %[[VAL_142:.*]] = constant dense<1.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_143:.*]] = mulf %[[VAL_141]], %[[VAL_142]] : vector<4xf64> -// CHECK: %[[VAL_144:.*]] = mulf %[[VAL_123]], %[[VAL_131]] : vector<4xf64> -// CHECK: %[[VAL_145:.*]] = mulf %[[VAL_144]], %[[VAL_139]] : 
vector<4xf64> -// CHECK: %[[VAL_146:.*]] = constant dense<1.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_147:.*]] = mulf %[[VAL_145]], %[[VAL_146]] : vector<4xf64> -// CHECK: %[[VAL_148:.*]] = addf %[[VAL_143]], %[[VAL_147]] : vector<4xf64> -// CHECK: %[[VAL_149:.*]] = math.log %[[VAL_148]] : vector<4xf64> +// CHECK: %[[VAL_122:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_123:.*]] = select %[[VAL_120]], %[[VAL_121]], %[[VAL_122]] : vector<4xi1>, vector<4xi1> +// CHECK: %[[VAL_124:.*]] = constant 0 : index +// CHECK: %[[VAL_125:.*]] = vector.gather %[[VAL_113]]{{\[}}%[[VAL_124]]] {{\[}}%[[VAL_114]]], %[[VAL_123]], %[[VAL_115]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_126:.*]] = memref.get_global @histogram_vec_0 : memref<2xf64> +// CHECK: %[[VAL_127:.*]] = fptoui %[[VAL_54]] : vector<4xf64> to vector<4xi64> +// CHECK: %[[VAL_128:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_129:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_130:.*]] = constant dense<2> : vector<4xi64> +// CHECK: %[[VAL_131:.*]] = cmpi sge, %[[VAL_127]], %[[VAL_129]] : vector<4xi64> +// CHECK: %[[VAL_132:.*]] = cmpi slt, %[[VAL_127]], %[[VAL_130]] : vector<4xi64> +// CHECK: %[[VAL_133:.*]] = and %[[VAL_131]], %[[VAL_132]] : vector<4xi1> +// CHECK: %[[VAL_134:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_135:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_136:.*]] = select %[[VAL_133]], %[[VAL_134]], %[[VAL_135]] : vector<4xi1>, vector<4xi1> +// CHECK: %[[VAL_137:.*]] = constant 0 : index +// CHECK: %[[VAL_138:.*]] = vector.gather %[[VAL_126]]{{\[}}%[[VAL_137]]] {{\[}}%[[VAL_127]]], %[[VAL_136]], %[[VAL_128]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_139:.*]] = memref.get_global @histogram_vec_1 : memref<2xf64> +// CHECK: %[[VAL_140:.*]] = fptoui %[[VAL_69]] : vector<4xf64> to vector<4xi64> +// CHECK: %[[VAL_141:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_142:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_143:.*]] = constant dense<2> : vector<4xi64> +// CHECK: %[[VAL_144:.*]] = cmpi sge, %[[VAL_140]], %[[VAL_142]] : vector<4xi64> +// CHECK: %[[VAL_145:.*]] = cmpi slt, %[[VAL_140]], %[[VAL_143]] : vector<4xi64> +// CHECK: %[[VAL_146:.*]] = and %[[VAL_144]], %[[VAL_145]] : vector<4xi1> +// CHECK: %[[VAL_147:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_148:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_149:.*]] = select %[[VAL_146]], %[[VAL_147]], %[[VAL_148]] : vector<4xi1>, vector<4xi1> // CHECK: %[[VAL_150:.*]] = constant 0 : index -// CHECK: vector.transfer_write %[[VAL_149]], %[[VAL_1]]{{\[}}%[[VAL_150]], %[[VAL_9]]] : vector<4xf64>, memref<1x?xf64> +// CHECK: %[[VAL_151:.*]] = vector.gather %[[VAL_139]]{{\[}}%[[VAL_150]]] {{\[}}%[[VAL_140]]], %[[VAL_149]], %[[VAL_141]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_152:.*]] = constant dense<0.3989422804014327> : vector<4xf64> +// CHECK: %[[VAL_153:.*]] = constant dense<-5.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_154:.*]] = constant dense<5.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_155:.*]] = subf %[[VAL_84]], %[[VAL_154]] : vector<4xf64> +// CHECK: %[[VAL_156:.*]] = mulf %[[VAL_155]], %[[VAL_155]] : vector<4xf64> +// CHECK: %[[VAL_157:.*]] = mulf %[[VAL_156]], %[[VAL_153]] : vector<4xf64> +// CHECK: %[[VAL_158:.*]] = math.exp %[[VAL_157]] : vector<4xf64> +// CHECK: %[[VAL_159:.*]] = mulf 
%[[VAL_152]], %[[VAL_158]] : vector<4xf64> +// CHECK: %[[VAL_160:.*]] = constant dense<3.9894228040143269> : vector<4xf64> +// CHECK: %[[VAL_161:.*]] = constant dense<-49.999999999999993> : vector<4xf64> +// CHECK: %[[VAL_162:.*]] = constant dense<2.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_163:.*]] = subf %[[VAL_99]], %[[VAL_162]] : vector<4xf64> +// CHECK: %[[VAL_164:.*]] = mulf %[[VAL_163]], %[[VAL_163]] : vector<4xf64> +// CHECK: %[[VAL_165:.*]] = mulf %[[VAL_164]], %[[VAL_161]] : vector<4xf64> +// CHECK: %[[VAL_166:.*]] = math.exp %[[VAL_165]] : vector<4xf64> +// CHECK: %[[VAL_167:.*]] = mulf %[[VAL_160]], %[[VAL_166]] : vector<4xf64> +// CHECK: %[[VAL_168:.*]] = mulf %[[VAL_112]], %[[VAL_125]] : vector<4xf64> +// CHECK: %[[VAL_169:.*]] = mulf %[[VAL_168]], %[[VAL_138]] : vector<4xf64> +// CHECK: %[[VAL_170:.*]] = constant dense<1.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_171:.*]] = mulf %[[VAL_169]], %[[VAL_170]] : vector<4xf64> +// CHECK: %[[VAL_172:.*]] = mulf %[[VAL_151]], %[[VAL_159]] : vector<4xf64> +// CHECK: %[[VAL_173:.*]] = mulf %[[VAL_172]], %[[VAL_167]] : vector<4xf64> +// CHECK: %[[VAL_174:.*]] = constant dense<1.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_175:.*]] = mulf %[[VAL_173]], %[[VAL_174]] : vector<4xf64> +// CHECK: %[[VAL_176:.*]] = addf %[[VAL_171]], %[[VAL_175]] : vector<4xf64> +// CHECK: %[[VAL_177:.*]] = math.log %[[VAL_176]] : vector<4xf64> +// CHECK: %[[VAL_178:.*]] = constant 0 : index +// CHECK: vector.transfer_write %[[VAL_177]], %[[VAL_1]]{{\[}}%[[VAL_178]], %[[VAL_9]]] : vector<4xf64>, memref<1x?xf64> // CHECK: } -// CHECK: %[[VAL_151:.*]] = constant 1 : index -// CHECK: scf.for %[[VAL_152:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_151]] { -// CHECK: %[[VAL_153:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_152]]) {staticIndex = 0 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_154:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_152]]) {staticIndex = 1 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_155:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_152]]) {staticIndex = 2 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_156:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_152]]) {staticIndex = 3 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_157:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_152]]) {staticIndex = 4 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_158:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_152]]) {staticIndex = 5 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_159:.*]] = "lo_spn.categorical"(%[[VAL_153]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_160:.*]] = "lo_spn.categorical"(%[[VAL_154]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_161:.*]] = "lo_spn.histogram"(%[[VAL_155]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_162:.*]] = "lo_spn.histogram"(%[[VAL_156]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_163:.*]] = "lo_spn.gaussian"(%[[VAL_157]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_164:.*]] = "lo_spn.gaussian"(%[[VAL_158]]) {mean = 2.500000e-01 : f64, stddev 
= 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_165:.*]] = "lo_spn.mul"(%[[VAL_159]], %[[VAL_160]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_166:.*]] = "lo_spn.mul"(%[[VAL_165]], %[[VAL_161]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_167:.*]] = "lo_spn.constant"() {type = f64, value = 1.000000e-01 : f64} : () -> f64 -// CHECK: %[[VAL_168:.*]] = "lo_spn.mul"(%[[VAL_166]], %[[VAL_167]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_169:.*]] = "lo_spn.mul"(%[[VAL_162]], %[[VAL_163]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_170:.*]] = "lo_spn.mul"(%[[VAL_169]], %[[VAL_164]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_171:.*]] = "lo_spn.constant"() {type = f64, value = 1.000000e-01 : f64} : () -> f64 -// CHECK: %[[VAL_172:.*]] = "lo_spn.mul"(%[[VAL_170]], %[[VAL_171]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_173:.*]] = "lo_spn.add"(%[[VAL_168]], %[[VAL_172]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_174:.*]] = "lo_spn.log"(%[[VAL_173]]) : (f64) -> f64 -// CHECK: "lo_spn.batch_write"(%[[VAL_1]], %[[VAL_152]], %[[VAL_174]]) {transposed = true} : (memref<1x?xf64>, index, f64) -> () +// CHECK: %[[VAL_179:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_180:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_179]] { +// CHECK: %[[VAL_181:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_180]]) {staticIndex = 0 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_182:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_180]]) {staticIndex = 1 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_183:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_180]]) {staticIndex = 2 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_184:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_180]]) {staticIndex = 3 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_185:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_180]]) {staticIndex = 4 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_186:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_180]]) {staticIndex = 5 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_187:.*]] = "lo_spn.categorical"(%[[VAL_181]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_188:.*]] = "lo_spn.categorical"(%[[VAL_182]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_189:.*]] = "lo_spn.histogram"(%[[VAL_183]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_190:.*]] = "lo_spn.histogram"(%[[VAL_184]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_191:.*]] = "lo_spn.gaussian"(%[[VAL_185]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_192:.*]] = "lo_spn.gaussian"(%[[VAL_186]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_193:.*]] = "lo_spn.mul"(%[[VAL_187]], %[[VAL_188]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_194:.*]] = "lo_spn.mul"(%[[VAL_193]], %[[VAL_189]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_195:.*]] = "lo_spn.constant"() {type = f64, value = 1.000000e-01 : f64} : () -> f64 +// CHECK: %[[VAL_196:.*]] = "lo_spn.mul"(%[[VAL_194]], %[[VAL_195]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_197:.*]] = "lo_spn.mul"(%[[VAL_190]], 
%[[VAL_191]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_198:.*]] = "lo_spn.mul"(%[[VAL_197]], %[[VAL_192]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_199:.*]] = "lo_spn.constant"() {type = f64, value = 1.000000e-01 : f64} : () -> f64 +// CHECK: %[[VAL_200:.*]] = "lo_spn.mul"(%[[VAL_198]], %[[VAL_199]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_201:.*]] = "lo_spn.add"(%[[VAL_196]], %[[VAL_200]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_202:.*]] = "lo_spn.log"(%[[VAL_201]]) : (f64) -> f64 +// CHECK: "lo_spn.batch_write"(%[[VAL_1]], %[[VAL_180]], %[[VAL_202]]) {transposed = true} : (memref<1x?xf64>, index, f64) -> () // CHECK: } // CHECK: return // CHECK: } diff --git a/mlir/test/transform/canonicalize/select-replacement-categorical.mlir b/mlir/test/transform/canonicalize/select-replacement-categorical.mlir new file mode 100644 index 00000000..0ba3bfb2 --- /dev/null +++ b/mlir/test/transform/canonicalize/select-replacement-categorical.mlir @@ -0,0 +1,51 @@ +// RUN: %optcall --canonicalize %s | FileCheck %s + +module { + "lo_spn.kernel"() ( { + ^bb0(%arg0: memref, %arg1: memref<1x?xf64>): // no predecessors + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.alloc(%0) : memref<1x?xf64> + "lo_spn.task"(%arg0, %1) ( { + ^bb0(%arg2: index, %arg3: memref, %arg4: memref<1x?xf64>): // no predecessors + %4 = "lo_spn.batch_read"(%arg3, %arg2) {staticIndex = 0 : ui32} : (memref, index) -> f64 + %5 = "lo_spn.body"(%4) ( { + ^bb0(%arg5: f64): // no predecessors + %6 = "lo_spn.categorical"(%arg5) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f64) -> f64 + %7 = "lo_spn.categorical"(%arg5) {probabilities = [4.500000e-01, 5.500000e-01], supportMarginal = false} : (f64) -> f64 + %8 = "lo_spn.log"(%7) : (f64) -> f64 + "lo_spn.yield"(%8) : (f64) -> () + }) : (f64) -> f64 + "lo_spn.batch_write"(%arg4, %arg2, %5) {transposed = true} : (memref<1x?xf64>, index, f64) -> () + "lo_spn.return"() : () -> () + }) {batchSize = 12 : ui32} : (memref, memref<1x?xf64>) -> () + %2 = memref.tensor_load %1 : memref<1x?xf64> + %3 = memref.buffer_cast %2 : memref<1x?xf64> + "lo_spn.copy"(%3, %arg1) : (memref<1x?xf64>, memref<1x?xf64>) -> () + "lo_spn.return"() : () -> () + }) {sym_name = "spn_kernel", type = (memref, memref<1x?xf64>) -> ()} : () -> () +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + + +// CHECK-LABEL: "lo_spn.kernel"() ( { +// CHECK: ^bb0(%[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref<1x?xf64>): +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref<1x?xf64> +// CHECK: "lo_spn.task"(%[[VAL_0]], %[[VAL_4]]) ( { +// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: memref, %[[VAL_7:.*]]: memref<1x?xf64>): +// CHECK: %[[VAL_8:.*]] = "lo_spn.batch_read"(%[[VAL_6]], %[[VAL_5]]) {staticIndex = 0 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_9:.*]] = "lo_spn.body"(%[[VAL_8]]) ( { +// CHECK: ^bb0(%[[VAL_10:.*]]: f64): +// CHECK: %[[VAL_11:.*]] = "lo_spn.select"(%[[VAL_10]]) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 +// CHECK: %[[VAL_12:.*]] = "lo_spn.log"(%[[VAL_11]]) : (f64) -> f64 +// CHECK: "lo_spn.yield"(%[[VAL_12]]) : (f64) -> () +// CHECK: }) : (f64) -> f64 +// CHECK: "lo_spn.batch_write"(%[[VAL_7]], %[[VAL_5]], %[[VAL_13:.*]]) {transposed = true} : (memref<1x?xf64>, index, f64) -> () +// CHECK: 
"lo_spn.return"() : () -> () +// CHECK: }) {batchSize = 12 : ui32} : (memref, memref<1x?xf64>) -> () +// CHECK: "lo_spn.copy"(%[[VAL_4]], %[[VAL_1]]) : (memref<1x?xf64>, memref<1x?xf64>) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: }) {sym_name = "spn_kernel", type = (memref, memref<1x?xf64>) -> ()} : () -> () diff --git a/mlir/test/transform/canonicalize/select-replacement-histogram.mlir b/mlir/test/transform/canonicalize/select-replacement-histogram.mlir new file mode 100644 index 00000000..96eb2fa8 --- /dev/null +++ b/mlir/test/transform/canonicalize/select-replacement-histogram.mlir @@ -0,0 +1,50 @@ +// RUN: %optcall --canonicalize %s | FileCheck %s + +module { + "lo_spn.kernel"() ( { + ^bb0(%arg0: memref, %arg1: memref<1x?xf64>): // no predecessors + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.alloc(%0) : memref<1x?xf64> + "lo_spn.task"(%arg0, %1) ( { + ^bb0(%arg2: index, %arg3: memref, %arg4: memref<1x?xf64>): // no predecessors + %4 = "lo_spn.batch_read"(%arg3, %arg2) {staticIndex = 0 : ui32} : (memref, index) -> i32 + %5 = "lo_spn.body"(%4) ( { + ^bb0(%arg5: i32): // no predecessors + %6 = "lo_spn.histogram"(%arg5) {bucketCount = 2 : ui32, buckets = [{lb = 41 : i32, ub = 42 : i32, val = 2.500000e-01 : f64}, {lb = 42 : i32, ub = 43 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (i32) -> f64 + %7 = "lo_spn.log"(%6) : (f64) -> f64 + "lo_spn.yield"(%7) : (f64) -> () + }) : (i32) -> f64 + "lo_spn.batch_write"(%arg4, %arg2, %5) {transposed = true} : (memref<1x?xf64>, index, f64) -> () + "lo_spn.return"() : () -> () + }) {batchSize = 12 : ui32} : (memref, memref<1x?xf64>) -> () + %2 = memref.tensor_load %1 : memref<1x?xf64> + %3 = memref.buffer_cast %2 : memref<1x?xf64> + "lo_spn.copy"(%3, %arg1) : (memref<1x?xf64>, memref<1x?xf64>) -> () + "lo_spn.return"() : () -> () + }) {sym_name = "spn_kernel", type = (memref, memref<1x?xf64>) -> ()} : () -> () +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + + +// CHECK-LABEL: "lo_spn.kernel"() ( { +// CHECK: ^bb0(%[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref<1x?xf64>): +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref<1x?xf64> +// CHECK: "lo_spn.task"(%[[VAL_0]], %[[VAL_4]]) ( { +// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: memref, %[[VAL_7:.*]]: memref<1x?xf64>): +// CHECK: %[[VAL_8:.*]] = "lo_spn.batch_read"(%[[VAL_6]], %[[VAL_5]]) {staticIndex = 0 : ui32} : (memref, index) -> i32 +// CHECK: %[[VAL_9:.*]] = "lo_spn.body"(%[[VAL_8]]) ( { +// CHECK: ^bb0(%[[VAL_10:.*]]: i32): +// CHECK: %[[VAL_11:.*]] = "lo_spn.select"(%[[VAL_10]]) {input_true_threshold = 4.200000e+01 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (i32) -> f64 +// CHECK: %[[VAL_12:.*]] = "lo_spn.log"(%[[VAL_11]]) : (f64) -> f64 +// CHECK: "lo_spn.yield"(%[[VAL_12]]) : (f64) -> () +// CHECK: }) : (i32) -> f64 +// CHECK: "lo_spn.batch_write"(%[[VAL_7]], %[[VAL_5]], %[[VAL_13:.*]]) {transposed = true} : (memref<1x?xf64>, index, f64) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: }) {batchSize = 12 : ui32} : (memref, memref<1x?xf64>) -> () +// CHECK: "lo_spn.copy"(%[[VAL_4]], %[[VAL_1]]) : (memref<1x?xf64>, memref<1x?xf64>) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: }) {sym_name = "spn_kernel", type = (memref, memref<1x?xf64>) -> ()} : () -> () diff --git 
a/mlir/test/transform/gpu/cuda/lit.local.cfg b/mlir/test/transform/gpu/cuda/lit.local.cfg new file mode 100644 index 00000000..4c8c4f3c --- /dev/null +++ b/mlir/test/transform/gpu/cuda/lit.local.cfg @@ -0,0 +1 @@ +config.unsupported=True \ No newline at end of file diff --git a/python-interface/test/cpu/test_cpu_histogram.py b/python-interface/test/cpu/test_cpu_histogram.py index 0bd03a00..dcdb2476 100644 --- a/python-interface/test/cpu/test_cpu_histogram.py +++ b/python-interface/test/cpu/test_cpu_histogram.py @@ -17,9 +17,9 @@ def test_cpu_histogram(): # Construct a minimal SPN. - h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.35, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) diff --git a/python-interface/test/cpu/test_cpu_out_of_bounds.py b/python-interface/test/cpu/test_cpu_out_of_bounds.py new file mode 100644 index 00000000..bad14d58 --- /dev/null +++ b/python-interface/test/cpu/test_cpu_out_of_bounds.py @@ -0,0 +1,63 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + +import pytest + + +def test_cpu_out_of_bounds(): + # Construct a minimal SPN. + h1 = Histogram([0., 1., 2., 3.], [0.125, 0.250, 0.625], [1, 1], scope=0) + h2 = Histogram([0., 1., 2., 3.], [0.100, 0.200, 0.700], [1, 1], scope=1) + c1 = Categorical(p=[0.1, 0.7, 0.2], scope=0) + c2 = Categorical(p=[0.4, 0.2, 0.4], scope=1) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[c1, c2]) + spn = Sum([0.3, 0.7], [p0, p1]) + + # Generate some out-of-bounds accesses + max_index = 2 + inputs = np.column_stack(( + np.random.randint(2 * max_index, size=30), + np.random.randint(2 * max_index, size=30), + )).astype("float64") + + # Execute the compiled Kernel. + results = CPUCompiler(verbose=False, computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, + supportMarginal=False) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Account for SPFlow behavior: Out-of-bounds values are not returned as -inf, but: + # for Categoricals: 0.0 + # for Histograms: (np.log(np.finfo(float).eps)) + # Note: np.log(np.finfo(float).eps) is equal to: np.log(2.220446049250313e-16) = -36.04365338911715 + # Find all inputs where an out-of-bounds access might occur. + # If the access is in-bounds, the corresponding index will be 0 (otherwise: > 0). 
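+    # (Added explanatory comments, not part of the original test logic.)
+    # Example with max_index = 2: an input row like [3., 1.] contains one
+    # out-of-bounds feature (3 > 2), so its entry in `cond` below is 1 and its
+    # reference value is forced to -inf, matching the compiled kernel; a fully
+    # in-bounds row like [2., 0.] keeps the SPFlow reference value unchanged.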
+ cond = np.sum(np.where((inputs > max_index), 1, 0), axis=1) + reference[cond > 0] = -np.inf + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_cpu_out_of_bounds() + print("COMPUTATION OK") diff --git a/python-interface/test/cpu/test_cpu_transformation_categorical_to_select.py b/python-interface/test/cpu/test_cpu_transformation_categorical_to_select.py new file mode 100644 index 00000000..402a87c6 --- /dev/null +++ b/python-interface/test/cpu/test_cpu_transformation_categorical_to_select.py @@ -0,0 +1,50 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + + +def test_cpu_transformation_categorical_to_select(): + # Construct a minimal SPN. + c1 = Categorical(p=[0.25, 0.75], scope=0) + c2 = Categorical(p=[0.33, 0.67], scope=1) + c3 = Categorical(p=[0.80, 0.20], scope=0) + c4 = Categorical(p=[0.50, 0.50], scope=1) + + p0 = Product(children=[c1, c2]) + p1 = Product(children=[c3, c4]) + spn = Sum([0.3, 0.7], [p0, p1]) + + inputs = np.column_stack(( + np.random.randint(2, size=30), + np.random.randint(2, size=30), + )).astype("float64") + + # Execute the compiled Kernel. + results = CPUCompiler(computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, + supportMarginal=False, + batchSize=10) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_cpu_transformation_categorical_to_select() + print("COMPUTATION OK") diff --git a/python-interface/test/cpu/test_cpu_transformation_histogram_to_select.py b/python-interface/test/cpu/test_cpu_transformation_histogram_to_select.py new file mode 100644 index 00000000..ebe0515e --- /dev/null +++ b/python-interface/test/cpu/test_cpu_transformation_histogram_to_select.py @@ -0,0 +1,50 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. 
+# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + + +def test_cpu_transformation_histogram_to_select(): + # Construct a minimal SPN. + h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 1., 2.], [0.35, 0.65], [1, 1], scope=1) + h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=1) + h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=0) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[h3, h4]) + spn = Sum([0.3, 0.7], [p0, p1]) + + inputs = np.column_stack(( + np.random.randint(2, size=30), + np.random.randint(2, size=30), + )).astype("float64") + + # Execute the compiled Kernel. + results = CPUCompiler(verbose=False, computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, + supportMarginal=False, + batchSize=10) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_cpu_transformation_histogram_to_select() + print("COMPUTATION OK") diff --git a/python-interface/test/cpu/test_graph_partitioning.py b/python-interface/test/cpu/test_graph_partitioning.py index 6fb6d6a1..c436813c 100644 --- a/python-interface/test/cpu/test_graph_partitioning.py +++ b/python-interface/test/cpu/test_graph_partitioning.py @@ -17,9 +17,9 @@ def test_cpu_histogram(): # Construct a minimal SPN. - h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) diff --git a/python-interface/test/cpu/test_marginal_cpu_histogram.py b/python-interface/test/cpu/test_marginal_cpu_histogram.py index 4b8753a5..d3cd77f4 100644 --- a/python-interface/test/cpu/test_marginal_cpu_histogram.py +++ b/python-interface/test/cpu/test_marginal_cpu_histogram.py @@ -17,10 +17,10 @@ def test_cpu_histogram(): # Construct a minimal SPN. 
- h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 1., 3.], [0.45, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 3.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 1., 3.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) diff --git a/python-interface/test/cpu/test_marginal_cpu_transformation_to_select.py b/python-interface/test/cpu/test_marginal_cpu_transformation_to_select.py new file mode 100644 index 00000000..4b33d121 --- /dev/null +++ b/python-interface/test/cpu/test_marginal_cpu_transformation_to_select.py @@ -0,0 +1,52 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + + +def test_marginal_cpu_transformation_to_select(): + # Construct a minimal SPN. + h1 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) + h2 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + c1 = Categorical(p=[0.3, 0.7], scope=0) + c2 = Categorical(p=[0.4, 0.6], scope=1) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[c1, c2]) + spn = Sum([0.3, 0.7], [p0, p1]) + + inputs = np.column_stack(( + np.random.randint(2, size=30), + np.random.randint(2, size=30), + )).astype("float64") + + # Insert some NaN in random places into the input data. + inputs.ravel()[np.random.choice(inputs.size, 5, replace=False)] = np.nan + + # Execute the compiled Kernel. + results = CPUCompiler(computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, supportMarginal=True, batchSize=10) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_marginal_cpu_transformation_to_select() + print("COMPUTATION OK") diff --git a/python-interface/test/vector/test_log_vector_graph_partitioning.py b/python-interface/test/vector/test_log_vector_graph_partitioning.py index 5c45814a..e0c8c7af 100644 --- a/python-interface/test/vector/test_log_vector_graph_partitioning.py +++ b/python-interface/test/vector/test_log_vector_graph_partitioning.py @@ -20,10 +20,10 @@ @pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") def test_log_vector_histogram(): # Construct a minimal SPN. 
- h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) diff --git a/python-interface/test/vector/test_log_vector_histogram.py b/python-interface/test/vector/test_log_vector_histogram.py index 100a2c18..1513f899 100644 --- a/python-interface/test/vector/test_log_vector_histogram.py +++ b/python-interface/test/vector/test_log_vector_histogram.py @@ -20,10 +20,10 @@ @pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") def test_log_vector_histogram(): # Construct a minimal SPN. - h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) diff --git a/python-interface/test/vector/test_vector_histogram.py b/python-interface/test/vector/test_vector_histogram.py index 7d24dbc0..b66ebbbf 100644 --- a/python-interface/test/vector/test_vector_histogram.py +++ b/python-interface/test/vector/test_vector_histogram.py @@ -20,10 +20,10 @@ @pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") def test_vector_histogram(): # Construct a minimal SPN. - h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) diff --git a/python-interface/test/vector/test_vector_out_of_bounds.py b/python-interface/test/vector/test_vector_out_of_bounds.py new file mode 100644 index 00000000..918cf557 --- /dev/null +++ b/python-interface/test/vector/test_vector_out_of_bounds.py @@ -0,0 +1,67 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. 
+# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + +import pytest + + +@pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") +def test_vector_out_of_bounds(): + # Construct a minimal SPN. + h1 = Histogram([0., 1., 2., 3.], [0.125, 0.250, 0.625], [1, 1], scope=0) + h2 = Histogram([0., 1., 2., 3.], [0.100, 0.200, 0.700], [1, 1], scope=1) + c1 = Categorical(p=[0.1, 0.7, 0.2], scope=0) + c2 = Categorical(p=[0.4, 0.2, 0.4], scope=1) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[c1, c2]) + spn = Sum([0.3, 0.7], [p0, p1]) + + # Generate some out-of-bounds accesses + max_index = 2 + inputs = np.column_stack(( + np.random.randint(2 * max_index, size=30), + np.random.randint(2 * max_index, size=30), + )).astype("float64") + + if not CPUCompiler.isVectorizationSupported(): + print("Test not supported by the compiler installation") + return 0 + + # Execute the compiled Kernel. + results = CPUCompiler(verbose=False, computeInLogSpace=False).log_likelihood(spn, inputs, supportMarginal=False) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Account for SPFlow behavior: Out-of-bounds values are not returned as -inf, but: + # for Categoricals: 0.0 + # for Histograms: (np.log(np.finfo(float).eps)) + # Note: np.log(np.finfo(float).eps) is equal to: np.log(2.220446049250313e-16) = -36.04365338911715 + # Find all inputs where an out-of-bounds access will occur. + # If the access is in-bounds, the corresponding index will be 0 (otherwise: > 0). + cond = np.sum(np.where((inputs > max_index), 1, 0), axis=1) + reference[cond > 0] = -np.inf + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_vector_out_of_bounds() + print("COMPUTATION OK")
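
Note (not part of the patch): the bounds check encoded by the updated CHECK lines above — compare the index vector against the table bounds, combine the two masks, and let the masked vector.gather substitute the pass-through value for out-of-range lanes — can be modelled in plain NumPy. This is only an illustrative sketch under that reading of the generated IR; the function and variable names below are ours and do not appear in the patch.

    import numpy as np

    def bounds_checked_gather(table, indices, pass_through):
        """NumPy model of the masked gather emitted by the lowering:
        lanes with 0 <= index < len(table) read from the table, all other
        lanes receive the pass-through value (e.g. -inf in log space)."""
        indices = indices.astype(np.int64)
        in_bounds = (indices >= 0) & (indices < len(table))   # cmpi sge / cmpi slt, then and
        safe = np.where(in_bounds, indices, 0)                # keep out-of-range lanes from indexing past the table
        return np.where(in_bounds, table[safe], pass_through) # masked gather with pass-through

    # Example with a three-entry probability table like the first categorical above;
    # lane 2 is out of bounds and receives the pass-through value.
    table = np.array([0.35, 0.55, 0.10])
    idx = np.array([0, 2, 3, 1])
    print(bounds_checked_gather(table, idx, -np.inf))
    # -> [0.35, 0.1, -inf, 0.55]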