diff --git a/xla/service/gpu/fusions/reduction_base_test.cc b/xla/service/gpu/fusions/reduction_base_test.cc index 77d9a8c52bec1..f0e2d914a3e8c 100644 --- a/xla/service/gpu/fusions/reduction_base_test.cc +++ b/xla/service/gpu/fusions/reduction_base_test.cc @@ -98,7 +98,6 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { s0 in [0, 0] s1 in [0, 0] s2 in [0, 15] - 0 in [0, 0] d0 mod 32 + s2 * 32 in [0, 511] d3 * 8 + d0 floordiv 32 in [0, 6399] )")); @@ -166,7 +165,6 @@ TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { s0 in [0, 0] s1 in [0, 0] s2 in [0, 0] - 0 in [0, 0] d0 mod 4 in [0, 3] d3 * 64 + d0 floordiv 4 in [0, 6399] )")); @@ -336,7 +334,6 @@ TEST_F(ReductionTest, ThreadIndexingSideOutput) { s0 in [0, 0] s1 in [0, 0] s2 in [0, 15] - 0 in [0, 0] d0 mod 32 + s2 * 32 in [0, 511] d3 * 8 + d0 floordiv 32 in [0, 6399] )"; @@ -391,7 +388,6 @@ TEST_F(ReductionTest, bla) { s1 in [0, 0] s2 in [0, 7] s3 in [0, 1] - 0 in [0, 0] d0 + s2 * 512 in [0, 4095] )")); } diff --git a/xla/service/gpu/model/indexing_map.cc b/xla/service/gpu/model/indexing_map.cc index 610f7163be447..3780fd2851399 100644 --- a/xla/service/gpu/model/indexing_map.cc +++ b/xla/service/gpu/model/indexing_map.cc @@ -744,6 +744,12 @@ void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { current_range = Intersect(current_range, range); return; } + if (auto constant_expr = mlir::dyn_cast(expr)) { + if (constant_expr.getValue() >= range.lower && + constant_expr.getValue() <= range.upper) { + return; + } + } if (SimplifyConstraintRange(&expr, &range)) { AddConstraint(expr, range); return; diff --git a/xla/service/gpu/model/indexing_map_test.cc b/xla/service/gpu/model/indexing_map_test.cc index 07587fbc948aa..87d06fdd8edb9 100644 --- a/xla/service/gpu/model/indexing_map_test.cc +++ b/xla/service/gpu/model/indexing_map_test.cc @@ -222,6 +222,32 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { )")); } +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); + indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), + Interval{-10, 5}); + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 49] + )")); +} + +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantOutOfRange) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); + // Addition of this constraint makes the domain empty. + indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), + Interval{10, 15}); + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 49] + 0 in [10, 15] + )")); +} + TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)",