diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD
index 31125e5145aa2..49f076cd14022 100644
--- a/xla/service/gpu/model/BUILD
+++ b/xla/service/gpu/model/BUILD
@@ -686,6 +686,7 @@ cc_library(
         "//xla/service/gpu/fusions:fusion_emitter",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
diff --git a/xla/service/gpu/model/coalescing_analysis.cc b/xla/service/gpu/model/coalescing_analysis.cc
index 7392ef4a186fd..83fe09366d8b3 100644
--- a/xla/service/gpu/model/coalescing_analysis.cc
+++ b/xla/service/gpu/model/coalescing_analysis.cc
@@ -18,12 +18,15 @@ limitations under the License.
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
+#include <cstdlib>
 #include <optional>
+#include <stack>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/types/span.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
@@ -95,95 +98,15 @@ bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind,
 
 namespace {
 
-using mlir::AffineExpr;
-using mlir::AffineMap;
-using mlir::getAffineConstantExpr;
-using mlir::MLIRContext;
-
-// Performs backtracking to find all feasible dimensions, symbols that satisfy
-// the constraints and then evaluates the affine map at those.
-// For example, for the following indexing map:
-//   (d0)[s0] -> (d0 + s0)
-//   domain:
-//   d0 in [0, 3]
-//   s0 in [0, 1, 2]
-//   s0 mod 2 in [0, 0]
-// The function will compute the following indices [0, 2, 1, 3, 2, 4, 3, 5].
-void FindAllIndices(const IndexingMap& thread_id_to_physical_index,
-                    MLIRContext* mlir_context, int dim_id, int symbol_id,
-                    std::vector<AffineExpr>* dimensions,
-                    std::vector<AffineExpr>* symbols,
-                    std::vector<int64_t>* indices) {
-  if (dim_id < thread_id_to_physical_index.GetDimensionCount()) {
-    Interval dim_range = thread_id_to_physical_index.GetDimensionBound(dim_id);
-    for (int64_t dim_value = dim_range.lower; dim_value <= dim_range.upper;
-         ++dim_value) {
-      dimensions->push_back(getAffineConstantExpr(dim_value, mlir_context));
-      FindAllIndices(thread_id_to_physical_index, mlir_context, dim_id + 1,
-                     symbol_id, dimensions, symbols, indices);
-      dimensions->pop_back();
-    }
-    return;
-  }
-  if (symbol_id < thread_id_to_physical_index.GetRangeVarsCount()) {
-    Interval symbol_range =
-        thread_id_to_physical_index.GetSymbolBound(symbol_id);
-    for (int64_t symbol_value = symbol_range.lower;
-         symbol_value <= symbol_range.upper; ++symbol_value) {
-      symbols->push_back(getAffineConstantExpr(symbol_value, mlir_context));
-      FindAllIndices(thread_id_to_physical_index, mlir_context, dim_id,
-                     symbol_id + 1, dimensions, symbols, indices);
-      symbols->pop_back();
-    }
-    return;
-  }
-  if (!thread_id_to_physical_index.ConstraintsSatisfied(*dimensions,
-                                                        *symbols)) {
-    return;
-  }
-  indices->push_back(
-      thread_id_to_physical_index.Evaluate(*dimensions, *symbols).front());
-}
-
-// Computes contiguous intervals of accessed elements.
-// For example, for an indexing map
-//   (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984)
-//   d0 in [0, 31]
-//   s0 in [0, 3]
-// The intervals are [0, 63] and [2047, 2111].
-// TODO(b/325613460): Make it faster than O(number of elements in the domain).
-std::vector<Interval> FindContiguousIntervals(
-    const IndexingMap& thread_id_to_physical_index) {
-  CHECK(thread_id_to_physical_index.GetAffineMap().getNumResults() == 1)
-      << "Expects an affine map that maps to 1D.";
-  MLIRContext* mlir_context = thread_id_to_physical_index.GetMLIRContext();
-
-  // Find all linear indices, sort and deduplicate them.
-  std::vector<AffineExpr> dimensions, symbols;
-  std::vector<int64_t> linear_indices;
-  FindAllIndices(thread_id_to_physical_index, mlir_context,
-                 /*dim_id=*/0,
-                 /*symbol_id=*/0, &dimensions, &symbols, &linear_indices);
-  std::sort(linear_indices.begin(), linear_indices.end());
-  linear_indices.erase(
-      std::unique(linear_indices.begin(), linear_indices.end()),
-      linear_indices.end());
-
-  // Scan over the sorted unique indices and combine them in intervals.
-  std::vector<Interval> intervals;
-  for (int i = 0, start, end; i < linear_indices.size(); ++i) {
-    start = linear_indices[i++];
-    end = start;
-    while (i < linear_indices.size() && linear_indices[i] == end + 1) {
-      ++end;
-      ++i;
-    }
-    intervals.push_back(Interval{start, end});
-  }
-  return intervals;
-}
-
-int64_t CeilDiv(int64_t a, int64_t b) { return a / b + (a % b != 0); }
+using ::mlir::AffineBinaryOpExpr;
+using ::mlir::AffineConstantExpr;
+using ::mlir::AffineDimExpr;
+using ::mlir::AffineExpr;
+using ::mlir::AffineExprKind;
+using ::mlir::AffineMap;
+using ::mlir::AffineSymbolExpr;
+using ::mlir::getAffineConstantExpr;
+using ::mlir::MLIRContext;
 
 // Approximately estimate the number of memory transactions needed to load all
 // elements in every range and compare it with the "ideal" number of memory
@@ -212,63 +135,6 @@ bool EstimateCoalescingViaMemoryTransactionsCount(
          memory_transactions * kIsCoalescedThreshold;
 }
 
-bool IsCoalesced(const IndexingMap& thread_id_to_input_indexing_map,
-                 PrimitiveType element_type) {
-  // Undefined indexing maps, i.e. those for which we don't know the indexing
-  // are assumed to be uncoalesced.
-  if (thread_id_to_input_indexing_map.IsUndefined()) {
-    return false;
-  }
-  // 0d constants are coalesced.
-  if (thread_id_to_input_indexing_map.GetAffineMap().getNumResults() == 0) {
-    return true;
-  }
-  MLIRContext* mlir_context = thread_id_to_input_indexing_map.GetMLIRContext();
-  AffineExpr thread_x_dim = mlir::getAffineDimExpr(
-      KernelFusionInterface::kIndexingMapThreadIdxDims[0], mlir_context);
-  AffineExpr c0 = mlir::getAffineConstantExpr(0, mlir_context);
-  IndexingMap thread_x_first_32_elements{
-      AffineMap::get(1, 0, {thread_x_dim, c0, c0, c0, c0, c0}, mlir_context),
-      {DimVar{{0, 31}}},
-      /*range_vars=*/{},
-      /*rt_vars=*/{}};
-  IndexingMap thread_x_to_linearized_input =
-      thread_x_first_32_elements * thread_id_to_input_indexing_map;
-
-  // If RTVars are present, replace them with constants.
-  if (thread_x_to_linearized_input.GetRTVarsCount() > 0) {
-    llvm::SmallVector<AffineExpr> symbol_replacements;
-    for (int64_t symbol_id = 0;
-         symbol_id < thread_x_to_linearized_input.GetRangeVarsCount();
-         ++symbol_id) {
-      symbol_replacements.push_back(
-          mlir::getAffineSymbolExpr(symbol_id, mlir_context));
-    }
-    for (const RTVar& rt_var : thread_x_to_linearized_input.GetRTVars()) {
-      // Take midpoint of the feasible interval for the RT variable.
-      symbol_replacements.push_back(getAffineConstantExpr(
-          (rt_var.feasible_values.lower + rt_var.feasible_values.upper) / 2,
-          mlir_context));
-    }
-    AffineMap thread_x_to_input_no_rt_symbols =
-        thread_x_to_linearized_input.GetAffineMap().replaceDimsAndSymbols(
-            {}, symbol_replacements,
-            thread_x_to_linearized_input.GetDimVarsCount(),
-            thread_x_to_linearized_input.GetRangeVarsCount());
-    thread_x_to_linearized_input = IndexingMap{
-        thread_x_to_input_no_rt_symbols,
-        thread_x_to_linearized_input.GetDimVars(),
-        thread_x_to_linearized_input.GetRangeVars(),
-        thread_x_to_linearized_input.GetRTVars(),
-    };
-  }
-  thread_x_to_linearized_input.Simplify(GetIndexingMapForInstruction);
-  thread_x_to_linearized_input.RescaleSymbols();
-  thread_x_to_linearized_input.RemoveUnusedSymbols();
-  return EstimateCoalescingViaMemoryTransactionsCount(
-      FindContiguousIntervals(thread_x_to_linearized_input), element_type);
-}
-
 // Returns a linearized shape, i.e. tensor<num_elements(input) x element_type>.
 Shape GetLinearizedShape(const Shape& shape) {
   if (shape.rank() == 0) {
@@ -287,7 +153,7 @@ std::optional<GroupedByOpIndexingMap> GetThreadIdToInputMemoryLayoutsMaps(
     const HloFusionAdaptor& fusion_adaptor,
     absl::Span<const HloInstruction* const> operands,
     const HloFusionAnalysis& fusion_analysis,
-    KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) {
+    KernelFusionInterface* fusion_interface, MLIRContext* mlir_context) {
   GroupedByOpIndexingMap result;
   for (const auto& [root_index, hero] :
        llvm::enumerate(fusion_analysis.fusion_heroes())) {
@@ -356,13 +222,358 @@ std::optional<GroupedByOpIndexingMap> GetThreadIdToInputMemoryLayoutsMaps(
   return result;
 }
 
+// Replaces RTVars with the midpoints of the feasible intervals.
+void AssignValuesToRTVars(IndexingMap* indexing_map) {
+  // If RTVars are present, replace them with constants.
+  if (indexing_map->GetRTVarsCount() == 0) {
+    return;
+  }
+  MLIRContext* mlir_context = indexing_map->GetMLIRContext();
+  llvm::SmallVector<AffineExpr> symbol_replacements;
+  for (int64_t symbol_id = 0; symbol_id < indexing_map->GetRangeVarsCount();
+       ++symbol_id) {
+    symbol_replacements.push_back(
+        mlir::getAffineSymbolExpr(symbol_id, mlir_context));
+  }
+  for (const RTVar& rt_var : indexing_map->GetRTVars()) {
+    // Take midpoint of the feasible interval for the RT variable.
+    symbol_replacements.push_back(getAffineConstantExpr(
+        (rt_var.feasible_values.lower + rt_var.feasible_values.upper) / 2,
+        mlir_context));
+  }
+  AffineMap thread_x_to_input_no_dim_symbols =
+      indexing_map->GetAffineMap().replaceDimsAndSymbols(
+          {}, symbol_replacements, indexing_map->GetDimVarsCount(),
+          indexing_map->GetRangeVarsCount());
+  *indexing_map = IndexingMap{thread_x_to_input_no_dim_symbols,
+                              indexing_map->GetDimVars(),
+                              indexing_map->GetRangeVars(),
+                              {}};
+  indexing_map->Simplify(GetIndexingMapForInstruction);
+  indexing_map->RemoveUnusedSymbols();
+}
+
+// Replaces all but one RangeVar with the first element of its range.
+// At the moment, we assume that the last RangeVar symbol corresponds to the
+// innermost loop induction variable.
+void AssignValuesToOuterLoopIVs(IndexingMap* indexing_map) {
+  if (indexing_map->GetRangeVarsCount() <= 1) {
+    return;
+  }
+  MLIRContext* mlir_context = indexing_map->GetMLIRContext();
+  llvm::SmallVector<AffineExpr> symbol_replacements;
+  for (const RangeVar& range_var : indexing_map->GetRangeVars()) {
+    symbol_replacements.push_back(
+        getAffineConstantExpr(range_var.range.lower, mlir_context));
+  }
+  symbol_replacements.push_back(mlir::getAffineSymbolExpr(
+      indexing_map->GetRangeVarsCount() - 1, mlir_context));
+
+  AffineMap thread_x_to_input_no_dim_symbols =
+      indexing_map->GetAffineMap().replaceDimsAndSymbols(
+          {}, symbol_replacements, indexing_map->GetDimVarsCount(), 1);
+  *indexing_map = IndexingMap{thread_x_to_input_no_dim_symbols,
+                              indexing_map->GetDimVars(),
+                              indexing_map->GetRangeVars(),
+                              {}};
+  indexing_map->Simplify(GetIndexingMapForInstruction);
+  indexing_map->RemoveUnusedSymbols();
+}
+
+// Result of partitioning of AffineExpr f(d0) + g(s0) into the summands.
+struct PartitionedExpr {
+  explicit PartitionedExpr(MLIRContext* mlir_context) {
+    AffineExpr zero = getAffineConstantExpr(0, mlir_context);
+    func_of_d0 = zero;
+    func_of_s0 = zero;
+  }
+  AffineExpr func_of_d0;
+  AffineExpr func_of_s0;
+};
+
+// Given an AffineExpr that depends on d0 and s0, attempts to split it into
+// f(d0) + g(s0). If it is not possible, returns std::nullopt.
+std::optional<PartitionedExpr> Partition(AffineExpr expr) {
+  PartitionedExpr result(expr.getContext());
+
+  std::vector<AffineExpr> summands;
+  std::stack<AffineExpr> dfs;
+  dfs.push(expr);
+  while (!dfs.empty()) {
+    auto top = dfs.top();
+    dfs.pop();
+    auto sum = mlir::dyn_cast<AffineBinaryOpExpr>(top);
+    if (sum && sum.getKind() == AffineExprKind::Add) {
+      dfs.push(sum.getLHS());
+      dfs.push(sum.getRHS());
+      continue;
+    }
+    bool depends_on_thread_x = top.isFunctionOfDim(0);
+    bool depends_on_range = top.isFunctionOfSymbol(0);
+
+    if (depends_on_thread_x && depends_on_range) {
+      return std::nullopt;
+    }
+    if (depends_on_thread_x) {
+      result.func_of_d0 = top + result.func_of_d0;
+    }
+    if (depends_on_range) {
+      result.func_of_s0 = top + result.func_of_s0;
+    }
+  }
+  return result;
+}
+
+// Given an AffineExpr and the values for its dimensions and symbols, evaluates
+// the result.
+int64_t EvaluateAffineExpr(AffineExpr expr,
+                           const std::vector<int64_t>& dim_values,
+                           const std::vector<int64_t>& symbol_values = {}) {
+  if (auto const_expr = mlir::dyn_cast<AffineConstantExpr>(expr)) {
+    return const_expr.getValue();
+  }
+  if (auto dim_expr = mlir::dyn_cast<AffineDimExpr>(expr)) {
+    return dim_values[dim_expr.getPosition()];
+  }
+  if (auto symbol_expr = mlir::dyn_cast<AffineSymbolExpr>(expr)) {
+    return symbol_values[symbol_expr.getPosition()];
+  }
+  auto binary_expr = mlir::cast<AffineBinaryOpExpr>(expr);
+  int64_t lhs =
+      EvaluateAffineExpr(binary_expr.getLHS(), dim_values, symbol_values);
+  int64_t rhs =
+      EvaluateAffineExpr(binary_expr.getRHS(), dim_values, symbol_values);
+  switch (binary_expr.getKind()) {
+    case AffineExprKind::Add:
+      return lhs + rhs;
+    case AffineExprKind::Mul:
+      return lhs * rhs;
+    case AffineExprKind::FloorDiv:
+      return FloorDiv(lhs, rhs);
+    case AffineExprKind::Mod:
+      return lhs % rhs;
+    default:
+      LOG(FATAL) << "Unsupported expression";
+  }
+}
+
+// Performs backtracking to find all feasible dimensions, symbols that satisfy
+// the constraints and then evaluates the affine map at those.
+// For example, for the following indexing map:
+//   (d0)[s0] -> (d0 + s0)
+//   domain:
+//   d0 in [0, 3]
+//   s0 in [0, 1, 2]
+//   s0 mod 2 in [0, 0]
+// The function will compute the following indices [0, 2, 1, 3, 2, 4, 3, 5].
+void FindAllIndices(AffineExpr expr, int dim_id, int symbol_id,
+                    const std::vector<Interval>& dimension_ranges,
+                    const std::vector<Interval>& symbol_ranges,
+                    std::vector<int64_t>* dimensions,
+                    std::vector<int64_t>* symbols,
+                    std::vector<int64_t>* indices) {
+  if (dim_id < dimension_ranges.size()) {
+    Interval dim_range = dimension_ranges[dim_id];
+    for (int64_t dim_value = dim_range.lower; dim_value <= dim_range.upper;
+         ++dim_value) {
+      dimensions->push_back(dim_value);
+      FindAllIndices(expr, dim_id + 1, symbol_id, dimension_ranges,
+                     symbol_ranges, dimensions, symbols, indices);
+      dimensions->pop_back();
+    }
+    return;
+  }
+  if (symbol_id < symbol_ranges.size()) {
+    Interval symbol_range = symbol_ranges[symbol_id];
+    for (int64_t symbol_value = symbol_range.lower;
+         symbol_value <= symbol_range.upper; ++symbol_value) {
+      symbols->push_back(symbol_value);
+      FindAllIndices(expr, dim_id, symbol_id + 1, dimension_ranges,
+                     symbol_ranges, dimensions, symbols, indices);
+      symbols->pop_back();
+    }
+    return;
+  }
+  indices->push_back(EvaluateAffineExpr(expr, *dimensions, *symbols));
+}
+
+// Computes contiguous intervals of accessed elements.
+// For example, for an indexing map
+//   (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984)
+//   d0 in [0, 31]
+//   s0 in [0, 3]
+// The intervals are [0, 63] and [2047, 2111].
+std::vector<Interval> FindIntervals(
+    AffineExpr expr, const std::vector<Interval>& dimension_ranges,
+    const std::vector<Interval>& symbol_ranges = {}) {
+  // Find all linear indices, sort and deduplicate them.
+  std::vector<int64_t> dimensions, symbols;
+  std::vector<int64_t> linear_indices;
+  FindAllIndices(expr, 0, 0, dimension_ranges, symbol_ranges, &dimensions,
+                 &symbols, &linear_indices);
+
+  std::sort(linear_indices.begin(), linear_indices.end());
+  linear_indices.erase(
+      std::unique(linear_indices.begin(), linear_indices.end()),
+      linear_indices.end());
+
+  // Scan over the sorted unique indices and combine them into intervals.
+  std::vector<Interval> intervals;
+  for (int i = 0, start, end; i < linear_indices.size();) {
+    start = linear_indices[i++];
+    end = start;
+    while (i < linear_indices.size() && linear_indices[i] == end + 1) {
+      ++end;
+      ++i;
+    }
+    intervals.push_back(Interval{start, end});
+  }
+  return intervals;
+}
+
+// Given a vector of intervals [lb, ub], computes intervals [lb, ub + length]
+// and then computes the union of the contiguous intervals.
+std::vector<Interval> ExtendIntervals(const std::vector<Interval>& intervals,
+                                      int64_t length) {
+  // Compute union of overlapped intervals.
+  std::vector<Interval> overlapped_intervals;
+  for (int i = 0; i < intervals.size();) {
+    int64_t lower = intervals[i].lower;
+    int64_t upper = intervals[i].upper + length;
+    ++i;
+    while (i < intervals.size() && upper >= intervals[i].lower - 1) {
+      upper = std::max(upper, intervals[i].upper + length);
+      ++i;
+    }
+    overlapped_intervals.push_back(Interval{lower, upper});
+  }
+  return overlapped_intervals;
+}
+
+// Computes contiguous intervals for an expression of the form
+// f(thread_x) + g(s).
+std::vector<Interval> FindContiguousIntervals(
+    const PartitionedExpr& partitioned_expr, const IndexingMap& indexing_map) {
+  constexpr int64_t kNumThreadsPerWarp = 32;
+  MLIRContext* mlir_context = indexing_map.GetMLIRContext();
+  AffineExpr thread_x = mlir::getAffineDimExpr(0, mlir_context);
+  AffineExpr range = mlir::getAffineSymbolExpr(0, mlir_context);
+
+  // Case 1: f(thread_x) = thread_x * multiplier.
+  // Case 1.1: multiplier == 1.
+  if (partitioned_expr.func_of_d0 == thread_x) {
+    return {Interval{0, kNumThreadsPerWarp - 1}};
+  }
+  if (auto mul =
+          mlir::dyn_cast<AffineBinaryOpExpr>(partitioned_expr.func_of_d0);
+      mul && mul.getKind() == AffineExprKind::Mul) {
+    if (auto multiplier = mlir::dyn_cast<AffineConstantExpr>(mul.getRHS());
+        multiplier) {
+      // Case 1.2: multiplier == -1.
+      if (multiplier.getValue() == -1) {
+        return {Interval{0, kNumThreadsPerWarp - 1}};
+      }
+      // Case 1.3: |multiplier| != 1 and g(s) = s.
+      if (partitioned_expr.func_of_s0 == range) {
+        Interval range_interval = indexing_map.GetSymbolBound(0);
+        int64_t num_elems = range_interval.NumElements();
+        // In this case we get a single interval, because the ranges that every
+        // thread is reading overlap.
+        if (num_elems >= std::abs(multiplier.getValue())) {
+          return {Interval{0, multiplier.getValue() * (kNumThreadsPerWarp - 1) +
+                                  num_elems - 1}};
+        }
+        std::vector<Interval> intervals;
+        for (int i = 0, dm = 0; i < kNumThreadsPerWarp;
+             ++i, dm += multiplier.getValue()) {
+          intervals.push_back(
+              {range_interval.lower + dm, range_interval.upper + dm});
+        }
+        return intervals;
+      }
+      // Case 1.4: |multiplier| != 1 and g(s) != s.
+      std::vector<Interval> intervals;
+      for (int i = 0, dm = 0; i < kNumThreadsPerWarp;
+           ++i, dm += multiplier.getValue()) {
+        intervals.push_back({dm, dm});
+      }
+      return intervals;
+    }
+  }
+  // Case 2: f(thread_x) != thread_x * multiplier.
+  auto intervals = FindIntervals(partitioned_expr.func_of_d0,
+                                 {indexing_map.GetDimVars(0).bounds});
+  // Case 2.1: g(s) != s.
+  if (partitioned_expr.func_of_s0 != range) {
+    return intervals;
+  }
+  // Case 2.2: g(s) = s.
+  Interval range_interval = indexing_map.GetSymbolBound(0);
+  return ExtendIntervals(intervals, range_interval.NumElements() - 1);
+}
+
+bool IsIndexingCoalesced(IndexingMap& thread_x_to_linearized_input,
+                         PrimitiveType element_type) {
+  // Undefined indexing maps, i.e. those for which we don't know the indexing
+  // are assumed to be uncoalesced.
+  if (thread_x_to_linearized_input.IsUndefined()) {
+    return false;
+  }
+  // 0d constants are coalesced.
+  if (thread_x_to_linearized_input.GetAffineMap().getNumResults() == 0) {
+    return true;
+  }
+  // Replace RTVars with the feasible values.
+  AssignValuesToRTVars(&thread_x_to_linearized_input);
+
+  // Compute the indexing map for the first [0, 31] threads. This should be
+  // extended to sampling several warps.
+  MLIRContext* mlir_context = thread_x_to_linearized_input.GetMLIRContext();
+  AffineExpr thread_x_dim = mlir::getAffineDimExpr(
+      KernelFusionInterface::kIndexingMapThreadIdxDims[0], mlir_context);
+  AffineExpr c0 = getAffineConstantExpr(0, mlir_context);
+  IndexingMap thread_x_first_32_elements{
+      AffineMap::get(1, 0, {thread_x_dim, c0, c0, c0, c0, c0}, mlir_context),
+      {DimVar{{0, 31}}},
+      /*range_vars=*/{},
+      /*rt_vars=*/{}};
+  IndexingMap thread_x_to_input_sample =
+      thread_x_first_32_elements * thread_x_to_linearized_input;
+  thread_x_to_input_sample.Simplify(GetIndexingMapForInstruction);
+  thread_x_to_input_sample.RescaleSymbols();
+  thread_x_to_input_sample.RemoveUnusedSymbols();
+
+  // If the indexing map is "empty", then the input is not used in this warp,
+  // therefore, it's coalesced.
+  if (thread_x_to_input_sample.IsKnownEmpty()) {
+    return true;
+  }
+  AssignValuesToOuterLoopIVs(&thread_x_to_input_sample);
+  auto partitioned_expr =
+      Partition(thread_x_to_input_sample.GetAffineMap().getResult(0));
+  if (!partitioned_expr.has_value()) {
+    return false;
+  }
+  // Right now we support only thread_x maps that do not have any constraints
+  // or have a single constraint that coincides with
+  // thread_x_to_input_sample.getAffineMap().
+  if (thread_x_to_input_sample.GetConstraintsCount() > 1 ||
+      (thread_x_to_input_sample.GetConstraintsCount() == 1 &&
+       thread_x_to_input_sample.GetConstraints().begin()->first !=
+           partitioned_expr->func_of_d0 + partitioned_expr->func_of_s0)) {
+    return false;
+  }
+  return EstimateCoalescingViaMemoryTransactionsCount(
+      FindContiguousIntervals(*partitioned_expr, thread_x_to_input_sample),
+      element_type);
+}
+
 }  // namespace
 
 CoalescingAnalysis::CoalescingAnalysis(
     const HloInstruction* instr,
     absl::Span<const HloInstruction* const> operands,
     const HloFusionAnalysis& fusion_analysis,
-    KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context,
+    KernelFusionInterface* fusion_interface, MLIRContext* mlir_context,
     bool use_heuristic) {
   auto fusion_adaptor = HloFusionAdaptor::ForInstruction(instr);
   if (!use_heuristic && ComputeCoalescingForAllOperands(
@@ -379,7 +590,7 @@ CoalescingAnalysis::CoalescingAnalysis(
     const HloInstruction* producer, const HloInstruction* consumer,
     absl::Span<const HloInstruction* const> operands,
     const HloFusionAnalysis& fusion_analysis,
-    KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context,
+    KernelFusionInterface* fusion_interface, MLIRContext* mlir_context,
     bool use_heuristic) {
   ProducerConsumerFusion fusion_adaptor(producer, consumer);
   if (!use_heuristic &&
@@ -396,7 +607,7 @@ bool CoalescingAnalysis::ComputeCoalescingForAllOperands(
     const HloFusionAdaptor& fusion_adaptor,
     absl::Span<const HloInstruction* const> operands,
     const HloFusionAnalysis& fusion_analysis,
-    KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) {
+    KernelFusionInterface* fusion_interface, MLIRContext* mlir_context) {
   std::optional<GroupedByOpIndexingMap> thread_id_to_input_memory_layouts =
       GetThreadIdToInputMemoryLayoutsMaps(fusion_adaptor, operands,
                                           fusion_analysis, fusion_interface,
@@ -417,10 +628,9 @@ bool CoalescingAnalysis::ComputeCoalescingForAllOperands(
       coalescing_per_operand_.insert({operand, true});
       continue;
     }
-    for (const IndexingMap& operand_indexing_map :
-         operand_indexing_maps->second) {
-      bool is_coalesced =
-          IsCoalesced(operand_indexing_map, operand->shape().element_type());
+    for (IndexingMap operand_indexing_map : operand_indexing_maps->second) {
+      bool is_coalesced = IsIndexingCoalesced(operand_indexing_map,
+                                              operand->shape().element_type());
       auto [it, inserted] =
           coalescing_per_operand_.insert({operand, is_coalesced});
       if (!inserted) {
diff --git a/xla/service/gpu/model/coalescing_analysis.h b/xla/service/gpu/model/coalescing_analysis.h
index 300036aa453ba..f65f4c8b8ddd4 100644
--- a/xla/service/gpu/model/coalescing_analysis.h
+++ b/xla/service/gpu/model/coalescing_analysis.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_map.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/model/coalescing_analysis_test.cc b/xla/service/gpu/model/coalescing_analysis_test.cc index 9e0251cf01d73..63c4dba8c5b48 100644 --- a/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/xla/service/gpu/model/coalescing_analysis_test.cc @@ -372,7 +372,7 @@ TEST_F(CoalescingTest, VariadicReduceViaLoopEmitter) { // Operands 1, 2: (d0)[s0] -> ((d0 floordiv 4) * 40 + d0 mod 4 + s0 * 4) // for s0 in [0, 9]. EXPECT_THAT(IsReadCoalescedPerOperand(ir), - ElementsAre(true, true, true, true)); + ElementsAre(false, false, true, true)); } TEST_F(CoalescingTest, VariadicReduceViaReductionEmitter) { diff --git a/xla/service/gpu/model/indexing_map.cc b/xla/service/gpu/model/indexing_map.cc index 3780fd2851399..eef79d639e777 100644 --- a/xla/service/gpu/model/indexing_map.cc +++ b/xla/service/gpu/model/indexing_map.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/base/optimization.h" @@ -62,16 +61,6 @@ using mlir::getAffineBinaryOpExpr; using mlir::getAffineConstantExpr; using mlir::MLIRContext; -int64_t FloorDiv(int64_t dividend, int64_t divisor) { - return dividend / divisor - - (((dividend >= 0) != (divisor >= 0) && dividend % divisor) ? 1 : 0); -} - -int64_t CeilDiv(int64_t dividend, int64_t divisor) { - return dividend / divisor + - (((dividend >= 0) == (divisor >= 0) && dividend % divisor) ? 1 : 0); -} - class AffineExprSimplifier { public: explicit AffineExprSimplifier(RangeEvaluator* range_evaluator) @@ -626,6 +615,16 @@ SmallVector MapSymbolsToComposedSymbolsList( } // namespace +int64_t FloorDiv(int64_t dividend, int64_t divisor) { + return dividend / divisor - + (((dividend >= 0) != (divisor >= 0) && dividend % divisor) ? 1 : 0); +} + +int64_t CeilDiv(int64_t dividend, int64_t divisor) { + return dividend / divisor + + (((dividend >= 0) == (divisor >= 0) && dividend % divisor) ? 1 : 0); +} + std::string Interval::ToString() const { std::stringstream ss; Print(ss); diff --git a/xla/service/gpu/model/indexing_map.h b/xla/service/gpu/model/indexing_map.h index bfc8abf30bdd3..f8994f9000221 100644 --- a/xla/service/gpu/model/indexing_map.h +++ b/xla/service/gpu/model/indexing_map.h @@ -44,6 +44,7 @@ struct Interval { void Print(std::ostream& out) const; bool IsPoint() const { return lower == upper; } + int64_t NumElements() const { return upper - lower + 1; } bool Contains(int64_t value) const { return value >= lower && value <= upper; @@ -370,6 +371,9 @@ H AbslHashValue(H h, const IndexingMap& indexing_map) { indexing_map.GetConstraintsCount()); } +int64_t FloorDiv(int64_t dividend, int64_t divisor); +int64_t CeilDiv(int64_t dividend, int64_t divisor); + } // namespace gpu } // namespace xla