Skip to content

Commit

Permalink
Improve aggressive factor propagation strategy in Shardy. There are t…
Browse files Browse the repository at this point in the history
…wo main differences from `BasicFactorPropagation`.

### Difference 1
`BasicFactorPropagation` propagates the same sharding axes to all the tensors
along a factor. This strategy can propagate different sharding axes to
different tensors. For example, Tensors T0, T1, T2 contains Factor F0. T0/F0
is already sharded along ["a", "b"], and "b" is already used by T2 ("b" can be
explicitly replicated, or it is used to shard another factor).
`BasicFactorPropagation` propagates ["a"] to both T1/F0 and T2/F0, while this
strategy propagates ["a", "b"] to T1/F0 and ["a"] to T2/F0, respectively.

### Difference 2
`BasicFactorPropagation` is conservative in terms of conflicts across
factors. The overlapped axis between factors cannot be propagated. This
strategy is more aggressive by allowing the overlapped axis being propagated
along different factors if there is no overlapped axis in the result
shardings.

PiperOrigin-RevId: 657641564
  • Loading branch information
ZixuanJiang authored and copybara-github committed Jul 30, 2024
1 parent 838d1aa commit 2eeee20
Show file tree
Hide file tree
Showing 14 changed files with 441 additions and 386 deletions.
3 changes: 2 additions & 1 deletion shardy/dialect/sdy/transforms/propagation/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,10 @@ cc_test(
srcs = ["aggressive_factor_propagation_test.cc"],
deps = [
":aggressive_factor_propagation",
":basic_factor_propagation",
":factor_propagation",
":sharding_projection",
":testing_utils",
":utils",
"//shardy/dialect/sdy/ir:dialect",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,146 +23,99 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"

namespace mlir {
namespace sdy {

AxesPerFactor
AggressiveFactorPropagation::getCompatibleMajorShardingAxesForAllFactors(
const ShardingProjection& projection, PropagationDirection direction,
namespace {

bool updateTensorSharding(ShardingProjection& projection, int64_t tensorIndex,
int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes) {
if (tensorIndex < projection.getNumOperands()) {
return projection.updateOperandSharding(tensorIndex, factorIndex, newAxes);
}
return projection.updateResultSharding(
tensorIndex - projection.getNumOperands(), factorIndex, newAxes);
}

} // namespace

UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const {
UpdateTensorShardings result{
.updateOperands = BitVector(projection.getNumOperands()),
.updateResults = BitVector(projection.getNumResults())};
if (direction == PropagationDirection::NONE) {
return AxesPerFactor(factorSizes.size());
return result;
}

// Finds the compatible major axes ignoring conflicts.
AxesPerFactor result;
result.reserve(factorSizes.size());
// Find the compatible major axes ignoring conflicts.
SmallVector<SmallVector<AxisRefAttr>> axesPerFactor;
axesPerFactor.reserve(factorSizes.size());
bool allElementsAreEmpty = true;
for (int64_t i = 0; i < factorSizes.size(); ++i) {
result.push_back(getCompatibleMajorAxes(projection, i, direction, op));
SmallVector<AxisRefAttr>& axes = axesPerFactor.emplace_back(
getCompatibleMajorAxes(projection, i, direction, op));
if (!axes.empty()) {
allElementsAreEmpty = false;
}
}
if (allElementsAreEmpty) {
return result;
}

// Removes the conflicts within every single factor. This strategy and
// `BasicFactorPropagation` handles conflicts within a factor in the same way.
for (const TensorFactorShardings& tensorFactorShardings :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
projection.getResults())) {
for (const auto& [factorIndex, factorSharding] :
tensorFactorShardings.factorIndexToSharding) {
// The propagation on each tensor is independent. This strategy can propagate
// different shardings to different tensors along the same factor. Examples
// are provided in the docstring of this class.
for (const auto& [tensorIndex, tensorFactorShardings] :
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
projection.getOperands(), projection.getResults()))) {
// Propagate the axes got in Step 1, and resolve conflicts within a factor.
FactorIndexToSharding newSharding =
tensorFactorShardings.factorIndexToSharding;
BitVector factorUpdated(factorSizes.size());
for (auto& [factorIndex, factorSharding] : newSharding) {
SmallVector<AxisRefAttr> newAxes = axesPerFactor[factorIndex];
truncateAxesByRemovingConflicts(
result[factorIndex],
newAxes,
[&, factorIndex = factorIndex, &factorSharding = factorSharding](
AxisRefAttr axisRef, int64_t shardedSize) {
return compatiblePrefixNoConflictsWithinFactor(
axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
shardedSize, factorSizes[factorIndex]);
},
mesh, conservativePropagation);
if (shouldUpdate(factorSharding.axisRefs, newAxes)) {
factorSharding.axisRefs = newAxes;
factorUpdated.set(factorIndex);
}
}
}

// Removes the conflicts across factors, where this strategy and
// `BasicFactorPropagation` diverge.
//
// With `BasicFactorPropagation`, the compatible axes of a factor Fi cannot
// overlap with the existing sharding axes or the overflow axes related to all
// other factors. This criterion is considered for all tensors, no matter if
// Fi is mapped to the tensor or not. The table below shows the criterion:
//
// existing sharding axes & overflow axes new sharding axes
// factor in tensor remove overlap -
// factor not in tensor remove overlap -
//
// On the contrary, `AggressiveFactorPropagation` has the following criterion:
//
// existing sharding axes & overflow axes new sharding axes
// factor in tensor remove overlap remove overlap
// factor not in tensor - -
//
// There are two differences:
//
// 1. `BasicFactorPropagation` removes the overlap between the compatible axes
// of a factor Fi with the existing sharding axes and overflow axes in a
// tensor Tj even if Fi is not in Tj. `AggressiveFactorPropagation` does not
// remove this overlap if Fi is not in Tj. `BasicFactorPropagation` is too
// strict, since we cannot propagate sharding axes to Tj along Fi.
//
// `AggressiveFactorPropagation` cannot handle the following case if we only
// have difference #1. `-` means that the factor is not mapped to the tensor.
// After removing conflicts within factors, we will propagate "x" to T2 along
// F0 and F1 at the same time, which induces a conflict. To resolve this
// conflict, we have difference #2.
//
// F0 F1
// T0 "x" -
// T1 - "x"
// T2 ? ?
//
// 2. `AggressiveFactorPropagation` removes the overlap between compatible
// axes of a factor Fi with the potential new sharding axes of other factors
// in Tj if Fi is in Tj. Thus, it is safe to propagate the axes to Tj along Fi
// without conflicts with other factors. In the example, we will not propagate
// "x" along F0 or F1 since their potential new sharding axes overlap.
//
// The potential new sharding axes are saved in `resultSnapshot`. It is a hard
// copy since we need to handle the following case.
//
// F0 F1 F2
// T0 "x" - -
// T1 - "x" -
// T2 - - "x"
// T3 ? ? ?
//
// The `result` and `resultSnapshot` is [["x"], ["x"], ["x"]] before removing
// conflicts across factors. After removing conflicts between F0/F1 and other
// factors, `result` is [[], [], ["x"]]. When we remove conflicts between F2
// and other factors, if we use `result` as the potential new sharding axes,
// we will not remove "x" for F2 because it is no longer present in 'result'
// for F0 and F1. We have to use `resultSnapshot` to save the potential new
// sharding axes and remove "x" for F2.
const AxesPerFactor resultSnapshot = result;
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
projection.getResults())) {
for (const auto& [factorIndex, factorSharding] :
tensorFactorSharding.factorIndexToSharding) {
// Resolve conflicts (overlapping sharding axes) between factors.
bool tensorUpdated = false;
for (const int64_t factorIndex : factorUpdated.set_bits()) {
SmallVector<AxisRefAttr> newAxes = newSharding[factorIndex].axisRefs;
truncateAxesByRemovingConflicts(
result[factorIndex],
newAxes,
[&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
return compatiblePrefixNoConflictsAcrossFactors(
axisRef, tensorFactorSharding.factorIndexToSharding,
factorIndex, resultSnapshot);
axisRef, newSharding, factorIndex);
},
mesh, conservativePropagation);
tensorUpdated |=
updateTensorSharding(projection, tensorIndex, factorIndex, newAxes);
}
}

return result;
}

UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const {
UpdateTensorShardings result{
.updateOperands = BitVector(projection.getNumOperands()),
.updateResults = BitVector(projection.getNumResults())};

// We get the compatible major sharding axes for all factors.
AxesPerFactor axesPerFactor = getCompatibleMajorShardingAxesForAllFactors(
projection, direction, factorSizes, mesh, op, conservativePropagation);

for (auto [factorIndex, axesToPropagate] : llvm::enumerate(axesPerFactor)) {
// Update all shardings along this factor if possible.
auto [updateOperandForFactor, updateResultForFactor] =
projection.updateSharding(factorIndex, axesToPropagate);

result.updateOperands |= updateOperandForFactor;
result.updateResults |= updateResultForFactor;
if (tensorIndex < projection.getNumOperands()) {
result.updateOperands[tensorIndex] = tensorUpdated;
} else {
result.updateResults[tensorIndex - projection.getNumOperands()] =
tensorUpdated;
}
}

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,53 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"

namespace mlir {
namespace sdy {

// An aggressive strategy of propagating sharding axes along factors.
// An aggressive strategy of propagating sharding axes along factors. There are
// two main differences from `BasicFactorPropagation`.
//
// This strategy is the same as `BasicFactorPropagation` on the conflicts within
// a factor. They are different on the conflicts across factors.
// `BasicFactorPropagation` propagates the same sharding axes to all tensors
// along a factor. This strategy can propagate different sharding axes to
// different tensors along the same factor. For example, Tensors T0, T1, T2
// contain Factor F0. T0/F0 is already sharded along ["a", "b"], and "b" is
// already used by T2 ("b" can be explicitly replicated, or it is used to shard
// another factor). `BasicFactorPropagation` propagates ["a"] to both T1/F0 and
// T2/F0, while this strategy propagates ["a", "b"] to T1/F0 and ["a"] to T2/F0,
// respectively. If T2/F0 is closed, `BasicFactorPropagation` propagates
// nothing, while this strategy propagates nothing to T2/F0 and still propagates
// ["a", "b"] to T1/F0.
//
// `BasicFactorPropagation` considers the conflicts across factors with a strict
// criterion. The result cannot overlap with the sharded axes or overflow axes
// related to all other factors. This aggressive strategy ignores "fake
// conflicts", which are propagation choices that can co-exist. This aggressive
// strategy ensures that the resultant axes can be propagated to all tensors
// containing the factor. Several examples of fake conflicts:
// `BasicFactorPropagation` is conservative in terms of conflicts across
// factors. The overlapped axis between factors cannot be propagated. This
// strategy is more aggressive by allowing the overlapped axis being propagated
// along different factors if there is no overlapped axis in the result
// shardings.
//
// 1. An axis is in factors Fi and Fj. If it is infeasible to propagate that
// axis along factor Fi, we may propagate that axis along factor Fj if all the
// destination tensors have not used that axis.
// Let us take C = dot(A, B) as an example. F0 is the factor corresponding to a
// non-contracting dimension of A. F1 corresponds to a non-contracting dimension
// of B. F2 corresponds to a contracting dimension. "-" means that the tensor
// does not contain the factor.
//
// 2. Two factors Fi and Fj do not co-exist in any tensor, so they never
// interfere with each other. If Fi and Fj are sharded along the same axis, we
// can propagate that axis along both factors.
// F0 F1 F2
// A "a" -
// B -
// C "a" -
// Case 1. Fake conflict. `BasicFactorPropagation` propagates nothing, while
// this strategy propagates "a" to B/F1.
//
// Although fake conflicts can co-exist without inference, we may still need to
// all-gather some tensors.
// F0 F1 F2
// A "a" -
// B - "a"
// C -
// Case 2. Real conflict. Both `BasicFactorPropagation` and this strategy
// propagate nothing. We can propagate "a" to C/F0 or C/F1, which is illegal
// since "a" cannot be used twice in C.
class AggressiveFactorPropagation : public BasicFactorPropagation {
public:
AxesPerFactor getCompatibleMajorShardingAxesForAllFactors(
const ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const override;

UpdateTensorShardings propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
Expand Down
Loading

0 comments on commit 2eeee20

Please sign in to comment.