From 01f33830838b899c06f73010b43349e791c40f94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Mon, 31 May 2021 19:14:10 +0200 Subject: [PATCH 01/12] [WIP] Added optimization which replaces specific Categoricals by Select. LoSPN Categorical leaves are lowered to SelectOp if they consist of two probabilities. --- .../Conversion/LoSPNtoCPU/NodePatterns.h | 10 ++++ mlir/include/Dialect/LoSPN/LoSPNOps.h | 1 + mlir/include/Dialect/LoSPN/LoSPNOps.td | 21 +++++++- .../Conversion/LoSPNtoCPU/NodePatterns.cpp | 17 ++++++ mlir/lib/Dialect/LoSPN/LoSPNOps.cpp | 22 +++++++- .../select-replacement-categorical.mlir | 54 +++++++++++++++++++ 6 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 mlir/test/transform/canonicalize/select-replacement-categorical.mlir diff --git a/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h b/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h index c588707d..eb8ed50d 100644 --- a/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h +++ b/mlir/include/Conversion/LoSPNtoCPU/NodePatterns.h @@ -146,6 +146,15 @@ namespace mlir { ConversionPatternRewriter& rewriter) const override; }; + struct SelectLowering : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(low::SPNSelectLeaf op, + ArrayRef operands, + ConversionPatternRewriter& rewriter) const override; + }; + struct ResolveConvertToVector : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -171,6 +180,7 @@ namespace mlir { patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); } diff --git a/mlir/include/Dialect/LoSPN/LoSPNOps.h b/mlir/include/Dialect/LoSPN/LoSPNOps.h index 4e0cb9f4..280d5364 100644 --- a/mlir/include/Dialect/LoSPN/LoSPNOps.h +++ 
b/mlir/include/Dialect/LoSPN/LoSPNOps.h @@ -18,6 +18,7 @@ #include "mlir/IR/FunctionSupport.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" #define GET_OP_CLASSES #include "LoSPN/LoSPNOps.h.inc" diff --git a/mlir/include/Dialect/LoSPN/LoSPNOps.td b/mlir/include/Dialect/LoSPN/LoSPNOps.td index 725f4b5e..3bee3f19 100644 --- a/mlir/include/Dialect/LoSPN/LoSPNOps.td +++ b/mlir/include/Dialect/LoSPN/LoSPNOps.td @@ -325,7 +325,6 @@ def SPNHistogramLeaf : LoSPNBodyOp<"histogram", [NoSideEffect, UI32Attr:$bucketCount, BoolAttr:$supportMarginal); let results = (outs LoSPNComputeType); - } /// @@ -343,6 +342,8 @@ def SPNCategoricalLeaf : LoSPNBodyOp<"categorical", [NoSideEffect, let arguments = (ins LoSPNInputType:$index, F64ArrayAttr:$probabilities, BoolAttr:$supportMarginal); + let hasCanonicalizeMethod = 1; + let results = (outs LoSPNComputeType); } @@ -364,4 +365,22 @@ def SPNGaussianLeaf : LoSPNBodyOp<"gaussian", [NoSideEffect, let results = (outs LoSPNComputeType); } +/// +/// Select of an SPN leaf node value. +/// +def SPNSelectLeaf : LoSPNBodyOp<"select", [NoSideEffect]> { + + let summary = "Leaf node value select"; + + let description = [{ + Single value select of a Categorical or Histogram leaf. 
+ }]; + + let arguments = (ins LoSPNInputType:$cond, LoSPNInputType:$threshold, + LoSPNComputeType:$val_true, LoSPNComputeType:$val_false); + + let results = (outs LoSPNComputeType); + +} + #endif // LoSPN_Ops \ No newline at end of file diff --git a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp index 1b2591cc..ae8be800 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp @@ -483,6 +483,23 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l values, resultType, "categorical_", computesLog); } +mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, + llvm::ArrayRef operands, + mlir::ConversionPatternRewriter& rewriter) const { + mlir::Value cond; + if (op.cond().getType().isa()) { + cond = rewriter.create(op->getLoc(), IntegerType::get(op.getContext(), 1), + mlir::CmpFPredicate::UGE, op.cond(), op.threshold()); + } else if (op.cond().getType().isa()) { + cond = rewriter.create(op->getLoc(), IntegerType::get(op.getContext(), 1), + mlir::CmpIPredicate::uge, op.cond(), op.threshold()); + } else { + return rewriter.notifyMatchFailure(op, "Expected condition-value to be either Float- or IntegerType"); + } + rewriter.replaceOpWithNewOp(op, cond, op.val_true(), op.val_false()); + return success(); +} + mlir::LogicalResult mlir::spn::ResolveConvertToVector::matchAndRewrite(mlir::spn::low::SPNConvertToVector op, llvm::ArrayRef operands, mlir::ConversionPatternRewriter& rewriter) const { diff --git a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp index c6323473..c614d717 100644 --- a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp +++ b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp @@ -13,7 +13,6 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" namespace mlir { namespace spn { @@ -335,6 +334,27 @@ 
::mlir::OpFoldResult mlir::spn::low::SPNAdd::fold(::llvm::ArrayRef<::mlir::Attri return nullptr; } +//===----------------------------------------------------------------------===// +// SPNCategoricalLeaf +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCategoricalLeaf op, PatternRewriter &rewriter) { + // Rewrite Categoricals which contain exactly two probabilities into a LoSPN Select. + auto probabilities = op.probabilities().getValue(); + if (probabilities.size() == 2) { + auto p0 = probabilities[0].dyn_cast(); + auto p1 = probabilities[1].dyn_cast(); + // auto index = FloatAttr::get(FloatType::getF64(op->getContext()), op.index().dyn_cast()); + auto threshold_max_true = FloatAttr::get(op.index().getType(), 1.0); + auto p0_Value = rewriter.create(op.getLoc(), p0.getType(), probabilities[0].dyn_cast(), p0); + auto p1_Value = rewriter.create(p0_Value.getLoc(), p1.getType(), probabilities[1].dyn_cast(), p1); + auto threshold = rewriter.create(p1_Value.getLoc(), threshold_max_true.getType(), threshold_max_true.dyn_cast(), threshold_max_true); + rewriter.replaceOpWithNewOp(op, p0.getType(), op.index(), threshold, p1_Value, p0_Value); + return success(); + } + return failure(); +} + //===----------------------------------------------------------------------===// // SPNGaussianLeaf //===----------------------------------------------------------------------===// diff --git a/mlir/test/transform/canonicalize/select-replacement-categorical.mlir b/mlir/test/transform/canonicalize/select-replacement-categorical.mlir new file mode 100644 index 00000000..3ab0acba --- /dev/null +++ b/mlir/test/transform/canonicalize/select-replacement-categorical.mlir @@ -0,0 +1,54 @@ +// RUN: %optcall --canonicalize %s | FileCheck %s + +module { + "lo_spn.kernel"() ( { + ^bb0(%arg0: memref, %arg1: memref): // no predecessors + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : 
memref + %1 = memref.alloc(%0) : memref + "lo_spn.task"(%arg0, %1) ( { + ^bb0(%arg2: index, %arg3: memref, %arg4: memref): // no predecessors + %4 = "lo_spn.batch_read"(%arg3, %arg2) {sampleIndex = 0 : ui32} : (memref, index) -> f64 + %5 = "lo_spn.body"(%4) ( { + ^bb0(%arg5: f64): // no predecessors + %6 = "lo_spn.categorical"(%arg5) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f64) -> f64 + %7 = "lo_spn.categorical"(%arg5) {probabilities = [4.500000e-01, 5.500000e-01], supportMarginal = false} : (f64) -> f64 + %8 = "lo_spn.log"(%7) : (f64) -> f64 + "lo_spn.yield"(%8) : (f64) -> () + }) : (f64) -> f64 + "lo_spn.batch_write"(%5, %arg4, %arg2) : (f64, memref, index) -> () + "lo_spn.return"() : () -> () + }) {batchSize = 12 : ui32} : (memref, memref) -> () + %2 = memref.tensor_load %1 : memref + %3 = memref.buffer_cast %2 : memref + "lo_spn.copy"(%3, %arg1) : (memref, memref) -> () + "lo_spn.return"() : () -> () + }) {sym_name = "spn_kernel", type = (memref, memref) -> ()} : () -> () +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + + +// CHECK-LABEL: "lo_spn.kernel"() ( { +// CHECK: ^bb0(%[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref): +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref +// CHECK: "lo_spn.task"(%[[VAL_0]], %[[VAL_4]]) ( { +// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: memref, %[[VAL_7:.*]]: memref): +// CHECK: %[[VAL_8:.*]] = "lo_spn.batch_read"(%[[VAL_6]], %[[VAL_5]]) {sampleIndex = 0 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_9:.*]] = "lo_spn.body"(%[[VAL_8]]) ( { +// CHECK: ^bb0(%[[VAL_10:.*]]: f64): +// CHECK: %[[VAL_11:.*]] = constant 4.500000e-01 : f64 +// CHECK: %[[VAL_12:.*]] = constant 5.500000e-01 : f64 +// CHECK: %[[VAL_13:.*]] = constant 1.000000e+00 : f64 +// CHECK: %[[VAL_14:.*]] = "lo_spn.select"(%[[VAL_10]], %[[VAL_13]], 
%[[VAL_12]], %[[VAL_11]]) : (f64, f64, f64, f64) -> f64 +// CHECK: %[[VAL_15:.*]] = "lo_spn.log"(%[[VAL_14]]) : (f64) -> f64 +// CHECK: "lo_spn.yield"(%[[VAL_15]]) : (f64) -> () +// CHECK: }) : (f64) -> f64 +// CHECK: "lo_spn.batch_write"(%[[VAL_16:.*]], %[[VAL_7]], %[[VAL_5]]) : (f64, memref, index) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: }) {batchSize = 12 : ui32} : (memref, memref) -> () +// CHECK: "lo_spn.copy"(%[[VAL_4]], %[[VAL_1]]) : (memref, memref) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: }) {sym_name = "spn_kernel", type = (memref, memref) -> ()} : () -> () From d315a35d6bc456d3c9814d8a40b1047cda0f559c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Tue, 15 Jun 2021 17:30:24 +0200 Subject: [PATCH 02/12] [WIP] Added support for Histogram w.r.t. Select-canonicalization. - ADDED support for marginals when lowering from lospn to cpu - ADDED/FIXED support for log-computations - FIXED argument types of SPNSelectLeaf --- mlir/include/Dialect/LoSPN/LoSPNOps.td | 11 ++-- .../LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp | 1 + .../Conversion/LoSPNtoCPU/NodePatterns.cpp | 49 +++++++++++++++-- mlir/lib/Dialect/LoSPN/LoSPNOps.cpp | 53 ++++++++++++++++--- .../select-replacement-categorical.mlir | 11 ++-- .../select-replacement-histogram.mlir | 50 +++++++++++++++++ 6 files changed, 152 insertions(+), 23 deletions(-) create mode 100644 mlir/test/transform/canonicalize/select-replacement-histogram.mlir diff --git a/mlir/include/Dialect/LoSPN/LoSPNOps.td b/mlir/include/Dialect/LoSPN/LoSPNOps.td index 3bee3f19..20e8ce54 100644 --- a/mlir/include/Dialect/LoSPN/LoSPNOps.td +++ b/mlir/include/Dialect/LoSPN/LoSPNOps.td @@ -324,6 +324,8 @@ def SPNHistogramLeaf : LoSPNBodyOp<"histogram", [NoSideEffect, let arguments = (ins LoSPNInputType:$index, BucketListAttr:$buckets, UI32Attr:$bucketCount, BoolAttr:$supportMarginal); + let hasCanonicalizeMethod = 1; + let results = (outs LoSPNComputeType); } @@ -367,8 +369,11 @@ def 
SPNGaussianLeaf : LoSPNBodyOp<"gaussian", [NoSideEffect, /// /// Select of an SPN leaf node value. +/// Corresponds to: ($input < $input_true_threshold) ? $val_true : $val_false; /// -def SPNSelectLeaf : LoSPNBodyOp<"select", [NoSideEffect]> { +def SPNSelectLeaf : LoSPNBodyOp<"select", [NoSideEffect, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Leaf node value select"; @@ -376,8 +381,8 @@ def SPNSelectLeaf : LoSPNBodyOp<"select", [NoSideEffect]> { Single value select of a Categorical or Histogram leaf. }]; - let arguments = (ins LoSPNInputType:$cond, LoSPNInputType:$threshold, - LoSPNComputeType:$val_true, LoSPNComputeType:$val_false); + let arguments = (ins LoSPNInputType:$input, F64Attr:$input_true_threshold, + F64Attr:$val_true, F64Attr:$val_false, BoolAttr:$supportMarginal); let results = (outs LoSPNComputeType); diff --git a/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp b/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp index 069cf8bb..b36294a1 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp @@ -137,6 +137,7 @@ void mlir::spn::LoSPNNodeVectorizationPass::runOnOperation() { target.addLegalOp(); OwningRewritePatternList patterns(&getContext()); + patterns.insert(typeConverter, &getContext()); mlir::spn::populateLoSPNCPUVectorizationNodePatterns(patterns, &getContext(), typeConverter); auto op = getOperation(); diff --git a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp index ae8be800..70fa469f 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp @@ -321,6 +321,7 @@ mlir::LogicalResult mlir::spn::GaussianLogLowering::matchAndRewrite(mlir::spn::l Value gaussian = rewriter.create(op->getLoc(), coefficientConst, fraction); if (op.supportMarginal()) { auto isNan = rewriter.create(op->getLoc(), 
CmpFPredicate::UNO, index, index); + // FixMe / Question: Could this be a bug? (Either rename to 'constZero' -OR- set to 1.0 (instead of 0.0)) auto constOne = rewriter.create(op.getLoc(), rewriter.getFloatAttr(resultType, 0.0)); gaussian = rewriter.create(op.getLoc(), isNan, constOne, gaussian); } @@ -486,17 +487,55 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, llvm::ArrayRef operands, mlir::ConversionPatternRewriter& rewriter) const { + // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. mlir::Value cond; - if (op.cond().getType().isa()) { + auto inputTy = op.input().getType(); + if (inputTy.isa()) { + auto thresholdAttr = FloatAttr::get(inputTy, op.input_true_thresholdAttr().getValueAsDouble()); + auto input_true_threshold = rewriter.create(op->getLoc(), inputTy, thresholdAttr); cond = rewriter.create(op->getLoc(), IntegerType::get(op.getContext(), 1), - mlir::CmpFPredicate::UGE, op.cond(), op.threshold()); - } else if (op.cond().getType().isa()) { + mlir::CmpFPredicate::ULT, op.input(), input_true_threshold); + } else if (inputTy.isa()) { + auto thresholdAttr = IntegerAttr::get(inputTy, op.input_true_thresholdAttr().getValueAsDouble()); + auto input_true_threshold = rewriter.create(op->getLoc(), inputTy, thresholdAttr); cond = rewriter.create(op->getLoc(), IntegerType::get(op.getContext(), 1), - mlir::CmpIPredicate::uge, op.cond(), op.threshold()); + mlir::CmpIPredicate::ult, op.input(), input_true_threshold); } else { return rewriter.notifyMatchFailure(op, "Expected condition-value to be either Float- or IntegerType"); } - rewriter.replaceOpWithNewOp(op, cond, op.val_true(), op.val_false()); + + Type resultType = op.getResult().getType(); + bool computesLog = false; + if (auto logType = resultType.dyn_cast()) { + resultType = logType.getBaseType(); + computesLog = true; + } 
+ + ConstantOp val_true, val_false; + if (computesLog) { + val_true = rewriter.create(op->getLoc(), + resultType, + FloatAttr::get(resultType, log(op.val_trueAttr().getValueAsDouble()))); + val_false = rewriter.create(op->getLoc(), + resultType, + FloatAttr::get(resultType, + log(op.val_falseAttr().getValueAsDouble()))); + } else { + val_true = rewriter.create(op->getLoc(), op.val_trueAttr().getType(), op.val_trueAttr()); + val_false = rewriter.create(op->getLoc(), op.val_falseAttr().getType(), op.val_falseAttr()); + } + + mlir::Value leaf = rewriter.template create(op.getLoc(), cond, val_true, val_false); + if (op.supportMarginal()) { + assert(inputTy.template isa()); + auto isNan = rewriter.create(op->getLoc(), mlir::CmpFPredicate::UNO, op.input(), op.input()); + auto marginalValue = (computesLog) ? 0.0 : 1.0; + auto constOne = rewriter.create(op.getLoc(), + rewriter.getFloatAttr(resultType, marginalValue)); + leaf = rewriter.create(op.getLoc(), isNan, constOne, leaf); + } + rewriter.replaceOp(op, leaf); + return success(); } diff --git a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp index c614d717..024339cd 100644 --- a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp +++ b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp @@ -338,23 +338,60 @@ ::mlir::OpFoldResult mlir::spn::low::SPNAdd::fold(::llvm::ArrayRef<::mlir::Attri // SPNCategoricalLeaf //===----------------------------------------------------------------------===// -::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCategoricalLeaf op, PatternRewriter &rewriter) { +::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCategoricalLeaf op, + PatternRewriter& rewriter) { // Rewrite Categoricals which contain exactly two probabilities into a LoSPN Select. 
auto probabilities = op.probabilities().getValue(); if (probabilities.size() == 2) { - auto p0 = probabilities[0].dyn_cast(); - auto p1 = probabilities[1].dyn_cast(); - // auto index = FloatAttr::get(FloatType::getF64(op->getContext()), op.index().dyn_cast()); + auto pTrue = probabilities[0].dyn_cast(); + auto pFalse = probabilities[1].dyn_cast(); auto threshold_max_true = FloatAttr::get(op.index().getType(), 1.0); - auto p0_Value = rewriter.create(op.getLoc(), p0.getType(), probabilities[0].dyn_cast(), p0); - auto p1_Value = rewriter.create(p0_Value.getLoc(), p1.getType(), probabilities[1].dyn_cast(), p1); - auto threshold = rewriter.create(p1_Value.getLoc(), threshold_max_true.getType(), threshold_max_true.dyn_cast(), threshold_max_true); - rewriter.replaceOpWithNewOp(op, p0.getType(), op.index(), threshold, p1_Value, p0_Value); + rewriter.replaceOpWithNewOp(op, + pTrue.getType(), + op.index(), + threshold_max_true, + pTrue, + pFalse, + op.supportMarginalAttr()); return success(); } return failure(); } +//===----------------------------------------------------------------------===// +// SPNHistogramLeaf +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::spn::low::SPNHistogramLeaf::canonicalize(SPNHistogramLeaf op, PatternRewriter& rewriter) { + // Rewrite certain Histograms which contain exactly two buckets into a LoSPN Select. + // Buckets' index range must be 1 and buckets have to be consecutive / contiguous. 
+ // i.e.: (UB_0-LB_0 == 1) && (UB_1-LB_1 == 1) && (UB_0 == LB_1) + auto buckets = op.buckets(); + if (buckets.size() == 2) { + auto b0 = buckets[0].cast(); + auto b1 = buckets[1].cast(); + + bool isQualifiedIndexRange = ((b0.ub().getInt() - b0.lb().getInt()) == 1) && + ((b1.ub().getInt() - b1.lb().getInt()) == 1); + bool isContiguous = (b0.ub().getInt() == b1.lb().getInt()); + + if (isQualifiedIndexRange && isContiguous) { + auto pTrue = b0.val(); + auto pFalse = b1.val(); + auto threshold_max_true = FloatAttr::get(Float64Type::get(op.getContext()), b0.ub().getInt()); + rewriter.replaceOpWithNewOp(op, + pTrue.getType(), + op.index(), + threshold_max_true, + pTrue, + pFalse, + op.supportMarginalAttr()); + return success(); + } + } + return failure(); +} + //===----------------------------------------------------------------------===// // SPNGaussianLeaf //===----------------------------------------------------------------------===// diff --git a/mlir/test/transform/canonicalize/select-replacement-categorical.mlir b/mlir/test/transform/canonicalize/select-replacement-categorical.mlir index 3ab0acba..fd611656 100644 --- a/mlir/test/transform/canonicalize/select-replacement-categorical.mlir +++ b/mlir/test/transform/canonicalize/select-replacement-categorical.mlir @@ -39,14 +39,11 @@ module { // CHECK: %[[VAL_8:.*]] = "lo_spn.batch_read"(%[[VAL_6]], %[[VAL_5]]) {sampleIndex = 0 : ui32} : (memref, index) -> f64 // CHECK: %[[VAL_9:.*]] = "lo_spn.body"(%[[VAL_8]]) ( { // CHECK: ^bb0(%[[VAL_10:.*]]: f64): -// CHECK: %[[VAL_11:.*]] = constant 4.500000e-01 : f64 -// CHECK: %[[VAL_12:.*]] = constant 5.500000e-01 : f64 -// CHECK: %[[VAL_13:.*]] = constant 1.000000e+00 : f64 -// CHECK: %[[VAL_14:.*]] = "lo_spn.select"(%[[VAL_10]], %[[VAL_13]], %[[VAL_12]], %[[VAL_11]]) : (f64, f64, f64, f64) -> f64 -// CHECK: %[[VAL_15:.*]] = "lo_spn.log"(%[[VAL_14]]) : (f64) -> f64 -// CHECK: "lo_spn.yield"(%[[VAL_15]]) : (f64) -> () +// CHECK: %[[VAL_11:.*]] = "lo_spn.select"(%[[VAL_10]]) 
{input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 +// CHECK: %[[VAL_12:.*]] = "lo_spn.log"(%[[VAL_11]]) : (f64) -> f64 +// CHECK: "lo_spn.yield"(%[[VAL_12]]) : (f64) -> () // CHECK: }) : (f64) -> f64 -// CHECK: "lo_spn.batch_write"(%[[VAL_16:.*]], %[[VAL_7]], %[[VAL_5]]) : (f64, memref, index) -> () +// CHECK: "lo_spn.batch_write"(%[[VAL_13:.*]], %[[VAL_7]], %[[VAL_5]]) : (f64, memref, index) -> () // CHECK: "lo_spn.return"() : () -> () // CHECK: }) {batchSize = 12 : ui32} : (memref, memref) -> () // CHECK: "lo_spn.copy"(%[[VAL_4]], %[[VAL_1]]) : (memref, memref) -> () diff --git a/mlir/test/transform/canonicalize/select-replacement-histogram.mlir b/mlir/test/transform/canonicalize/select-replacement-histogram.mlir new file mode 100644 index 00000000..e0b3f0f4 --- /dev/null +++ b/mlir/test/transform/canonicalize/select-replacement-histogram.mlir @@ -0,0 +1,50 @@ +// RUN: %optcall --canonicalize %s | FileCheck %s + +module { + "lo_spn.kernel"() ( { + ^bb0(%arg0: memref, %arg1: memref): // no predecessors + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.alloc(%0) : memref + "lo_spn.task"(%arg0, %1) ( { + ^bb0(%arg2: index, %arg3: memref, %arg4: memref): // no predecessors + %4 = "lo_spn.batch_read"(%arg3, %arg2) {sampleIndex = 0 : ui32} : (memref, index) -> i32 + %5 = "lo_spn.body"(%4) ( { + ^bb0(%arg5: i32): // no predecessors + %6 = "lo_spn.histogram"(%arg5) {bucketCount = 2 : ui32, buckets = [{lb = 41 : i32, ub = 42 : i32, val = 2.500000e-01 : f64}, {lb = 42 : i32, ub = 43 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (i32) -> f64 + %7 = "lo_spn.log"(%6) : (f64) -> f64 + "lo_spn.yield"(%7) : (f64) -> () + }) : (i32) -> f64 + "lo_spn.batch_write"(%5, %arg4, %arg2) : (f64, memref, index) -> () + "lo_spn.return"() : () -> () + }) {batchSize = 12 : ui32} : (memref, memref) -> () + %2 = memref.tensor_load %1 : memref + %3 = 
memref.buffer_cast %2 : memref + "lo_spn.copy"(%3, %arg1) : (memref, memref) -> () + "lo_spn.return"() : () -> () + }) {sym_name = "spn_kernel", type = (memref, memref) -> ()} : () -> () +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + + +// CHECK-LABEL: "lo_spn.kernel"() ( { +// CHECK: ^bb0(%[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref): +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref +// CHECK: "lo_spn.task"(%[[VAL_0]], %[[VAL_4]]) ( { +// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: memref, %[[VAL_7:.*]]: memref): +// CHECK: %[[VAL_8:.*]] = "lo_spn.batch_read"(%[[VAL_6]], %[[VAL_5]]) {sampleIndex = 0 : ui32} : (memref, index) -> i32 +// CHECK: %[[VAL_9:.*]] = "lo_spn.body"(%[[VAL_8]]) ( { +// CHECK: ^bb0(%[[VAL_10:.*]]: i32): +// CHECK: %[[VAL_11:.*]] = "lo_spn.select"(%[[VAL_10]]) {input_true_threshold = 4.200000e+01 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (i32) -> f64 +// CHECK: %[[VAL_12:.*]] = "lo_spn.log"(%[[VAL_11]]) : (f64) -> f64 +// CHECK: "lo_spn.yield"(%[[VAL_12]]) : (f64) -> () +// CHECK: }) : (i32) -> f64 +// CHECK: "lo_spn.batch_write"(%[[VAL_13:.*]], %[[VAL_7]], %[[VAL_5]]) : (f64, memref, index) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: }) {batchSize = 12 : ui32} : (memref, memref) -> () +// CHECK: "lo_spn.copy"(%[[VAL_4]], %[[VAL_1]]) : (memref, memref) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: }) {sym_name = "spn_kernel", type = (memref, memref) -> ()} : () -> () From b59fb0c4be3b9635fa244a1089280279ba62ac51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Mon, 28 Jun 2021 13:33:15 +0200 Subject: [PATCH 03/12] [WIP] Added stubs for vectorized transformation of SPNSelectLeaf. 
--- .../LoSPNtoCPU/Vectorization/VectorizationPatterns.h | 10 ++++++++++ mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp | 3 +++ .../Vectorization/VectorizeNodePatterns.cpp | 11 +++++++++++ 3 files changed, 24 insertions(+) diff --git a/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h b/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h index 1f25b020..55188b5f 100644 --- a/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h +++ b/mlir/include/Conversion/LoSPNtoCPU/Vectorization/VectorizationPatterns.h @@ -149,6 +149,15 @@ namespace mlir { ConversionPatternRewriter& rewriter) const override; }; + struct VectorizeSelectLeaf : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(low::SPNSelectLeaf op, + ArrayRef operands, + ConversionPatternRewriter& rewriter) const override; + }; + static inline void populateLoSPNCPUVectorizationNodePatterns(OwningRewritePatternList& patterns, MLIRContext* context, TypeConverter& typeConverter) { @@ -159,6 +168,7 @@ namespace mlir { patterns.insert(typeConverter, context, 2); patterns.insert(typeConverter, context, 2); patterns.insert(typeConverter, context, 2); + patterns.insert(typeConverter, context, 2); } } diff --git a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp index 70fa469f..0b7d49ef 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp @@ -487,6 +487,9 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, llvm::ArrayRef operands, mlir::ConversionPatternRewriter& rewriter) const { + if (op.checkVectorized()) { + return rewriter.notifyMatchFailure(op, "Pattern only matches non-vectorized SelectLeaf"); + } // If the input type is not an integer, but also not a 
float, we cannot convert it and this pattern fails. mlir::Value cond; auto inputTy = op.input().getType(); diff --git a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp index 41078027..19eab198 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp @@ -608,4 +608,15 @@ mlir::LogicalResult mlir::spn::ResolveVectorizedStripLog::matchAndRewrite(low::S } rewriter.replaceOp(op, operands[0]); return success(); +} + +mlir::LogicalResult mlir::spn::VectorizeSelectLeaf::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, + llvm::ArrayRef operands, + mlir::ConversionPatternRewriter& rewriter) const { + // Replace the vectorized version of a BatchRead with a Gather load from the input memref. + if (!op.checkVectorized()) { + return rewriter.notifyMatchFailure(op, "Pattern only matches vectorized Select"); + } + + return rewriter.notifyMatchFailure(op, "Pattern matched a vectorized Select"); } \ No newline at end of file From 9cb2ed06aa702fca4e67e2ef9b78b3266fd2ba67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Mon, 28 Jun 2021 22:07:44 +0200 Subject: [PATCH 04/12] [WIP] Implemented first iteration of vectorized SPNSelectleaf pattern. 
--- .../Vectorization/VectorizeNodePatterns.cpp | 63 ++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp index c46c8c2a..d124a608 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp @@ -659,10 +659,69 @@ mlir::LogicalResult mlir::spn::ResolveVectorizedConvertLog::matchAndRewrite(mlir mlir::LogicalResult mlir::spn::VectorizeSelectLeaf::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, llvm::ArrayRef operands, mlir::ConversionPatternRewriter& rewriter) const { - // Replace the vectorized version of a BatchRead with a Gather load from the input memref. if (!op.checkVectorized()) { return rewriter.notifyMatchFailure(op, "Pattern only matches vectorized Select"); } + if (op.getResult().getType().isa()) { + return rewriter.notifyMatchFailure(op, "Pattern does not match for log-space computation"); + } - return rewriter.notifyMatchFailure(op, "Pattern matched a vectorized Select"); + auto inputVec = operands.front(); + auto inputVecTy = inputVec.getType(); + + // Input should be a vector + if (!inputVecTy.isa()) { + return rewriter.notifyMatchFailure(op, "Vectorization pattern did not match, input was not a vector"); + } + + auto inputTy = op.input().getType(); + VectorType vectorType = inputVecTy.dyn_cast(); + + // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. 
+ mlir::Value cond; + auto booleanVectorTy = VectorType::get(vectorType.getShape(), IntegerType::get(op.getContext(), 1)); + Value thresholdVec = + broadcastVectorConstant(vectorType, op.input_true_thresholdAttr().getValueAsDouble(), rewriter, op.getLoc()); + if (inputTy.isa()) { + cond = + rewriter.create(op->getLoc(), booleanVectorTy, mlir::CmpFPredicate::ULT, inputVec, thresholdVec); + } else if (inputTy.isa()) { + // Convert from floating-point input to integer value if necessary. + // This conversion is also possible in vectorized mode. + auto intVectorTy = VectorType::get(vectorType.getShape(), inputTy); + thresholdVec = rewriter.create(op->getLoc(), thresholdVec, intVectorTy); + cond = + rewriter.create(op->getLoc(), booleanVectorTy, mlir::CmpIPredicate::ult, inputVec, thresholdVec); + } else { + return rewriter.notifyMatchFailure(op, "Expected condition-value to be either Float- or IntegerType"); + } + + Type resultType = op.getResult().getType(); + bool computesLog = false; + if (auto logType = resultType.dyn_cast()) { + resultType = logType.getBaseType(); + computesLog = true; + } + + ConstantOp val_true, val_false; + if (computesLog) { + auto logVecTy = VectorType::get(vectorType.getShape(), resultType); + val_true = broadcastVectorConstant(logVecTy, log(op.val_trueAttr().getValueAsDouble()), rewriter, op.getLoc()); + val_false = broadcastVectorConstant(logVecTy, log(op.val_falseAttr().getValueAsDouble()), rewriter, op.getLoc()); + } else { + val_true = broadcastVectorConstant(vectorType, op.val_trueAttr().getValueAsDouble(), rewriter, op.getLoc()); + val_false = broadcastVectorConstant(vectorType, op.val_falseAttr().getValueAsDouble(), rewriter, op.getLoc()); + } + + mlir::Value leaf = rewriter.template create(op.getLoc(), cond, val_true, val_false); + if (op.supportMarginal()) { + assert(inputTy.template isa()); + auto isNan = rewriter.create(op->getLoc(), CmpFPredicate::UNO, inputVec, inputVec); + auto marginalValue = (computesLog) ? 
0.0 : 1.0; + auto constOne = broadcastVectorConstant(vectorType, marginalValue, rewriter, op.getLoc()); + leaf = rewriter.create(op.getLoc(), isNan, constOne, leaf); + } + rewriter.replaceOp(op, leaf); + + return success(); } \ No newline at end of file From 1e5f758339c64a9e228d3e51e3e946ca04acc1fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Thu, 1 Jul 2021 23:51:26 +0200 Subject: [PATCH 05/12] [WIP] [FIX] SPNSelectLeaf should now emit the correct type when using log-space. --- mlir/lib/Dialect/LoSPN/LoSPNOps.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp index ef5268cb..ea39a9bd 100644 --- a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp +++ b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp @@ -378,8 +378,14 @@ ::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCatego auto pTrue = probabilities[0].dyn_cast(); auto pFalse = probabilities[1].dyn_cast(); auto threshold_max_true = FloatAttr::get(op.index().getType(), 1.0); + + mlir::Type outputTy = pTrue.getType(); + if (auto outputLogType = op.getResult().getType().dyn_cast()) { + outputTy = outputLogType; + } + rewriter.replaceOpWithNewOp(op, - pTrue.getType(), + outputTy, op.index(), threshold_max_true, pTrue, @@ -411,8 +417,14 @@ ::mlir::LogicalResult mlir::spn::low::SPNHistogramLeaf::canonicalize(SPNHistogra auto pTrue = b0.val(); auto pFalse = b1.val(); auto threshold_max_true = FloatAttr::get(Float64Type::get(op.getContext()), b0.ub().getInt()); + + mlir::Type outputTy = pTrue.getType(); + if (auto outputLogType = op.getResult().getType().dyn_cast()) { + outputTy = outputLogType; + } + rewriter.replaceOpWithNewOp(op, - pTrue.getType(), + outputTy, op.index(), threshold_max_true, pTrue, From 68cf2701350530dfcdefb67aa106ef9325c0ed91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Fri, 2 Jul 2021 11:32:54 +0200 Subject: [PATCH 06/12] [WIP] 
[FIX] Removed faulty/unnecessary log-space check. --- .../Vectorization/VectorizeNodePatterns.cpp | 3 --- mlir/lib/Dialect/LoSPN/LoSPNOps.cpp | 14 ++------------ 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp index d124a608..4f8712e4 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp @@ -662,9 +662,6 @@ mlir::LogicalResult mlir::spn::VectorizeSelectLeaf::matchAndRewrite(mlir::spn::l if (!op.checkVectorized()) { return rewriter.notifyMatchFailure(op, "Pattern only matches vectorized Select"); } - if (op.getResult().getType().isa()) { - return rewriter.notifyMatchFailure(op, "Pattern does not match for log-space computation"); - } auto inputVec = operands.front(); auto inputVecTy = inputVec.getType(); diff --git a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp index ea39a9bd..93fdd157 100644 --- a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp +++ b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp @@ -379,13 +379,8 @@ ::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCatego auto pFalse = probabilities[1].dyn_cast(); auto threshold_max_true = FloatAttr::get(op.index().getType(), 1.0); - mlir::Type outputTy = pTrue.getType(); - if (auto outputLogType = op.getResult().getType().dyn_cast()) { - outputTy = outputLogType; - } - rewriter.replaceOpWithNewOp(op, - outputTy, + op.getResult().getType(), op.index(), threshold_max_true, pTrue, @@ -418,13 +413,8 @@ ::mlir::LogicalResult mlir::spn::low::SPNHistogramLeaf::canonicalize(SPNHistogra auto pFalse = b1.val(); auto threshold_max_true = FloatAttr::get(Float64Type::get(op.getContext()), b0.ub().getInt()); - mlir::Type outputTy = pTrue.getType(); - if (auto outputLogType = op.getResult().getType().dyn_cast()) { - outputTy = 
outputLogType; - } - rewriter.replaceOpWithNewOp(op, - outputTy, + op.getResult().getType(), op.index(), threshold_max_true, pTrue, From 9045d7f9f5b1aefafedb91352101225b1f57ad15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Thu, 22 Jul 2021 18:14:23 +0200 Subject: [PATCH 07/12] [RC] Added tests for lospn.select transformations. - Test if lospn.select is replaced by llvm.select (scalar and vectorized) --- .../lower-to-cpu-nodes-select-vectorize.mlir | 265 ++++++++++++++++++ .../lower-to-cpu-nodes-select.mlir | 134 +++++++++ 2 files changed, 399 insertions(+) create mode 100644 mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir create mode 100644 mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir new file mode 100644 index 00000000..580ca32e --- /dev/null +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir @@ -0,0 +1,265 @@ +// RUN: %optcall --vectorize-lospn-nodes %s | FileCheck %s + +module { + func @vec_task_0(%arg0: memref, %arg1: memref) { + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %c4 = constant 4 : index + %1 = remi_unsigned %0, %c4 : index + %2 = subi %0, %1 : index + %c0_0 = constant 0 : index + %c4_1 = constant 4 : index + scf.for %arg2 = %c0_0 to %2 step %c4_1 { + %3 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 + %4 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 1 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 + %5 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 2 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 + %6 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 3 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 + %7 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 4 : ui32, 
vector_width = 4 : i32} : (memref, index) -> f64 + %8 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 5 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 + %cst = constant 1.000000e-01 : f64 + %9 = "lo_spn.select"(%3) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 + %10 = "lo_spn.categorical"(%4) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false, vector_width = 4 : i32} : (f64) -> f64 + %11 = "lo_spn.select"(%5) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 + %12 = "lo_spn.select"(%6) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 + %13 = "lo_spn.gaussian"(%7) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false, vector_width = 4 : i32} : (f64) -> f64 + %14 = "lo_spn.gaussian"(%8) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false, vector_width = 4 : i32} : (f64) -> f64 + %15 = "lo_spn.mul"(%9, %10) {vector_width = 4 : i32} : (f64, f64) -> f64 + %16 = "lo_spn.mul"(%15, %11) {vector_width = 4 : i32} : (f64, f64) -> f64 + %17 = "lo_spn.mul"(%16, %cst) {vector_width = 4 : i32} : (f64, f64) -> f64 + %18 = "lo_spn.mul"(%12, %13) {vector_width = 4 : i32} : (f64, f64) -> f64 + %19 = "lo_spn.mul"(%18, %14) {vector_width = 4 : i32} : (f64, f64) -> f64 + %20 = "lo_spn.mul"(%19, %cst) {vector_width = 4 : i32} : (f64, f64) -> f64 + %21 = "lo_spn.add"(%17, %20) {vector_width = 4 : i32} : (f64, f64) -> f64 + %22 = "lo_spn.log"(%21) {vector_width = 4 : i32} : (f64) -> f64 + "lo_spn.batch_write"(%22, %arg1, %arg2) {vector_width = 4 : i32} : (f64, memref, index) -> () + } + %c1 = constant 1 : index + scf.for %arg2 
= %2 to %0 step %c1 { + %3 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32} : (memref, index) -> f64 + %4 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 1 : ui32} : (memref, index) -> f64 + %5 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 2 : ui32} : (memref, index) -> f64 + %6 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 3 : ui32} : (memref, index) -> f64 + %7 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 4 : ui32} : (memref, index) -> f64 + %8 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 5 : ui32} : (memref, index) -> f64 + %cst = constant 1.000000e-01 : f64 + %9 = "lo_spn.select"(%3) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 + %10 = "lo_spn.categorical"(%4) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 + %11 = "lo_spn.select"(%5) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (f64) -> f64 + %12 = "lo_spn.select"(%6) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 + %13 = "lo_spn.gaussian"(%7) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 + %14 = "lo_spn.gaussian"(%8) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 + %15 = "lo_spn.mul"(%9, %10) : (f64, f64) -> f64 + %16 = "lo_spn.mul"(%15, %11) : (f64, f64) -> f64 + %17 = "lo_spn.mul"(%16, %cst) : (f64, f64) -> f64 + %18 = "lo_spn.mul"(%12, %13) : (f64, f64) -> f64 + %19 = "lo_spn.mul"(%18, %14) : (f64, f64) -> f64 + %20 = "lo_spn.mul"(%19, %cst) : (f64, f64) -> f64 + %21 = "lo_spn.add"(%17, %20) : (f64, f64) -> f64 + %22 = "lo_spn.log"(%21) : (f64) -> f64 + "lo_spn.batch_write"(%22, %arg1, %arg2) : (f64, memref, index) -> () + } + return 
+ } + func @spn_vector(%arg0: memref, %arg1: memref) { + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.alloc(%0) : memref + call @vec_task_0(%arg0, %1) : (memref, memref) -> () + "lo_spn.copy"(%1, %arg1) : (memref, memref) -> () + "lo_spn.return"() : () -> () + } +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: memref.global "private" constant @categorical_vec_0 : memref<3xf64> = dense<[2.500000e-01, 6.250000e-01, 1.250000e-01]> + +// CHECK-LABEL: func @vec_task_0( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = constant 4 : index +// CHECK: %[[VAL_5:.*]] = remi_unsigned %[[VAL_3]], %[[VAL_4]] : index +// CHECK: %[[VAL_6:.*]] = subi %[[VAL_3]], %[[VAL_5]] : index +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 4 : index +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] { +// CHECK: %[[VAL_10:.*]] = index_cast %[[VAL_9]] : index to i64 +// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : i64 to vector<4xi64> +// CHECK: %[[VAL_12:.*]] = constant dense<[0, 6, 12, 18]> : vector<4xi64> +// CHECK: %[[VAL_13:.*]] = constant dense<6> : vector<4xi64> +// CHECK: %[[VAL_14:.*]] = muli %[[VAL_11]], %[[VAL_13]] : vector<4xi64> +// CHECK: %[[VAL_15:.*]] = addi %[[VAL_14]], %[[VAL_12]] : vector<4xi64> +// CHECK: %[[VAL_16:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_17:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_18:.*]] = constant 0 : index +// CHECK: %[[VAL_19:.*]] = memref.dim %[[VAL_0]], %[[VAL_18]] : memref +// CHECK: %[[VAL_20:.*]] = constant 6 : index +// CHECK: %[[VAL_21:.*]] = muli %[[VAL_19]], %[[VAL_20]] : index +// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_21]]], 
strides: [1] : memref to memref +// CHECK: %[[VAL_23:.*]] = constant 0 : index +// CHECK: %[[VAL_24:.*]] = vector.gather %[[VAL_22]]{{\[}}%[[VAL_23]]] {{\[}}%[[VAL_15]]], %[[VAL_17]], %[[VAL_16]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_25:.*]] = index_cast %[[VAL_9]] : index to i64 +// CHECK: %[[VAL_26:.*]] = vector.broadcast %[[VAL_25]] : i64 to vector<4xi64> +// CHECK: %[[VAL_27:.*]] = constant dense<[1, 7, 13, 19]> : vector<4xi64> +// CHECK: %[[VAL_28:.*]] = constant dense<6> : vector<4xi64> +// CHECK: %[[VAL_29:.*]] = muli %[[VAL_26]], %[[VAL_28]] : vector<4xi64> +// CHECK: %[[VAL_30:.*]] = addi %[[VAL_29]], %[[VAL_27]] : vector<4xi64> +// CHECK: %[[VAL_31:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_32:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_33:.*]] = constant 0 : index +// CHECK: %[[VAL_34:.*]] = memref.dim %[[VAL_0]], %[[VAL_33]] : memref +// CHECK: %[[VAL_35:.*]] = constant 6 : index +// CHECK: %[[VAL_36:.*]] = muli %[[VAL_34]], %[[VAL_35]] : index +// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_36]]], strides: [1] : memref to memref +// CHECK: %[[VAL_38:.*]] = constant 0 : index +// CHECK: %[[VAL_39:.*]] = vector.gather %[[VAL_37]]{{\[}}%[[VAL_38]]] {{\[}}%[[VAL_30]]], %[[VAL_32]], %[[VAL_31]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_40:.*]] = index_cast %[[VAL_9]] : index to i64 +// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : i64 to vector<4xi64> +// CHECK: %[[VAL_42:.*]] = constant dense<[2, 8, 14, 20]> : vector<4xi64> +// CHECK: %[[VAL_43:.*]] = constant dense<6> : vector<4xi64> +// CHECK: %[[VAL_44:.*]] = muli %[[VAL_41]], %[[VAL_43]] : vector<4xi64> +// CHECK: %[[VAL_45:.*]] = addi %[[VAL_44]], %[[VAL_42]] : vector<4xi64> +// CHECK: %[[VAL_46:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_47:.*]] = constant dense : vector<4xi1> +// 
CHECK: %[[VAL_48:.*]] = constant 0 : index +// CHECK: %[[VAL_49:.*]] = memref.dim %[[VAL_0]], %[[VAL_48]] : memref +// CHECK: %[[VAL_50:.*]] = constant 6 : index +// CHECK: %[[VAL_51:.*]] = muli %[[VAL_49]], %[[VAL_50]] : index +// CHECK: %[[VAL_52:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_51]]], strides: [1] : memref to memref +// CHECK: %[[VAL_53:.*]] = constant 0 : index +// CHECK: %[[VAL_54:.*]] = vector.gather %[[VAL_52]]{{\[}}%[[VAL_53]]] {{\[}}%[[VAL_45]]], %[[VAL_47]], %[[VAL_46]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_55:.*]] = index_cast %[[VAL_9]] : index to i64 +// CHECK: %[[VAL_56:.*]] = vector.broadcast %[[VAL_55]] : i64 to vector<4xi64> +// CHECK: %[[VAL_57:.*]] = constant dense<[3, 9, 15, 21]> : vector<4xi64> +// CHECK: %[[VAL_58:.*]] = constant dense<6> : vector<4xi64> +// CHECK: %[[VAL_59:.*]] = muli %[[VAL_56]], %[[VAL_58]] : vector<4xi64> +// CHECK: %[[VAL_60:.*]] = addi %[[VAL_59]], %[[VAL_57]] : vector<4xi64> +// CHECK: %[[VAL_61:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_62:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_63:.*]] = constant 0 : index +// CHECK: %[[VAL_64:.*]] = memref.dim %[[VAL_0]], %[[VAL_63]] : memref +// CHECK: %[[VAL_65:.*]] = constant 6 : index +// CHECK: %[[VAL_66:.*]] = muli %[[VAL_64]], %[[VAL_65]] : index +// CHECK: %[[VAL_67:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_66]]], strides: [1] : memref to memref +// CHECK: %[[VAL_68:.*]] = constant 0 : index +// CHECK: %[[VAL_69:.*]] = vector.gather %[[VAL_67]]{{\[}}%[[VAL_68]]] {{\[}}%[[VAL_60]]], %[[VAL_62]], %[[VAL_61]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_70:.*]] = index_cast %[[VAL_9]] : index to i64 +// CHECK: %[[VAL_71:.*]] = vector.broadcast %[[VAL_70]] : i64 to vector<4xi64> +// CHECK: %[[VAL_72:.*]] = constant dense<[4, 10, 16, 22]> : vector<4xi64> +// 
CHECK: %[[VAL_73:.*]] = constant dense<6> : vector<4xi64> +// CHECK: %[[VAL_74:.*]] = muli %[[VAL_71]], %[[VAL_73]] : vector<4xi64> +// CHECK: %[[VAL_75:.*]] = addi %[[VAL_74]], %[[VAL_72]] : vector<4xi64> +// CHECK: %[[VAL_76:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_77:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_78:.*]] = constant 0 : index +// CHECK: %[[VAL_79:.*]] = memref.dim %[[VAL_0]], %[[VAL_78]] : memref +// CHECK: %[[VAL_80:.*]] = constant 6 : index +// CHECK: %[[VAL_81:.*]] = muli %[[VAL_79]], %[[VAL_80]] : index +// CHECK: %[[VAL_82:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_81]]], strides: [1] : memref to memref +// CHECK: %[[VAL_83:.*]] = constant 0 : index +// CHECK: %[[VAL_84:.*]] = vector.gather %[[VAL_82]]{{\[}}%[[VAL_83]]] {{\[}}%[[VAL_75]]], %[[VAL_77]], %[[VAL_76]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_85:.*]] = index_cast %[[VAL_9]] : index to i64 +// CHECK: %[[VAL_86:.*]] = vector.broadcast %[[VAL_85]] : i64 to vector<4xi64> +// CHECK: %[[VAL_87:.*]] = constant dense<[5, 11, 17, 23]> : vector<4xi64> +// CHECK: %[[VAL_88:.*]] = constant dense<6> : vector<4xi64> +// CHECK: %[[VAL_89:.*]] = muli %[[VAL_86]], %[[VAL_88]] : vector<4xi64> +// CHECK: %[[VAL_90:.*]] = addi %[[VAL_89]], %[[VAL_87]] : vector<4xi64> +// CHECK: %[[VAL_91:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_92:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_93:.*]] = constant 0 : index +// CHECK: %[[VAL_94:.*]] = memref.dim %[[VAL_0]], %[[VAL_93]] : memref +// CHECK: %[[VAL_95:.*]] = constant 6 : index +// CHECK: %[[VAL_96:.*]] = muli %[[VAL_94]], %[[VAL_95]] : index +// CHECK: %[[VAL_97:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_96]]], strides: [1] : memref to memref +// CHECK: %[[VAL_98:.*]] = constant 0 : index +// CHECK: %[[VAL_99:.*]] = vector.gather %[[VAL_97]]{{\[}}%[[VAL_98]]] 
{{\[}}%[[VAL_90]]], %[[VAL_92]], %[[VAL_91]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_100:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_101:.*]] = constant dense<1.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_102:.*]] = cmpf ult, %[[VAL_24]], %[[VAL_101]] : vector<4xf64> +// CHECK: %[[VAL_103:.*]] = constant dense<3.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_104:.*]] = constant dense<5.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_105:.*]] = select %[[VAL_102]], %[[VAL_103]], %[[VAL_104]] : vector<4xi1>, vector<4xf64> +// CHECK: %[[VAL_106:.*]] = memref.get_global @categorical_vec_0 : memref<3xf64> +// CHECK: %[[VAL_107:.*]] = fptoui %[[VAL_39]] : vector<4xf64> to vector<4xi64> +// CHECK: %[[VAL_108:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_109:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_110:.*]] = constant 0 : index +// CHECK: %[[VAL_111:.*]] = vector.gather %[[VAL_106]]{{\[}}%[[VAL_110]]] {{\[}}%[[VAL_107]]], %[[VAL_109]], %[[VAL_108]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_112:.*]] = constant dense<1.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_113:.*]] = cmpf ult, %[[VAL_54]], %[[VAL_112]] : vector<4xf64> +// CHECK: %[[VAL_114:.*]] = constant dense<2.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_115:.*]] = constant dense<7.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_116:.*]] = select %[[VAL_113]], %[[VAL_114]], %[[VAL_115]] : vector<4xi1>, vector<4xf64> +// CHECK: %[[VAL_117:.*]] = constant dense<1.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_118:.*]] = cmpf ult, %[[VAL_69]], %[[VAL_117]] : vector<4xf64> +// CHECK: %[[VAL_119:.*]] = constant dense<4.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_120:.*]] = constant dense<5.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_121:.*]] = select %[[VAL_118]], %[[VAL_119]], %[[VAL_120]] : vector<4xi1>, vector<4xf64> +// CHECK: %[[VAL_122:.*]] = constant 
dense<0.3989422804014327> : vector<4xf64> +// CHECK: %[[VAL_123:.*]] = constant dense<-5.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_124:.*]] = constant dense<5.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_125:.*]] = subf %[[VAL_84]], %[[VAL_124]] : vector<4xf64> +// CHECK: %[[VAL_126:.*]] = mulf %[[VAL_125]], %[[VAL_125]] : vector<4xf64> +// CHECK: %[[VAL_127:.*]] = mulf %[[VAL_126]], %[[VAL_123]] : vector<4xf64> +// CHECK: %[[VAL_128:.*]] = math.exp %[[VAL_127]] : vector<4xf64> +// CHECK: %[[VAL_129:.*]] = mulf %[[VAL_122]], %[[VAL_128]] : vector<4xf64> +// CHECK: %[[VAL_130:.*]] = constant dense<3.9894228040143269> : vector<4xf64> +// CHECK: %[[VAL_131:.*]] = constant dense<-49.999999999999993> : vector<4xf64> +// CHECK: %[[VAL_132:.*]] = constant dense<2.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_133:.*]] = subf %[[VAL_99]], %[[VAL_132]] : vector<4xf64> +// CHECK: %[[VAL_134:.*]] = mulf %[[VAL_133]], %[[VAL_133]] : vector<4xf64> +// CHECK: %[[VAL_135:.*]] = mulf %[[VAL_134]], %[[VAL_131]] : vector<4xf64> +// CHECK: %[[VAL_136:.*]] = math.exp %[[VAL_135]] : vector<4xf64> +// CHECK: %[[VAL_137:.*]] = mulf %[[VAL_130]], %[[VAL_136]] : vector<4xf64> +// CHECK: %[[VAL_138:.*]] = mulf %[[VAL_105]], %[[VAL_111]] : vector<4xf64> +// CHECK: %[[VAL_139:.*]] = mulf %[[VAL_138]], %[[VAL_116]] : vector<4xf64> +// CHECK: %[[VAL_140:.*]] = "lo_spn.to_vector"(%[[VAL_100]]) : (f64) -> vector<4xf64> +// CHECK: %[[VAL_141:.*]] = mulf %[[VAL_139]], %[[VAL_140]] : vector<4xf64> +// CHECK: %[[VAL_142:.*]] = mulf %[[VAL_121]], %[[VAL_129]] : vector<4xf64> +// CHECK: %[[VAL_143:.*]] = mulf %[[VAL_142]], %[[VAL_137]] : vector<4xf64> +// CHECK: %[[VAL_144:.*]] = "lo_spn.to_vector"(%[[VAL_100]]) : (f64) -> vector<4xf64> +// CHECK: %[[VAL_145:.*]] = mulf %[[VAL_143]], %[[VAL_144]] : vector<4xf64> +// CHECK: %[[VAL_146:.*]] = addf %[[VAL_141]], %[[VAL_145]] : vector<4xf64> +// CHECK: %[[VAL_147:.*]] = math.log %[[VAL_146]] : vector<4xf64> +// CHECK: vector.transfer_write 
%[[VAL_147]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : vector<4xf64>, memref +// CHECK: } +// CHECK: %[[VAL_148:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_149:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_148]] { +// CHECK: %[[VAL_150:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 0 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_151:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 1 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_152:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 2 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_153:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 3 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_154:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 4 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_155:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 5 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_156:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_157:.*]] = "lo_spn.select"(%[[VAL_150]]) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 +// CHECK: %[[VAL_158:.*]] = "lo_spn.categorical"(%[[VAL_151]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_159:.*]] = "lo_spn.select"(%[[VAL_152]]) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (f64) -> f64 +// CHECK: %[[VAL_160:.*]] = "lo_spn.select"(%[[VAL_153]]) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 +// CHECK: %[[VAL_161:.*]] = "lo_spn.gaussian"(%[[VAL_154]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_162:.*]] = 
"lo_spn.gaussian"(%[[VAL_155]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_163:.*]] = "lo_spn.mul"(%[[VAL_157]], %[[VAL_158]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_164:.*]] = "lo_spn.mul"(%[[VAL_163]], %[[VAL_159]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_165:.*]] = "lo_spn.mul"(%[[VAL_164]], %[[VAL_156]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_166:.*]] = "lo_spn.mul"(%[[VAL_160]], %[[VAL_161]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_167:.*]] = "lo_spn.mul"(%[[VAL_166]], %[[VAL_162]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_168:.*]] = "lo_spn.mul"(%[[VAL_167]], %[[VAL_156]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_169:.*]] = "lo_spn.add"(%[[VAL_165]], %[[VAL_168]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_170:.*]] = "lo_spn.log"(%[[VAL_169]]) : (f64) -> f64 +// CHECK: "lo_spn.batch_write"(%[[VAL_170]], %[[VAL_1]], %[[VAL_149]]) : (f64, memref, index) -> () +// CHECK: } +// CHECK: return +// CHECK: } + +// CHECK-LABEL: func @spn_vector( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref +// CHECK: call @vec_task_0(%[[VAL_0]], %[[VAL_4]]) : (memref, memref) -> () +// CHECK: "lo_spn.copy"(%[[VAL_4]], %[[VAL_1]]) : (memref, memref) -> () +// CHECK: "lo_spn.return"() : () -> () +// CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir new file mode 100644 index 00000000..40ef3270 --- /dev/null +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir @@ -0,0 +1,134 @@ +// RUN: %optcall --convert-lospn-nodes-to-cpu %s | FileCheck %s + +module { + func @task_0(%arg0: memref, %arg1: memref) { + %c0 = constant 0 : index + %c0_0 = constant 0 : index + %0 = memref.dim %arg0, %c0_0 : memref + %c1 = constant 1 
: index + scf.for %arg2 = %c0 to %0 step %c1 { + %1 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32} : (memref, index) -> f64 + %2 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 1 : ui32} : (memref, index) -> f64 + %3 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 2 : ui32} : (memref, index) -> f64 + %4 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 3 : ui32} : (memref, index) -> f64 + %5 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 4 : ui32} : (memref, index) -> f64 + %6 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 5 : ui32} : (memref, index) -> f64 + %cst = constant 1.000000e-01 : f64 + %7 = "lo_spn.select"(%1) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 + %8 = "lo_spn.categorical"(%2) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 + %9 = "lo_spn.select"(%3) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (f64) -> f64 + %10 = "lo_spn.select"(%4) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 + %11 = "lo_spn.gaussian"(%5) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 + %12 = "lo_spn.gaussian"(%6) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 + %13 = "lo_spn.mul"(%7, %8) : (f64, f64) -> f64 + %14 = "lo_spn.mul"(%13, %9) : (f64, f64) -> f64 + %15 = "lo_spn.mul"(%14, %cst) : (f64, f64) -> f64 + %16 = "lo_spn.mul"(%10, %11) : (f64, f64) -> f64 + %17 = "lo_spn.mul"(%16, %12) : (f64, f64) -> f64 + %18 = "lo_spn.mul"(%17, %cst) : (f64, f64) -> f64 + %19 = "lo_spn.add"(%15, %18) : (f64, f64) -> f64 + %20 = "lo_spn.log"(%19) : (f64) -> f64 + "lo_spn.batch_write"(%20, %arg1, %arg2) : (f64, memref, 
index) -> () + } + return + } + func @spn_vector(%arg0: memref, %arg1: memref) { + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.alloc(%0) : memref + call @task_0(%arg0, %1) : (memref, memref) -> () + "lo_spn.copy"(%1, %arg1) : (memref, memref) -> () + "lo_spn.return"() : () -> () + } +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: memref.global "private" constant @categorical_0 : memref<3xf64> = dense<[2.500000e-01, 6.250000e-01, 1.250000e-01]> + +// CHECK-LABEL: func @task_0( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_5]] { +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref +// CHECK: %[[VAL_9:.*]] = constant 1 : index +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_9]]] : memref +// CHECK: %[[VAL_11:.*]] = constant 2 : index +// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = constant 3 : index +// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_13]]] : memref +// CHECK: %[[VAL_15:.*]] = constant 4 : index +// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]] = constant 5 : index +// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_17]]] : memref +// CHECK: %[[VAL_19:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_20:.*]] = constant 1.000000e+00 : f64 +// CHECK: %[[VAL_21:.*]] = cmpf ult, %[[VAL_8]], %[[VAL_20]] : f64 +// CHECK: %[[VAL_22:.*]] = constant 3.500000e-01 : f64 +// CHECK: 
%[[VAL_23:.*]] = constant 5.500000e-01 : f64 +// CHECK: %[[VAL_24:.*]] = select %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : f64 +// CHECK: %[[VAL_25:.*]] = memref.get_global @categorical_0 : memref<3xf64> +// CHECK: %[[VAL_26:.*]] = fptoui %[[VAL_10]] : f64 to i64 +// CHECK: %[[VAL_27:.*]] = index_cast %[[VAL_26]] : i64 to index +// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_27]]] : memref<3xf64> +// CHECK: %[[VAL_29:.*]] = constant 1.000000e+00 : f64 +// CHECK: %[[VAL_30:.*]] = cmpf ult, %[[VAL_12]], %[[VAL_29]] : f64 +// CHECK: %[[VAL_31:.*]] = constant 2.500000e-01 : f64 +// CHECK: %[[VAL_32:.*]] = constant 7.500000e-01 : f64 +// CHECK: %[[VAL_33:.*]] = select %[[VAL_30]], %[[VAL_31]], %[[VAL_32]] : f64 +// CHECK: %[[VAL_34:.*]] = constant 1.000000e+00 : f64 +// CHECK: %[[VAL_35:.*]] = cmpf ult, %[[VAL_14]], %[[VAL_34]] : f64 +// CHECK: %[[VAL_36:.*]] = constant 4.500000e-01 : f64 +// CHECK: %[[VAL_37:.*]] = constant 5.500000e-01 : f64 +// CHECK: %[[VAL_38:.*]] = select %[[VAL_35]], %[[VAL_36]], %[[VAL_37]] : f64 +// CHECK: %[[VAL_39:.*]] = constant 0.3989422804014327 : f64 +// CHECK: %[[VAL_40:.*]] = constant -5.000000e-01 : f64 +// CHECK: %[[VAL_41:.*]] = constant 5.000000e-01 : f64 +// CHECK: %[[VAL_42:.*]] = subf %[[VAL_16]], %[[VAL_41]] : f64 +// CHECK: %[[VAL_43:.*]] = mulf %[[VAL_42]], %[[VAL_42]] : f64 +// CHECK: %[[VAL_44:.*]] = mulf %[[VAL_43]], %[[VAL_40]] : f64 +// CHECK: %[[VAL_45:.*]] = math.exp %[[VAL_44]] : f64 +// CHECK: %[[VAL_46:.*]] = mulf %[[VAL_39]], %[[VAL_45]] : f64 +// CHECK: %[[VAL_47:.*]] = constant 3.9894228040143269 : f64 +// CHECK: %[[VAL_48:.*]] = constant -49.999999999999993 : f64 +// CHECK: %[[VAL_49:.*]] = constant 2.500000e-01 : f64 +// CHECK: %[[VAL_50:.*]] = subf %[[VAL_18]], %[[VAL_49]] : f64 +// CHECK: %[[VAL_51:.*]] = mulf %[[VAL_50]], %[[VAL_50]] : f64 +// CHECK: %[[VAL_52:.*]] = mulf %[[VAL_51]], %[[VAL_48]] : f64 +// CHECK: %[[VAL_53:.*]] = math.exp %[[VAL_52]] : f64 +// CHECK: %[[VAL_54:.*]] = mulf 
%[[VAL_47]], %[[VAL_53]] : f64 +// CHECK: %[[VAL_55:.*]] = mulf %[[VAL_24]], %[[VAL_28]] : f64 +// CHECK: %[[VAL_56:.*]] = mulf %[[VAL_55]], %[[VAL_33]] : f64 +// CHECK: %[[VAL_57:.*]] = mulf %[[VAL_56]], %[[VAL_19]] : f64 +// CHECK: %[[VAL_58:.*]] = mulf %[[VAL_38]], %[[VAL_46]] : f64 +// CHECK: %[[VAL_59:.*]] = mulf %[[VAL_58]], %[[VAL_54]] : f64 +// CHECK: %[[VAL_60:.*]] = mulf %[[VAL_59]], %[[VAL_19]] : f64 +// CHECK: %[[VAL_61:.*]] = addf %[[VAL_57]], %[[VAL_60]] : f64 +// CHECK: %[[VAL_62:.*]] = math.log %[[VAL_61]] : f64 +// CHECK: memref.store %[[VAL_62]], %[[VAL_1]]{{\[}}%[[VAL_6]]] : memref +// CHECK: } +// CHECK: return +// CHECK: } + +// CHECK-LABEL: func @spn_vector( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref +// CHECK: call @task_0(%[[VAL_0]], %[[VAL_4]]) : (memref, memref) -> () +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_4]], %[[VAL_5]] : memref +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] { +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] : memref +// CHECK: memref.store %[[VAL_10]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref +// CHECK: } +// CHECK: return +// CHECK: } From 01840f3578259e675f69608a7dfe9d6dae31575d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Mon, 2 Aug 2021 14:24:37 +0200 Subject: [PATCH 08/12] [FIX] Reduced / concentrated tests on specific instructions. 
--- .../lower-to-cpu-nodes-select-vectorize.mlir | 273 ++---------------- .../lower-to-cpu-nodes-select.mlir | 130 +-------- 2 files changed, 40 insertions(+), 363 deletions(-) diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir index 580ca32e..1838c9de 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select-vectorize.mlir @@ -1,265 +1,48 @@ // RUN: %optcall --vectorize-lospn-nodes %s | FileCheck %s module { - func @vec_task_0(%arg0: memref, %arg1: memref) { + func @vec_task_0(%arg0: memref, %arg1: memref) { %c0 = constant 0 : index - %0 = memref.dim %arg0, %c0 : memref %c4 = constant 4 : index - %1 = remi_unsigned %0, %c4 : index - %2 = subi %0, %1 : index - %c0_0 = constant 0 : index - %c4_1 = constant 4 : index - scf.for %arg2 = %c0_0 to %2 step %c4_1 { - %3 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 - %4 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 1 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 - %5 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 2 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 - %6 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 3 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 - %7 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 4 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 - %8 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 5 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 - %cst = constant 1.000000e-01 : f64 - %9 = "lo_spn.select"(%3) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 - %10 = "lo_spn.categorical"(%4) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], 
supportMarginal = false, vector_width = 4 : i32} : (f64) -> f64 - %11 = "lo_spn.select"(%5) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 - %12 = "lo_spn.select"(%6) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 - %13 = "lo_spn.gaussian"(%7) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false, vector_width = 4 : i32} : (f64) -> f64 - %14 = "lo_spn.gaussian"(%8) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false, vector_width = 4 : i32} : (f64) -> f64 - %15 = "lo_spn.mul"(%9, %10) {vector_width = 4 : i32} : (f64, f64) -> f64 - %16 = "lo_spn.mul"(%15, %11) {vector_width = 4 : i32} : (f64, f64) -> f64 - %17 = "lo_spn.mul"(%16, %cst) {vector_width = 4 : i32} : (f64, f64) -> f64 - %18 = "lo_spn.mul"(%12, %13) {vector_width = 4 : i32} : (f64, f64) -> f64 - %19 = "lo_spn.mul"(%18, %14) {vector_width = 4 : i32} : (f64, f64) -> f64 - %20 = "lo_spn.mul"(%19, %cst) {vector_width = 4 : i32} : (f64, f64) -> f64 - %21 = "lo_spn.add"(%17, %20) {vector_width = 4 : i32} : (f64, f64) -> f64 - %22 = "lo_spn.log"(%21) {vector_width = 4 : i32} : (f64) -> f64 - "lo_spn.batch_write"(%22, %arg1, %arg2) {vector_width = 4 : i32} : (f64, memref, index) -> () - } - %c1 = constant 1 : index - scf.for %arg2 = %2 to %0 step %c1 { - %3 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32} : (memref, index) -> f64 - %4 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 1 : ui32} : (memref, index) -> f64 - %5 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 2 : ui32} : (memref, index) -> f64 - %6 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 3 : ui32} : (memref, index) -> f64 - %7 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 4 : ui32} : (memref, index) -> f64 - %8 = 
"lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 5 : ui32} : (memref, index) -> f64 - %cst = constant 1.000000e-01 : f64 - %9 = "lo_spn.select"(%3) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 - %10 = "lo_spn.categorical"(%4) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 - %11 = "lo_spn.select"(%5) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (f64) -> f64 - %12 = "lo_spn.select"(%6) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 - %13 = "lo_spn.gaussian"(%7) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 - %14 = "lo_spn.gaussian"(%8) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 - %15 = "lo_spn.mul"(%9, %10) : (f64, f64) -> f64 - %16 = "lo_spn.mul"(%15, %11) : (f64, f64) -> f64 - %17 = "lo_spn.mul"(%16, %cst) : (f64, f64) -> f64 - %18 = "lo_spn.mul"(%12, %13) : (f64, f64) -> f64 - %19 = "lo_spn.mul"(%18, %14) : (f64, f64) -> f64 - %20 = "lo_spn.mul"(%19, %cst) : (f64, f64) -> f64 - %21 = "lo_spn.add"(%17, %20) : (f64, f64) -> f64 - %22 = "lo_spn.log"(%21) : (f64) -> f64 - "lo_spn.batch_write"(%22, %arg1, %arg2) : (f64, memref, index) -> () + scf.for %arg2 = %c0 to %c4 step %c4 { + %0 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32, vector_width = 4 : i32} : (memref, index) -> f64 + %1 = "lo_spn.select"(%0) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64, vector_width = 4 : i32} : (f64) -> f64 + "lo_spn.batch_write"(%1, %arg1, %arg2) {vector_width = 4 : i32} : (f64, memref, index) -> () } return } - func @spn_vector(%arg0: memref, 
%arg1: memref) { - %c0 = constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.alloc(%0) : memref - call @vec_task_0(%arg0, %1) : (memref, memref) -> () - "lo_spn.copy"(%1, %arg1) : (memref, memref) -> () - "lo_spn.return"() : () -> () - } } // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// CHECK-LABEL: memref.global "private" constant @categorical_vec_0 : memref<3xf64> = dense<[2.500000e-01, 6.250000e-01, 1.250000e-01]> // CHECK-LABEL: func @vec_task_0( -// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref, // CHECK-SAME: %[[VAL_1:.*]]: memref) { // CHECK: %[[VAL_2:.*]] = constant 0 : index -// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref -// CHECK: %[[VAL_4:.*]] = constant 4 : index -// CHECK: %[[VAL_5:.*]] = remi_unsigned %[[VAL_3]], %[[VAL_4]] : index -// CHECK: %[[VAL_6:.*]] = subi %[[VAL_3]], %[[VAL_5]] : index -// CHECK: %[[VAL_7:.*]] = constant 0 : index -// CHECK: %[[VAL_8:.*]] = constant 4 : index -// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] { -// CHECK: %[[VAL_10:.*]] = index_cast %[[VAL_9]] : index to i64 -// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : i64 to vector<4xi64> -// CHECK: %[[VAL_12:.*]] = constant dense<[0, 6, 12, 18]> : vector<4xi64> -// CHECK: %[[VAL_13:.*]] = constant dense<6> : vector<4xi64> -// CHECK: %[[VAL_14:.*]] = muli %[[VAL_11]], %[[VAL_13]] : vector<4xi64> -// CHECK: %[[VAL_15:.*]] = addi %[[VAL_14]], %[[VAL_12]] : vector<4xi64> -// CHECK: %[[VAL_16:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_17:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_3:.*]] = constant 4 : index +// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_3]] { +// CHECK: %[[VAL_5:.*]] = index_cast %[[VAL_4]] : index to i64 +// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : i64 to vector<4xi64> +// CHECK: %[[VAL_7:.*]] = constant dense<[0, 1, 2, 3]> : vector<4xi64> +// 
CHECK: %[[VAL_8:.*]] = constant dense<1> : vector<4xi64> +// CHECK: %[[VAL_9:.*]] = muli %[[VAL_6]], %[[VAL_8]] : vector<4xi64> +// CHECK: %[[VAL_10:.*]] = addi %[[VAL_9]], %[[VAL_7]] : vector<4xi64> +// CHECK: %[[VAL_11:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_12:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_13:.*]] = constant 0 : index +// CHECK: %[[VAL_14:.*]] = memref.dim %[[VAL_0]], %[[VAL_13]] : memref +// CHECK: %[[VAL_15:.*]] = constant 1 : index +// CHECK: %[[VAL_16:.*]] = muli %[[VAL_14]], %[[VAL_15]] : index +// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_16]]], strides: [1] : memref to memref // CHECK: %[[VAL_18:.*]] = constant 0 : index -// CHECK: %[[VAL_19:.*]] = memref.dim %[[VAL_0]], %[[VAL_18]] : memref -// CHECK: %[[VAL_20:.*]] = constant 6 : index -// CHECK: %[[VAL_21:.*]] = muli %[[VAL_19]], %[[VAL_20]] : index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_21]]], strides: [1] : memref to memref -// CHECK: %[[VAL_23:.*]] = constant 0 : index -// CHECK: %[[VAL_24:.*]] = vector.gather %[[VAL_22]]{{\[}}%[[VAL_23]]] {{\[}}%[[VAL_15]]], %[[VAL_17]], %[[VAL_16]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_25:.*]] = index_cast %[[VAL_9]] : index to i64 -// CHECK: %[[VAL_26:.*]] = vector.broadcast %[[VAL_25]] : i64 to vector<4xi64> -// CHECK: %[[VAL_27:.*]] = constant dense<[1, 7, 13, 19]> : vector<4xi64> -// CHECK: %[[VAL_28:.*]] = constant dense<6> : vector<4xi64> -// CHECK: %[[VAL_29:.*]] = muli %[[VAL_26]], %[[VAL_28]] : vector<4xi64> -// CHECK: %[[VAL_30:.*]] = addi %[[VAL_29]], %[[VAL_27]] : vector<4xi64> -// CHECK: %[[VAL_31:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_32:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_33:.*]] = constant 0 : index -// CHECK: %[[VAL_34:.*]] = memref.dim %[[VAL_0]], %[[VAL_33]] : memref -// CHECK: 
%[[VAL_35:.*]] = constant 6 : index -// CHECK: %[[VAL_36:.*]] = muli %[[VAL_34]], %[[VAL_35]] : index -// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_36]]], strides: [1] : memref to memref -// CHECK: %[[VAL_38:.*]] = constant 0 : index -// CHECK: %[[VAL_39:.*]] = vector.gather %[[VAL_37]]{{\[}}%[[VAL_38]]] {{\[}}%[[VAL_30]]], %[[VAL_32]], %[[VAL_31]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_40:.*]] = index_cast %[[VAL_9]] : index to i64 -// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : i64 to vector<4xi64> -// CHECK: %[[VAL_42:.*]] = constant dense<[2, 8, 14, 20]> : vector<4xi64> -// CHECK: %[[VAL_43:.*]] = constant dense<6> : vector<4xi64> -// CHECK: %[[VAL_44:.*]] = muli %[[VAL_41]], %[[VAL_43]] : vector<4xi64> -// CHECK: %[[VAL_45:.*]] = addi %[[VAL_44]], %[[VAL_42]] : vector<4xi64> -// CHECK: %[[VAL_46:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_47:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_48:.*]] = constant 0 : index -// CHECK: %[[VAL_49:.*]] = memref.dim %[[VAL_0]], %[[VAL_48]] : memref -// CHECK: %[[VAL_50:.*]] = constant 6 : index -// CHECK: %[[VAL_51:.*]] = muli %[[VAL_49]], %[[VAL_50]] : index -// CHECK: %[[VAL_52:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_51]]], strides: [1] : memref to memref -// CHECK: %[[VAL_53:.*]] = constant 0 : index -// CHECK: %[[VAL_54:.*]] = vector.gather %[[VAL_52]]{{\[}}%[[VAL_53]]] {{\[}}%[[VAL_45]]], %[[VAL_47]], %[[VAL_46]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_55:.*]] = index_cast %[[VAL_9]] : index to i64 -// CHECK: %[[VAL_56:.*]] = vector.broadcast %[[VAL_55]] : i64 to vector<4xi64> -// CHECK: %[[VAL_57:.*]] = constant dense<[3, 9, 15, 21]> : vector<4xi64> -// CHECK: %[[VAL_58:.*]] = constant dense<6> : vector<4xi64> -// CHECK: %[[VAL_59:.*]] = muli %[[VAL_56]], %[[VAL_58]] : vector<4xi64> 
-// CHECK: %[[VAL_60:.*]] = addi %[[VAL_59]], %[[VAL_57]] : vector<4xi64> -// CHECK: %[[VAL_61:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_62:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_63:.*]] = constant 0 : index -// CHECK: %[[VAL_64:.*]] = memref.dim %[[VAL_0]], %[[VAL_63]] : memref -// CHECK: %[[VAL_65:.*]] = constant 6 : index -// CHECK: %[[VAL_66:.*]] = muli %[[VAL_64]], %[[VAL_65]] : index -// CHECK: %[[VAL_67:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_66]]], strides: [1] : memref to memref -// CHECK: %[[VAL_68:.*]] = constant 0 : index -// CHECK: %[[VAL_69:.*]] = vector.gather %[[VAL_67]]{{\[}}%[[VAL_68]]] {{\[}}%[[VAL_60]]], %[[VAL_62]], %[[VAL_61]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_70:.*]] = index_cast %[[VAL_9]] : index to i64 -// CHECK: %[[VAL_71:.*]] = vector.broadcast %[[VAL_70]] : i64 to vector<4xi64> -// CHECK: %[[VAL_72:.*]] = constant dense<[4, 10, 16, 22]> : vector<4xi64> -// CHECK: %[[VAL_73:.*]] = constant dense<6> : vector<4xi64> -// CHECK: %[[VAL_74:.*]] = muli %[[VAL_71]], %[[VAL_73]] : vector<4xi64> -// CHECK: %[[VAL_75:.*]] = addi %[[VAL_74]], %[[VAL_72]] : vector<4xi64> -// CHECK: %[[VAL_76:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_77:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_78:.*]] = constant 0 : index -// CHECK: %[[VAL_79:.*]] = memref.dim %[[VAL_0]], %[[VAL_78]] : memref -// CHECK: %[[VAL_80:.*]] = constant 6 : index -// CHECK: %[[VAL_81:.*]] = muli %[[VAL_79]], %[[VAL_80]] : index -// CHECK: %[[VAL_82:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_81]]], strides: [1] : memref to memref -// CHECK: %[[VAL_83:.*]] = constant 0 : index -// CHECK: %[[VAL_84:.*]] = vector.gather %[[VAL_82]]{{\[}}%[[VAL_83]]] {{\[}}%[[VAL_75]]], %[[VAL_77]], %[[VAL_76]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: 
%[[VAL_85:.*]] = index_cast %[[VAL_9]] : index to i64 -// CHECK: %[[VAL_86:.*]] = vector.broadcast %[[VAL_85]] : i64 to vector<4xi64> -// CHECK: %[[VAL_87:.*]] = constant dense<[5, 11, 17, 23]> : vector<4xi64> -// CHECK: %[[VAL_88:.*]] = constant dense<6> : vector<4xi64> -// CHECK: %[[VAL_89:.*]] = muli %[[VAL_86]], %[[VAL_88]] : vector<4xi64> -// CHECK: %[[VAL_90:.*]] = addi %[[VAL_89]], %[[VAL_87]] : vector<4xi64> -// CHECK: %[[VAL_91:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_92:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_93:.*]] = constant 0 : index -// CHECK: %[[VAL_94:.*]] = memref.dim %[[VAL_0]], %[[VAL_93]] : memref -// CHECK: %[[VAL_95:.*]] = constant 6 : index -// CHECK: %[[VAL_96:.*]] = muli %[[VAL_94]], %[[VAL_95]] : index -// CHECK: %[[VAL_97:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: {{\[}}%[[VAL_96]]], strides: [1] : memref to memref -// CHECK: %[[VAL_98:.*]] = constant 0 : index -// CHECK: %[[VAL_99:.*]] = vector.gather %[[VAL_97]]{{\[}}%[[VAL_98]]] {{\[}}%[[VAL_90]]], %[[VAL_92]], %[[VAL_91]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_100:.*]] = constant 1.000000e-01 : f64 -// CHECK: %[[VAL_101:.*]] = constant dense<1.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_102:.*]] = cmpf ult, %[[VAL_24]], %[[VAL_101]] : vector<4xf64> -// CHECK: %[[VAL_103:.*]] = constant dense<3.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_104:.*]] = constant dense<5.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_105:.*]] = select %[[VAL_102]], %[[VAL_103]], %[[VAL_104]] : vector<4xi1>, vector<4xf64> -// CHECK: %[[VAL_106:.*]] = memref.get_global @categorical_vec_0 : memref<3xf64> -// CHECK: %[[VAL_107:.*]] = fptoui %[[VAL_39]] : vector<4xf64> to vector<4xi64> -// CHECK: %[[VAL_108:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_109:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_110:.*]] = constant 0 : index -// CHECK: %[[VAL_111:.*]] = 
vector.gather %[[VAL_106]]{{\[}}%[[VAL_110]]] {{\[}}%[[VAL_107]]], %[[VAL_109]], %[[VAL_108]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_112:.*]] = constant dense<1.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_113:.*]] = cmpf ult, %[[VAL_54]], %[[VAL_112]] : vector<4xf64> -// CHECK: %[[VAL_114:.*]] = constant dense<2.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_115:.*]] = constant dense<7.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_116:.*]] = select %[[VAL_113]], %[[VAL_114]], %[[VAL_115]] : vector<4xi1>, vector<4xf64> -// CHECK: %[[VAL_117:.*]] = constant dense<1.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_118:.*]] = cmpf ult, %[[VAL_69]], %[[VAL_117]] : vector<4xf64> -// CHECK: %[[VAL_119:.*]] = constant dense<4.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_120:.*]] = constant dense<5.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_121:.*]] = select %[[VAL_118]], %[[VAL_119]], %[[VAL_120]] : vector<4xi1>, vector<4xf64> -// CHECK: %[[VAL_122:.*]] = constant dense<0.3989422804014327> : vector<4xf64> -// CHECK: %[[VAL_123:.*]] = constant dense<-5.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_124:.*]] = constant dense<5.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_125:.*]] = subf %[[VAL_84]], %[[VAL_124]] : vector<4xf64> -// CHECK: %[[VAL_126:.*]] = mulf %[[VAL_125]], %[[VAL_125]] : vector<4xf64> -// CHECK: %[[VAL_127:.*]] = mulf %[[VAL_126]], %[[VAL_123]] : vector<4xf64> -// CHECK: %[[VAL_128:.*]] = math.exp %[[VAL_127]] : vector<4xf64> -// CHECK: %[[VAL_129:.*]] = mulf %[[VAL_122]], %[[VAL_128]] : vector<4xf64> -// CHECK: %[[VAL_130:.*]] = constant dense<3.9894228040143269> : vector<4xf64> -// CHECK: %[[VAL_131:.*]] = constant dense<-49.999999999999993> : vector<4xf64> -// CHECK: %[[VAL_132:.*]] = constant dense<2.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_133:.*]] = subf %[[VAL_99]], %[[VAL_132]] : vector<4xf64> -// CHECK: %[[VAL_134:.*]] = mulf %[[VAL_133]], %[[VAL_133]] : vector<4xf64> -// CHECK: 
%[[VAL_135:.*]] = mulf %[[VAL_134]], %[[VAL_131]] : vector<4xf64> -// CHECK: %[[VAL_136:.*]] = math.exp %[[VAL_135]] : vector<4xf64> -// CHECK: %[[VAL_137:.*]] = mulf %[[VAL_130]], %[[VAL_136]] : vector<4xf64> -// CHECK: %[[VAL_138:.*]] = mulf %[[VAL_105]], %[[VAL_111]] : vector<4xf64> -// CHECK: %[[VAL_139:.*]] = mulf %[[VAL_138]], %[[VAL_116]] : vector<4xf64> -// CHECK: %[[VAL_140:.*]] = "lo_spn.to_vector"(%[[VAL_100]]) : (f64) -> vector<4xf64> -// CHECK: %[[VAL_141:.*]] = mulf %[[VAL_139]], %[[VAL_140]] : vector<4xf64> -// CHECK: %[[VAL_142:.*]] = mulf %[[VAL_121]], %[[VAL_129]] : vector<4xf64> -// CHECK: %[[VAL_143:.*]] = mulf %[[VAL_142]], %[[VAL_137]] : vector<4xf64> -// CHECK: %[[VAL_144:.*]] = "lo_spn.to_vector"(%[[VAL_100]]) : (f64) -> vector<4xf64> -// CHECK: %[[VAL_145:.*]] = mulf %[[VAL_143]], %[[VAL_144]] : vector<4xf64> -// CHECK: %[[VAL_146:.*]] = addf %[[VAL_141]], %[[VAL_145]] : vector<4xf64> -// CHECK: %[[VAL_147:.*]] = math.log %[[VAL_146]] : vector<4xf64> -// CHECK: vector.transfer_write %[[VAL_147]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : vector<4xf64>, memref -// CHECK: } -// CHECK: %[[VAL_148:.*]] = constant 1 : index -// CHECK: scf.for %[[VAL_149:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_148]] { -// CHECK: %[[VAL_150:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 0 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_151:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 1 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_152:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 2 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_153:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 3 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_154:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 4 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_155:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_149]]) {sampleIndex = 5 : ui32} : (memref, index) -> f64 -// CHECK: 
%[[VAL_156:.*]] = constant 1.000000e-01 : f64 -// CHECK: %[[VAL_157:.*]] = "lo_spn.select"(%[[VAL_150]]) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 -// CHECK: %[[VAL_158:.*]] = "lo_spn.categorical"(%[[VAL_151]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_159:.*]] = "lo_spn.select"(%[[VAL_152]]) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (f64) -> f64 -// CHECK: %[[VAL_160:.*]] = "lo_spn.select"(%[[VAL_153]]) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 -// CHECK: %[[VAL_161:.*]] = "lo_spn.gaussian"(%[[VAL_154]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_162:.*]] = "lo_spn.gaussian"(%[[VAL_155]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_163:.*]] = "lo_spn.mul"(%[[VAL_157]], %[[VAL_158]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_164:.*]] = "lo_spn.mul"(%[[VAL_163]], %[[VAL_159]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_165:.*]] = "lo_spn.mul"(%[[VAL_164]], %[[VAL_156]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_166:.*]] = "lo_spn.mul"(%[[VAL_160]], %[[VAL_161]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_167:.*]] = "lo_spn.mul"(%[[VAL_166]], %[[VAL_162]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_168:.*]] = "lo_spn.mul"(%[[VAL_167]], %[[VAL_156]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_169:.*]] = "lo_spn.add"(%[[VAL_165]], %[[VAL_168]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_170:.*]] = "lo_spn.log"(%[[VAL_169]]) : (f64) -> f64 -// CHECK: "lo_spn.batch_write"(%[[VAL_170]], %[[VAL_1]], %[[VAL_149]]) : (f64, memref, index) -> () +// CHECK: %[[VAL_19:.*]] = vector.gather 
%[[VAL_17]]{{\[}}%[[VAL_18]]] {{\[}}%[[VAL_10]]], %[[VAL_12]], %[[VAL_11]] : memref, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_20:.*]] = constant dense<1.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_21:.*]] = cmpf ult, %[[VAL_19]], %[[VAL_20]] : vector<4xf64> +// CHECK: %[[VAL_22:.*]] = constant dense<3.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_23:.*]] = constant dense<5.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_24:.*]] = select %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : vector<4xi1>, vector<4xf64> +// CHECK: vector.transfer_write %[[VAL_24]], %[[VAL_1]]{{\[}}%[[VAL_4]]] : vector<4xf64>, memref // CHECK: } // CHECK: return // CHECK: } - -// CHECK-LABEL: func @spn_vector( -// CHECK-SAME: %[[VAL_0:.*]]: memref, -// CHECK-SAME: %[[VAL_1:.*]]: memref) { -// CHECK: %[[VAL_2:.*]] = constant 0 : index -// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref -// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref -// CHECK: call @vec_task_0(%[[VAL_0]], %[[VAL_4]]) : (memref, memref) -> () -// CHECK: "lo_spn.copy"(%[[VAL_4]], %[[VAL_1]]) : (memref, memref) -> () -// CHECK: "lo_spn.return"() : () -> () -// CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir index 40ef3270..2ef17bfc 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir @@ -2,133 +2,27 @@ module { func @task_0(%arg0: memref, %arg1: memref) { - %c0 = constant 0 : index - %c0_0 = constant 0 : index - %0 = memref.dim %arg0, %c0_0 : memref - %c1 = constant 1 : index - scf.for %arg2 = %c0 to %0 step %c1 { - %1 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32} : (memref, index) -> f64 - %2 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 1 : ui32} : (memref, index) -> f64 - %3 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 2 : ui32} : (memref, index) -> 
f64 - %4 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 3 : ui32} : (memref, index) -> f64 - %5 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 4 : ui32} : (memref, index) -> f64 - %6 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 5 : ui32} : (memref, index) -> f64 - %cst = constant 1.000000e-01 : f64 - %7 = "lo_spn.select"(%1) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 - %8 = "lo_spn.categorical"(%2) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 - %9 = "lo_spn.select"(%3) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (f64) -> f64 - %10 = "lo_spn.select"(%4) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 4.500000e-01 : f64} : (f64) -> f64 - %11 = "lo_spn.gaussian"(%5) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 - %12 = "lo_spn.gaussian"(%6) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 - %13 = "lo_spn.mul"(%7, %8) : (f64, f64) -> f64 - %14 = "lo_spn.mul"(%13, %9) : (f64, f64) -> f64 - %15 = "lo_spn.mul"(%14, %cst) : (f64, f64) -> f64 - %16 = "lo_spn.mul"(%10, %11) : (f64, f64) -> f64 - %17 = "lo_spn.mul"(%16, %12) : (f64, f64) -> f64 - %18 = "lo_spn.mul"(%17, %cst) : (f64, f64) -> f64 - %19 = "lo_spn.add"(%15, %18) : (f64, f64) -> f64 - %20 = "lo_spn.log"(%19) : (f64) -> f64 - "lo_spn.batch_write"(%20, %arg1, %arg2) : (f64, memref, index) -> () - } + %cst1 = constant 1.000000e-01 : f64 + %ind1 = constant 1 : index + %0 = "lo_spn.select"(%cst1) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 5.500000e-01 : f64, val_true = 3.500000e-01 : f64} : (f64) -> f64 + "lo_spn.batch_write"(%0, %arg1, %ind1) : (f64, memref, 
index) -> () return } - func @spn_vector(%arg0: memref, %arg1: memref) { - %c0 = constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.alloc(%0) : memref - call @task_0(%arg0, %1) : (memref, memref) -> () - "lo_spn.copy"(%1, %arg1) : (memref, memref) -> () - "lo_spn.return"() : () -> () - } } // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// CHECK-LABEL: memref.global "private" constant @categorical_0 : memref<3xf64> = dense<[2.500000e-01, 6.250000e-01, 1.250000e-01]> // CHECK-LABEL: func @task_0( // CHECK-SAME: %[[VAL_0:.*]]: memref, // CHECK-SAME: %[[VAL_1:.*]]: memref) { -// CHECK: %[[VAL_2:.*]] = constant 0 : index -// CHECK: %[[VAL_3:.*]] = constant 0 : index -// CHECK: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref -// CHECK: %[[VAL_5:.*]] = constant 1 : index -// CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_5]] { -// CHECK: %[[VAL_7:.*]] = constant 0 : index -// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref -// CHECK: %[[VAL_9:.*]] = constant 1 : index -// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_9]]] : memref -// CHECK: %[[VAL_11:.*]] = constant 2 : index -// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_11]]] : memref -// CHECK: %[[VAL_13:.*]] = constant 3 : index -// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_13]]] : memref -// CHECK: %[[VAL_15:.*]] = constant 4 : index -// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_15]]] : memref -// CHECK: %[[VAL_17:.*]] = constant 5 : index -// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_17]]] : memref -// CHECK: %[[VAL_19:.*]] = constant 1.000000e-01 : f64 -// CHECK: %[[VAL_20:.*]] = constant 1.000000e+00 : f64 -// CHECK: %[[VAL_21:.*]] = cmpf ult, %[[VAL_8]], %[[VAL_20]] : f64 -// CHECK: %[[VAL_22:.*]] = constant 3.500000e-01 : f64 -// CHECK: %[[VAL_23:.*]] = 
constant 5.500000e-01 : f64 -// CHECK: %[[VAL_24:.*]] = select %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : f64 -// CHECK: %[[VAL_25:.*]] = memref.get_global @categorical_0 : memref<3xf64> -// CHECK: %[[VAL_26:.*]] = fptoui %[[VAL_10]] : f64 to i64 -// CHECK: %[[VAL_27:.*]] = index_cast %[[VAL_26]] : i64 to index -// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_27]]] : memref<3xf64> -// CHECK: %[[VAL_29:.*]] = constant 1.000000e+00 : f64 -// CHECK: %[[VAL_30:.*]] = cmpf ult, %[[VAL_12]], %[[VAL_29]] : f64 -// CHECK: %[[VAL_31:.*]] = constant 2.500000e-01 : f64 -// CHECK: %[[VAL_32:.*]] = constant 7.500000e-01 : f64 -// CHECK: %[[VAL_33:.*]] = select %[[VAL_30]], %[[VAL_31]], %[[VAL_32]] : f64 -// CHECK: %[[VAL_34:.*]] = constant 1.000000e+00 : f64 -// CHECK: %[[VAL_35:.*]] = cmpf ult, %[[VAL_14]], %[[VAL_34]] : f64 -// CHECK: %[[VAL_36:.*]] = constant 4.500000e-01 : f64 -// CHECK: %[[VAL_37:.*]] = constant 5.500000e-01 : f64 -// CHECK: %[[VAL_38:.*]] = select %[[VAL_35]], %[[VAL_36]], %[[VAL_37]] : f64 -// CHECK: %[[VAL_39:.*]] = constant 0.3989422804014327 : f64 -// CHECK: %[[VAL_40:.*]] = constant -5.000000e-01 : f64 -// CHECK: %[[VAL_41:.*]] = constant 5.000000e-01 : f64 -// CHECK: %[[VAL_42:.*]] = subf %[[VAL_16]], %[[VAL_41]] : f64 -// CHECK: %[[VAL_43:.*]] = mulf %[[VAL_42]], %[[VAL_42]] : f64 -// CHECK: %[[VAL_44:.*]] = mulf %[[VAL_43]], %[[VAL_40]] : f64 -// CHECK: %[[VAL_45:.*]] = math.exp %[[VAL_44]] : f64 -// CHECK: %[[VAL_46:.*]] = mulf %[[VAL_39]], %[[VAL_45]] : f64 -// CHECK: %[[VAL_47:.*]] = constant 3.9894228040143269 : f64 -// CHECK: %[[VAL_48:.*]] = constant -49.999999999999993 : f64 -// CHECK: %[[VAL_49:.*]] = constant 2.500000e-01 : f64 -// CHECK: %[[VAL_50:.*]] = subf %[[VAL_18]], %[[VAL_49]] : f64 -// CHECK: %[[VAL_51:.*]] = mulf %[[VAL_50]], %[[VAL_50]] : f64 -// CHECK: %[[VAL_52:.*]] = mulf %[[VAL_51]], %[[VAL_48]] : f64 -// CHECK: %[[VAL_53:.*]] = math.exp %[[VAL_52]] : f64 -// CHECK: %[[VAL_54:.*]] = mulf %[[VAL_47]], %[[VAL_53]] 
: f64 -// CHECK: %[[VAL_55:.*]] = mulf %[[VAL_24]], %[[VAL_28]] : f64 -// CHECK: %[[VAL_56:.*]] = mulf %[[VAL_55]], %[[VAL_33]] : f64 -// CHECK: %[[VAL_57:.*]] = mulf %[[VAL_56]], %[[VAL_19]] : f64 -// CHECK: %[[VAL_58:.*]] = mulf %[[VAL_38]], %[[VAL_46]] : f64 -// CHECK: %[[VAL_59:.*]] = mulf %[[VAL_58]], %[[VAL_54]] : f64 -// CHECK: %[[VAL_60:.*]] = mulf %[[VAL_59]], %[[VAL_19]] : f64 -// CHECK: %[[VAL_61:.*]] = addf %[[VAL_57]], %[[VAL_60]] : f64 -// CHECK: %[[VAL_62:.*]] = math.log %[[VAL_61]] : f64 -// CHECK: memref.store %[[VAL_62]], %[[VAL_1]]{{\[}}%[[VAL_6]]] : memref -// CHECK: } -// CHECK: return -// CHECK: } - -// CHECK-LABEL: func @spn_vector( -// CHECK-SAME: %[[VAL_0:.*]]: memref, -// CHECK-SAME: %[[VAL_1:.*]]: memref) { -// CHECK: %[[VAL_2:.*]] = constant 0 : index -// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref -// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref -// CHECK: call @task_0(%[[VAL_0]], %[[VAL_4]]) : (memref, memref) -> () -// CHECK: %[[VAL_5:.*]] = constant 0 : index -// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_4]], %[[VAL_5]] : memref -// CHECK: %[[VAL_7:.*]] = constant 0 : index -// CHECK: %[[VAL_8:.*]] = constant 1 : index -// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] { -// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] : memref -// CHECK: memref.store %[[VAL_10]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref -// CHECK: } +// CHECK: %[[VAL_2:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_3:.*]] = constant 1 : index +// CHECK: %[[VAL_4:.*]] = constant 1.000000e+00 : f64 +// CHECK: %[[VAL_5:.*]] = cmpf ult, %[[VAL_2]], %[[VAL_4]] : f64 +// CHECK: %[[VAL_6:.*]] = constant 3.500000e-01 : f64 +// CHECK: %[[VAL_7:.*]] = constant 5.500000e-01 : f64 +// CHECK: %[[VAL_8:.*]] = select %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : f64 +// CHECK: memref.store %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref // CHECK: return // CHECK: } From 7f2f782bddc85d3a9ece7cd72326d4a5a1d1610c 
Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Mon, 2 Aug 2021 14:55:01 +0200 Subject: [PATCH 09/12] [FIX] Incorporated remaining feedback (excluding LoSPNOps.cpp/.h) from LS. --- .../Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp | 1 - mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp | 6 +++--- .../LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp | 4 ++-- mlir/lib/Dialect/LoSPN/LoSPNOps.cpp | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp b/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp index b36294a1..069cf8bb 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/LoSPNtoCPUConversionPasses.cpp @@ -137,7 +137,6 @@ void mlir::spn::LoSPNNodeVectorizationPass::runOnOperation() { target.addLegalOp(); OwningRewritePatternList patterns(&getContext()); - patterns.insert(typeConverter, &getContext()); mlir::spn::populateLoSPNCPUVectorizationNodePatterns(patterns, &getContext(), typeConverter); auto op = getOperation(); diff --git a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp index 63140a36..3d4be1cb 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp @@ -323,7 +323,6 @@ mlir::LogicalResult mlir::spn::GaussianLogLowering::matchAndRewrite(mlir::spn::l Value gaussian = rewriter.create(op->getLoc(), coefficientConst, fraction); if (op.supportMarginal()) { auto isNan = rewriter.create(op->getLoc(), CmpFPredicate::UNO, index, index); - // FixMe / Question: Could this be a bug? 
(Either rename to 'constZero' -OR- set to 1.0 (instead of 0.0)) auto constOne = rewriter.create(op.getLoc(), rewriter.getFloatAttr(resultType, 0.0)); gaussian = rewriter.create(op.getLoc(), isNan, constOne, gaussian); } @@ -530,9 +529,10 @@ mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::S val_false = rewriter.create(op->getLoc(), op.val_falseAttr().getType(), op.val_falseAttr()); } - mlir::Value leaf = rewriter.template create(op.getLoc(), cond, val_true, val_false); + + mlir::Value leaf = rewriter.create(op.getLoc(), cond, val_true, val_false); if (op.supportMarginal()) { - assert(inputTy.template isa()); + assert(inputTy.isa()); auto isNan = rewriter.create(op->getLoc(), mlir::CmpFPredicate::UNO, op.input(), op.input()); auto marginalValue = (computesLog) ? 0.0 : 1.0; auto constOne = rewriter.create(op.getLoc(), diff --git a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp index 4e091683..dd42a978 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp @@ -699,9 +699,9 @@ mlir::LogicalResult mlir::spn::VectorizeSelectLeaf::matchAndRewrite(mlir::spn::l val_false = broadcastVectorConstant(vectorType, op.val_falseAttr().getValueAsDouble(), rewriter, op.getLoc()); } - mlir::Value leaf = rewriter.template create(op.getLoc(), cond, val_true, val_false); + mlir::Value leaf = rewriter.create(op.getLoc(), cond, val_true, val_false); if (op.supportMarginal()) { - assert(inputTy.template isa()); + assert(inputTy.isa()); auto isNan = rewriter.create(op->getLoc(), CmpFPredicate::UNO, inputVec, inputVec); auto marginalValue = (computesLog) ? 
0.0 : 1.0; auto constOne = broadcastVectorConstant(vectorType, marginalValue, rewriter, op.getLoc()); diff --git a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp index 7ea354ba..569b9c6f 100644 --- a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp +++ b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp @@ -388,7 +388,7 @@ ::mlir::LogicalResult mlir::spn::low::SPNCategoricalLeaf::canonicalize(SPNCatego op.supportMarginalAttr()); return success(); } - return failure(); + return rewriter.notifyMatchFailure(op, "Categorical held != 2 probabilities (no reduction to select possible)"); } //===----------------------------------------------------------------------===// @@ -423,7 +423,7 @@ ::mlir::LogicalResult mlir::spn::low::SPNHistogramLeaf::canonicalize(SPNHistogra return success(); } } - return failure(); + return rewriter.notifyMatchFailure(op, "Histogram held != 2 buckets (no reduction to select possible)"); } #define GET_OP_CLASSES From 89ff05bac2057668e3689deda4a8fd1a86af5803 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Mon, 2 Aug 2021 19:55:15 +0200 Subject: [PATCH 10/12] [FIX] Create correct element-type for the threshold vector (instead of convert). --- .../LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp index dd42a978..3a345732 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp @@ -666,16 +666,14 @@ mlir::LogicalResult mlir::spn::VectorizeSelectLeaf::matchAndRewrite(mlir::spn::l // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. 
mlir::Value cond; auto booleanVectorTy = VectorType::get(vectorType.getShape(), IntegerType::get(op.getContext(), 1)); - Value thresholdVec = - broadcastVectorConstant(vectorType, op.input_true_thresholdAttr().getValueAsDouble(), rewriter, op.getLoc()); if (inputTy.isa()) { + auto thresholdVec = + broadcastVectorConstant(vectorType, op.input_true_thresholdAttr().getValueAsDouble(), rewriter, op.getLoc()); cond = rewriter.create(op->getLoc(), booleanVectorTy, mlir::CmpFPredicate::ULT, inputVec, thresholdVec); } else if (inputTy.isa()) { - // Convert from floating-point input to integer value if necessary. - // This conversion is also possible in vectorized mode. - auto intVectorTy = VectorType::get(vectorType.getShape(), inputTy); - thresholdVec = rewriter.create(op->getLoc(), thresholdVec, intVectorTy); + auto thresholdVec = + broadcastVectorConstant(vectorType, op.input_true_thresholdAttr().getValueAsDouble(), rewriter, op.getLoc()); cond = rewriter.create(op->getLoc(), booleanVectorTy, mlir::CmpIPredicate::ult, inputVec, thresholdVec); } else { From 47beb840d65f2d09009116c1aae16ab858d22c8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Thu, 5 Aug 2021 13:50:51 +0200 Subject: [PATCH 11/12] [FIX] Adapted Python tests. - Now, they will not(!) be eligible for "select-transformation". (New tests will be added in the next commit.) 
--- python-interface/test/cpu/test_cpu_histogram.py | 6 +++--- python-interface/test/cpu/test_graph_partitioning.py | 4 ++-- python-interface/test/cpu/test_marginal_cpu_histogram.py | 8 ++++---- .../test/vector/test_log_vector_graph_partitioning.py | 8 ++++---- python-interface/test/vector/test_log_vector_histogram.py | 8 ++++---- python-interface/test/vector/test_vector_histogram.py | 8 ++++---- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/python-interface/test/cpu/test_cpu_histogram.py b/python-interface/test/cpu/test_cpu_histogram.py index 0bd03a00..dcdb2476 100644 --- a/python-interface/test/cpu/test_cpu_histogram.py +++ b/python-interface/test/cpu/test_cpu_histogram.py @@ -17,9 +17,9 @@ def test_cpu_histogram(): # Construct a minimal SPN. - h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.35, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) diff --git a/python-interface/test/cpu/test_graph_partitioning.py b/python-interface/test/cpu/test_graph_partitioning.py index 6fb6d6a1..c436813c 100644 --- a/python-interface/test/cpu/test_graph_partitioning.py +++ b/python-interface/test/cpu/test_graph_partitioning.py @@ -17,9 +17,9 @@ def test_cpu_histogram(): # Construct a minimal SPN. 
- h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) diff --git a/python-interface/test/cpu/test_marginal_cpu_histogram.py b/python-interface/test/cpu/test_marginal_cpu_histogram.py index 4b8753a5..d3cd77f4 100644 --- a/python-interface/test/cpu/test_marginal_cpu_histogram.py +++ b/python-interface/test/cpu/test_marginal_cpu_histogram.py @@ -17,10 +17,10 @@ def test_cpu_histogram(): # Construct a minimal SPN. - h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 1., 3.], [0.45, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 3.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 1., 3.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) diff --git a/python-interface/test/vector/test_log_vector_graph_partitioning.py b/python-interface/test/vector/test_log_vector_graph_partitioning.py index 5c45814a..e0c8c7af 100644 --- a/python-interface/test/vector/test_log_vector_graph_partitioning.py +++ b/python-interface/test/vector/test_log_vector_graph_partitioning.py @@ -20,10 +20,10 @@ @pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") def test_log_vector_histogram(): # Construct a minimal SPN. 
- h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) diff --git a/python-interface/test/vector/test_log_vector_histogram.py b/python-interface/test/vector/test_log_vector_histogram.py index 100a2c18..1513f899 100644 --- a/python-interface/test/vector/test_log_vector_histogram.py +++ b/python-interface/test/vector/test_log_vector_histogram.py @@ -20,10 +20,10 @@ @pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") def test_log_vector_histogram(): # Construct a minimal SPN. 
- h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) diff --git a/python-interface/test/vector/test_vector_histogram.py b/python-interface/test/vector/test_vector_histogram.py index 7d24dbc0..b66ebbbf 100644 --- a/python-interface/test/vector/test_vector_histogram.py +++ b/python-interface/test/vector/test_vector_histogram.py @@ -20,10 +20,10 @@ @pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") def test_vector_histogram(): # Construct a minimal SPN. - h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) - h2 = Histogram([0., 1., 2.], [0.45, 0.55], [1, 1], scope=1) - h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) - h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + h1 = Histogram([0., 1., 3.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 3., 6., 8.], [0.45, 0.1, 0.55], [1, 1], scope=1) + h3 = Histogram([0., 1., 4.], [0.33, 0.67], [1, 1], scope=0) + h4 = Histogram([0., 5., 8.], [0.875, 0.125], [1, 1], scope=1) p0 = Product(children=[h1, h2]) p1 = Product(children=[h3, h4]) From 6ebc4e04a47db1dade47424cd3132e2fb55aabc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Halkenh=C3=A4user?= Date: Mon, 4 Oct 2021 19:36:19 +0200 Subject: [PATCH 12/12] Implemented avoidance of out-of-bounds access. 
- Added corresponding scalar and vectorized codegen - Adapted regression tests - Added Python computation tests --- .../Conversion/LoSPNtoCPU/NodePatterns.cpp | 166 ++++++++++++---- .../Vectorization/VectorizeNodePatterns.cpp | 56 +++++- mlir/lib/Dialect/LoSPN/LoSPNOps.cpp | 3 +- .../lower-to-cpu-nodes-boundary-check.mlir | 74 +++++++ .../lower-to-cpu-nodes-scalar-log.mlir | 136 ++++++++----- .../lower-to-cpu-nodes-scalar.mlir | 130 ++++++++----- .../lower-to-cpu-nodes-select.mlir | 19 +- .../lower-to-cpu-nodes-vectorize-log.mlir | 180 ++++++++++-------- .../lower-to-cpu-nodes-vectorize.mlir | 172 ++++++++++------- .../test/cpu/test_cpu_out_of_bounds.py | 63 ++++++ ...pu_transformation_categorical_to_select.py | 50 +++++ ..._cpu_transformation_histogram_to_select.py | 50 +++++ ...t_marginal_cpu_transformation_to_select.py | 52 +++++ .../test/vector/test_vector_out_of_bounds.py | 67 +++++++ 14 files changed, 930 insertions(+), 288 deletions(-) create mode 100644 mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-boundary-check.mlir create mode 100644 python-interface/test/cpu/test_cpu_out_of_bounds.py create mode 100644 python-interface/test/cpu/test_cpu_transformation_categorical_to_select.py create mode 100644 python-interface/test/cpu/test_cpu_transformation_histogram_to_select.py create mode 100644 python-interface/test/cpu/test_marginal_cpu_transformation_to_select.py create mode 100644 python-interface/test/vector/test_vector_out_of_bounds.py diff --git a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp index 3d4be1cb..447eff8f 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/NodePatterns.cpp @@ -336,11 +336,28 @@ namespace { mlir::LogicalResult replaceOpWithGlobalMemref(SourceOp op, mlir::ConversionPatternRewriter& rewriter, mlir::Value indexOperand, llvm::ArrayRef arrayValues, mlir::Type resultType, const std::string& tablePrefix, - bool computesLog) { + bool 
computesLog, int lowerBound, int upperBound) { static int tableCount = 0; if (!resultType.isIntOrFloat()) { // Currently only handling Int and Float result types. - return mlir::failure(); + return rewriter.notifyMatchFailure(op, "Match failed because result is neither float nor integer"); + } + + // Convert input value from float to integer if necessary. + mlir::Value index = indexOperand; + if (!index.getType().isIntOrIndex()) { + // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. + if (!index.getType().isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, "Match failed because input is neither float nor integer/index"); + } + index = rewriter.template create(op.getLoc(), index, rewriter.getI64Type()); + } + auto integerIndex = index; + auto idxTy = index.getType(); + + // Cast input value to index if necessary. + if (!index.getType().isIndex()) { + index = rewriter.template create(op.getLoc(), rewriter.getIndexType(), index); } // Construct a DenseElementsAttr to hold the array values. 
@@ -361,34 +378,66 @@ namespace { // Restore insertion point rewriter.restoreInsertionPoint(restore); + auto idxMin = rewriter.create(op->getLoc(), + idxTy, + rewriter.getIntegerAttr(idxTy, lowerBound)); + auto idxMax = rewriter.create(op->getLoc(), + idxTy, + rewriter.getIntegerAttr(idxTy, upperBound)); + + auto boolTy = mlir::IntegerType::get(op.getContext(), 1); + auto inBoundsLB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::sge, integerIndex, idxMin); + auto inBoundsUB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::slt, integerIndex, idxMax); + auto inBounds = rewriter.create(op->getLoc(), inBoundsLB, inBoundsUB); + + // Perform range-check: if (inBounds) { return memRefValue; } else { return defaultValue; } + auto boundaryIf = rewriter.create(op->getLoc(), resultType, inBounds, true); + + // If-Then branch: Legal range access -> memref load + rewriter.setInsertionPointToStart(&boundaryIf.thenRegion().front()); + // Use GetGlobalMemref operation to access the global created above. auto addressOf = rewriter.template create(op.getLoc(), memrefType, symbolName); - // Convert input value from float to integer if necessary. - mlir::Value index = indexOperand; - if (!index.getType().isIntOrIndex()) { - // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. - if (!index.getType().isIntOrFloat()) { - return mlir::failure(); - } - index = rewriter.template create(op.getLoc(), index, rewriter.getI64Type()); - } - // Cast input value to index if necessary. - if (!index.getType().isIndex()) { - index = rewriter.template create(op.getLoc(), rewriter.getIndexType(), index); - } + // Replace the source operation with a load from the global memref, // using the source operation's input value as index. 
- mlir::Value leaf = rewriter.template create(op.getLoc(), addressOf, mlir::ValueRange{index}); + mlir::Value leafThen = rewriter.template create(op.getLoc(), addressOf, mlir::ValueRange{index}); + + // If-Else branch: Illegal range access -> default return + rewriter.setInsertionPointToStart(&boundaryIf.elseRegion().front()); + double defaultValue = (computesLog) ? static_cast(-INFINITY) : 0.0; + mlir::Value leafElse = rewriter.create(op.getLoc(), resultType, rewriter.getFloatAttr(resultType, defaultValue)); + if (op.supportMarginal()) { - assert(indexOperand.getType().template isa()); + // Set the insertion point right before the if-op + rewriter.setInsertionPoint(boundaryIf); + // Define the values needed for marginal-support (for both [then/else] branches) auto isNan = rewriter.create(op->getLoc(), mlir::CmpFPredicate::UNO, indexOperand, indexOperand); auto marginalValue = (computesLog) ? 0.0 : 1.0; auto constOne = rewriter.create(op.getLoc(), rewriter.getFloatAttr(resultType, marginalValue)); - leaf = rewriter.create(op.getLoc(), isNan, constOne, leaf); + + // Place select and yield ops in the corresponding branches + rewriter.setInsertionPointToEnd(&boundaryIf.thenRegion().back()); + mlir::Value leaf = rewriter.create(op.getLoc(), isNan, constOne, leafThen); + rewriter.create(op->getLoc(), leaf); + + rewriter.setInsertionPointToEnd(&boundaryIf.elseRegion().back()); + leaf = rewriter.create(op.getLoc(), isNan, constOne, leafElse); + rewriter.create(op->getLoc(), leaf); + } else { + // Place yield ops in the corresponding branches + rewriter.setInsertionPointToEnd(&boundaryIf.thenRegion().back()); + rewriter.create(op->getLoc(), leafThen); + rewriter.setInsertionPointToEnd(&boundaryIf.elseRegion().back()); + rewriter.create(op->getLoc(), leafElse); } - rewriter.replaceOp(op, leaf); + + // Replace access by result(s) of created if-op + rewriter.replaceOp(op, boundaryIf.results()); return mlir::success(); } @@ -408,7 +457,7 @@ mlir::LogicalResult 
mlir::spn::HistogramLowering::matchAndRewrite(mlir::spn::low llvm::DenseMap values; int minLB = std::numeric_limits::max(); int maxUB = std::numeric_limits::min(); - for (auto& b : op.bucketsAttr()) { + for (auto& b: op.bucketsAttr()) { auto bucket = b.cast(); auto lb = bucket.lb().getInt(); auto ub = bucket.ub().getInt(); @@ -443,7 +492,7 @@ mlir::LogicalResult mlir::spn::HistogramLowering::matchAndRewrite(mlir::spn::low if (values.count(i)) { indexVal = (computesLog) ? log(values[i]) : values[i]; } else { - // Fill up with 0 if no value was defined by the histogram. + // Fill up with corresponding zero if no value was defined by the histogram. indexVal = (computesLog) ? static_cast(-INFINITY) : 0; } // Construct attribute with constant value. Need to distinguish cases here due to different builder methods. @@ -455,7 +504,7 @@ mlir::LogicalResult mlir::spn::HistogramLowering::matchAndRewrite(mlir::spn::low } return replaceOpWithGlobalMemref(op, rewriter, operands[0], valArray, - resultType, "histogram_", computesLog); + resultType, "histogram_", computesLog, minLB, maxUB); } mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::low::SPNCategoricalLeaf op, llvm::ArrayRef operands, @@ -471,8 +520,9 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l resultType = logType.getBaseType(); computesLog = true; } + SmallVector values; - for (auto val : op.probabilities().getValue()) { + for (auto val: op.probabilities().getValue()) { if (computesLog) { auto floatVal = val.dyn_cast(); assert(floatVal); @@ -481,8 +531,15 @@ mlir::LogicalResult mlir::spn::CategoricalLowering::matchAndRewrite(mlir::spn::l values.push_back(val); } } - return replaceOpWithGlobalMemref(op, rewriter, operands[0], - values, resultType, "categorical_", computesLog); + return replaceOpWithGlobalMemref(op, + rewriter, + operands[0], + values, + resultType, + "categorical_", + computesLog, + 0, + op.probabilities().size()); } mlir::LogicalResult 
mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::SPNSelectLeaf op, @@ -491,21 +548,56 @@ mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::S if (op.checkVectorized()) { return rewriter.notifyMatchFailure(op, "Pattern only matches non-vectorized SelectLeaf"); } - // If the input type is not an integer, but also not a float, we cannot convert it and this pattern fails. - mlir::Value cond; + auto inputTy = op.input().getType(); + if (!inputTy.isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, "Input type should be either int or float."); + } + + mlir::Value cond; + mlir::Value idxMin, idxMax; + mlir::Value inBoundsLB, inBoundsUB; + auto boolTy = IntegerType::get(op.getContext(), 1); + // Regarding the actual LB/UB values we will use the knowledge that: + // (UB - LB = 2) && input_true_thresholdAttr == LB + 1 == UB - 1 if (inputTy.isa()) { + idxMin = rewriter.create(op->getLoc(), + rewriter.getF64Type(), + rewriter.getFloatAttr(rewriter.getF64Type(), + op.input_true_thresholdAttr().getValueAsDouble() + - 1.0)); + idxMax = rewriter.create(op->getLoc(), + rewriter.getF64Type(), + rewriter.getFloatAttr(rewriter.getF64Type(), + op.input_true_thresholdAttr().getValueAsDouble() + + 1.0)); + inBoundsLB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpFPredicate::UGE, op.input(), idxMin); + inBoundsUB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpFPredicate::ULT, op.input(), idxMax); auto thresholdAttr = FloatAttr::get(inputTy, op.input_true_thresholdAttr().getValueAsDouble()); auto input_true_threshold = rewriter.create(op->getLoc(), inputTy, thresholdAttr); - cond = rewriter.create(op->getLoc(), IntegerType::get(op.getContext(), 1), - mlir::CmpFPredicate::ULT, op.input(), input_true_threshold); + cond = + rewriter.create(op->getLoc(), boolTy, mlir::CmpFPredicate::ULT, op.input(), input_true_threshold); } else if (inputTy.isa()) { + idxMin = rewriter.create(op->getLoc(), + op.input().getType(), + 
rewriter.getIntegerAttr(op.input().getType(), + op.input_true_thresholdAttr().getValueAsDouble() + - 1.0)); + idxMax = rewriter.create(op->getLoc(), + op.input().getType(), + rewriter.getIntegerAttr(op.input().getType(), + op.input_true_thresholdAttr().getValueAsDouble() + + 1.0)); + inBoundsLB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::sge, op.input(), idxMin); + inBoundsUB = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::slt, op.input(), idxMax); auto thresholdAttr = IntegerAttr::get(inputTy, op.input_true_thresholdAttr().getValueAsDouble()); auto input_true_threshold = rewriter.create(op->getLoc(), inputTy, thresholdAttr); - cond = rewriter.create(op->getLoc(), IntegerType::get(op.getContext(), 1), - mlir::CmpIPredicate::ult, op.input(), input_true_threshold); - } else { - return rewriter.notifyMatchFailure(op, "Expected condition-value to be either Float- or IntegerType"); + cond = + rewriter.create(op->getLoc(), boolTy, mlir::CmpIPredicate::ult, op.input(), input_true_threshold); } Type resultType = op.getResult().getType(); @@ -529,7 +621,6 @@ mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::S val_false = rewriter.create(op->getLoc(), op.val_falseAttr().getType(), op.val_falseAttr()); } - mlir::Value leaf = rewriter.create(op.getLoc(), cond, val_true, val_false); if (op.supportMarginal()) { assert(inputTy.isa()); @@ -539,6 +630,13 @@ mlir::LogicalResult mlir::spn::SelectLowering::matchAndRewrite(mlir::spn::low::S rewriter.getFloatAttr(resultType, marginalValue)); leaf = rewriter.create(op.getLoc(), isNan, constOne, leaf); } + + auto inBounds = rewriter.create(op->getLoc(), inBoundsLB, inBoundsUB); + double defaultValue = (computesLog) ? 
static_cast(-INFINITY) : 0.0; + mlir::Value defaultReturn = + rewriter.create(op.getLoc(), resultType, rewriter.getFloatAttr(resultType, defaultValue)); + + leaf = rewriter.create(op.getLoc(), inBounds, leaf, defaultReturn); rewriter.replaceOp(op, leaf); return success(); diff --git a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp index 3a345732..d07a4620 100644 --- a/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp +++ b/mlir/lib/Conversion/LoSPNtoCPU/Vectorization/VectorizeNodePatterns.cpp @@ -446,6 +446,8 @@ namespace { mlir::LogicalResult replaceOpWithGatherFromGlobalMemref(SourceOp op, mlir::ConversionPatternRewriter& rewriter, mlir::Value indexOperand, + int lowerBound, + int upperBound, llvm::ArrayRef arrayValues, mlir::Type resultType, const std::string& tablePrefix, bool computesLog) { @@ -453,7 +455,7 @@ namespace { auto inputType = indexOperand.getType(); if (!inputType.template isa()) { // This pattern only handles vectorized implementations and fails if the input is not a vector. - return mlir::failure(); + return rewriter.notifyMatchFailure(op, "Match failed because input is not a VectorType"); } // Construct a DenseElementsAttr to hold the array values. @@ -483,20 +485,53 @@ namespace { auto indexType = inputType.template dyn_cast().getElementType(); if (!indexType.isIntOrIndex()) { if (indexType.template isa()) { + indexType = rewriter.getI64Type(); index = rewriter.template create(op.getLoc(), index, - mlir::VectorType::get(vectorShape, rewriter.getI64Type())); + mlir::VectorType::get(vectorShape, indexType)); } else { // The input type is neither int/index nor float, conversion unknown, fail this pattern. 
- return mlir::failure(); + return rewriter.notifyMatchFailure(op, "Match failed because input is neither float nor integer/index"); } } // Construct the constant pass-thru value (values used if the mask is false for an element of the vector). auto vectorType = mlir::VectorType::get(vectorShape, resultType); - auto passThru = broadcastVectorConstant(vectorType, 0.0, + double defaultValue = (computesLog) ? static_cast(-INFINITY) : 0.0; + auto passThru = broadcastVectorConstant(vectorType, defaultValue, rewriter, op->getLoc()); - // Construct the constant mask. - auto mask = broadcastVectorConstant(mlir::VectorType::get(vectorShape, rewriter.getI1Type()), true, - rewriter, op->getLoc()); + + mlir::ConstantOp idxMinVec; + mlir::ConstantOp idxMaxVec; + if (indexType.isInteger(32)) { + idxMinVec = + broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), lowerBound, rewriter, op->getLoc()); + idxMaxVec = + broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), upperBound, rewriter, op->getLoc()); + } else { + idxMinVec = broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), + static_cast(lowerBound), + rewriter, + op->getLoc()); + idxMaxVec = broadcastVectorConstant(mlir::VectorType::get(vectorShape, indexType), + static_cast(upperBound), + rewriter, + op->getLoc()); + } + + auto boolVecTy = mlir::VectorType::get(vectorShape, rewriter.getI1Type()); + auto inBoundsLB = + rewriter.create(op->getLoc(), boolVecTy, mlir::CmpIPredicate::sge, index, idxMinVec); + auto inBoundsUB = + rewriter.create(op->getLoc(), boolVecTy, mlir::CmpIPredicate::slt, index, idxMaxVec); + auto inBounds = rewriter.create(op->getLoc(), inBoundsLB, inBoundsUB); + auto trueVec = + broadcastVectorConstant(mlir::VectorType::get(vectorShape, rewriter.getI1Type()), true, rewriter, op->getLoc()); + auto falseVec = broadcastVectorConstant(mlir::VectorType::get(vectorShape, rewriter.getI1Type()), + false, + rewriter, + op->getLoc()); + // Create mask, 
dependant on an "inBounds" select -> Mask all out-of-bounds accesses. + auto mask = rewriter.create(op.getLoc(), inBounds, trueVec, falseVec); + // Replace the source operation with a gather load from the global memref. mlir::Value constIndex = rewriter.template create(op.getLoc(), rewriter.getIndexAttr(0)); mlir::Value leaf = rewriter.template create(op.getLoc(), vectorType, addressOf, @@ -539,7 +574,8 @@ mlir::LogicalResult mlir::spn::VectorizeCategorical::matchAndRewrite(mlir::spn:: values.push_back(val); } } - return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], + int idxMax = op.probabilities().size(); + return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], 0, idxMax, values, resultType, "categorical_vec_", computesLog); } @@ -558,7 +594,7 @@ mlir::LogicalResult mlir::spn::VectorizeHistogram::matchAndRewrite(mlir::spn::lo llvm::DenseMap values; int minLB = std::numeric_limits::max(); int maxUB = std::numeric_limits::min(); - for (auto& b : op.bucketsAttr()) { + for (auto& b: op.bucketsAttr()) { auto bucket = b.cast(); auto lb = bucket.lb().getInt(); auto ub = bucket.ub().getInt(); @@ -604,7 +640,7 @@ mlir::LogicalResult mlir::spn::VectorizeHistogram::matchAndRewrite(mlir::spn::lo } } - return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], valArray, + return replaceOpWithGatherFromGlobalMemref(op, rewriter, operands[0], minLB, maxUB, valArray, resultType, "histogram_vec_", computesLog); } diff --git a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp index 569b9c6f..cbc46d01 100644 --- a/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp +++ b/mlir/lib/Dialect/LoSPN/LoSPNOps.cpp @@ -10,6 +10,7 @@ #include "LoSPN/LoSPNDialect.h" #include "LoSPN/LoSPNAttributes.h" #include "LoSPN/LoSPNInterfaces.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -423,7 +424,7 @@ ::mlir::LogicalResult 
mlir::spn::low::SPNHistogramLeaf::canonicalize(SPNHistogra return success(); } } - return rewriter.notifyMatchFailure(op, "Histogram held != 2 buckets (no reduction to select possible)"); + return rewriter.notifyMatchFailure(op, "Histogram was not eligible for reduction to select"); } #define GET_OP_CLASSES diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-boundary-check.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-boundary-check.mlir new file mode 100644 index 00000000..90246a2c --- /dev/null +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-boundary-check.mlir @@ -0,0 +1,74 @@ +// RUN: %optcall --convert-lospn-nodes-to-cpu %s | FileCheck %s + +module { + func @task_0(%arg0: memref, %arg1: memref) { + %c0 = constant 0 : index + %c0_0 = constant 0 : index + %0 = memref.dim %arg0, %c0_0 : memref + %c1 = constant 1 : index + scf.for %arg2 = %c0 to %0 step %c1 { + %1 = "lo_spn.batch_read"(%arg0, %arg2) {sampleIndex = 0 : ui32} : (memref, index) -> i32 + %2 = "lo_spn.select"(%1) {input_true_threshold = 1.000000e+00 : f64, supportMarginal = false, val_false = 7.500000e-01 : f64, val_true = 2.500000e-01 : f64} : (i32) -> f64 + %3 = "lo_spn.log"(%2) : (f64) -> f64 + "lo_spn.batch_write"(%3, %arg1, %arg2) : (f64, memref, index) -> () + } + return + } + func @spn_kernel(%arg0: memref, %arg1: memref) { + %c0 = constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.alloc(%0) : memref + call @task_0(%arg0, %1) : (memref, memref) -> () + "lo_spn.copy"(%1, %arg1) : (memref, memref) -> () + "lo_spn.return"() : () -> () + } +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + + +// CHECK-LABEL: func @task_0( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// 
CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_5]] { +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref +// CHECK: %[[VAL_9:.*]] = constant 0 : i32 +// CHECK: %[[VAL_10:.*]] = constant 2 : i32 +// CHECK: %[[VAL_11:.*]] = cmpi sge, %[[VAL_8]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_12:.*]] = cmpi slt, %[[VAL_8]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_13:.*]] = constant 1 : i32 +// CHECK: %[[VAL_14:.*]] = cmpi ult, %[[VAL_8]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_15:.*]] = constant 2.500000e-01 : f64 +// CHECK: %[[VAL_16:.*]] = constant 7.500000e-01 : f64 +// CHECK: %[[VAL_17:.*]] = select %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : f64 +// CHECK: %[[VAL_18:.*]] = and %[[VAL_11]], %[[VAL_12]] : i1 +// CHECK: %[[VAL_19:.*]] = constant 0.000000e+00 : f64 +// CHECK: %[[VAL_20:.*]] = select %[[VAL_18]], %[[VAL_17]], %[[VAL_19]] : f64 +// CHECK: %[[VAL_21:.*]] = math.log %[[VAL_20]] : f64 +// CHECK: memref.store %[[VAL_21]], %[[VAL_1]]{{\[}}%[[VAL_6]]] : memref +// CHECK: } +// CHECK: return +// CHECK: } + +// CHECK-LABEL: func @spn_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : memref +// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) : memref +// CHECK: call @task_0(%[[VAL_0]], %[[VAL_4]]) : (memref, memref) -> () +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_4]], %[[VAL_5]] : memref +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] { +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] : memref +// CHECK: memref.store %[[VAL_10]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref +// CHECK: } +// CHECK: return +// CHECK: } diff --git 
a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir index 6a46ad26..0d00ca64 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar-log.mlir @@ -384,52 +384,96 @@ module { // CHECK: %[[VAL_166:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_156]], %[[VAL_165]]] : memref // CHECK: %[[VAL_167:.*]] = constant 5 : index // CHECK: %[[VAL_168:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_156]], %[[VAL_167]]] : memref -// CHECK: %[[VAL_169:.*]] = memref.get_global @categorical_0 : memref<3xf32> -// CHECK: %[[VAL_170:.*]] = fptoui %[[VAL_158]] : f32 to i64 -// CHECK: %[[VAL_171:.*]] = index_cast %[[VAL_170]] : i64 to index -// CHECK: %[[VAL_172:.*]] = memref.load %[[VAL_169]]{{\[}}%[[VAL_171]]] : memref<3xf32> -// CHECK: %[[VAL_173:.*]] = memref.get_global @categorical_1 : memref<3xf32> -// CHECK: %[[VAL_174:.*]] = fptoui %[[VAL_160]] : f32 to i64 -// CHECK: %[[VAL_175:.*]] = index_cast %[[VAL_174]] : i64 to index -// CHECK: %[[VAL_176:.*]] = memref.load %[[VAL_173]]{{\[}}%[[VAL_175]]] : memref<3xf32> -// CHECK: %[[VAL_177:.*]] = memref.get_global @histogram_0 : memref<2xf32> -// CHECK: %[[VAL_178:.*]] = fptoui %[[VAL_162]] : f32 to i64 -// CHECK: %[[VAL_179:.*]] = index_cast %[[VAL_178]] : i64 to index -// CHECK: %[[VAL_180:.*]] = memref.load %[[VAL_177]]{{\[}}%[[VAL_179]]] : memref<2xf32> -// CHECK: %[[VAL_181:.*]] = memref.get_global @histogram_1 : memref<2xf32> -// CHECK: %[[VAL_182:.*]] = fptoui %[[VAL_164]] : f32 to i64 -// CHECK: %[[VAL_183:.*]] = index_cast %[[VAL_182]] : i64 to index -// CHECK: %[[VAL_184:.*]] = memref.load %[[VAL_181]]{{\[}}%[[VAL_183]]] : memref<2xf32> -// CHECK: %[[VAL_185:.*]] = constant -5.000000e-01 : f32 -// CHECK: %[[VAL_186:.*]] = constant -0.918938517 : f32 -// CHECK: %[[VAL_187:.*]] = constant 5.000000e-01 : f32 -// CHECK: %[[VAL_188:.*]] = subf %[[VAL_166]], %[[VAL_187]] : 
f32 -// CHECK: %[[VAL_189:.*]] = mulf %[[VAL_188]], %[[VAL_188]] : f32 -// CHECK: %[[VAL_190:.*]] = mulf %[[VAL_189]], %[[VAL_185]] : f32 -// CHECK: %[[VAL_191:.*]] = addf %[[VAL_186]], %[[VAL_190]] : f32 -// CHECK: %[[VAL_192:.*]] = constant -5.000000e+01 : f32 -// CHECK: %[[VAL_193:.*]] = constant 1.38364661 : f32 -// CHECK: %[[VAL_194:.*]] = constant 2.500000e-01 : f32 -// CHECK: %[[VAL_195:.*]] = subf %[[VAL_168]], %[[VAL_194]] : f32 -// CHECK: %[[VAL_196:.*]] = mulf %[[VAL_195]], %[[VAL_195]] : f32 -// CHECK: %[[VAL_197:.*]] = mulf %[[VAL_196]], %[[VAL_192]] : f32 -// CHECK: %[[VAL_198:.*]] = addf %[[VAL_193]], %[[VAL_197]] : f32 -// CHECK: %[[VAL_199:.*]] = addf %[[VAL_172]], %[[VAL_176]] : f32 -// CHECK: %[[VAL_200:.*]] = addf %[[VAL_199]], %[[VAL_180]] : f32 -// CHECK: %[[VAL_201:.*]] = constant 1.000000e-01 : f32 -// CHECK: %[[VAL_202:.*]] = addf %[[VAL_200]], %[[VAL_201]] : f32 -// CHECK: %[[VAL_203:.*]] = addf %[[VAL_184]], %[[VAL_191]] : f32 -// CHECK: %[[VAL_204:.*]] = addf %[[VAL_203]], %[[VAL_198]] : f32 -// CHECK: %[[VAL_205:.*]] = constant 1.000000e-01 : f32 -// CHECK: %[[VAL_206:.*]] = addf %[[VAL_204]], %[[VAL_205]] : f32 -// CHECK: %[[VAL_207:.*]] = cmpf ogt, %[[VAL_202]], %[[VAL_206]] : f32 -// CHECK: %[[VAL_208:.*]] = select %[[VAL_207]], %[[VAL_202]], %[[VAL_206]] : f32 -// CHECK: %[[VAL_209:.*]] = select %[[VAL_207]], %[[VAL_206]], %[[VAL_202]] : f32 -// CHECK: %[[VAL_210:.*]] = subf %[[VAL_209]], %[[VAL_208]] : f32 -// CHECK: %[[VAL_211:.*]] = math.exp %[[VAL_210]] : f32 -// CHECK: %[[VAL_212:.*]] = math.log1p %[[VAL_211]] : f32 -// CHECK: %[[VAL_213:.*]] = addf %[[VAL_208]], %[[VAL_212]] : f32 -// CHECK: memref.store %[[VAL_213]], %[[VAL_1]]{{\[}}%[[VAL_156]]] : memref +// CHECK: %[[VAL_169:.*]] = fptoui %[[VAL_158]] : f32 to i64 +// CHECK: %[[VAL_170:.*]] = index_cast %[[VAL_169]] : i64 to index +// CHECK: %[[VAL_171:.*]] = constant 0 : i64 +// CHECK: %[[VAL_172:.*]] = constant 3 : i64 +// CHECK: %[[VAL_173:.*]] = cmpi sge, %[[VAL_169]], 
%[[VAL_171]] : i64 +// CHECK: %[[VAL_174:.*]] = cmpi slt, %[[VAL_169]], %[[VAL_172]] : i64 +// CHECK: %[[VAL_175:.*]] = and %[[VAL_173]], %[[VAL_174]] : i1 +// CHECK: %[[VAL_176:.*]] = scf.if %[[VAL_175]] -> (f32) { +// CHECK: %[[VAL_177:.*]] = memref.get_global @categorical_0 : memref<3xf32> +// CHECK: %[[VAL_178:.*]] = memref.load %[[VAL_177]]{{\[}}%[[VAL_170]]] : memref<3xf32> +// CHECK: scf.yield %[[VAL_178]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_179:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_179]] : f32 +// CHECK: } +// CHECK: %[[VAL_180:.*]] = fptoui %[[VAL_160]] : f32 to i64 +// CHECK: %[[VAL_181:.*]] = index_cast %[[VAL_180]] : i64 to index +// CHECK: %[[VAL_182:.*]] = constant 0 : i64 +// CHECK: %[[VAL_183:.*]] = constant 3 : i64 +// CHECK: %[[VAL_184:.*]] = cmpi sge, %[[VAL_180]], %[[VAL_182]] : i64 +// CHECK: %[[VAL_185:.*]] = cmpi slt, %[[VAL_180]], %[[VAL_183]] : i64 +// CHECK: %[[VAL_186:.*]] = and %[[VAL_184]], %[[VAL_185]] : i1 +// CHECK: %[[VAL_187:.*]] = scf.if %[[VAL_186]] -> (f32) { +// CHECK: %[[VAL_188:.*]] = memref.get_global @categorical_1 : memref<3xf32> +// CHECK: %[[VAL_189:.*]] = memref.load %[[VAL_188]]{{\[}}%[[VAL_181]]] : memref<3xf32> +// CHECK: scf.yield %[[VAL_189]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_190:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_190]] : f32 +// CHECK: } +// CHECK: %[[VAL_191:.*]] = fptoui %[[VAL_162]] : f32 to i64 +// CHECK: %[[VAL_192:.*]] = index_cast %[[VAL_191]] : i64 to index +// CHECK: %[[VAL_193:.*]] = constant 0 : i64 +// CHECK: %[[VAL_194:.*]] = constant 2 : i64 +// CHECK: %[[VAL_195:.*]] = cmpi sge, %[[VAL_191]], %[[VAL_193]] : i64 +// CHECK: %[[VAL_196:.*]] = cmpi slt, %[[VAL_191]], %[[VAL_194]] : i64 +// CHECK: %[[VAL_197:.*]] = and %[[VAL_195]], %[[VAL_196]] : i1 +// CHECK: %[[VAL_198:.*]] = scf.if %[[VAL_197]] -> (f32) { +// CHECK: %[[VAL_199:.*]] = memref.get_global @histogram_0 : memref<2xf32> +// CHECK: %[[VAL_200:.*]] = memref.load 
%[[VAL_199]]{{\[}}%[[VAL_192]]] : memref<2xf32> +// CHECK: scf.yield %[[VAL_200]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_201:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_201]] : f32 +// CHECK: } +// CHECK: %[[VAL_202:.*]] = fptoui %[[VAL_164]] : f32 to i64 +// CHECK: %[[VAL_203:.*]] = index_cast %[[VAL_202]] : i64 to index +// CHECK: %[[VAL_204:.*]] = constant 0 : i64 +// CHECK: %[[VAL_205:.*]] = constant 2 : i64 +// CHECK: %[[VAL_206:.*]] = cmpi sge, %[[VAL_202]], %[[VAL_204]] : i64 +// CHECK: %[[VAL_207:.*]] = cmpi slt, %[[VAL_202]], %[[VAL_205]] : i64 +// CHECK: %[[VAL_208:.*]] = and %[[VAL_206]], %[[VAL_207]] : i1 +// CHECK: %[[VAL_209:.*]] = scf.if %[[VAL_208]] -> (f32) { +// CHECK: %[[VAL_210:.*]] = memref.get_global @histogram_1 : memref<2xf32> +// CHECK: %[[VAL_211:.*]] = memref.load %[[VAL_210]]{{\[}}%[[VAL_203]]] : memref<2xf32> +// CHECK: scf.yield %[[VAL_211]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_212:.*]] = constant 0xFF800000 : f32 +// CHECK: scf.yield %[[VAL_212]] : f32 +// CHECK: } +// CHECK: %[[VAL_213:.*]] = constant -5.000000e-01 : f32 +// CHECK: %[[VAL_214:.*]] = constant -0.918938517 : f32 +// CHECK: %[[VAL_215:.*]] = constant 5.000000e-01 : f32 +// CHECK: %[[VAL_216:.*]] = subf %[[VAL_166]], %[[VAL_215]] : f32 +// CHECK: %[[VAL_217:.*]] = mulf %[[VAL_216]], %[[VAL_216]] : f32 +// CHECK: %[[VAL_218:.*]] = mulf %[[VAL_217]], %[[VAL_213]] : f32 +// CHECK: %[[VAL_219:.*]] = addf %[[VAL_214]], %[[VAL_218]] : f32 +// CHECK: %[[VAL_220:.*]] = constant -5.000000e+01 : f32 +// CHECK: %[[VAL_221:.*]] = constant 1.38364661 : f32 +// CHECK: %[[VAL_222:.*]] = constant 2.500000e-01 : f32 +// CHECK: %[[VAL_223:.*]] = subf %[[VAL_168]], %[[VAL_222]] : f32 +// CHECK: %[[VAL_224:.*]] = mulf %[[VAL_223]], %[[VAL_223]] : f32 +// CHECK: %[[VAL_225:.*]] = mulf %[[VAL_224]], %[[VAL_220]] : f32 +// CHECK: %[[VAL_226:.*]] = addf %[[VAL_221]], %[[VAL_225]] : f32 +// CHECK: %[[VAL_227:.*]] = addf %[[VAL_228:.*]], %[[VAL_229:.*]] : f32 +// 
CHECK: %[[VAL_230:.*]] = addf %[[VAL_227]], %[[VAL_231:.*]] : f32 +// CHECK: %[[VAL_232:.*]] = constant 1.000000e-01 : f32 +// CHECK: %[[VAL_233:.*]] = addf %[[VAL_230]], %[[VAL_232]] : f32 +// CHECK: %[[VAL_234:.*]] = addf %[[VAL_235:.*]], %[[VAL_219]] : f32 +// CHECK: %[[VAL_236:.*]] = addf %[[VAL_234]], %[[VAL_226]] : f32 +// CHECK: %[[VAL_237:.*]] = constant 1.000000e-01 : f32 +// CHECK: %[[VAL_238:.*]] = addf %[[VAL_236]], %[[VAL_237]] : f32 +// CHECK: %[[VAL_239:.*]] = cmpf ogt, %[[VAL_233]], %[[VAL_238]] : f32 +// CHECK: %[[VAL_240:.*]] = select %[[VAL_239]], %[[VAL_233]], %[[VAL_238]] : f32 +// CHECK: %[[VAL_241:.*]] = select %[[VAL_239]], %[[VAL_238]], %[[VAL_233]] : f32 +// CHECK: %[[VAL_242:.*]] = subf %[[VAL_241]], %[[VAL_240]] : f32 +// CHECK: %[[VAL_243:.*]] = math.exp %[[VAL_242]] : f32 +// CHECK: %[[VAL_244:.*]] = math.log1p %[[VAL_243]] : f32 +// CHECK: %[[VAL_245:.*]] = addf %[[VAL_240]], %[[VAL_244]] : f32 +// CHECK: memref.store %[[VAL_245]], %[[VAL_1]]{{\[}}%[[VAL_156]]] : memref // CHECK: } // CHECK: return // CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir index 09ba1e0c..42dc5f8b 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-scalar.mlir @@ -374,49 +374,93 @@ module { // CHECK: %[[VAL_161:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_151]], %[[VAL_160]]] : memref // CHECK: %[[VAL_162:.*]] = constant 5 : index // CHECK: %[[VAL_163:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_151]], %[[VAL_162]]] : memref -// CHECK: %[[VAL_164:.*]] = memref.get_global @categorical_0 : memref<3xf64> -// CHECK: %[[VAL_165:.*]] = fptoui %[[VAL_153]] : f64 to i64 -// CHECK: %[[VAL_166:.*]] = index_cast %[[VAL_165]] : i64 to index -// CHECK: %[[VAL_167:.*]] = memref.load %[[VAL_164]]{{\[}}%[[VAL_166]]] : memref<3xf64> -// CHECK: %[[VAL_168:.*]] = memref.get_global @categorical_1 : 
memref<3xf64> -// CHECK: %[[VAL_169:.*]] = fptoui %[[VAL_155]] : f64 to i64 -// CHECK: %[[VAL_170:.*]] = index_cast %[[VAL_169]] : i64 to index -// CHECK: %[[VAL_171:.*]] = memref.load %[[VAL_168]]{{\[}}%[[VAL_170]]] : memref<3xf64> -// CHECK: %[[VAL_172:.*]] = memref.get_global @histogram_0 : memref<2xf64> -// CHECK: %[[VAL_173:.*]] = fptoui %[[VAL_157]] : f64 to i64 -// CHECK: %[[VAL_174:.*]] = index_cast %[[VAL_173]] : i64 to index -// CHECK: %[[VAL_175:.*]] = memref.load %[[VAL_172]]{{\[}}%[[VAL_174]]] : memref<2xf64> -// CHECK: %[[VAL_176:.*]] = memref.get_global @histogram_1 : memref<2xf64> -// CHECK: %[[VAL_177:.*]] = fptoui %[[VAL_159]] : f64 to i64 -// CHECK: %[[VAL_178:.*]] = index_cast %[[VAL_177]] : i64 to index -// CHECK: %[[VAL_179:.*]] = memref.load %[[VAL_176]]{{\[}}%[[VAL_178]]] : memref<2xf64> -// CHECK: %[[VAL_180:.*]] = constant 0.3989422804014327 : f64 -// CHECK: %[[VAL_181:.*]] = constant -5.000000e-01 : f64 -// CHECK: %[[VAL_182:.*]] = constant 5.000000e-01 : f64 -// CHECK: %[[VAL_183:.*]] = subf %[[VAL_161]], %[[VAL_182]] : f64 -// CHECK: %[[VAL_184:.*]] = mulf %[[VAL_183]], %[[VAL_183]] : f64 -// CHECK: %[[VAL_185:.*]] = mulf %[[VAL_184]], %[[VAL_181]] : f64 -// CHECK: %[[VAL_186:.*]] = math.exp %[[VAL_185]] : f64 -// CHECK: %[[VAL_187:.*]] = mulf %[[VAL_180]], %[[VAL_186]] : f64 -// CHECK: %[[VAL_188:.*]] = constant 3.9894228040143269 : f64 -// CHECK: %[[VAL_189:.*]] = constant -49.999999999999993 : f64 -// CHECK: %[[VAL_190:.*]] = constant 2.500000e-01 : f64 -// CHECK: %[[VAL_191:.*]] = subf %[[VAL_163]], %[[VAL_190]] : f64 -// CHECK: %[[VAL_192:.*]] = mulf %[[VAL_191]], %[[VAL_191]] : f64 -// CHECK: %[[VAL_193:.*]] = mulf %[[VAL_192]], %[[VAL_189]] : f64 -// CHECK: %[[VAL_194:.*]] = math.exp %[[VAL_193]] : f64 -// CHECK: %[[VAL_195:.*]] = mulf %[[VAL_188]], %[[VAL_194]] : f64 -// CHECK: %[[VAL_196:.*]] = mulf %[[VAL_167]], %[[VAL_171]] : f64 -// CHECK: %[[VAL_197:.*]] = mulf %[[VAL_196]], %[[VAL_175]] : f64 -// CHECK: %[[VAL_198:.*]] = 
constant 1.000000e-01 : f64 -// CHECK: %[[VAL_199:.*]] = mulf %[[VAL_197]], %[[VAL_198]] : f64 -// CHECK: %[[VAL_200:.*]] = mulf %[[VAL_179]], %[[VAL_187]] : f64 -// CHECK: %[[VAL_201:.*]] = mulf %[[VAL_200]], %[[VAL_195]] : f64 -// CHECK: %[[VAL_202:.*]] = constant 1.000000e-01 : f64 -// CHECK: %[[VAL_203:.*]] = mulf %[[VAL_201]], %[[VAL_202]] : f64 -// CHECK: %[[VAL_204:.*]] = addf %[[VAL_199]], %[[VAL_203]] : f64 -// CHECK: %[[VAL_205:.*]] = math.log %[[VAL_204]] : f64 -// CHECK: memref.store %[[VAL_205]], %[[VAL_1]]{{\[}}%[[VAL_151]]] : memref +// CHECK: %[[VAL_164:.*]] = fptoui %[[VAL_153]] : f64 to i64 +// CHECK: %[[VAL_165:.*]] = index_cast %[[VAL_164]] : i64 to index +// CHECK: %[[VAL_166:.*]] = constant 0 : i64 +// CHECK: %[[VAL_167:.*]] = constant 3 : i64 +// CHECK: %[[VAL_168:.*]] = cmpi sge, %[[VAL_164]], %[[VAL_166]] : i64 +// CHECK: %[[VAL_169:.*]] = cmpi slt, %[[VAL_164]], %[[VAL_167]] : i64 +// CHECK: %[[VAL_170:.*]] = and %[[VAL_168]], %[[VAL_169]] : i1 +// CHECK: %[[VAL_171:.*]] = scf.if %[[VAL_170]] -> (f64) { +// CHECK: %[[VAL_172:.*]] = memref.get_global @categorical_0 : memref<3xf64> +// CHECK: %[[VAL_173:.*]] = memref.load %[[VAL_172]]{{\[}}%[[VAL_165]]] : memref<3xf64> +// CHECK: scf.yield %[[VAL_173]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_174:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_174]] : f64 +// CHECK: } +// CHECK: %[[VAL_175:.*]] = fptoui %[[VAL_155]] : f64 to i64 +// CHECK: %[[VAL_176:.*]] = index_cast %[[VAL_175]] : i64 to index +// CHECK: %[[VAL_177:.*]] = constant 0 : i64 +// CHECK: %[[VAL_178:.*]] = constant 3 : i64 +// CHECK: %[[VAL_179:.*]] = cmpi sge, %[[VAL_175]], %[[VAL_177]] : i64 +// CHECK: %[[VAL_180:.*]] = cmpi slt, %[[VAL_175]], %[[VAL_178]] : i64 +// CHECK: %[[VAL_181:.*]] = and %[[VAL_179]], %[[VAL_180]] : i1 +// CHECK: %[[VAL_182:.*]] = scf.if %[[VAL_181]] -> (f64) { +// CHECK: %[[VAL_183:.*]] = memref.get_global @categorical_1 : memref<3xf64> +// CHECK: %[[VAL_184:.*]] = memref.load 
%[[VAL_183]]{{\[}}%[[VAL_176]]] : memref<3xf64> +// CHECK: scf.yield %[[VAL_184]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_185:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_185]] : f64 +// CHECK: } +// CHECK: %[[VAL_186:.*]] = fptoui %[[VAL_157]] : f64 to i64 +// CHECK: %[[VAL_187:.*]] = index_cast %[[VAL_186]] : i64 to index +// CHECK: %[[VAL_188:.*]] = constant 0 : i64 +// CHECK: %[[VAL_189:.*]] = constant 2 : i64 +// CHECK: %[[VAL_190:.*]] = cmpi sge, %[[VAL_186]], %[[VAL_188]] : i64 +// CHECK: %[[VAL_191:.*]] = cmpi slt, %[[VAL_186]], %[[VAL_189]] : i64 +// CHECK: %[[VAL_192:.*]] = and %[[VAL_190]], %[[VAL_191]] : i1 +// CHECK: %[[VAL_193:.*]] = scf.if %[[VAL_192]] -> (f64) { +// CHECK: %[[VAL_194:.*]] = memref.get_global @histogram_0 : memref<2xf64> +// CHECK: %[[VAL_195:.*]] = memref.load %[[VAL_194]]{{\[}}%[[VAL_187]]] : memref<2xf64> +// CHECK: scf.yield %[[VAL_195]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_196:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_196]] : f64 +// CHECK: } +// CHECK: %[[VAL_197:.*]] = fptoui %[[VAL_159]] : f64 to i64 +// CHECK: %[[VAL_198:.*]] = index_cast %[[VAL_197]] : i64 to index +// CHECK: %[[VAL_199:.*]] = constant 0 : i64 +// CHECK: %[[VAL_200:.*]] = constant 2 : i64 +// CHECK: %[[VAL_201:.*]] = cmpi sge, %[[VAL_197]], %[[VAL_199]] : i64 +// CHECK: %[[VAL_202:.*]] = cmpi slt, %[[VAL_197]], %[[VAL_200]] : i64 +// CHECK: %[[VAL_203:.*]] = and %[[VAL_201]], %[[VAL_202]] : i1 +// CHECK: %[[VAL_204:.*]] = scf.if %[[VAL_203]] -> (f64) { +// CHECK: %[[VAL_205:.*]] = memref.get_global @histogram_1 : memref<2xf64> +// CHECK: %[[VAL_206:.*]] = memref.load %[[VAL_205]]{{\[}}%[[VAL_198]]] : memref<2xf64> +// CHECK: scf.yield %[[VAL_206]] : f64 +// CHECK: } else { +// CHECK: %[[VAL_207:.*]] = constant 0.000000e+00 : f64 +// CHECK: scf.yield %[[VAL_207]] : f64 +// CHECK: } +// CHECK: %[[VAL_208:.*]] = constant 0.3989422804014327 : f64 +// CHECK: %[[VAL_209:.*]] = constant -5.000000e-01 : f64 +// 
CHECK: %[[VAL_210:.*]] = constant 5.000000e-01 : f64 +// CHECK: %[[VAL_211:.*]] = subf %[[VAL_161]], %[[VAL_210]] : f64 +// CHECK: %[[VAL_212:.*]] = mulf %[[VAL_211]], %[[VAL_211]] : f64 +// CHECK: %[[VAL_213:.*]] = mulf %[[VAL_212]], %[[VAL_209]] : f64 +// CHECK: %[[VAL_214:.*]] = math.exp %[[VAL_213]] : f64 +// CHECK: %[[VAL_215:.*]] = mulf %[[VAL_208]], %[[VAL_214]] : f64 +// CHECK: %[[VAL_216:.*]] = constant 3.9894228040143269 : f64 +// CHECK: %[[VAL_217:.*]] = constant -49.999999999999993 : f64 +// CHECK: %[[VAL_218:.*]] = constant 2.500000e-01 : f64 +// CHECK: %[[VAL_219:.*]] = subf %[[VAL_163]], %[[VAL_218]] : f64 +// CHECK: %[[VAL_220:.*]] = mulf %[[VAL_219]], %[[VAL_219]] : f64 +// CHECK: %[[VAL_221:.*]] = mulf %[[VAL_220]], %[[VAL_217]] : f64 +// CHECK: %[[VAL_222:.*]] = math.exp %[[VAL_221]] : f64 +// CHECK: %[[VAL_223:.*]] = mulf %[[VAL_216]], %[[VAL_222]] : f64 +// CHECK: %[[VAL_224:.*]] = mulf %[[VAL_225:.*]], %[[VAL_226:.*]] : f64 +// CHECK: %[[VAL_227:.*]] = mulf %[[VAL_224]], %[[VAL_228:.*]] : f64 +// CHECK: %[[VAL_229:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_230:.*]] = mulf %[[VAL_227]], %[[VAL_229]] : f64 +// CHECK: %[[VAL_231:.*]] = mulf %[[VAL_232:.*]], %[[VAL_215]] : f64 +// CHECK: %[[VAL_233:.*]] = mulf %[[VAL_231]], %[[VAL_223]] : f64 +// CHECK: %[[VAL_234:.*]] = constant 1.000000e-01 : f64 +// CHECK: %[[VAL_235:.*]] = mulf %[[VAL_233]], %[[VAL_234]] : f64 +// CHECK: %[[VAL_236:.*]] = addf %[[VAL_230]], %[[VAL_235]] : f64 +// CHECK: %[[VAL_237:.*]] = math.log %[[VAL_236]] : f64 +// CHECK: memref.store %[[VAL_237]], %[[VAL_1]]{{\[}}%[[VAL_151]]] : memref // CHECK: } // CHECK: return // CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir index 2ef17bfc..40ae61e4 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-select.mlir @@ -18,11 +18,18 @@ module { // 
CHECK-SAME: %[[VAL_1:.*]]: memref) { // CHECK: %[[VAL_2:.*]] = constant 1.000000e-01 : f64 // CHECK: %[[VAL_3:.*]] = constant 1 : index -// CHECK: %[[VAL_4:.*]] = constant 1.000000e+00 : f64 -// CHECK: %[[VAL_5:.*]] = cmpf ult, %[[VAL_2]], %[[VAL_4]] : f64 -// CHECK: %[[VAL_6:.*]] = constant 3.500000e-01 : f64 -// CHECK: %[[VAL_7:.*]] = constant 5.500000e-01 : f64 -// CHECK: %[[VAL_8:.*]] = select %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : f64 -// CHECK: memref.store %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_4:.*]] = constant 0.000000e+00 : f64 +// CHECK: %[[VAL_5:.*]] = constant 2.000000e+00 : f64 +// CHECK: %[[VAL_6:.*]] = cmpf uge, %[[VAL_2]], %[[VAL_4]] : f64 +// CHECK: %[[VAL_7:.*]] = cmpf ult, %[[VAL_2]], %[[VAL_5]] : f64 +// CHECK: %[[VAL_8:.*]] = constant 1.000000e+00 : f64 +// CHECK: %[[VAL_9:.*]] = cmpf ult, %[[VAL_2]], %[[VAL_8]] : f64 +// CHECK: %[[VAL_10:.*]] = constant 3.500000e-01 : f64 +// CHECK: %[[VAL_11:.*]] = constant 5.500000e-01 : f64 +// CHECK: %[[VAL_12:.*]] = select %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : f64 +// CHECK: %[[VAL_13:.*]] = and %[[VAL_6]], %[[VAL_7]] : i1 +// CHECK: %[[VAL_14:.*]] = constant 0.000000e+00 : f64 +// CHECK: %[[VAL_15:.*]] = select %[[VAL_13]], %[[VAL_12]], %[[VAL_14]] : f64 +// CHECK: memref.store %[[VAL_15]], %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref // CHECK: return // CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir index 9c019ad0..049cb514 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize-log.mlir @@ -184,84 +184,112 @@ module { // CHECK: %[[VAL_99:.*]] = vector.gather %[[VAL_97]]{{\[}}%[[VAL_98]]] {{\[}}%[[VAL_90]]], %[[VAL_92]], %[[VAL_91]] : memref, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> // CHECK: %[[VAL_100:.*]] = memref.get_global @categorical_vec_0 : 
memref<3xf32> // CHECK: %[[VAL_101:.*]] = fptoui %[[VAL_24]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_102:.*]] = constant dense<0.000000e+00> : vector<8xf32> -// CHECK: %[[VAL_103:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_104:.*]] = constant 0 : index -// CHECK: %[[VAL_105:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_104]]] {{\[}}%[[VAL_101]]], %[[VAL_103]], %[[VAL_102]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_106:.*]] = memref.get_global @categorical_vec_1 : memref<3xf32> -// CHECK: %[[VAL_107:.*]] = fptoui %[[VAL_39]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_108:.*]] = constant dense<0.000000e+00> : vector<8xf32> -// CHECK: %[[VAL_109:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_110:.*]] = constant 0 : index -// CHECK: %[[VAL_111:.*]] = vector.gather %[[VAL_106]]{{\[}}%[[VAL_110]]] {{\[}}%[[VAL_107]]], %[[VAL_109]], %[[VAL_108]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_112:.*]] = memref.get_global @histogram_vec_0 : memref<2xf32> -// CHECK: %[[VAL_113:.*]] = fptoui %[[VAL_54]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_114:.*]] = constant dense<0.000000e+00> : vector<8xf32> -// CHECK: %[[VAL_115:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_116:.*]] = constant 0 : index -// CHECK: %[[VAL_117:.*]] = vector.gather %[[VAL_112]]{{\[}}%[[VAL_116]]] {{\[}}%[[VAL_113]]], %[[VAL_115]], %[[VAL_114]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_118:.*]] = memref.get_global @histogram_vec_1 : memref<2xf32> -// CHECK: %[[VAL_119:.*]] = fptoui %[[VAL_69]] : vector<8xf32> to vector<8xi64> -// CHECK: %[[VAL_120:.*]] = constant dense<0.000000e+00> : vector<8xf32> +// CHECK: %[[VAL_102:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_103:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_104:.*]] = constant dense<3> : vector<8xi64> 
+// CHECK: %[[VAL_105:.*]] = cmpi sge, %[[VAL_101]], %[[VAL_103]] : vector<8xi64> +// CHECK: %[[VAL_106:.*]] = cmpi slt, %[[VAL_101]], %[[VAL_104]] : vector<8xi64> +// CHECK: %[[VAL_107:.*]] = and %[[VAL_105]], %[[VAL_106]] : vector<8xi1> +// CHECK: %[[VAL_108:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_109:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_110:.*]] = select %[[VAL_107]], %[[VAL_108]], %[[VAL_109]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_111:.*]] = constant 0 : index +// CHECK: %[[VAL_112:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_111]]] {{\[}}%[[VAL_101]]], %[[VAL_110]], %[[VAL_102]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_113:.*]] = memref.get_global @categorical_vec_1 : memref<3xf32> +// CHECK: %[[VAL_114:.*]] = fptoui %[[VAL_39]] : vector<8xf32> to vector<8xi64> +// CHECK: %[[VAL_115:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_116:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_117:.*]] = constant dense<3> : vector<8xi64> +// CHECK: %[[VAL_118:.*]] = cmpi sge, %[[VAL_114]], %[[VAL_116]] : vector<8xi64> +// CHECK: %[[VAL_119:.*]] = cmpi slt, %[[VAL_114]], %[[VAL_117]] : vector<8xi64> +// CHECK: %[[VAL_120:.*]] = and %[[VAL_118]], %[[VAL_119]] : vector<8xi1> // CHECK: %[[VAL_121:.*]] = constant dense : vector<8xi1> -// CHECK: %[[VAL_122:.*]] = constant 0 : index -// CHECK: %[[VAL_123:.*]] = vector.gather %[[VAL_118]]{{\[}}%[[VAL_122]]] {{\[}}%[[VAL_119]]], %[[VAL_121]], %[[VAL_120]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK: %[[VAL_124:.*]] = constant dense<-5.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_125:.*]] = constant dense<-0.918938517> : vector<8xf32> -// CHECK: %[[VAL_126:.*]] = constant dense<5.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_127:.*]] = subf %[[VAL_84]], %[[VAL_126]] : vector<8xf32> -// CHECK: %[[VAL_128:.*]] = mulf %[[VAL_127]], %[[VAL_127]] : vector<8xf32> 
-// CHECK: %[[VAL_129:.*]] = mulf %[[VAL_128]], %[[VAL_124]] : vector<8xf32> -// CHECK: %[[VAL_130:.*]] = addf %[[VAL_125]], %[[VAL_129]] : vector<8xf32> -// CHECK: %[[VAL_131:.*]] = constant dense<-5.000000e+01> : vector<8xf32> -// CHECK: %[[VAL_132:.*]] = constant dense<1.38364661> : vector<8xf32> -// CHECK: %[[VAL_133:.*]] = constant dense<2.500000e-01> : vector<8xf32> -// CHECK: %[[VAL_134:.*]] = subf %[[VAL_99]], %[[VAL_133]] : vector<8xf32> -// CHECK: %[[VAL_135:.*]] = mulf %[[VAL_134]], %[[VAL_134]] : vector<8xf32> -// CHECK: %[[VAL_136:.*]] = mulf %[[VAL_135]], %[[VAL_131]] : vector<8xf32> -// CHECK: %[[VAL_137:.*]] = addf %[[VAL_132]], %[[VAL_136]] : vector<8xf32> -// CHECK: %[[VAL_138:.*]] = addf %[[VAL_105]], %[[VAL_111]] : vector<8xf32> -// CHECK: %[[VAL_139:.*]] = addf %[[VAL_138]], %[[VAL_117]] : vector<8xf32> -// CHECK: %[[VAL_140:.*]] = constant dense<1.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_141:.*]] = addf %[[VAL_139]], %[[VAL_140]] : vector<8xf32> -// CHECK: %[[VAL_142:.*]] = addf %[[VAL_123]], %[[VAL_130]] : vector<8xf32> -// CHECK: %[[VAL_143:.*]] = addf %[[VAL_142]], %[[VAL_137]] : vector<8xf32> -// CHECK: %[[VAL_144:.*]] = constant dense<1.000000e-01> : vector<8xf32> -// CHECK: %[[VAL_145:.*]] = addf %[[VAL_143]], %[[VAL_144]] : vector<8xf32> -// CHECK: %[[VAL_146:.*]] = cmpf ogt, %[[VAL_141]], %[[VAL_145]] : vector<8xf32> -// CHECK: %[[VAL_147:.*]] = select %[[VAL_146]], %[[VAL_141]], %[[VAL_145]] : vector<8xi1>, vector<8xf32> -// CHECK: %[[VAL_148:.*]] = select %[[VAL_146]], %[[VAL_145]], %[[VAL_141]] : vector<8xi1>, vector<8xf32> -// CHECK: %[[VAL_149:.*]] = subf %[[VAL_148]], %[[VAL_147]] : vector<8xf32> -// CHECK: %[[VAL_150:.*]] = math.exp %[[VAL_149]] : vector<8xf32> -// CHECK: %[[VAL_151:.*]] = math.log1p %[[VAL_150]] : vector<8xf32> -// CHECK: %[[VAL_152:.*]] = addf %[[VAL_147]], %[[VAL_151]] : vector<8xf32> -// CHECK: vector.transfer_write %[[VAL_152]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : vector<8xf32>, memref +// CHECK: 
%[[VAL_122:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_123:.*]] = select %[[VAL_120]], %[[VAL_121]], %[[VAL_122]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_124:.*]] = constant 0 : index +// CHECK: %[[VAL_125:.*]] = vector.gather %[[VAL_113]]{{\[}}%[[VAL_124]]] {{\[}}%[[VAL_114]]], %[[VAL_123]], %[[VAL_115]] : memref<3xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_126:.*]] = memref.get_global @histogram_vec_0 : memref<2xf32> +// CHECK: %[[VAL_127:.*]] = fptoui %[[VAL_54]] : vector<8xf32> to vector<8xi64> +// CHECK: %[[VAL_128:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_129:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_130:.*]] = constant dense<2> : vector<8xi64> +// CHECK: %[[VAL_131:.*]] = cmpi sge, %[[VAL_127]], %[[VAL_129]] : vector<8xi64> +// CHECK: %[[VAL_132:.*]] = cmpi slt, %[[VAL_127]], %[[VAL_130]] : vector<8xi64> +// CHECK: %[[VAL_133:.*]] = and %[[VAL_131]], %[[VAL_132]] : vector<8xi1> +// CHECK: %[[VAL_134:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_135:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_136:.*]] = select %[[VAL_133]], %[[VAL_134]], %[[VAL_135]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_137:.*]] = constant 0 : index +// CHECK: %[[VAL_138:.*]] = vector.gather %[[VAL_126]]{{\[}}%[[VAL_137]]] {{\[}}%[[VAL_127]]], %[[VAL_136]], %[[VAL_128]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_139:.*]] = memref.get_global @histogram_vec_1 : memref<2xf32> +// CHECK: %[[VAL_140:.*]] = fptoui %[[VAL_69]] : vector<8xf32> to vector<8xi64> +// CHECK: %[[VAL_141:.*]] = constant dense<0xFF800000> : vector<8xf32> +// CHECK: %[[VAL_142:.*]] = constant dense<0> : vector<8xi64> +// CHECK: %[[VAL_143:.*]] = constant dense<2> : vector<8xi64> +// CHECK: %[[VAL_144:.*]] = cmpi sge, %[[VAL_140]], %[[VAL_142]] : vector<8xi64> +// CHECK: %[[VAL_145:.*]] = cmpi slt, %[[VAL_140]], %[[VAL_143]] : vector<8xi64> 
+// CHECK: %[[VAL_146:.*]] = and %[[VAL_144]], %[[VAL_145]] : vector<8xi1> +// CHECK: %[[VAL_147:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_148:.*]] = constant dense : vector<8xi1> +// CHECK: %[[VAL_149:.*]] = select %[[VAL_146]], %[[VAL_147]], %[[VAL_148]] : vector<8xi1>, vector<8xi1> +// CHECK: %[[VAL_150:.*]] = constant 0 : index +// CHECK: %[[VAL_151:.*]] = vector.gather %[[VAL_139]]{{\[}}%[[VAL_150]]] {{\[}}%[[VAL_140]]], %[[VAL_149]], %[[VAL_141]] : memref<2xf32>, vector<8xi64>, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK: %[[VAL_152:.*]] = constant dense<-5.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_153:.*]] = constant dense<-0.918938517> : vector<8xf32> +// CHECK: %[[VAL_154:.*]] = constant dense<5.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_155:.*]] = subf %[[VAL_84]], %[[VAL_154]] : vector<8xf32> +// CHECK: %[[VAL_156:.*]] = mulf %[[VAL_155]], %[[VAL_155]] : vector<8xf32> +// CHECK: %[[VAL_157:.*]] = mulf %[[VAL_156]], %[[VAL_152]] : vector<8xf32> +// CHECK: %[[VAL_158:.*]] = addf %[[VAL_153]], %[[VAL_157]] : vector<8xf32> +// CHECK: %[[VAL_159:.*]] = constant dense<-5.000000e+01> : vector<8xf32> +// CHECK: %[[VAL_160:.*]] = constant dense<1.38364661> : vector<8xf32> +// CHECK: %[[VAL_161:.*]] = constant dense<2.500000e-01> : vector<8xf32> +// CHECK: %[[VAL_162:.*]] = subf %[[VAL_99]], %[[VAL_161]] : vector<8xf32> +// CHECK: %[[VAL_163:.*]] = mulf %[[VAL_162]], %[[VAL_162]] : vector<8xf32> +// CHECK: %[[VAL_164:.*]] = mulf %[[VAL_163]], %[[VAL_159]] : vector<8xf32> +// CHECK: %[[VAL_165:.*]] = addf %[[VAL_160]], %[[VAL_164]] : vector<8xf32> +// CHECK: %[[VAL_166:.*]] = addf %[[VAL_112]], %[[VAL_125]] : vector<8xf32> +// CHECK: %[[VAL_167:.*]] = addf %[[VAL_166]], %[[VAL_138]] : vector<8xf32> +// CHECK: %[[VAL_168:.*]] = constant dense<1.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_169:.*]] = addf %[[VAL_167]], %[[VAL_168]] : vector<8xf32> +// CHECK: %[[VAL_170:.*]] = addf %[[VAL_151]], %[[VAL_158]] : vector<8xf32> +// CHECK: 
%[[VAL_171:.*]] = addf %[[VAL_170]], %[[VAL_165]] : vector<8xf32> +// CHECK: %[[VAL_172:.*]] = constant dense<1.000000e-01> : vector<8xf32> +// CHECK: %[[VAL_173:.*]] = addf %[[VAL_171]], %[[VAL_172]] : vector<8xf32> +// CHECK: %[[VAL_174:.*]] = cmpf ogt, %[[VAL_169]], %[[VAL_173]] : vector<8xf32> +// CHECK: %[[VAL_175:.*]] = select %[[VAL_174]], %[[VAL_169]], %[[VAL_173]] : vector<8xi1>, vector<8xf32> +// CHECK: %[[VAL_176:.*]] = select %[[VAL_174]], %[[VAL_173]], %[[VAL_169]] : vector<8xi1>, vector<8xf32> +// CHECK: %[[VAL_177:.*]] = subf %[[VAL_176]], %[[VAL_175]] : vector<8xf32> +// CHECK: %[[VAL_178:.*]] = math.exp %[[VAL_177]] : vector<8xf32> +// CHECK: %[[VAL_179:.*]] = math.log1p %[[VAL_178]] : vector<8xf32> +// CHECK: %[[VAL_180:.*]] = addf %[[VAL_175]], %[[VAL_179]] : vector<8xf32> +// CHECK: vector.transfer_write %[[VAL_180]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : vector<8xf32>, memref // CHECK: } -// CHECK: %[[VAL_153:.*]] = constant 1 : index -// CHECK: scf.for %[[VAL_154:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_153]] { -// CHECK: %[[VAL_155:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_154]]) {sampleIndex = 0 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_156:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_154]]) {sampleIndex = 1 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_157:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_154]]) {sampleIndex = 2 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_158:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_154]]) {sampleIndex = 3 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_159:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_154]]) {sampleIndex = 4 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_160:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_154]]) {sampleIndex = 5 : ui32} : (memref, index) -> f32 -// CHECK: %[[VAL_161:.*]] = "lo_spn.categorical"(%[[VAL_155]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_162:.*]] = 
"lo_spn.categorical"(%[[VAL_156]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_163:.*]] = "lo_spn.histogram"(%[[VAL_157]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_164:.*]] = "lo_spn.histogram"(%[[VAL_158]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_165:.*]] = "lo_spn.gaussian"(%[[VAL_159]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_166:.*]] = "lo_spn.gaussian"(%[[VAL_160]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f32) -> !lo_spn.log -// CHECK: %[[VAL_167:.*]] = "lo_spn.mul"(%[[VAL_161]], %[[VAL_162]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_168:.*]] = "lo_spn.mul"(%[[VAL_167]], %[[VAL_163]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_169:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log -// CHECK: %[[VAL_170:.*]] = "lo_spn.mul"(%[[VAL_168]], %[[VAL_169]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_171:.*]] = "lo_spn.mul"(%[[VAL_164]], %[[VAL_165]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_172:.*]] = "lo_spn.mul"(%[[VAL_171]], %[[VAL_166]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_173:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log -// CHECK: %[[VAL_174:.*]] = "lo_spn.mul"(%[[VAL_172]], %[[VAL_173]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_175:.*]] = "lo_spn.add"(%[[VAL_170]], %[[VAL_174]]) : (!lo_spn.log, 
!lo_spn.log) -> !lo_spn.log -// CHECK: %[[VAL_176:.*]] = "lo_spn.strip_log"(%[[VAL_175]]) {target = f32} : (!lo_spn.log) -> f32 -// CHECK: "lo_spn.batch_write"(%[[VAL_176]], %[[VAL_1]], %[[VAL_154]]) : (f32, memref, index) -> () +// CHECK: %[[VAL_181:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_182:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_181]] { +// CHECK: %[[VAL_183:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_182]]) {sampleIndex = 0 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_184:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_182]]) {sampleIndex = 1 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_185:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_182]]) {sampleIndex = 2 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_186:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_182]]) {sampleIndex = 3 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_187:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_182]]) {sampleIndex = 4 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_188:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_182]]) {sampleIndex = 5 : ui32} : (memref, index) -> f32 +// CHECK: %[[VAL_189:.*]] = "lo_spn.categorical"(%[[VAL_183]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_190:.*]] = "lo_spn.categorical"(%[[VAL_184]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_191:.*]] = "lo_spn.histogram"(%[[VAL_185]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_192:.*]] = "lo_spn.histogram"(%[[VAL_186]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_193:.*]] 
= "lo_spn.gaussian"(%[[VAL_187]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_194:.*]] = "lo_spn.gaussian"(%[[VAL_188]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f32) -> !lo_spn.log +// CHECK: %[[VAL_195:.*]] = "lo_spn.mul"(%[[VAL_189]], %[[VAL_190]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_196:.*]] = "lo_spn.mul"(%[[VAL_195]], %[[VAL_191]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_197:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log +// CHECK: %[[VAL_198:.*]] = "lo_spn.mul"(%[[VAL_196]], %[[VAL_197]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_199:.*]] = "lo_spn.mul"(%[[VAL_192]], %[[VAL_193]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_200:.*]] = "lo_spn.mul"(%[[VAL_199]], %[[VAL_194]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_201:.*]] = "lo_spn.constant"() {type = !lo_spn.log, value = 1.000000e-01 : f64} : () -> !lo_spn.log +// CHECK: %[[VAL_202:.*]] = "lo_spn.mul"(%[[VAL_200]], %[[VAL_201]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_203:.*]] = "lo_spn.add"(%[[VAL_198]], %[[VAL_202]]) : (!lo_spn.log, !lo_spn.log) -> !lo_spn.log +// CHECK: %[[VAL_204:.*]] = "lo_spn.strip_log"(%[[VAL_203]]) {target = f32} : (!lo_spn.log) -> f32 +// CHECK: "lo_spn.batch_write"(%[[VAL_204]], %[[VAL_1]], %[[VAL_182]]) : (f32, memref, index) -> () // CHECK: } // CHECK: return // CHECK: } diff --git a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir index d23e6ac8..2779f499 100644 --- a/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir +++ b/mlir/test/lowering/lospn-to-cpu/lower-to-cpu-nodes-vectorize.mlir @@ -185,80 +185,108 @@ module { // CHECK: %[[VAL_100:.*]] = memref.get_global 
@categorical_vec_0 : memref<3xf64> // CHECK: %[[VAL_101:.*]] = fptoui %[[VAL_24]] : vector<4xf64> to vector<4xi64> // CHECK: %[[VAL_102:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_103:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_104:.*]] = constant 0 : index -// CHECK: %[[VAL_105:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_104]]] {{\[}}%[[VAL_101]]], %[[VAL_103]], %[[VAL_102]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_106:.*]] = memref.get_global @categorical_vec_1 : memref<3xf64> -// CHECK: %[[VAL_107:.*]] = fptoui %[[VAL_39]] : vector<4xf64> to vector<4xi64> -// CHECK: %[[VAL_108:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_109:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_110:.*]] = constant 0 : index -// CHECK: %[[VAL_111:.*]] = vector.gather %[[VAL_106]]{{\[}}%[[VAL_110]]] {{\[}}%[[VAL_107]]], %[[VAL_109]], %[[VAL_108]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_112:.*]] = memref.get_global @histogram_vec_0 : memref<2xf64> -// CHECK: %[[VAL_113:.*]] = fptoui %[[VAL_54]] : vector<4xf64> to vector<4xi64> -// CHECK: %[[VAL_114:.*]] = constant dense<0.000000e+00> : vector<4xf64> -// CHECK: %[[VAL_115:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_116:.*]] = constant 0 : index -// CHECK: %[[VAL_117:.*]] = vector.gather %[[VAL_112]]{{\[}}%[[VAL_116]]] {{\[}}%[[VAL_113]]], %[[VAL_115]], %[[VAL_114]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_118:.*]] = memref.get_global @histogram_vec_1 : memref<2xf64> -// CHECK: %[[VAL_119:.*]] = fptoui %[[VAL_69]] : vector<4xf64> to vector<4xi64> -// CHECK: %[[VAL_120:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_103:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_104:.*]] = constant dense<3> : vector<4xi64> +// CHECK: %[[VAL_105:.*]] = cmpi sge, 
%[[VAL_101]], %[[VAL_103]] : vector<4xi64> +// CHECK: %[[VAL_106:.*]] = cmpi slt, %[[VAL_101]], %[[VAL_104]] : vector<4xi64> +// CHECK: %[[VAL_107:.*]] = and %[[VAL_105]], %[[VAL_106]] : vector<4xi1> +// CHECK: %[[VAL_108:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_109:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_110:.*]] = select %[[VAL_107]], %[[VAL_108]], %[[VAL_109]] : vector<4xi1>, vector<4xi1> +// CHECK: %[[VAL_111:.*]] = constant 0 : index +// CHECK: %[[VAL_112:.*]] = vector.gather %[[VAL_100]]{{\[}}%[[VAL_111]]] {{\[}}%[[VAL_101]]], %[[VAL_110]], %[[VAL_102]] : memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_113:.*]] = memref.get_global @categorical_vec_1 : memref<3xf64> +// CHECK: %[[VAL_114:.*]] = fptoui %[[VAL_39]] : vector<4xf64> to vector<4xi64> +// CHECK: %[[VAL_115:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_116:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_117:.*]] = constant dense<3> : vector<4xi64> +// CHECK: %[[VAL_118:.*]] = cmpi sge, %[[VAL_114]], %[[VAL_116]] : vector<4xi64> +// CHECK: %[[VAL_119:.*]] = cmpi slt, %[[VAL_114]], %[[VAL_117]] : vector<4xi64> +// CHECK: %[[VAL_120:.*]] = and %[[VAL_118]], %[[VAL_119]] : vector<4xi1> // CHECK: %[[VAL_121:.*]] = constant dense : vector<4xi1> -// CHECK: %[[VAL_122:.*]] = constant 0 : index -// CHECK: %[[VAL_123:.*]] = vector.gather %[[VAL_118]]{{\[}}%[[VAL_122]]] {{\[}}%[[VAL_119]]], %[[VAL_121]], %[[VAL_120]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> -// CHECK: %[[VAL_124:.*]] = constant dense<0.3989422804014327> : vector<4xf64> -// CHECK: %[[VAL_125:.*]] = constant dense<-5.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_126:.*]] = constant dense<5.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_127:.*]] = subf %[[VAL_84]], %[[VAL_126]] : vector<4xf64> -// CHECK: %[[VAL_128:.*]] = mulf %[[VAL_127]], %[[VAL_127]] : vector<4xf64> -// CHECK: %[[VAL_129:.*]] = mulf 
%[[VAL_128]], %[[VAL_125]] : vector<4xf64> -// CHECK: %[[VAL_130:.*]] = math.exp %[[VAL_129]] : vector<4xf64> -// CHECK: %[[VAL_131:.*]] = mulf %[[VAL_124]], %[[VAL_130]] : vector<4xf64> -// CHECK: %[[VAL_132:.*]] = constant dense<3.9894228040143269> : vector<4xf64> -// CHECK: %[[VAL_133:.*]] = constant dense<-49.999999999999993> : vector<4xf64> -// CHECK: %[[VAL_134:.*]] = constant dense<2.500000e-01> : vector<4xf64> -// CHECK: %[[VAL_135:.*]] = subf %[[VAL_99]], %[[VAL_134]] : vector<4xf64> -// CHECK: %[[VAL_136:.*]] = mulf %[[VAL_135]], %[[VAL_135]] : vector<4xf64> -// CHECK: %[[VAL_137:.*]] = mulf %[[VAL_136]], %[[VAL_133]] : vector<4xf64> -// CHECK: %[[VAL_138:.*]] = math.exp %[[VAL_137]] : vector<4xf64> -// CHECK: %[[VAL_139:.*]] = mulf %[[VAL_132]], %[[VAL_138]] : vector<4xf64> -// CHECK: %[[VAL_140:.*]] = mulf %[[VAL_105]], %[[VAL_111]] : vector<4xf64> -// CHECK: %[[VAL_141:.*]] = mulf %[[VAL_140]], %[[VAL_117]] : vector<4xf64> -// CHECK: %[[VAL_142:.*]] = constant dense<1.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_143:.*]] = mulf %[[VAL_141]], %[[VAL_142]] : vector<4xf64> -// CHECK: %[[VAL_144:.*]] = mulf %[[VAL_123]], %[[VAL_131]] : vector<4xf64> -// CHECK: %[[VAL_145:.*]] = mulf %[[VAL_144]], %[[VAL_139]] : vector<4xf64> -// CHECK: %[[VAL_146:.*]] = constant dense<1.000000e-01> : vector<4xf64> -// CHECK: %[[VAL_147:.*]] = mulf %[[VAL_145]], %[[VAL_146]] : vector<4xf64> -// CHECK: %[[VAL_148:.*]] = addf %[[VAL_143]], %[[VAL_147]] : vector<4xf64> -// CHECK: %[[VAL_149:.*]] = math.log %[[VAL_148]] : vector<4xf64> -// CHECK: vector.transfer_write %[[VAL_149]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : vector<4xf64>, memref +// CHECK: %[[VAL_122:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_123:.*]] = select %[[VAL_120]], %[[VAL_121]], %[[VAL_122]] : vector<4xi1>, vector<4xi1> +// CHECK: %[[VAL_124:.*]] = constant 0 : index +// CHECK: %[[VAL_125:.*]] = vector.gather %[[VAL_113]]{{\[}}%[[VAL_124]]] {{\[}}%[[VAL_114]]], %[[VAL_123]], %[[VAL_115]] : 
memref<3xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_126:.*]] = memref.get_global @histogram_vec_0 : memref<2xf64> +// CHECK: %[[VAL_127:.*]] = fptoui %[[VAL_54]] : vector<4xf64> to vector<4xi64> +// CHECK: %[[VAL_128:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_129:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_130:.*]] = constant dense<2> : vector<4xi64> +// CHECK: %[[VAL_131:.*]] = cmpi sge, %[[VAL_127]], %[[VAL_129]] : vector<4xi64> +// CHECK: %[[VAL_132:.*]] = cmpi slt, %[[VAL_127]], %[[VAL_130]] : vector<4xi64> +// CHECK: %[[VAL_133:.*]] = and %[[VAL_131]], %[[VAL_132]] : vector<4xi1> +// CHECK: %[[VAL_134:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_135:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_136:.*]] = select %[[VAL_133]], %[[VAL_134]], %[[VAL_135]] : vector<4xi1>, vector<4xi1> +// CHECK: %[[VAL_137:.*]] = constant 0 : index +// CHECK: %[[VAL_138:.*]] = vector.gather %[[VAL_126]]{{\[}}%[[VAL_137]]] {{\[}}%[[VAL_127]]], %[[VAL_136]], %[[VAL_128]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_139:.*]] = memref.get_global @histogram_vec_1 : memref<2xf64> +// CHECK: %[[VAL_140:.*]] = fptoui %[[VAL_69]] : vector<4xf64> to vector<4xi64> +// CHECK: %[[VAL_141:.*]] = constant dense<0.000000e+00> : vector<4xf64> +// CHECK: %[[VAL_142:.*]] = constant dense<0> : vector<4xi64> +// CHECK: %[[VAL_143:.*]] = constant dense<2> : vector<4xi64> +// CHECK: %[[VAL_144:.*]] = cmpi sge, %[[VAL_140]], %[[VAL_142]] : vector<4xi64> +// CHECK: %[[VAL_145:.*]] = cmpi slt, %[[VAL_140]], %[[VAL_143]] : vector<4xi64> +// CHECK: %[[VAL_146:.*]] = and %[[VAL_144]], %[[VAL_145]] : vector<4xi1> +// CHECK: %[[VAL_147:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_148:.*]] = constant dense : vector<4xi1> +// CHECK: %[[VAL_149:.*]] = select %[[VAL_146]], %[[VAL_147]], %[[VAL_148]] : vector<4xi1>, vector<4xi1> +// CHECK: 
%[[VAL_150:.*]] = constant 0 : index +// CHECK: %[[VAL_151:.*]] = vector.gather %[[VAL_139]]{{\[}}%[[VAL_150]]] {{\[}}%[[VAL_140]]], %[[VAL_149]], %[[VAL_141]] : memref<2xf64>, vector<4xi64>, vector<4xi1>, vector<4xf64> into vector<4xf64> +// CHECK: %[[VAL_152:.*]] = constant dense<0.3989422804014327> : vector<4xf64> +// CHECK: %[[VAL_153:.*]] = constant dense<-5.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_154:.*]] = constant dense<5.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_155:.*]] = subf %[[VAL_84]], %[[VAL_154]] : vector<4xf64> +// CHECK: %[[VAL_156:.*]] = mulf %[[VAL_155]], %[[VAL_155]] : vector<4xf64> +// CHECK: %[[VAL_157:.*]] = mulf %[[VAL_156]], %[[VAL_153]] : vector<4xf64> +// CHECK: %[[VAL_158:.*]] = math.exp %[[VAL_157]] : vector<4xf64> +// CHECK: %[[VAL_159:.*]] = mulf %[[VAL_152]], %[[VAL_158]] : vector<4xf64> +// CHECK: %[[VAL_160:.*]] = constant dense<3.9894228040143269> : vector<4xf64> +// CHECK: %[[VAL_161:.*]] = constant dense<-49.999999999999993> : vector<4xf64> +// CHECK: %[[VAL_162:.*]] = constant dense<2.500000e-01> : vector<4xf64> +// CHECK: %[[VAL_163:.*]] = subf %[[VAL_99]], %[[VAL_162]] : vector<4xf64> +// CHECK: %[[VAL_164:.*]] = mulf %[[VAL_163]], %[[VAL_163]] : vector<4xf64> +// CHECK: %[[VAL_165:.*]] = mulf %[[VAL_164]], %[[VAL_161]] : vector<4xf64> +// CHECK: %[[VAL_166:.*]] = math.exp %[[VAL_165]] : vector<4xf64> +// CHECK: %[[VAL_167:.*]] = mulf %[[VAL_160]], %[[VAL_166]] : vector<4xf64> +// CHECK: %[[VAL_168:.*]] = mulf %[[VAL_112]], %[[VAL_125]] : vector<4xf64> +// CHECK: %[[VAL_169:.*]] = mulf %[[VAL_168]], %[[VAL_138]] : vector<4xf64> +// CHECK: %[[VAL_170:.*]] = constant dense<1.000000e-01> : vector<4xf64> +// CHECK: %[[VAL_171:.*]] = mulf %[[VAL_169]], %[[VAL_170]] : vector<4xf64> +// CHECK: %[[VAL_172:.*]] = mulf %[[VAL_151]], %[[VAL_159]] : vector<4xf64> +// CHECK: %[[VAL_173:.*]] = mulf %[[VAL_172]], %[[VAL_167]] : vector<4xf64> +// CHECK: %[[VAL_174:.*]] = constant dense<1.000000e-01> : vector<4xf64> +// CHECK: 
%[[VAL_175:.*]] = mulf %[[VAL_173]], %[[VAL_174]] : vector<4xf64> +// CHECK: %[[VAL_176:.*]] = addf %[[VAL_171]], %[[VAL_175]] : vector<4xf64> +// CHECK: %[[VAL_177:.*]] = math.log %[[VAL_176]] : vector<4xf64> +// CHECK: vector.transfer_write %[[VAL_177]], %[[VAL_1]]{{\[}}%[[VAL_9]]] : vector<4xf64>, memref // CHECK: } -// CHECK: %[[VAL_150:.*]] = constant 1 : index -// CHECK: scf.for %[[VAL_151:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_150]] { -// CHECK: %[[VAL_152:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_151]]) {sampleIndex = 0 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_153:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_151]]) {sampleIndex = 1 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_154:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_151]]) {sampleIndex = 2 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_155:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_151]]) {sampleIndex = 3 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_156:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_151]]) {sampleIndex = 4 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_157:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_151]]) {sampleIndex = 5 : ui32} : (memref, index) -> f64 -// CHECK: %[[VAL_158:.*]] = "lo_spn.categorical"(%[[VAL_152]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_159:.*]] = "lo_spn.categorical"(%[[VAL_153]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_160:.*]] = "lo_spn.histogram"(%[[VAL_154]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_161:.*]] = "lo_spn.histogram"(%[[VAL_155]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], 
supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_162:.*]] = "lo_spn.gaussian"(%[[VAL_156]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_163:.*]] = "lo_spn.gaussian"(%[[VAL_157]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 -// CHECK: %[[VAL_164:.*]] = "lo_spn.mul"(%[[VAL_158]], %[[VAL_159]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_165:.*]] = "lo_spn.mul"(%[[VAL_164]], %[[VAL_160]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_166:.*]] = "lo_spn.constant"() {type = f64, value = 1.000000e-01 : f64} : () -> f64 -// CHECK: %[[VAL_167:.*]] = "lo_spn.mul"(%[[VAL_165]], %[[VAL_166]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_168:.*]] = "lo_spn.mul"(%[[VAL_161]], %[[VAL_162]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_169:.*]] = "lo_spn.mul"(%[[VAL_168]], %[[VAL_163]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_170:.*]] = "lo_spn.constant"() {type = f64, value = 1.000000e-01 : f64} : () -> f64 -// CHECK: %[[VAL_171:.*]] = "lo_spn.mul"(%[[VAL_169]], %[[VAL_170]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_172:.*]] = "lo_spn.add"(%[[VAL_167]], %[[VAL_171]]) : (f64, f64) -> f64 -// CHECK: %[[VAL_173:.*]] = "lo_spn.log"(%[[VAL_172]]) : (f64) -> f64 -// CHECK: "lo_spn.batch_write"(%[[VAL_173]], %[[VAL_1]], %[[VAL_151]]) : (f64, memref, index) -> () +// CHECK: %[[VAL_178:.*]] = constant 1 : index +// CHECK: scf.for %[[VAL_179:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_178]] { +// CHECK: %[[VAL_180:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_179]]) {sampleIndex = 0 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_181:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_179]]) {sampleIndex = 1 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_182:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_179]]) {sampleIndex = 2 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_183:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_179]]) {sampleIndex = 3 : ui32} : (memref, index) -> f64 +// 
CHECK: %[[VAL_184:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_179]]) {sampleIndex = 4 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_185:.*]] = "lo_spn.batch_read"(%[[VAL_0]], %[[VAL_179]]) {sampleIndex = 5 : ui32} : (memref, index) -> f64 +// CHECK: %[[VAL_186:.*]] = "lo_spn.categorical"(%[[VAL_180]]) {probabilities = [3.500000e-01, 5.500000e-01, 1.000000e-01], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_187:.*]] = "lo_spn.categorical"(%[[VAL_181]]) {probabilities = [2.500000e-01, 6.250000e-01, 1.250000e-01], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_188:.*]] = "lo_spn.histogram"(%[[VAL_182]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 2.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 7.500000e-01 : f64}], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_189:.*]] = "lo_spn.histogram"(%[[VAL_183]]) {bucketCount = 2 : ui32, buckets = [{lb = 0 : i32, ub = 1 : i32, val = 4.500000e-01 : f64}, {lb = 1 : i32, ub = 2 : i32, val = 5.500000e-01 : f64}], supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_190:.*]] = "lo_spn.gaussian"(%[[VAL_184]]) {mean = 5.000000e-01 : f64, stddev = 1.000000e+00 : f64, supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_191:.*]] = "lo_spn.gaussian"(%[[VAL_185]]) {mean = 2.500000e-01 : f64, stddev = 1.000000e-01 : f64, supportMarginal = false} : (f64) -> f64 +// CHECK: %[[VAL_192:.*]] = "lo_spn.mul"(%[[VAL_186]], %[[VAL_187]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_193:.*]] = "lo_spn.mul"(%[[VAL_192]], %[[VAL_188]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_194:.*]] = "lo_spn.constant"() {type = f64, value = 1.000000e-01 : f64} : () -> f64 +// CHECK: %[[VAL_195:.*]] = "lo_spn.mul"(%[[VAL_193]], %[[VAL_194]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_196:.*]] = "lo_spn.mul"(%[[VAL_189]], %[[VAL_190]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_197:.*]] = "lo_spn.mul"(%[[VAL_196]], %[[VAL_191]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_198:.*]] = "lo_spn.constant"() 
{type = f64, value = 1.000000e-01 : f64} : () -> f64 +// CHECK: %[[VAL_199:.*]] = "lo_spn.mul"(%[[VAL_197]], %[[VAL_198]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_200:.*]] = "lo_spn.add"(%[[VAL_195]], %[[VAL_199]]) : (f64, f64) -> f64 +// CHECK: %[[VAL_201:.*]] = "lo_spn.log"(%[[VAL_200]]) : (f64) -> f64 +// CHECK: "lo_spn.batch_write"(%[[VAL_201]], %[[VAL_1]], %[[VAL_179]]) : (f64, memref, index) -> () // CHECK: } // CHECK: return // CHECK: } diff --git a/python-interface/test/cpu/test_cpu_out_of_bounds.py b/python-interface/test/cpu/test_cpu_out_of_bounds.py new file mode 100644 index 00000000..bad14d58 --- /dev/null +++ b/python-interface/test/cpu/test_cpu_out_of_bounds.py @@ -0,0 +1,63 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + +import pytest + + +def test_cpu_out_of_bounds(): + # Construct a minimal SPN. 
+ h1 = Histogram([0., 1., 2., 3.], [0.125, 0.250, 0.625], [1, 1], scope=0) + h2 = Histogram([0., 1., 2., 3.], [0.100, 0.200, 0.700], [1, 1], scope=1) + c1 = Categorical(p=[0.1, 0.7, 0.2], scope=0) + c2 = Categorical(p=[0.4, 0.2, 0.4], scope=1) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[c1, c2]) + spn = Sum([0.3, 0.7], [p0, p1]) + + # Generate some out-of-bounds accesses + max_index = 2 + inputs = np.column_stack(( + np.random.randint(2 * max_index, size=30), + np.random.randint(2 * max_index, size=30), + )).astype("float64") + + # Execute the compiled Kernel. + results = CPUCompiler(verbose=False, computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, + supportMarginal=False) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Account for SPFlow behavior: Out-of-bounds values are not returned as -inf, but: + # for Categoricals: 0.0 + # for Histograms: (np.log(np.finfo(float).eps)) + # Note: np.log(np.finfo(float).eps) is equal to: np.log(2.220446049250313e-16) = -36.04365338911715 + # Find all inputs where an out-of-bounds access might occur. + # If the access is in-bounds, the corresponding index will be 0 (otherwise: > 0). + cond = np.sum(np.where((inputs > max_index), 1, 0), axis=1) + reference[cond > 0] = -np.inf + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. 
+ assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_cpu_out_of_bounds() + print("COMPUTATION OK") diff --git a/python-interface/test/cpu/test_cpu_transformation_categorical_to_select.py b/python-interface/test/cpu/test_cpu_transformation_categorical_to_select.py new file mode 100644 index 00000000..402a87c6 --- /dev/null +++ b/python-interface/test/cpu/test_cpu_transformation_categorical_to_select.py @@ -0,0 +1,50 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + + +def test_cpu_transformation_categorical_to_select(): + # Construct a minimal SPN. + c1 = Categorical(p=[0.25, 0.75], scope=0) + c2 = Categorical(p=[0.33, 0.67], scope=1) + c3 = Categorical(p=[0.80, 0.20], scope=0) + c4 = Categorical(p=[0.50, 0.50], scope=1) + + p0 = Product(children=[c1, c2]) + p1 = Product(children=[c3, c4]) + spn = Sum([0.3, 0.7], [p0, p1]) + + inputs = np.column_stack(( + np.random.randint(2, size=30), + np.random.randint(2, size=30), + )).astype("float64") + + # Execute the compiled Kernel. + results = CPUCompiler(computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, + supportMarginal=False, + batchSize=10) + + # Compute the reference results using the inference from SPFlow. 
+ reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_cpu_transformation_categorical_to_select() + print("COMPUTATION OK") diff --git a/python-interface/test/cpu/test_cpu_transformation_histogram_to_select.py b/python-interface/test/cpu/test_cpu_transformation_histogram_to_select.py new file mode 100644 index 00000000..ebe0515e --- /dev/null +++ b/python-interface/test/cpu/test_cpu_transformation_histogram_to_select.py @@ -0,0 +1,50 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + + +def test_cpu_transformation_histogram_to_select(): + # Construct a minimal SPN. 
+ h1 = Histogram([0., 1., 2.], [0.25, 0.75], [1, 1], scope=0) + h2 = Histogram([0., 1., 2.], [0.35, 0.65], [1, 1], scope=1) + h3 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=1) + h4 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=0) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[h3, h4]) + spn = Sum([0.3, 0.7], [p0, p1]) + + inputs = np.column_stack(( + np.random.randint(2, size=30), + np.random.randint(2, size=30), + )).astype("float64") + + # Execute the compiled Kernel. + results = CPUCompiler(verbose=False, computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, + supportMarginal=False, + batchSize=10) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_cpu_transformation_histogram_to_select() + print("COMPUTATION OK") diff --git a/python-interface/test/cpu/test_marginal_cpu_transformation_to_select.py b/python-interface/test/cpu/test_marginal_cpu_transformation_to_select.py new file mode 100644 index 00000000..4b33d121 --- /dev/null +++ b/python-interface/test/cpu/test_marginal_cpu_transformation_to_select.py @@ -0,0 +1,52 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. 
+# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + + +def test_marginal_cpu_transformation_to_select(): + # Construct a minimal SPN. + h1 = Histogram([0., 1., 2.], [0.33, 0.67], [1, 1], scope=0) + h2 = Histogram([0., 1., 2.], [0.875, 0.125], [1, 1], scope=1) + c1 = Categorical(p=[0.3, 0.7], scope=0) + c2 = Categorical(p=[0.4, 0.6], scope=1) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[c1, c2]) + spn = Sum([0.3, 0.7], [p0, p1]) + + inputs = np.column_stack(( + np.random.randint(2, size=30), + np.random.randint(2, size=30), + )).astype("float64") + + # Insert some NaN in random places into the input data. + inputs.ravel()[np.random.choice(inputs.size, 5, replace=False)] = np.nan + + # Execute the compiled Kernel. + results = CPUCompiler(computeInLogSpace=False, vectorize=False).log_likelihood(spn, inputs, supportMarginal=True, batchSize=10) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. 
+ assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_marginal_cpu_transformation_to_select() + print("COMPUTATION OK") diff --git a/python-interface/test/vector/test_vector_out_of_bounds.py b/python-interface/test/vector/test_vector_out_of_bounds.py new file mode 100644 index 00000000..918cf557 --- /dev/null +++ b/python-interface/test/vector/test_vector_out_of_bounds.py @@ -0,0 +1,67 @@ +# ============================================================================== +# This file is part of the SPNC project under the Apache License v2.0 by the +# Embedded Systems and Applications Group, TU Darmstadt. +# For the full copyright and license information, please view the LICENSE +# file that was distributed with this source code. +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================== + +import numpy as np + +from spn.structure.Base import Product, Sum +from spn.structure.leaves.histogram.Histograms import Histogram +from spn.structure.leaves.parametric.Parametric import Categorical +from spn.algorithms.Inference import log_likelihood + +from spnc.cpu import CPUCompiler + +import pytest + + +@pytest.mark.skipif(not CPUCompiler.isVectorizationSupported(), reason="CPU vectorization not supported") +def test_vector_out_of_bounds(): + # Construct a minimal SPN. 
+ h1 = Histogram([0., 1., 2., 3.], [0.125, 0.250, 0.625], [1, 1], scope=0) + h2 = Histogram([0., 1., 2., 3.], [0.100, 0.200, 0.700], [1, 1], scope=1) + c1 = Categorical(p=[0.1, 0.7, 0.2], scope=0) + c2 = Categorical(p=[0.4, 0.2, 0.4], scope=1) + + p0 = Product(children=[h1, h2]) + p1 = Product(children=[c1, c2]) + spn = Sum([0.3, 0.7], [p0, p1]) + + # Generate some out-of-bounds accesses + max_index = 2 + inputs = np.column_stack(( + np.random.randint(2 * max_index, size=30), + np.random.randint(2 * max_index, size=30), + )).astype("float64") + + if not CPUCompiler.isVectorizationSupported(): + print("Test not supported by the compiler installation") + return 0 + + # Execute the compiled Kernel. + results = CPUCompiler(verbose=False, computeInLogSpace=False).log_likelihood(spn, inputs, supportMarginal=False) + + # Compute the reference results using the inference from SPFlow. + reference = log_likelihood(spn, inputs) + reference = reference.reshape(30) + + # Account for SPFlow behavior: Out-of-bounds values are not returned as -inf, but: + # for Categoricals: 0.0 + # for Histograms: (np.log(np.finfo(float).eps)) + # Note: np.log(np.finfo(float).eps) is equal to: np.log(2.220446049250313e-16) = -36.04365338911715 + # Find all inputs where an out-of-bounds access will occur. + # If the access is in-bounds, the corresponding index will be 0 (otherwise: > 0). + cond = np.sum(np.where((inputs > max_index), 1, 0), axis=1) + reference[cond > 0] = -np.inf + + # Check the computation results against the reference + # Check in normal space if log-results are not very close to each other. + assert np.all(np.isclose(results, reference)) or np.all(np.isclose(np.exp(results), np.exp(reference))) + + +if __name__ == "__main__": + test_vector_out_of_bounds() + print("COMPUTATION OK")