Skip to content

Commit

Permalink
Revert "[CINN] Fix remove reduce one axis in anchor fusion (#70747)" (#…
Browse files Browse the repository at this point in the history
…70775)

* Revert "[CINN] Fix remove reduce one axis in anchor fusion (#70747)"

This reverts commit e5ff051.

* Revert "[CINN] Enhance reduce anchor fusion with different flatten axis (#70665)"

This reverts commit 440570e.
  • Loading branch information
huangjiyi authored Jan 11, 2025
1 parent c18f089 commit 8c78f9d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 71 deletions.
81 changes: 32 additions & 49 deletions paddle/cinn/operator_fusion/policy/iters_fusion_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,37 +154,14 @@ std::optional<ItersTransform> ItersFusionPolicy::GetReuseItersTransform(
}
}

std::optional<ItersTransform> ItersFusionPolicy::GetAppendItersTransform(
FusionIters* source_iters, const FusionIters& target_iters) {
const auto target_unique_iters =
GatherFirstNotInSecond(target_iters, *source_iters);
if (!target_unique_iters.empty()) {
if (!transform_strategy_[ItersTransformType::AppendIters] ||
!FLAGS_enable_append_iters_in_fusion) {
VLOG(4) << "Can not append iters in fusion, because of AppendIters "
"tranform is disabled.";
return std::nullopt;
}
std::vector<int32_t> append_axis;
std::vector<symbol::DimExpr> append_symbols;
for (const auto& iter : target_unique_iters) {
const size_t pos =
std::find(target_iters.begin(), target_iters.end(), iter) -
target_iters.begin();
append_axis.push_back(pos);
append_symbols.push_back(iters_manager_->GetIterSymbol(iter));
source_iters->insert(source_iters->begin() + pos, iter);
}
return AppendItersTransform(append_axis, append_symbols);
}
return IdentityItersTransform();
}

std::optional<ItersTransformRoute>
ItersFusionPolicy::SearchTransformRouteFromReduce2Reduce(
const FusionItersSignature& source, const FusionItersSignature& target) {
VLOG(4) << "Start search transform Route from reduce to reduce.";
if (source.reduce_iter_nums == target.reduce_iter_nums) {
if (source.loop_iters.size() == target.loop_iters.size() &&
source.reduce_iter_nums == target.reduce_iter_nums) {
// Currently only support fusion with same iter_nums and same reduce axis
// TODO(huangjiyi): Analysis fusion with different non reduce axis
auto [source_flatten_iters, source_reduce_iters] = SplitReduceIters(source);
auto [target_flatten_iters, target_reduce_iters] = SplitReduceIters(target);

Expand All @@ -209,15 +186,6 @@ ItersFusionPolicy::SearchTransformRouteFromReduce2Reduce(
route.push_back(flatten_reuse_iters_transform.value());
route.push_back(reduce_reuse_iters_transform.value());

// 2. Apply AppendItersTransform for flatten iters
const auto flatten_append_iters_transform =
GetAppendItersTransform(&source_flatten_iters, target_flatten_iters);
if (flatten_append_iters_transform == std::nullopt) {
return std::nullopt;
} else {
route.push_back(flatten_append_iters_transform.value());
}

// 2. Apply TransposeItersTransform
if (source_flatten_iters == target_flatten_iters &&
source_reduce_iters == target_reduce_iters) {
Expand Down Expand Up @@ -257,13 +225,12 @@ std::optional<ItersTransformRoute> ItersFusionPolicy::SearchItersTransformRoute(
auto squeezed_source = source;
if (squeeze_source) {
// Remove iters equal to one in source
std::vector<int> source_ones;
for (int i = 0; i < source.loop_iters.size() - source.reduce_iter_nums;
++i) {
if (iters_manager_->IterSymbolEqualOne(source.loop_iters[i])) {
source_ones.push_back(i);
}
}
auto source_ones = MapVectorIfTrue<std::pair<std::string, int>, int>(
Enumerate(source.loop_iters),
[this](std::pair<std::string, int> p) { return p.second; },
[this](std::pair<std::string, int> p) {
return this->iters_manager_->IterSymbolEqualOne(p.first);
});
if (!source_ones.empty() &&
source_ones.size() != source.loop_iters.size()) {
iters_transforms.emplace_back(RemoveOnesTransform(source_ones));
Expand Down Expand Up @@ -350,12 +317,28 @@ std::optional<ItersTransformRoute> ItersFusionPolicy::SearchItersTransformRoute(
// 3. Apply AppendItersTransform
// if exist iters in target can not find in source
FusionIters appended_source_iters = reused_source_iters;
const auto append_iters_transform =
GetAppendItersTransform(&appended_source_iters, target_iters);
if (append_iters_transform == std::nullopt) {
return std::nullopt;
} else {
iters_transforms.push_back(append_iters_transform.value());
if (!reused_target_unique_iters.empty()) {
if (!transform_strategy_[ItersTransformType::AppendIters] ||
!FLAGS_enable_append_iters_in_fusion) {
VLOG(4) << "Can not append iters in fusion, because of AppendIters "
"tranform is disabled.";
return std::nullopt;
}
std::vector<int32_t> append_axis;
std::vector<symbol::DimExpr> append_symbols;
for (const auto& iter : reused_target_unique_iters) {
const size_t pos =
std::find(target_iters.begin(), target_iters.end(), iter) -
target_iters.begin();
append_axis.push_back(pos);
append_symbols.push_back(iters_manager_->GetIterSymbol(iter));
appended_source_iters.insert(appended_source_iters.begin() + pos, iter);
}
iters_transforms.push_back(
AppendItersTransform(append_axis, append_symbols));
if (appended_source_iters == target_iters) {
return iters_transforms;
}
}
VLOG(4) << "source iters after reuse and append: "
<< PrintFusionIters(appended_source_iters);
Expand Down
2 changes: 0 additions & 2 deletions paddle/cinn/operator_fusion/policy/iters_fusion_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ struct ItersFusionPolicy final : public PolicyBase {
private:
std::optional<ItersTransform> GetReuseItersTransform(
FusionIters* source_iters, const FusionIters& target_iters);
std::optional<ItersTransform> GetAppendItersTransform(
FusionIters* source_iters, const FusionIters& target_iters);
std::optional<ItersTransformRoute> SearchTransformRouteFromReduce2Reduce(
const FusionItersSignature& source, const FusionItersSignature& target);
std::optional<ItersTransformRoute> SearchItersTransformRoute(
Expand Down
20 changes: 0 additions & 20 deletions test/ir/pir/cinn/test_reduce_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,6 @@ def init():

self.check_accuracy_and_kernel_num(init, func)

def test_reduce_anchor_fusion(self):
# T
# / \
# R --> T
# / \
# R --> T
def func(x):
x = x + 1
a = paddle.max(x, axis=-1, keepdim=True)
b = x + a
c = paddle.max(b, axis=-1, keepdim=True)
d = c + b
return d

def init():
x = paddle.rand((1, 32, 4, 8), dtype='float32')
return (x,)

self.check_accuracy_and_kernel_num(init, func, kernel_num=1)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8c78f9d

Please sign in to comment.