diff --git a/paddle/cinn/backends/codegen_c.cc b/paddle/cinn/backends/codegen_c.cc
index 573f5d54083c0..eb6ed6263a156 100644
--- a/paddle/cinn/backends/codegen_c.cc
+++ b/paddle/cinn/backends/codegen_c.cc
@@ -22,7 +22,7 @@
 #include "paddle/cinn/ir/lowered_func.h"
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/ir/utils/ir_verify.h"
-#include "paddle/cinn/optim/ir_simplify.h"
+#include "paddle/cinn/optim/ir_simplify_pass.h"
 #include "paddle/cinn/runtime/cpu/thread_backend.h"
 #include "paddle/cinn/runtime/intrinsic.h"
 #include "paddle/cinn/utils/string.h"
@@ -170,7 +170,7 @@ void CodeGenC::Visit(const ir::Mul *op) { IrPrinter::Visit(op); }
 void CodeGenC::Visit(const ir::Div *op) { IrPrinter::Visit(op); }
 void CodeGenC::Visit(const ir::Mod *op) {
   auto copied = op->b();
-  optim::Simplify(&copied);
+  optim::Simplify(&copied, optim::SimplifyType::kExpr);
   if (copied.is_constant()) {
     int temp = static_cast<int>(copied.get_constant());
     if ((temp & (temp - 1)) == 0) {
@@ -709,8 +709,6 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
   Expr func_body = ir::Block::Make(new_body);

-  optim::SimplifyBlocks(&func_body);
-
   IrPrinter::Visit(func_body);
 }

 void CodeGenC::PrintIncludes() {
diff --git a/paddle/cinn/backends/codegen_cuda_generate_test.cc b/paddle/cinn/backends/codegen_cuda_generate_test.cc
index 4530bb88bcaf8..6855b50fc589b 100644
--- a/paddle/cinn/backends/codegen_cuda_generate_test.cc
+++ b/paddle/cinn/backends/codegen_cuda_generate_test.cc
@@ -33,7 +33,6 @@
 #include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
 #include "paddle/cinn/lang/lower.h"
-#include "paddle/cinn/optim/ir_simplify.h"
 #include "paddle/cinn/utils/timer.h"
 #include "paddle/common/enforce.h"

diff --git a/paddle/cinn/backends/codegen_gpu_dev.cc b/paddle/cinn/backends/codegen_gpu_dev.cc
index 9886d7c3a9fc4..3751c4859e93b 100644
--- a/paddle/cinn/backends/codegen_gpu_dev.cc
+++ b/paddle/cinn/backends/codegen_gpu_dev.cc
@@ -25,7 +25,7 @@
 #include "paddle/cinn/common/ir_util.h"
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/ir/utils/ir_verify.h"
-#include "paddle/cinn/optim/ir_simplify.h"
+#include "paddle/cinn/optim/ir_simplify_pass.h"
 #include "paddle/common/enforce.h"
 #include "paddle/common/errors.h"
 namespace cinn {
@@ -167,7 +167,6 @@ void CodeGenGpuDev::Visit(const ir::_LoweredFunc_ *op) {

   Expr func_body = ir::Block::Make(new_body);

-  optim::SimplifyBlocks(&func_body);
   // Make sure that the function's body is wrapped by a block
   if (!func_body.As<ir::Block>()) {
     func_body = ir::Block::Make({func_body});
@@ -307,7 +306,7 @@ void CodeGenGpuDev::PrintTempBufferCreation(const ir::Buffer &buffer) {
     for (int i = 0; i < buffer->shape.size(); i++) {
       buffer_size = buffer_size * buffer->shape[i];
     }
-    optim::Simplify(&buffer_size);
+    optim::Simplify(&buffer_size, optim::SimplifyType::kExpr);
     bool has_symbolic_constant = false;
     ir::ir_utils::CollectIRNodes(buffer_size, [&](const Expr *x) {
       if (x->as_var()) {
@@ -339,7 +338,7 @@ void CodeGenGpuDev::PrintTempBufferCreation(const ir::Buffer &buffer) {
     int type_bytes = buffer->dtype.bytes();
     dyn_shared_mem_offset_ =
         dyn_shared_mem_offset_ + buffer_size * Expr(type_bytes);
-    optim::Simplify(&dyn_shared_mem_offset_);
+    optim::Simplify(&dyn_shared_mem_offset_, optim::SimplifyType::kExpr);
     VLOG(6) << "dyn_shared_mem_offset_ = " << dyn_shared_mem_offset_;
   } else if (buffer->memory_type == ir::MemoryType::GPULocal) {
     // print func of static allocation
diff --git a/paddle/cinn/common/cas.h b/paddle/cinn/common/cas.h
index
53a37eb58004c..111b5397ad1e4 100644 --- a/paddle/cinn/common/cas.h +++ b/paddle/cinn/common/cas.h @@ -21,7 +21,7 @@ #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_printer.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/common/enforce.h" namespace cinn { namespace common { diff --git a/paddle/cinn/hlir/op/reduction.cc b/paddle/cinn/hlir/op/reduction.cc index f1e90157a8185..663c09191684a 100644 --- a/paddle/cinn/hlir/op/reduction.cc +++ b/paddle/cinn/hlir/op/reduction.cc @@ -27,7 +27,7 @@ #include "paddle/cinn/hlir/pe/transform.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/runtime/flags.h" PD_DECLARE_bool(cinn_enable_map_expr); @@ -205,7 +205,6 @@ std::shared_ptr StrategyForReduce( // support the length-1 for loop yet. So we simplify here. The todo // is that remove SimplifyForLoops below and change reduction schedule optim::SimplifyForLoops(&temp); - optim::SimplifyBlocks(&temp); vec_ast.emplace_back(temp); } else if (arg_pack[i].is_tensor()) { Expr temp = arg_pack[i]; diff --git a/paddle/cinn/hlir/pe/ir_schedule_pe.cc b/paddle/cinn/hlir/pe/ir_schedule_pe.cc index aeb91e4335e26..a4b661a26c7d8 100644 --- a/paddle/cinn/hlir/pe/ir_schedule_pe.cc +++ b/paddle/cinn/hlir/pe/ir_schedule_pe.cc @@ -32,7 +32,7 @@ #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/utils/ir_copy.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/poly/isl_utils.h" #include "paddle/cinn/utils/string.h" @@ -1300,9 +1300,9 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, // NOLINT int n = output->shape[0].as_int32(); int c = output->shape[1].as_int32(); - optim::Simplify(&(output->shape[2])); + optim::Simplify(&(output->shape[2]), optim::SimplifyType::kExpr); int h = output->shape[2].as_int32(); - optim::Simplify(&(output->shape[3])); + optim::Simplify(&(output->shape[3]), optim::SimplifyType::kExpr); int w = output->shape[3].as_int32(); int rc = input_pad->shape[1].as_int32(); @@ -1480,8 +1480,8 @@ void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, // NOLINT // stages[input_pad]->ComputeInline(); - optim::Simplify(&(output->shape[2])); - optim::Simplify(&(output->shape[3])); + optim::Simplify(&(output->shape[2]), optim::SimplifyType::kExpr); + optim::Simplify(&(output->shape[3]), optim::SimplifyType::kExpr); VLOG(3) << "Begin IRCudaScheduleConv2 with expr : " << ir_sch.GetModule().GetExprs().at(0); diff --git a/paddle/cinn/hlir/pe/schedule.cc b/paddle/cinn/hlir/pe/schedule.cc index fada77826134b..31b61eb1086f3 100644 --- a/paddle/cinn/hlir/pe/schedule.cc +++ b/paddle/cinn/hlir/pe/schedule.cc @@ -27,7 +27,7 @@ #include "paddle/cinn/common/cas.h" #include "paddle/cinn/hlir/pe/load_x86_params.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/poly/isl_utils.h" #include "paddle/cinn/utils/string.h" #include "paddle/common/enforce.h" diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index 0b7e195f0d85c..ae63aa6895253 100644 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -28,7 +28,6 @@ #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/utils/ir_copy.h" -#include "paddle/cinn/optim/ir_simplify.h" 
#include "paddle/common/enforce.h" #include "paddle/common/errors.h" diff --git a/paddle/cinn/ir/ir_printer.cc b/paddle/cinn/ir/ir_printer.cc index 93f7c3b9861b1..5d4f767609fda 100644 --- a/paddle/cinn/ir/ir_printer.cc +++ b/paddle/cinn/ir/ir_printer.cc @@ -21,7 +21,6 @@ #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/module.h" #include "paddle/cinn/ir/tensor.h" -#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/runtime/intrinsic.h" #include "paddle/cinn/utils/string.h" #include "paddle/common/enforce.h" diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc index d6c92481f706b..a754415e6cff7 100644 --- a/paddle/cinn/ir/module.cc +++ b/paddle/cinn/ir/module.cc @@ -17,7 +17,6 @@ #include #include "paddle/cinn/ir/ir_printer.h" -#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/optimize.h" #include "paddle/common/enforce.h" diff --git a/paddle/cinn/ir/schedule/impl/loop_transformation.cc b/paddle/cinn/ir/schedule/impl/loop_transformation.cc index e0797212ad4d7..973c56b3878bb 100644 --- a/paddle/cinn/ir/schedule/impl/loop_transformation.cc +++ b/paddle/cinn/ir/schedule/impl/loop_transformation.cc @@ -435,7 +435,7 @@ Expr DyScheduleImpl::Fuse(const std::vector& loops) { Expr fused_body = ir::ir_utils::IRCopy(for_nodes.back()->body); ReplaceExpr(&fused_body, loop_vars, substitute_value); - optim::Simplify(&fused_body); + optim::Simplify(&fused_body, optim::SimplifyType::kBlock); Expr fused_extent(int64_t(1)); for (int i = 0; i < loops_number; ++i) { fused_extent = fused_extent * for_nodes[i]->extent; diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc index 6b1d161500ceb..0b09e06dc2033 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.cc +++ b/paddle/cinn/ir/schedule/ir_schedule.cc @@ -40,7 +40,6 @@ #include "paddle/cinn/ir/schedule/ir_schedule_util.h" #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/lang/compute.h" -#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/utils/string.h" #include "paddle/common/enforce.h" diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.cc b/paddle/cinn/ir/schedule/ir_schedule_util.cc index 756b76f271efb..a8b24e5c85592 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_util.cc @@ -33,7 +33,7 @@ #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/lang/compute.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/common/enforce.h" namespace cinn { diff --git a/paddle/cinn/lang/compute.cc b/paddle/cinn/lang/compute.cc index 0fea7f91daa9b..1e4f0fb30e3a0 100644 --- a/paddle/cinn/lang/compute.cc +++ b/paddle/cinn/lang/compute.cc @@ -17,7 +17,7 @@ #include "paddle/cinn/backends/extern_func_protos.h" #include "paddle/cinn/common/common.h" #include "paddle/cinn/ir/operation.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/poly/dim.h" #include "paddle/cinn/poly/domain.h" #include "paddle/cinn/poly/stage.h" @@ -179,13 +179,13 @@ ir::Tensor Compute(const std::vector &domain, // construct the shape. 
for (auto dim : domain) { auto copied = dim; - optim::Simplify(&copied); + optim::Simplify(&copied, optim::SimplifyType::kExpr); domain_without_reduce_axis.push_back(copied); } for (auto dim : shape) { auto copied = dim; - optim::Simplify(&copied); + optim::Simplify(&copied, optim::SimplifyType::kExpr); shape_simplified.push_back(copied); } diff --git a/paddle/cinn/lang/lower_impl.cc b/paddle/cinn/lang/lower_impl.cc index fdc88ae5c9f23..2eed252849f9a 100644 --- a/paddle/cinn/lang/lower_impl.cc +++ b/paddle/cinn/lang/lower_impl.cc @@ -25,7 +25,7 @@ #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/tensor.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/optim/transform_polyfor_to_for.h" #include "paddle/cinn/poly/stage.h" @@ -384,7 +384,6 @@ std::vector LowerImpl::operator()() { if (support_ir_schedule_) { optim::TransformPolyForToFor(&func->body); - optim::SimplifyBlocks(&func->body); func->body = ir::Block::Make({func->body}); result.push_back(func); num_func++; diff --git a/paddle/cinn/lang/lower_tensor_group.cc b/paddle/cinn/lang/lower_tensor_group.cc index 589f88f866f26..3364bae4d12d4 100644 --- a/paddle/cinn/lang/lower_tensor_group.cc +++ b/paddle/cinn/lang/lower_tensor_group.cc @@ -28,7 +28,7 @@ #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/tensor.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/optim/transform_polyfor_to_for.h" #include "paddle/cinn/poly/stage.h" @@ -147,7 +147,6 @@ std::vector LowerTensorGroup::operator()() { actual_fn_name, func_args, func_body, temp_buffers); // 6. Final clean up - optim::SimplifyBlocks(&func->body); func->body = ir::Block::Make({func->body}); result.push_back(ir::LoweredFunc(func.get())); num_func++; diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index d95c6e1d23840..0536628395dc2 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -5,7 +5,7 @@ gather_srcs( SRCS replace_call_with_expr.cc replace_var_with_expr.cc - ir_simplify.cc + ir_simplify_pass.cc optimize.cc vectorize_loops.cc unroll_loops.cc diff --git a/paddle/cinn/optim/cast_simplify_test.cc b/paddle/cinn/optim/cast_simplify_test.cc index d9f9ffab1be6c..4b28add5acf10 100644 --- a/paddle/cinn/optim/cast_simplify_test.cc +++ b/paddle/cinn/optim/cast_simplify_test.cc @@ -16,7 +16,7 @@ #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/op/ir_operators.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" namespace cinn::optim { TEST(CastSimplify, same_type) { diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc deleted file mode 100644 index 7dc54f5b47c1a..0000000000000 --- a/paddle/cinn/optim/ir_simplify.cc +++ /dev/null @@ -1,531 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/ir_simplify.h" - -#include -#include -#include - -#include -#include - -#include "paddle/cinn/common/arithmetic.h" -#include "paddle/cinn/common/cas.h" -#include "paddle/cinn/common/ir_util.h" -#include "paddle/cinn/ir/ir_mutator.h" -#include "paddle/cinn/ir/ir_printer.h" -#include "paddle/cinn/ir/ir_visitor.h" -#include "paddle/cinn/ir/op/ir_operators.h" -#include "paddle/cinn/ir/tensor.h" -#include "paddle/cinn/ir/utils/ir_copy.h" -#include "paddle/cinn/utils/string.h" - -namespace cinn { -namespace optim { -using namespace ir; // NOLINT -using cinn::common::bfloat16; -using cinn::common::ExprToGinacConverter; -using cinn::common::float16; -using utils::GetStreamCnt; -using utils::Replace; - -namespace { - -//! Simplify the expression but Load. -struct SimplifyNoPureMathMutator : public ir::IRMutator { - void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } - - using ir::IRMutator<>::Visit; - -#define __(op__) \ - void Visit(const op__* op, Expr* expr) override { \ - *expr = ArithSimplify(*expr); \ - } - - __(Add) - __(Mul) - __(Sub) - __(Div) - __(Min) - __(Max) -#undef __ -}; - -struct SimplifyLoadMutator : public ir::IRMutator { - void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } - - void Visit(const Load* expr, Expr* op) override { - auto* node = op->As(); - for (auto& idx : node->indices) { - if (cinn::common::IsPureMath(idx)) { - idx = ArithSimplify(idx); - } else { - SimplifyNoPureMathMutator()(&idx); - } - } - } - - void Visit(const For* op, Expr* expr) override { - auto* node = expr->As(); - operator()(&node->body); - operator()(&node->extent); - } -}; - -struct SimplifyStoreMutator : public ir::IRMutator { - void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } - - void Visit(const Store* expr, Expr* op) override { - auto* node = op->As(); - - for (auto& idx : node->indices) { - if (cinn::common::IsPureMath(idx)) { - idx = ArithSimplify(idx); - } else { - SimplifyNoPureMathMutator()(&idx); - } - } - } - - void Visit(const For* op, Expr* expr) override { - auto* node = expr->As(); - operator()(&node->body); - operator()(&node->extent); - } -}; - -struct SimplifyRampMutator : public ir::IRMutator { - void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } - - void Visit(const Ramp* op, Expr* expr) override { - auto* node = expr->As(); - - PADDLE_ENFORCE_EQ( - cinn::common::IsPureMath(node->base), - true, - ::common::errors::InvalidArgument("node->base is not a pure math!")); - PADDLE_ENFORCE_EQ( - cinn::common::IsPureMath(node->stride), - true, - ::common::errors::InvalidArgument("node->stride is not a pure math!")); - node->base = ArithSimplify(node->base); - node->stride = ArithSimplify(node->stride); - } - // ramp + ramp - void Visit(const Add* op, Expr* expr) override { - auto* node = expr->As(); - Expr a = node->a(); - Expr b = node->b(); - auto a_ramp = a.As(); - auto b_ramp = b.As(); - - if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) { - Expr base_add = optim::ArithSimplify(a_ramp->base + b_ramp->base); - Expr stride_add = optim::ArithSimplify(a_ramp->stride + b_ramp->stride); - *expr = 
ir::Ramp::Make(base_add, stride_add, a_ramp->lanes); - } - } -}; - -struct SimplifyIfThenElseMutator : public ir::IRMutator<> { - void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); } - - using ir::IRMutator<>::Visit; - - void Visit(const IfThenElse* op, Expr* expr) override { - auto* node = expr->As(); - - auto* condition_int = node->condition.As(); - auto* condition_uint = node->condition.As(); - - // not deterministic - if (!condition_int && !condition_uint) { - Visit(&node->true_case, &node->true_case); - if (node->false_case.defined()) { - Visit(&node->false_case, &node->false_case); - } - return; - } - - bool value = condition_int ? condition_int->value : condition_uint->value; - if (value) { - *expr = op->true_case; - Visit(expr, expr); - } else if (op->false_case.defined()) { - *expr = op->false_case; - Visit(expr, expr); - } else { - *expr = ir::Block::Make({}); - } - } -}; - -struct SimplifySelectMutator : public ir::IRMutator<> { - void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); } - - using ir::IRMutator<>::Visit; - - void Visit(const Select* op, Expr* expr) override { - auto* node = expr->As(); - - auto* condition_int = node->condition.As(); - auto* condition_uint = node->condition.As(); - - // not deterministic - if (!condition_int && !condition_uint) { - Visit(&node->true_value, &node->true_value); - Visit(&node->false_value, &node->false_value); - return; - } - - bool value = condition_int ? condition_int->value : condition_uint->value; - if (value) { - *expr = op->true_value; - Visit(expr, expr); - } else { - *expr = op->false_value; - Visit(expr, expr); - } - } -}; - -struct SimplifyLogicalMutator : public ir::ExprMutator<> { - void operator()(Expr* expr) { ir::ExprMutator<>::Visit(expr, expr); } - -#define DEFINE_VISIT_CMP_OP(OpType, Method) \ - void Visit(const ir::OpType* op, Expr* expr) override { \ - VLOG(7) << "Begin Visit Cmp op: " << *expr; \ - auto* node = expr->As(); \ - ir::ExprMutator<>::Visit(&node->a(), &node->a()); \ - ir::ExprMutator<>::Visit(&node->b(), &node->b()); \ - if (node->a().is_constant() && node->b().is_constant()) \ - if (node->a().get_constant() Method node->b().get_constant()) \ - *expr = Expr(true); \ - VLOG(7) << "End Visit Cmp op: " << *expr; \ - } - DEFINE_VISIT_CMP_OP(LE, <=) - DEFINE_VISIT_CMP_OP(LT, <) - DEFINE_VISIT_CMP_OP(GE, >=) - DEFINE_VISIT_CMP_OP(GT, >) - DEFINE_VISIT_CMP_OP(EQ, ==) - DEFINE_VISIT_CMP_OP(NE, !=) - -#undef DEFINE_VISIT_CMP_OP - - void Visit(const ir::And* op, Expr* expr) override { - VLOG(7) << "Begin Visit And op: " << *expr; - auto* node = expr->As(); - ir::ExprMutator<>::Visit(&node->a(), &node->a()); - if (common::IsZero(node->a())) { - *expr = Expr(false); - VLOG(7) << "End Visit And op: " << *expr; - return; - } - ir::ExprMutator<>::Visit(&node->b(), &node->b()); - if (common::IsZero(node->b())) { - VLOG(7) << "End Visit And op: " << *expr; - *expr = Expr(false); - return; - } - if (common::IsOne(node->a()) && common::IsOne(node->b())) - *expr = Expr(true); - VLOG(7) << "End Visit And op: " << *expr; - } - - void Visit(const ir::Or* op, Expr* expr) override { - VLOG(7) << "Begin Visit Or op: " << *expr; - auto* node = expr->As(); - ir::ExprMutator<>::Visit(&node->a(), &node->a()); - if (common::IsOne(node->a())) { - *expr = Expr(true); - VLOG(7) << "End visit Or op: " << *expr; - return; - } - ir::ExprMutator<>::Visit(&node->b(), &node->b()); - if (common::IsOne(node->b())) { - *expr = Expr(true); - VLOG(7) << "End visit Or op: " << *expr; - return; - } - if (common::IsZero(node->a()) && 
common::IsZero(node->b())) - *expr = Expr(false); - VLOG(7) << "End visit Or op: " << *expr; - } - - void Visit(const ir::Not* op, Expr* expr) override { - VLOG(7) << "Begin Visit Not op: " << *expr; - auto* node = expr->As(); - auto v = node->v(); - ir::ExprMutator<>::Visit(&v, &v); - switch (v.node_type()) { - case ir::IrNodeTy::IntImm: - case ir::IrNodeTy::UIntImm: - *expr = common::IsZero(v) ? Expr(true) : Expr(false); - return; - case ir::IrNodeTy::Not: - *expr = v.As()->v(); - return; - case ir::IrNodeTy::LE: - *expr = ir::GT::Make(v->operand(0), v->operand(1)); - return; - case ir::IrNodeTy::LT: - *expr = ir::GE::Make(v->operand(0), v->operand(1)); - return; - case ir::IrNodeTy::GE: - *expr = ir::LT::Make(v->operand(0), v->operand(1)); - return; - case ir::IrNodeTy::GT: - *expr = ir::LE::Make(v->operand(0), v->operand(1)); - return; - default: - VLOG(7) << "End Visit Not op: " << *expr; - return; - } - VLOG(7) << "End Visit Not op: " << *expr; - } -}; - -struct ReplaceFracWithDivMutator : public ir::IRMutator<> { - void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); } - - void Visit(const FracOp* op, Expr* expr) override { - auto* node = expr->As(); - - ir::IRMutator<>::Visit(&node->operand(0), &node->operand(0)); - ir::IRMutator<>::Visit(&node->operand(1), &node->operand(1)); - - *expr = ir::Div::Make(node->operand(0), node->operand(1)); - } -}; - -struct SimplifyBlocksMutator : public ir::IRMutator<> { - SimplifyBlocksMutator() {} - - void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } - - using ir::IRMutator<>::Visit; - - void Visit(const Block* op, Expr* expr) override { - auto* node = expr->As(); - - if (node->stmts.size() == 1 && node->stmts[0].As()) { - VLOG(6) << "Simplify size-1 ir::Block"; - *expr = node->stmts[0]; - Visit(expr, expr); - } else { - for (auto& s : node->stmts) { - Visit(&s, &s); - } - std::vector stmts; - for (auto& s : node->stmts) { - if (s.As()) { - VLOG(6) << "Simplify ir::Block inside ir::Block"; - auto inner_block = s.As(); - for (auto inner_stmt : inner_block->stmts) { - stmts.push_back(inner_stmt); - } - } else { - stmts.push_back(s); - } - } - expr->As()->stmts = stmts; - } - } - - void Visit(const ScheduleBlock* op, Expr* expr) override { - auto* node = expr->As(); - PADDLE_ENFORCE_NOT_NULL(node, - ::common::errors::InvalidArgument( - "The node expr->As() is null")); - for (auto& var : node->iter_vars) { - if (var->lower_bound.defined()) { - Visit(&var->lower_bound, &var->lower_bound); - } - if (var->upper_bound.defined()) { - Visit(&var->upper_bound, &var->upper_bound); - } - } - for (auto& buffer_region : node->read_buffers) { - Visit(&buffer_region, &buffer_region); - } - for (auto& buffer_region : node->write_buffers) { - Visit(&buffer_region, &buffer_region); - } - - if (node->body.As()) { - if (node->body.As()->stmts.size() == 1) { - node->body = node->body.As()->stmts[0]; - } - } - - Visit(&(node->body), &(node->body)); - } -}; - -struct SimplifyForLoopsMutator : public ir::IRMutator<> { - absl::flat_hash_map var_mins; - SimplifyForLoopsMutator() {} - - void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } - - using ir::IRMutator<>::Visit; - - void Visit(const For* op, Expr* expr) override { - auto* node = expr->As(); - Visit(&node->min, &node->min); - Visit(&node->extent, &node->extent); - auto* min_i = node->min.As(); - auto* extent_i = node->extent.As(); - if (min_i && extent_i && extent_i->value - min_i->value == 1) { - VLOG(6) << "Simplify current For Loop"; - std::string var_name = node->loop_var->name; - 
var_mins.emplace(var_name, node->min); - - *expr = node->body; - - Visit(expr, expr); - var_mins.erase(var_name); - } else { - Visit(&node->body, &node->body); - } - } - - void Visit(const _Var_* op, Expr* expr) override { - auto* node = expr->As(); - - if (var_mins.count(node->name)) { - *expr = var_mins.at(node->name); - } - } -}; - -template -CastType NormCastValue(T value) { - if (type_of().is_uint() || type_of().is_uint()) { - // not support uint - return static_cast(value); - } - - if (std::isinf(value)) { - if (CastType(value) == -std::numeric_limits::infinity()) { - return -std::numeric_limits::infinity(); - } - return std::numeric_limits::infinity(); - } else if (std::isnan(value)) { - return std::numeric_limits::signaling_NaN(); - } else if (value >= static_cast(std::numeric_limits::max())) { - return std::numeric_limits::max(); - } else if (value <= static_cast(std::numeric_limits::lowest())) { - return std::numeric_limits::lowest(); - } - return static_cast(value); -} - -struct SimplifyCastMutator : public ir::IRMutator<> { - void operator()(Expr* expr) { ir::IRMutator::Visit(expr, expr); } - - void Visit(const ir::Cast* op, Expr* expr) { - auto* node = expr->As(); - - ir::IRMutator::Visit(&node->v(), &node->v()); - - if (op->type() == op->v().type()) { - *expr = op->v(); - return; - } - -#define __CAST_TO_TYPE(type__) \ - if (auto* i = op->v().As()) { \ - *expr = Expr(static_cast(i->value)); \ - } else if (auto* f = op->v().As()) { \ - *expr = Expr(static_cast(NormCastValue(f->value))); \ - } else if (auto* u = op->v().As()) { \ - *expr = Expr(static_cast(u->value)); \ - } else { \ - CINN_NOT_IMPLEMENTED \ - } - - if (op->v().is_constant()) { - if (op->type() == type_of()) { - __CAST_TO_TYPE(int8_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(int16_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(int32_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(int64_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint8_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint16_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint32_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint64_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(float) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(double) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(bool) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint32_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint64_t) - } else if (op->type() == type_of()) { - // Cannot simplify!!! pass - __CAST_TO_TYPE(bfloat16) - } else if (op->type() == type_of()) { - // Cannot simplify!!! 
pass - __CAST_TO_TYPE(float16) - } else { - CINN_NOT_IMPLEMENTED - } - } -#undef __CAST_TO_TYPE - } -}; - -} // namespace - -void SimplifyCast(Expr* expr) { SimplifyCastMutator()(expr); } -void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); } -void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); } - -void SimplifyLogical(Expr* expr) { SimplifyLogicalMutator()(expr); } - -Expr ArithSimplify(const Expr& u) { - if (!u.is_index()) return u; - auto copied = ir_utils::IRCopy(u); - return copied.as_index().Normalize(); -} - -void Simplify(Expr* expr) { - VLOG(6) << "Begin Simplify " << *expr; - SimplifyNoPureMathMutator()(expr); - SimplifyCastMutator()(expr); - SimplifyRampMutator()(expr); - SimplifyLoadMutator()(expr); - SimplifyStoreMutator()(expr); - SimplifyLogicalMutator()(expr); - SimplifyIfThenElseMutator()(expr); - SimplifySelectMutator()(expr); - SimplifyNoPureMathMutator()(expr); - - ReplaceFracWithDivMutator()(expr); - VLOG(6) << "End Simplify " << *expr; -} -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/ir_simplify_pass.cc b/paddle/cinn/optim/ir_simplify_pass.cc new file mode 100644 index 0000000000000..d59341f676d73 --- /dev/null +++ b/paddle/cinn/optim/ir_simplify_pass.cc @@ -0,0 +1,762 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/optim/ir_simplify_pass.h" + +#include +#include +#include + +#include +#include + +#include "paddle/cinn/common/arithmetic.h" +#include "paddle/cinn/common/cas.h" +#include "paddle/cinn/common/ir_util.h" +#include "paddle/cinn/ir/ir_mutator.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/ir_visitor.h" +#include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/tensor.h" +#include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/ir/utils/stmt_converter.h" +#include "paddle/cinn/pass/pass.h" +#include "paddle/cinn/pass/pass_manager.h" +#include "paddle/cinn/utils/string.h" + +namespace cinn { +namespace optim { +using namespace ir; // NOLINT +using cinn::common::bfloat16; +using cinn::common::ExprToGinacConverter; +using cinn::common::float16; +using utils::GetStreamCnt; +using utils::Replace; + +namespace { + +//! Simplify the expression but Load. 
+class SimplifyNoPureMathMutator : public ir::IRMutator { + public: + void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } + + private: + using ir::IRMutator<>::Visit; + +#define __(op__) \ + void Visit(const op__* op, Expr* expr) override { \ + *expr = ArithSimplify(*expr); \ + } + + __(Add) + __(Mul) + __(Sub) + __(Div) + __(Min) + __(Max) +#undef __ +}; + +class SimplifyLoadMutator : public ir::IRMutator { + public: + void operator()(ir::Expr* x) { ir::IRMutator::Visit(x, x); } + + private: + void Visit(const Load* expr, Expr* op) override { + auto* node = op->As(); + for (auto& idx : node->indices) { + if (cinn::common::IsPureMath(idx)) { + idx = ArithSimplify(idx); + } else { + SimplifyNoPureMathMutator()(&idx); + } + } + } +}; + +class SimplifyStoreMutator { + public: + void operator()(ir::stmt::BlockRef block) { + for (auto& stmt : block->stmts()) { + if (stmt->stmt_type() == ir::StmtNodeTy::Store) { + VisitStmt(stmt.as()); + } + } + } + + private: + void VisitStmt(ir::stmt::Store stmt) { + std::vector new_indices = stmt->indices(); + for (ir::Expr& index : new_indices) { + if (cinn::common::IsPureMath(index)) { + index = ArithSimplify(index); + } else { + SimplifyNoPureMathMutator()(&index); + } + } + stmt->set_indices(new_indices); + } +}; + +class SimplifyRampMutator : public ir::IRMutator { + public: + void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } + + private: + void Visit(const Ramp* op, Expr* expr) override { + auto* node = expr->As(); + + PADDLE_ENFORCE_EQ( + cinn::common::IsPureMath(node->base), + true, + ::common::errors::InvalidArgument("node->base is not a pure math!")); + PADDLE_ENFORCE_EQ( + cinn::common::IsPureMath(node->stride), + true, + ::common::errors::InvalidArgument("node->stride is not a pure math!")); + node->base = ArithSimplify(node->base); + node->stride = ArithSimplify(node->stride); + } + // ramp + ramp + void Visit(const Add* op, Expr* expr) override { + auto* node = expr->As(); + Expr a = node->a(); + Expr b = node->b(); + auto a_ramp = a.As(); + auto b_ramp = b.As(); + + if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) { + Expr base_add = cinn::common::AutoSimplify(a_ramp->base + b_ramp->base); + Expr stride_add = + cinn::common::AutoSimplify(a_ramp->stride + b_ramp->stride); + *expr = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes); + } + } +}; + +class SimplifyIfThenElseMutator { + public: + void operator()(ir::stmt::BlockRef block) { VisitBlock(block); } + + private: + void VisitBlock(ir::stmt::BlockRef block) { + std::unordered_set empty_stmt_id; + std::vector stmts = block->stmts(); + for (int i = 0; i < stmts.size(); i++) { + if (stmts[i].isa()) + if (IsEmptyIf(stmts[i].as())) + empty_stmt_id.insert(i); + } + + std::vector new_stmts; + for (int i = 0; i < stmts.size(); i++) { + if (!empty_stmt_id.count(i)) new_stmts.push_back(stmts[i]); + } + block->set_stmts(new_stmts); + } + + bool IsEmptyIf(ir::stmt::IfThenElse stmt) { + const Expr& condition = stmt->condition(); + stmt->set_condition(ArithSimplify(condition)); + + auto* condition_int = stmt->condition().As(); + auto* condition_uint = stmt->condition().As(); + + // not deterministic + if (!condition_int && !condition_uint) { + VisitBlock(stmt->true_case()); + if (stmt->false_case().defined()) { + VisitBlock(stmt->false_case()); + } + return false; + } + + bool value = condition_int ? 
condition_int->value : condition_uint->value; + if (value) { + VisitBlock(stmt->true_case()); + return false; + } else if (stmt->false_case().defined()) { + VisitBlock(stmt->false_case()); + return false; + } else { + return true; + } + } +}; + +class SimplifySelectMutator : public ir::IRMutator<> { + public: + void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); } + + private: + void Visit(const Select* op, Expr* expr) override { + auto* node = expr->As(); + + auto* condition_int = node->condition.As(); + auto* condition_uint = node->condition.As(); + + // not deterministic + if (!condition_int && !condition_uint) { + operator()(&node->true_value); + operator()(&node->false_value); + return; + } + + bool value = condition_int ? condition_int->value : condition_uint->value; + if (value) { + *expr = op->true_value; + ir::IRMutator<>::Visit(expr, expr); + } else { + *expr = op->false_value; + ir::IRMutator<>::Visit(expr, expr); + } + } +}; + +struct SimplifyLogicalMutator : public ir::ExprMutator<> { + public: + void operator()(Expr* expr) { ir::ExprMutator<>::Visit(expr, expr); } + + private: +#define DEFINE_VISIT_CMP_OP(OpType, Method) \ + void Visit(const ir::OpType* op, Expr* expr) override { \ + VLOG(7) << "Begin Visit Cmp op: " << *expr; \ + auto* node = expr->As(); \ + ir::ExprMutator<>::Visit(&node->a(), &node->a()); \ + ir::ExprMutator<>::Visit(&node->b(), &node->b()); \ + if (node->a().is_constant() && node->b().is_constant()) \ + if (node->a().get_constant() Method node->b().get_constant()) \ + *expr = Expr(true); \ + VLOG(7) << "End Visit Cmp op: " << *expr; \ + } + DEFINE_VISIT_CMP_OP(LE, <=) + DEFINE_VISIT_CMP_OP(LT, <) + DEFINE_VISIT_CMP_OP(GE, >=) + DEFINE_VISIT_CMP_OP(GT, >) + DEFINE_VISIT_CMP_OP(EQ, ==) + DEFINE_VISIT_CMP_OP(NE, !=) + +#undef DEFINE_VISIT_CMP_OP + + void Visit(const ir::And* op, Expr* expr) override { + VLOG(7) << "Begin Visit And op: " << *expr; + auto* node = expr->As(); + ir::ExprMutator<>::Visit(&node->a(), &node->a()); + if (common::IsZero(node->a())) { + *expr = Expr(false); + VLOG(7) << "End Visit And op: " << *expr; + return; + } + ir::ExprMutator<>::Visit(&node->b(), &node->b()); + if (common::IsZero(node->b())) { + VLOG(7) << "End Visit And op: " << *expr; + *expr = Expr(false); + return; + } + if (common::IsOne(node->a()) && common::IsOne(node->b())) + *expr = Expr(true); + VLOG(7) << "End Visit And op: " << *expr; + } + + void Visit(const ir::Or* op, Expr* expr) override { + VLOG(7) << "Begin Visit Or op: " << *expr; + auto* node = expr->As(); + ir::ExprMutator<>::Visit(&node->a(), &node->a()); + if (common::IsOne(node->a())) { + *expr = Expr(true); + VLOG(7) << "End visit Or op: " << *expr; + return; + } + ir::ExprMutator<>::Visit(&node->b(), &node->b()); + if (common::IsOne(node->b())) { + *expr = Expr(true); + VLOG(7) << "End visit Or op: " << *expr; + return; + } + if (common::IsZero(node->a()) && common::IsZero(node->b())) + *expr = Expr(false); + VLOG(7) << "End visit Or op: " << *expr; + } + + void Visit(const ir::Not* op, Expr* expr) override { + VLOG(7) << "Begin Visit Not op: " << *expr; + auto* node = expr->As(); + auto v = node->v(); + ir::ExprMutator<>::Visit(&v, &v); + switch (v.node_type()) { + case ir::IrNodeTy::IntImm: + case ir::IrNodeTy::UIntImm: + *expr = common::IsZero(v) ? 
Expr(true) : Expr(false); + return; + case ir::IrNodeTy::Not: + *expr = v.As()->v(); + return; + case ir::IrNodeTy::LE: + *expr = ir::GT::Make(v->operand(0), v->operand(1)); + return; + case ir::IrNodeTy::LT: + *expr = ir::GE::Make(v->operand(0), v->operand(1)); + return; + case ir::IrNodeTy::GE: + *expr = ir::LT::Make(v->operand(0), v->operand(1)); + return; + case ir::IrNodeTy::GT: + *expr = ir::LE::Make(v->operand(0), v->operand(1)); + return; + default: + VLOG(7) << "End Visit Not op: " << *expr; + return; + } + VLOG(7) << "End Visit Not op: " << *expr; + } +}; + +class SimplifyLogicalPass : public ExprPass { + public: + SimplifyLogicalPass() : ExprPass("simplify_logical_pass") {} + + LogicalResult Run(ir::Expr* expr) override { + SimplifyLogicalMutator mutator; + mutator(expr); + return LogicalResult::success(); + } +}; +std::unique_ptr CreateSimplifyLogicalPass() { + return std::make_unique(); +} + +struct ReplaceFracWithDivMutator : public ir::IRMutator<> { + public: + void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); } + + private: + void Visit(const FracOp* op, Expr* expr) override { + auto* node = expr->As(); + + ir::IRMutator<>::Visit(&node->operand(0), &node->operand(0)); + ir::IRMutator<>::Visit(&node->operand(1), &node->operand(1)); + + *expr = ir::Div::Make(node->operand(0), node->operand(1)); + } +}; + +class SimplifyFracWithDivPass : public ExprPass { + public: + SimplifyFracWithDivPass() : ExprPass("simplify_frac_with_div_pass") {} + + LogicalResult Run(ir::Expr* expr) override { + ReplaceFracWithDivMutator mutator; + mutator(expr); + return LogicalResult::success(); + } +}; +std::unique_ptr CreateSimplifyFracWithDivPass() { + return std::make_unique(); +} + +class SimplifyForLoopsMutator : public ir::IRMutator, + public ir::stmt::StmtMutator { + public: + void operator()(ir::Expr* x) { ir::IRMutator::Visit(x, x); } + + void operator()(ir::stmt::BlockRef block) { VisitBlock(block); } + + private: + bool VisitBlock(ir::stmt::BlockRef block) override { + std::vector stmts = block->stmts(); + for (auto i = stmts.size() - 1; i >= 0; i--) { + if (stmts[i].isa()) { + if (VisitStmt(stmts[i].as())) { + const std::vector& inner_stmts = + stmts[i].as()->body()->stmts(); + stmts.insert( + stmts.begin() + i, inner_stmts.begin(), inner_stmts.end()); + stmts.erase(stmts.begin() + i + inner_stmts.size()); + } + } else { + ir::stmt::StmtMutator::VisitStmt(stmts[i]); + } + } + block->set_stmts(stmts); + } + + bool VisitStmt(ir::stmt::For stmt) override { + Expr min = stmt->min(); + Expr extent = stmt->extent(); + operator()(&min); + operator()(&extent); + stmt->set_min(min); + stmt->set_extent(extent); + const IntImm* min_i = stmt->min().As(); + const IntImm* extent_i = stmt->extent().As(); + if (min_i && extent_i && (extent_i->value - min_i->value) == 1) { + VLOG(6) << "Simplify current For Loop"; + std::string var_name = stmt->loop_var()->name; + var_mins.emplace(var_name, stmt->min()); + VisitBlock(stmt->body()); + var_mins.erase(var_name); + return true; + } else { + VisitBlock(stmt->body()); + return false; + } + } + + bool VisitStmt(ir::stmt::IfThenElse stmt) override { + Expr condition = stmt->condition(); + operator()(&condition); + stmt->set_condition(condition); + + VisitBlock(stmt->true_case()); + if (stmt->false_case().defined()) { + VisitBlock(stmt->false_case()); + } + } + + bool VisitStmt(ir::stmt::Schedule stmt) override { + std::vector iter_values = stmt->iter_values(); + std::vector read_buffers = stmt->read_buffers(); + std::vector write_buffers = 
stmt->write_buffers(); + + for (Expr& iter_value : iter_values) operator()(&iter_value); + for (Expr& read_buffer : read_buffers) operator()(&read_buffer); + for (Expr& write_buffer : write_buffers) operator()(&write_buffer); + + stmt->set_iter_values(iter_values); + stmt->set_read_buffers(read_buffers); + stmt->set_write_buffers(write_buffers); + } + + bool VisitStmt(ir::stmt::Let stmt) override { + Expr value = stmt->body(); + operator()(&value); + stmt->set_body(value); + } + + bool VisitStmt(ir::stmt::Store stmt) override { + Expr value = stmt->value(); + operator()(&value); + stmt->set_value(value); + + std::vector indices = stmt->indices(); + for (Expr& index : indices) { + operator()(&index); + } + stmt->set_indices(indices); + } + + bool VisitStmt(ir::stmt::Alloc stmt) override { + Expr condition = stmt->condition(); + operator()(&condition); + stmt->set_condition(condition); + } + + bool VisitStmt(ir::stmt::Free stmt) override { + Expr destination = stmt->destination(); + operator()(&destination); + stmt->set_destination(destination); + } + + bool VisitStmt(ir::stmt::Evaluate stmt) override {} + + void Visit(const _Var_* op, Expr* expr) override { + auto* node = expr->As(); + + if (var_mins.count(node->name)) { + *expr = var_mins.at(node->name); + } + } + + absl::flat_hash_map var_mins; +}; + +template +CastType NormCastValue(T value) { + if (type_of().is_uint() || type_of().is_uint()) { + // not support uint + return static_cast(value); + } + + if (std::isinf(value)) { + if (CastType(value) == -std::numeric_limits::infinity()) { + return -std::numeric_limits::infinity(); + } + return std::numeric_limits::infinity(); + } else if (std::isnan(value)) { + return std::numeric_limits::signaling_NaN(); + } else if (value >= static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } else if (value <= static_cast(std::numeric_limits::lowest())) { + return std::numeric_limits::lowest(); + } + return static_cast(value); +} + +class SimplifyCastMutator : public ir::IRMutator<> { + public: + void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); } + + private: + void Visit(const ir::Cast* op, Expr* expr) override { + auto* node = expr->As(); + + ir::IRMutator::Visit(&node->v(), &node->v()); + + if (op->type() == op->v().type()) { + *expr = op->v(); + return; + } + +#define __CAST_TO_TYPE(type__) \ + if (auto* i = op->v().As()) { \ + *expr = Expr(static_cast(i->value)); \ + } else if (auto* f = op->v().As()) { \ + *expr = Expr(static_cast(NormCastValue(f->value))); \ + } else if (auto* u = op->v().As()) { \ + *expr = Expr(static_cast(u->value)); \ + } else { \ + CINN_NOT_IMPLEMENTED \ + } + + if (op->v().is_constant()) { + if (op->type() == type_of()) { + __CAST_TO_TYPE(int8_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int16_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int64_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint8_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint16_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint64_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(float) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(double) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(bool) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint64_t) + } else if 
(op->type() == type_of()) { + // Cannot simplify!!! pass + __CAST_TO_TYPE(bfloat16) + } else if (op->type() == type_of()) { + // Cannot simplify!!! pass + __CAST_TO_TYPE(float16) + } else { + CINN_NOT_IMPLEMENTED + } + } +#undef __CAST_TO_TYPE + } +}; +} // namespace + +class SimplifyNoPureMathPass : public ExprPass { + public: + SimplifyNoPureMathPass() : ExprPass("simplify_no_pure_math_pass") {} + + LogicalResult Run(ir::Expr* expr) override { + SimplifyNoPureMathMutator mutator; + mutator(expr); + return LogicalResult::success(); + } +}; +std::unique_ptr CreateSimplifyNoPureMathPass() { + return std::make_unique(); +} + +class SimplifyLoadPass : public ExprPass { + public: + SimplifyLoadPass() : ExprPass("simplify_load_pass") {} + + LogicalResult Run(ir::Expr* expr) override { + SimplifyLoadMutator mutator; + mutator(expr); + return LogicalResult::success(); + } +}; +std::unique_ptr CreateSimplifyLoadPass() { + return std::make_unique(); +} + +class SimplifyStorePass : public BlockPass { + public: + SimplifyStorePass() : BlockPass("simplify_store_pass") {} + + LogicalResult Run(ir::stmt::BlockRef block) override { + SimplifyStoreMutator mutator; + mutator(block); + return LogicalResult::success(); + } +}; +std::unique_ptr CreateSimplifyStorePass() { + return std::make_unique(); +} + +class SimplifySelectPass : public ExprPass { + public: + SimplifySelectPass() : ExprPass("simplify_select_pass") {} + + LogicalResult Run(ir::Expr* expr) override { + SimplifySelectMutator mutator; + mutator(expr); + return LogicalResult::success(); + } +}; +std::unique_ptr CreateSimplifySelectPass() { + return std::make_unique(); +} + +class SimplifyForLoopsPass : public BlockPass { + public: + SimplifyForLoopsPass() : BlockPass("simplify_for_loops_pass") {} + + LogicalResult Run(ir::stmt::BlockRef block) override { + SimplifyForLoopsMutator()(block); + return LogicalResult::success(); + } +}; + +std::unique_ptr CreateSimplifyForLoopsPass() { + return std::make_unique(); +} + +class SimplifyIfThenElsePass : public BlockPass { + public: + SimplifyIfThenElsePass() : BlockPass("simplify_ifthenelse_loops_pass") {} + + LogicalResult Run(ir::stmt::BlockRef block) override { + SimplifyIfThenElseMutator()(block); + return LogicalResult::success(); + } +}; + +std::unique_ptr CreateSimplifyIfThenElsePass() { + return std::make_unique(); +} + +class SimplifyRampPass : public ExprPass { + public: + SimplifyRampPass() : ExprPass("simplify_ramp_pass") {} + + LogicalResult Run(ir::Expr* expr) override { + SimplifyRampMutator mutator; + mutator(expr); + return LogicalResult::success(); + } +}; +std::unique_ptr CreateSimplifyRampPass() { + return std::make_unique(); +} + +class SimplifyCastPass : public ExprPass { + public: + SimplifyCastPass() : ExprPass("simplify_cast_pass") {} + + LogicalResult Run(ir::Expr* expr) override { + SimplifyCastMutator mutator; + mutator(expr); + return LogicalResult::success(); + } +}; + +std::unique_ptr CreateSimplifyCastPass() { + return std::make_unique(); +} + +Expr ArithSimplify(const Expr& u) { + if (!u.is_index()) return u; + auto copied = ir_utils::IRCopy(u); + return copied.as_index().Normalize(); +} + +void Simplify(ir::Expr* expr, const SimplifyType type) { + VLOG(6) << "Begin Simplify: \n " << *expr; + + switch (type) { + case SimplifyType::kExpr: + SimplifyNoPureMathMutator()(expr); + SimplifyCastMutator()(expr); + SimplifyRampMutator()(expr); + SimplifyLoadMutator()(expr); + SimplifyLogicalMutator()(expr); + SimplifySelectMutator()(expr); + SimplifyNoPureMathMutator()(expr); + 
break; + case SimplifyType::kBlock: + PADDLE_ENFORCE_EQ(expr->node_type(), + ir::Block::_node_type_, + ::common::errors::InvalidArgument( + "The Expr to simplify must be a block.")); + ir::stmt::BlockRef _block = ir::ConvertExprBlockToStmtBlock(*expr); + optim::ExprPassManager expr_pass_manager; + expr_pass_manager.AddPass(CreateSimplifyNoPureMathPass()); + expr_pass_manager.AddPass(CreateSimplifyCastPass()); + expr_pass_manager.AddPass(CreateSimplifyRampPass()); + expr_pass_manager.AddPass(CreateSimplifyLoadPass()); + expr_pass_manager.AddPass(CreateSimplifyLogicalPass()); + expr_pass_manager.AddPass(CreateSimplifySelectPass()); + expr_pass_manager.Run(_block); + + optim::BlockPassManager block_pass_manager; + block_pass_manager.AddPass(CreateSimplifyStorePass()); + block_pass_manager.AddPass(CreateSimplifyIfThenElsePass()); + block_pass_manager.Run(_block); + + optim::ExprPassManager expr_pass_manager1; + expr_pass_manager1.AddPass(CreateSimplifyNoPureMathPass()); + expr_pass_manager1.AddPass(CreateSimplifyFracWithDivPass()); + expr_pass_manager1.Run(_block); + + *expr = ir::ConvertStmtBlockToExprBlock(_block); + break; + } + + VLOG(6) << "End Simplify: \n" << *expr; +} + +void SimplifyForLoops(ir::Expr* expr) { + PADDLE_ENFORCE_EQ(expr->node_type(), + ir::Block::_node_type_, + ::common::errors::InvalidArgument( + "The Expr to simplify must be a block.")); + ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr); + optim::BlockPassManager pass_manager; + pass_manager.AddPass(CreateSimplifyForLoopsPass()); + pass_manager.Run(block); + *expr = ir::ConvertStmtBlockToExprBlock(block); +} + +void SimplifyCast(ir::Expr* expr) { + SimplifyCastMutator mutator; + mutator(expr); +} + +void SimplifyCast(ir::stmt::BlockRef block) { + optim::ExprPassManager pass_manager; + pass_manager.AddPass(CreateSimplifyCastPass()); + pass_manager.Run(block); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/ir_simplify.h b/paddle/cinn/optim/ir_simplify_pass.h similarity index 78% rename from paddle/cinn/optim/ir_simplify.h rename to paddle/cinn/optim/ir_simplify_pass.h index bfb72b738f1b6..1f5d48e4d8e37 100644 --- a/paddle/cinn/optim/ir_simplify.h +++ b/paddle/cinn/optim/ir_simplify_pass.h @@ -14,10 +14,13 @@ #pragma once #include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/stmt.h" namespace cinn { namespace optim { +enum class SimplifyType { kExpr, kBlock }; + /** * Simplify the expression on Cast, Ramp, Load, Store, IfThenElse and Select * operations. @@ -84,7 +87,7 @@ namespace optim { * Output IR: * true_value */ -void Simplify(Expr *expr); +void Simplify(ir::Expr *expr, SimplifyType type); /** * Simplify type casting expressions. @@ -119,7 +122,7 @@ void Simplify(Expr *expr); * Cast(x) (Type mismatch, remains unchanged) * 5.0 (Constant value will be cast) */ -void SimplifyCast(Expr *expr); +void SimplifyCast(ir::Expr *expr); /** * Simplify for loop structures in the IR. @@ -149,41 +152,7 @@ void SimplifyCast(Expr *expr); * Output IR: * for (int i = 0; i < 2; ++i) { doSomething(i); } (remains unchanged) */ -void SimplifyForLoops(Expr *expr); - -/** - * Simplify block structures in the IR. - * - * This pass is applicable in scenarios where blocks contain redundant or nested - * blocks that can be flattened. This is useful in optimizing the structure of - * the IR for better performance. 
- - * When applied, this pass will recursively check and simplify blocks of three - * kinds: 1) block(s) containing only a single statement or block will be - * replaced by the inner body; 2) nested block will be flattened by - * extracting the child or current statements into current block; 3) iterative - * variables and buffer regions of ScheduleBlock will be replaced by block body - * when the body is single. - - * Performance impact: This pass can improve performance by reducing the - * overhead of block management and enabling better optimization opportunities. - - * Examples: - * 1. Single statement block: - * Input IR: - * Block { Block { stmt0 } } - * Output IR: - * Block { stmt0 } - * - * 2. Nested blocks: - * Input IR: - * Block { Block { stmt1 }, Block { stmt2 }, stmt3 } - * Output IR: - * Block { stmt1, stmt2, stmt3 } - */ -void SimplifyBlocks(Expr *expr); - -void SimplifyLogical(Expr *expr); +void SimplifyForLoops(ir::Expr *expr); Expr ArithSimplify(const Expr &u); } // namespace optim diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index fec6877220b8c..95ee3e86b20a6 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -27,7 +27,7 @@ #include "paddle/cinn/optim/if_fold_pass.h" #include "paddle/cinn/optim/if_fusion_pass.h" #include "paddle/cinn/optim/insert_debug_log_callee.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/optim/lower_function_call_bind_vars.h" #include "paddle/cinn/optim/lower_intrin.h" #include "paddle/cinn/optim/map_extern_call.h" @@ -62,7 +62,7 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn, ReplaceConstParamToInteger(&copied->body); // Simplify already contains CastSimplify - Simplify(&copied->body); + Simplify(&copied->body, optim::SimplifyType::kBlock); EliminateInvariantLoop(&copied->body); VLOG(4) << "After Optimize EliminateInvariantLoop:" << copied; ReplaceCrossThreadReduction(copied); @@ -98,9 +98,6 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn, [&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED }, [](auto) {}); - SimplifyBlocks(&copied->body); - VLOG(4) << "After SimplifyBlocks:" << copied; - MapExternCall(&copied->body, target); VLOG(10) << "After Optimize MapExternCall:" << copied; @@ -113,8 +110,8 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn, << copied; // Simplify already contains CastSimplify - Simplify(&copied->body); - VLOG(4) << "After Optimize Simplify:" << copied; + Simplify(&copied->body, optim::SimplifyType::kBlock); + VLOG(10) << "After Optimize Simplify:" << copied; BlockPassManager pass_manager; pass_manager.AddPass(CreateIfFusionPass()); @@ -132,7 +129,7 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn, VectorizeForTrans(&copied->body); VLOG(10) << "After Optimize vectorize" << copied; - Simplify(&copied->body); + Simplify(&copied->body, optim::SimplifyType::kBlock); VLOG(10) << "After Optimize Simplify" << copied; pass_manager.AddPass(CreateRemoveScheduleBlockPass()); diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc index 94514ff440f0c..faa8a135cbd62 100644 --- a/paddle/cinn/optim/replace_var_with_expr.cc +++ b/paddle/cinn/optim/replace_var_with_expr.cc @@ -21,7 +21,6 @@ #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/utils/ir_copy.h" -#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/replace_const_param_to_integer.h" namespace cinn { diff --git 
a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc index 4012acb2ca10d..ff26ba70bc8e2 100644 --- a/paddle/cinn/optim/transform_gpu_forloop.cc +++ b/paddle/cinn/optim/transform_gpu_forloop.cc @@ -30,7 +30,6 @@ #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/stmt_converter.h" #include "paddle/cinn/optim/eliminate_common_factor_of_local_index.h" -#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/longlong2int_pass.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/optim/resize_buffer.h" diff --git a/paddle/cinn/optim/transform_polyfor_to_for.cc b/paddle/cinn/optim/transform_polyfor_to_for.cc index 99a145d924ff3..9fa21f0023f28 100644 --- a/paddle/cinn/optim/transform_polyfor_to_for.cc +++ b/paddle/cinn/optim/transform_polyfor_to_for.cc @@ -26,7 +26,7 @@ #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_copy.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" namespace cinn { namespace optim { @@ -40,14 +40,14 @@ Expr PlusOneWithMinMax(Expr expr) { if (min_n) { min_n->a() = min_n->a() + 1; min_n->b() = min_n->b() + 1; - Simplify(&min_n->a()); - Simplify(&min_n->b()); + Simplify(&min_n->a(), optim::SimplifyType::kExpr); + Simplify(&min_n->b(), optim::SimplifyType::kExpr); return expr; } else if (max_n) { max_n->a() = max_n->a() + 1; max_n->b() = max_n->b() + 1; - Simplify(&max_n->a()); - Simplify(&max_n->b()); + Simplify(&max_n->a(), optim::SimplifyType::kExpr); + Simplify(&max_n->b(), optim::SimplifyType::kExpr); return expr; } return expr + 1; diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc index e6324d0db2e40..abcfaf4d5e26e 100644 --- a/paddle/cinn/optim/vectorize_loops.cc +++ b/paddle/cinn/optim/vectorize_loops.cc @@ -30,7 +30,7 @@ #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_replace.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/optim/unroll_loops.h" #include "paddle/cinn/utils/functional.h" @@ -713,7 +713,7 @@ struct VectorizeLoops_ : public IRMutator { for (int i = 0; i < indices.size(); i++) { node->indices[i] = cinn::common::AutoSimplify(node->indices[i], var_intervals); - Simplify(&node->indices[i]); + Simplify(&node->indices[i], optim::SimplifyType::kExpr); if (!node->indices[i].same_as(indices[i])) { is_changed = true; } @@ -734,7 +734,7 @@ struct VectorizeLoops_ : public IRMutator { for (int i = 0; i < indices.size(); i++) { node->indices[i] = cinn::common::AutoSimplify(node->indices[i], var_intervals); - Simplify(&node->indices[i]); + Simplify(&node->indices[i], optim::SimplifyType::kExpr); if (!node->indices[i].same_as(indices[i])) { is_changed = true; } @@ -782,7 +782,7 @@ struct VectorizeLoops_ : public IRMutator { ::common::errors::InvalidArgument( "The minimum of forloop should be zero, please check.")); Expr for_extent = cinn::optim::ArithSimplify(forloop->extent); - Simplify(&for_extent); + Simplify(&for_extent, optim::SimplifyType::kExpr); node->extent = for_extent; auto *extent_min = for_extent.As(); auto *extent_max = for_extent.As(); @@ -918,8 +918,8 @@ struct VectorizeLoops_ : public IRMutator { inner_for, ::common::errors::InvalidArgument( "Inner_for is nullptr in UnrollCmpFor function.")); - Expr inner_for_extent = 
cinn::optim::ArithSimplify(inner_for->extent); - Simplify(&inner_for_extent); + Expr inner_for_extent = cinn::common::AutoSimplify(inner_for->extent); + Simplify(&inner_for_extent, optim::SimplifyType::kExpr); auto *extent_min = inner_for_extent.As(); if (extent_min) { PADDLE_ENFORCE_EQ( @@ -937,7 +937,7 @@ struct VectorizeLoops_ : public IRMutator { if (a_int || b_int) { condition = cinn::common::SolveInequality(LE::Make(a, b), outer_for->loop_var); - Simplify(&condition); + Simplify(&condition, optim::SimplifyType::kExpr); } if (condition.defined()) { auto le_n = condition.As(); @@ -1023,7 +1023,7 @@ struct VectorizeLoops_ : public IRMutator { } else { times = cinn::optim::ArithSimplify( Div::Make(forloop->extent, make_const(factor))); - Simplify(×); + Simplify(×, optim::SimplifyType::kExpr); } // update the current forloop diff --git a/paddle/cinn/poly/dim.cc b/paddle/cinn/poly/dim.cc index e72a3e5ab264c..d96e1c44e9110 100644 --- a/paddle/cinn/poly/dim.cc +++ b/paddle/cinn/poly/dim.cc @@ -15,7 +15,7 @@ #include "paddle/cinn/poly/dim.h" #include "paddle/cinn/ir/ir_printer.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/utils/string.h" namespace cinn { @@ -30,8 +30,8 @@ std::string Dim::range_repr() const { Dim::Dim(std::string id, ir::Expr lower_bound, ir::Expr upper_bound) : id(std::move(id)), lower_bound(lower_bound), upper_bound(upper_bound) { - optim::Simplify(&this->lower_bound); - optim::Simplify(&this->upper_bound); + optim::Simplify(&this->lower_bound, optim::SimplifyType::kExpr); + optim::Simplify(&this->upper_bound, optim::SimplifyType::kExpr); } } // namespace poly diff --git a/paddle/cinn/poly/stage.cc b/paddle/cinn/poly/stage.cc index 4d6b02cd78a68..0b5531152ab0d 100644 --- a/paddle/cinn/poly/stage.cc +++ b/paddle/cinn/poly/stage.cc @@ -31,7 +31,7 @@ #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_replace.h" #include "paddle/cinn/lang/compute.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/poly/compute_at_transform.h" #include "paddle/cinn/poly/isl_utils.h" @@ -330,7 +330,7 @@ void Stage::ChangeIndex(Stage *other) { // Return a - b as integer. int Minus(const Expr &a, const Expr &b) { Expr diff = ir::Sub::Make(a, b); - optim::Simplify(&diff); + optim::Simplify(&diff, optim::SimplifyType::kExpr); if (!diff.is_constant()) { LOG(ERROR) << "Range is not constant"; } @@ -553,7 +553,7 @@ void Stage::EditTempTensor(Stage *other, int level) { optim::ReplaceVarWithExpr(&i, dim_var, Expr(j.second)); } i = ir::Add::Make(i, Expr(1)); - optim::Simplify(&i); + optim::Simplify(&i, optim::SimplifyType::kExpr); } // Set new shape. 
VLOG(3) << "Tensor is : " << this->tensor()->name; diff --git a/paddle/cinn/pybind/optim.cc b/paddle/cinn/pybind/optim.cc index 6baf3cd8cfd91..a0d525b604db9 100755 --- a/paddle/cinn/pybind/optim.cc +++ b/paddle/cinn/pybind/optim.cc @@ -19,7 +19,7 @@ #include "paddle/cinn/common/type.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_copy.h" -#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/ir_simplify_pass.h" #include "paddle/cinn/pybind/bind.h" #include "paddle/cinn/pybind/bind_utils.h" #include "paddle/cinn/utils/string.h" @@ -37,7 +37,7 @@ void BindSimplify(py::module* m) { "simplify", [](const Expr& expr) -> Expr { auto copied = ir::ir_utils::IRCopy(expr); - Simplify(&copied); + Simplify(&copied, optim::SimplifyType::kBlock); return copied; }, py::arg("expr"));