From ddc6e60e778b2469d21e45399f6931fb5e983e92 Mon Sep 17 00:00:00 2001
From: Bill Varcho
Date: Sat, 9 Nov 2024 23:54:18 -0800
Subject: [PATCH] [SDY] refactor propagation functions in basic_propagation.cc
 to utilize a parameter struct (to clean up function signatures).

PiperOrigin-RevId: 695000954
---
 .../propagation/basic_propagation.cc | 354 ++++++++++--------
 1 file changed, 201 insertions(+), 153 deletions(-)

diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
index c7cf2dbb..edac0c1e 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -121,23 +121,44 @@ void notifyShardingModified(Value value,
   notifyUsersModified(value, notifyOpModified);
 }
 
+// Struct to hold common parameters for sharding propagation.
+struct PropagationSharedParams {
+  const ShardingGroupMap& shardingGroupMap;
+  StringRef meshName;
+  MeshAttr mesh;
+  std::optional<NotifyOpModifiedCallback> notifyOpModified;
+};
+
+struct PropagationTensorParams {
+  ValueRange tensors;
+  ArrayRef<TensorShardingAttr> shardings;
+  SetShardingPerTensorCallback setShardingCallback;
+
+  PropagationTensorParams(ValueRange tensors,
+                          ArrayRef<TensorShardingAttr> shardings,
+                          SetShardingPerTensorCallback setShardingCallback)
+      : tensors(tensors),
+        shardings(shardings),
+        setShardingCallback(setShardingCallback) {}
+};
+
 // Update the sharding of `value` to the sharding in `tensorFactorShardings`.
 //
 // Returns true if it's possible to update the sharding, i.e., if strided view
 // isn't needed and all non-minor-most factors are divisible by sharding axes.
-bool updateTensorSharding(
-    TensorShardingAttr oldTensorSharding,
-    SetTensorShardingCallback setTensorShardingCallback,
-    const TensorFactorShardings& tensorFactorShardings,
-    TensorMappingAttr tensorMapping, ArrayRef<int64_t> factorSizes,
-    StringRef meshName, MeshAttr mesh, Value modifiedValue,
-    const ShardingGroupMap& shardingGroupMap,
-    std::optional<NotifyOpModifiedCallback> notifyOpModified) {
+bool updateTensorSharding(Value modifiedValue,
+                          TensorShardingAttr oldTensorSharding,
+                          SetTensorShardingCallback setTensorShardingCallback,
+                          const TensorFactorShardings& tensorFactorShardings,
+                          TensorMappingAttr tensorMapping,
+                          ArrayRef<int64_t> factorSizes,
+                          const PropagationSharedParams& params) {
   // We can assume `modifiedValue` exists since we are updating its sharding.
   assert(modifiedValue && "modified value should exist");
   TensorShardingAttr newSharding =
       tensorFactorShardings.createTensorShardingAttr(
-          mesh.getContext(), tensorMapping, factorSizes, meshName, mesh);
+          params.mesh.getContext(), tensorMapping, factorSizes, params.meshName,
+          params.mesh);
   // `oldTensorSharding` may be null if there is no sharding, in which case we
   // check if `newSharding` is empty.
   // TODO(tomnatan): remove this checking if the new sharding equals the old
@@ -155,19 +176,20 @@ bool updateTensorSharding(
 
   setTensorShardingCallback(newSharding);
 
-  if (notifyOpModified) {
-    notifyShardingModified(modifiedValue, *notifyOpModified);
+  if (params.notifyOpModified) {
+    notifyShardingModified(modifiedValue, *params.notifyOpModified);
   }
 
   // Set the sharding of all values in the same sharding group to be equivalent
   // (skipping the modified value which has already been updated).
-  for (Value groupValue : shardingGroupMap.getGroupMembers(modifiedValue)) {
+  for (Value groupValue :
+       params.shardingGroupMap.getGroupMembers(modifiedValue)) {
     if (groupValue == modifiedValue) {
       continue;
     }
     setSharding(groupValue, newSharding);
-    if (notifyOpModified) {
-      notifyShardingModified(groupValue, *notifyOpModified);
+    if (params.notifyOpModified) {
+      notifyShardingModified(groupValue, *params.notifyOpModified);
    }
  }
 
@@ -182,47 +204,35 @@ bool updateTensorSharding(
 // `tensorFactorShardings`, e.g., if strided view is required, sets the
 // respective bit in `updateTensor` or `updateResult` to false.
 void updateTensorShardings(
-    ValueRange tensors, ArrayRef<TensorShardingAttr> tensorShardings,
-    SetShardingPerTensorCallback setTensorShardingCallback,
+    const PropagationTensorParams& tensorParams,
     ArrayRef<TensorFactorShardings> tensorFactorShardings,
     ArrayRef<TensorMappingAttr> tensorMappings, ArrayRef<int64_t> factorSizes,
-    BitVector& updateTensor, StringRef meshName, MeshAttr mesh,
-    const ShardingGroupMap& shardingGroupMap,
-    std::optional<NotifyOpModifiedCallback> notifyOpModified) {
+    BitVector& updateTensor, const PropagationSharedParams& params) {
   for (int64_t index : updateTensor.set_bits()) {
-    if (!updateTensorSharding(
-            tensorShardings[index],
-            std::bind(setTensorShardingCallback, std::placeholders::_1, index),
-            tensorFactorShardings[index], tensorMappings[index], factorSizes,
-            meshName, mesh, getShardableValue(tensors[index]), shardingGroupMap,
-            notifyOpModified)) {
+    if (!updateTensorSharding(getShardableValue(tensorParams.tensors[index]),
+                              tensorParams.shardings[index],
+                              std::bind(tensorParams.setShardingCallback,
+                                        std::placeholders::_1, index),
+                              tensorFactorShardings[index],
+                              tensorMappings[index], factorSizes, params)) {
      updateTensor.reset(index);
    }
  }
 }
 
 // Same as the overload above, except operates on both operands and results.
-void updateTensorShardings(
-    ValueRange operands, ValueRange results,
-    ArrayRef<TensorShardingAttr> operandShardings,
-    ArrayRef<TensorShardingAttr> resultShardings,
-    SetShardingPerTensorCallback setOperandShardingCallback,
-    SetShardingPerTensorCallback setResultShardingCallback,
-    OpShardingRuleAttr shardingRule,
-    const ShardingProjection& shardingProjection, BitVector& updateOperand,
-    BitVector& updateResult, StringRef meshName, MeshAttr mesh,
-    const ShardingGroupMap& shardingGroupMap,
-    std::optional<NotifyOpModifiedCallback> notifyOpModified) {
-  updateTensorShardings(operands, operandShardings, setOperandShardingCallback,
-                        shardingProjection.getOperands(),
+void updateTensorShardings(const PropagationTensorParams& operandsParams,
+                           const PropagationTensorParams& resultsParams,
+                           OpShardingRuleAttr shardingRule,
+                           const ShardingProjection& shardingProjection,
+                           BitVector& updateOperand, BitVector& updateResult,
+                           const PropagationSharedParams& params) {
+  updateTensorShardings(operandsParams, shardingProjection.getOperands(),
                         shardingRule.getOperandMappings(),
-                        shardingRule.getFactorSizes(), updateOperand, meshName,
-                        mesh, shardingGroupMap, notifyOpModified);
-  updateTensorShardings(results, resultShardings, setResultShardingCallback,
-                        shardingProjection.getResults(),
+                        shardingRule.getFactorSizes(), updateOperand, params);
+  updateTensorShardings(resultsParams, shardingProjection.getResults(),
                         shardingRule.getResultMappings(),
-                        shardingRule.getFactorSizes(), updateResult, meshName,
-                        mesh, shardingGroupMap, notifyOpModified);
+                        shardingRule.getFactorSizes(), updateResult, params);
 }
 
 // Propagates tensor shardings of the given `operands` and `results` according
@@ -232,17 +242,15 @@ void updateTensorShardings(
 // the Operation. For example, for CaseOp, an op with no operands, it's called
 // with the return values of each branch/region.
 LogicalResult propagateTensorShardings(
-    ValueRange operands, ValueRange results,
-    ArrayRef<TensorShardingAttr> operandShardings,
-    ArrayRef<TensorShardingAttr> resultShardings,
-    SetShardingPerTensorCallback setOperandShardingCallback,
-    SetShardingPerTensorCallback setResultShardingCallback,
+    const PropagationTensorParams& operandsParams,
+    const PropagationTensorParams& resultsParams,
     OpShardingRuleAttr shardingRule, PropagationDirection direction,
-    const FactorPropagation& factorPropagation,
-    const ShardingGroupMap& shardingGroupMap, bool conservativePropagation,
-    Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter) {
-  std::optional<StringRef> meshName =
-      getCommonMeshName(operandShardings, resultShardings, symbolTable);
+    const FactorPropagation& factorPropagation, bool conservativePropagation,
+    Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter,
+    ShardingGroupMap shardingGroupMap) {
+  std::optional<StringRef> meshName = getCommonMeshName(
+      operandsParams.shardings, resultsParams.shardings, symbolTable);
+
   if (!meshName.has_value()) {
     // This means none of the operands or results have a sharding attribute or
     // the sharding attributes use different meshes.
@@ -256,7 +264,7 @@ LogicalResult propagateTensorShardings(
   assert(mesh && "unknown mesh");
 
   ShardingProjection shardingProjection = ShardingProjection::build(
-      operandShardings, resultShardings, shardingRule, mesh);
+      operandsParams.shardings, resultsParams.shardings, shardingRule, mesh);
 
   auto [updateOperand, updateResult] =
       factorPropagation.propagateFactorShardings(
@@ -278,15 +286,12 @@ LogicalResult propagateTensorShardings(
     };
   }
 
-  op->getContext()->executeAction(
-      [&]() {
-        updateTensorShardings(
-            operands, results, operandShardings, resultShardings,
-            setOperandShardingCallback, setResultShardingCallback, shardingRule,
-            shardingProjection, updateOperand, updateResult, meshName.value(),
-            mesh, shardingGroupMap, notifyOpModified);
-      },
-      /*IRUnits=*/{op}, operands, results, operandShardings, resultShardings);
+  PropagationSharedParams params{shardingGroupMap, meshName.value(), mesh,
+                                 notifyOpModified};
+
+  updateTensorShardings(operandsParams, resultsParams, shardingRule,
+                        shardingProjection, updateOperand, updateResult,
+                        params);
 
   bool anyUpdated = updateOperand.any() || updateResult.any();
   if (rewriter && !anyUpdated) {
@@ -299,10 +304,8 @@
 
 // Same as the overload above, except there is a single operand and result.
 LogicalResult propagateTensorShardings(
-    Value operand, Value result, TensorShardingAttr operandSharding,
-    TensorShardingAttr resultsSharding,
-    SetTensorShardingCallback setOperandShardingCallback,
-    SetTensorShardingCallback setResultShardingCallback,
+    const PropagationTensorParams& operandsParams,
+    const PropagationTensorParams& resultsParams,
     OpShardingRuleAttr shardingRule, Operation* op,
     const SymbolTable& symbolTable, PatternRewriter* rewriter,
     const FactorPropagation& factorPropagation,
@@ -310,15 +313,8 @@ LogicalResult propagateTensorShardings(
     PropagationDirection direction = PropagationDirection::BOTH,
     bool conservativePropagation = false) {
   return propagateTensorShardings(
-      operand, result, operandSharding, resultsSharding,
-      [&](TensorShardingAttr sharding, int64_t) {
-        setOperandShardingCallback(sharding);
-      },
-      [&](TensorShardingAttr sharding, int64_t) {
-        setResultShardingCallback(sharding);
-      },
-      shardingRule, direction, factorPropagation, shardingGroupMap,
-      conservativePropagation, op, symbolTable, rewriter);
+      operandsParams, resultsParams, shardingRule, direction, factorPropagation,
+      conservativePropagation, op, symbolTable, rewriter, shardingGroupMap);
 }
 
 // Same as the overload above, except the operand and result shardings are
@@ -330,16 +326,24 @@ LogicalResult propagateTensorShardings(
     const ShardingGroupMap& shardingGroupMap,
     PropagationDirection direction = PropagationDirection::BOTH,
     bool conservativePropagation = false) {
-  return propagateTensorShardings(
-      operands, results, getShardings(operands), getShardings(results),
-      [&](TensorShardingAttr sharding, int64_t index) {
+  SmallVector<TensorShardingAttr> operandsShardings = getShardings(operands);
+  SmallVector<TensorShardingAttr> resultsShardings = getShardings(results);
+  PropagationTensorParams operandsParams = PropagationTensorParams(
+      /*tensors=*/operands,
+      /*shardings=*/operandsShardings,
+      /*setShardingCallback=*/[&](TensorShardingAttr sharding, int64_t index) {
        setSharding(operands[index], sharding);
-      },
-      [&](TensorShardingAttr sharding, int64_t index) {
+      });
+  PropagationTensorParams resultsParams = PropagationTensorParams(
+      /*tensors=*/results,
+      /*shardings=*/resultsShardings,
+      /*setShardingCallback=*/[&](TensorShardingAttr sharding, int64_t index) {
        setSharding(results[index], sharding);
-      },
-      shardingRule, direction, factorPropagation, shardingGroupMap,
-      conservativePropagation, op, symbolTable, &rewriter);
+      });
+
+  return propagateTensorShardings(
+      operandsParams, resultsParams, shardingRule, direction, factorPropagation,
+      conservativePropagation, op, symbolTable, &rewriter, shardingGroupMap);
 }
 
 // Propagates the shardings between the operands of the `funcOp`'s terminator
@@ -359,23 +363,32 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
   // NOTE: we void the returned `LogicalResult` since function updates aren't
   // done through a rewriter, can ignore whether operands/results were
   // updated.
-  (void)propagateTensorShardings(
-      // The operand/result function arguments are used to:
-      // - invoke the rewriter (if specified) that a value was updated. But
-      //   a rewriter isn't used here.
-      // - log warnings on the defining op. In this case it would either be on
-      //   the defining op of `returnValue` or `funcOp` if it's a function
-      //   argument. Here it will be okay to log the warning on the defining
-      //   op of `returnValue`.
-      // As such, we pass `returnValue` as both the operand and result.
-      returnValue, returnValue, getSharding(returnValue),
-      getFuncResultSharding(funcOp, resNum),
-      [&](TensorShardingAttr sharding) {
+  // The operand/result function arguments are used to:
+  // - invoke the rewriter (if specified) that a value was updated. But
+  //   a rewriter isn't used here.
+  // - log warnings on the defining op. In this case it would either be on
+  //   the defining op of `returnValue` or `funcOp` if it's a function
+  //   argument. Here it will be okay to log the warning on the defining
+  //   op of `returnValue`.
+  // As such, we pass `returnValue` as both the operand and result.
+  TensorShardingAttr operandShardingRef = getSharding(returnValue);
+  TensorShardingAttr resultsShardingRef =
+      getFuncResultSharding(funcOp, resNum);
+  PropagationTensorParams operandsParams = PropagationTensorParams(
+      /*tensors=*/returnValue,
+      /*shardings=*/operandShardingRef,
+      /*setShardingCallback=*/[&](TensorShardingAttr sharding, int64_t) {
        setSharding(returnValue, sharding);
-      },
-      [&](TensorShardingAttr sharding) {
+      });
+  PropagationTensorParams resultsParams = PropagationTensorParams(
+      /*tensors=*/returnValue,
+      /*shardings=*/resultsShardingRef,
+      /*setShardingCallback=*/[&](TensorShardingAttr sharding, int64_t) {
        setFuncResultSharding(funcOp, resNum, sharding);
-      },
+      });
+
+  (void)propagateTensorShardings(
+      operandsParams, resultsParams,
       // Treat the sharding data flow b/w the `funcOp` terminator and func
       // result attrs as an identity op. Create an equivalent sharding
       // rule.
@@ -469,24 +482,35 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
     SmallVector<Value> sources = getDataFlowSources(dataFlowEdgeOp);
     // The sharding of `dataFlowEdgeOp.getResult()` is the sharding of all
     // targets.
-    return propagateTensorShardings(
-        sources, dataFlowEdgeOp.getResult(), getShardings(sources),
-        transformTargetSharding(
-            dataFlowEdgeOp, dataFlowEdgeOp.getShardingAttr(),
-            DataFlowShardingTransformType::kBeforeEdgePropagation),
-        [&](TensorShardingAttr sharding, int64_t index) {
+    SmallVector<TensorShardingAttr> operandShardingRef = getShardings(sources);
+    TensorShardingAttr resultsShardingRef = transformTargetSharding(
+        dataFlowEdgeOp, dataFlowEdgeOp.getShardingAttr(),
+        DataFlowShardingTransformType::kBeforeEdgePropagation);
+    PropagationTensorParams operandsParams = PropagationTensorParams(
+        /*tensors=*/sources,
+        /*shardings=*/operandShardingRef,
+        /*setShardingCallback=*/
+        [&sources](TensorShardingAttr sharding, int64_t index) {
          setSharding(sources[index], sharding);
-        },
-        [&](TensorShardingAttr sharding, int64_t _) {
+        });
+    Value result = dataFlowEdgeOp.getResult();
+    PropagationTensorParams resultsParams = PropagationTensorParams(
+        /*tensors=*/result,
+        /*shardings=*/resultsShardingRef,
+        /*setShardingCallback=*/
+        [&dataFlowEdgeOp](TensorShardingAttr sharding, int64_t _) {
          dataFlowEdgeOp.setShardingAttr(transformTargetSharding(
              dataFlowEdgeOp, sharding,
              DataFlowShardingTransformType::kAfterEdgePropagation));
-        },
+        });
+
+    return propagateTensorShardings(
+        operandsParams, resultsParams,
         createIdentityShardingRule(cast(dataFlowEdgeOp.getType()),
                                    sources.size()),
-        PropagationDirection::BOTH, factorPropagation, shardingGroupMap,
+        PropagationDirection::BOTH, factorPropagation,
        /*conservativePropagation=*/false, dataFlowEdgeOp, symbolTable,
-        &rewriter);
+        &rewriter, shardingGroupMap);
  }
 
 private:
@@ -542,31 +566,41 @@ class PropagateManualComputationOp
          manualComputationOp.getBody().getArguments()) {
      const int64_t argNumber = blockArg.getArgNumber();
      Value operand = manualComputationOp->getOperand(argNumber);
-      updated |=
-          propagateTensorShardings(
-              operand, blockArg,
-              // Since this is propagating outside of the region of the
-              // `ManualComputationOp`, make sure we keep the manual axes
-              // as we may be able to propagate those backwards.
-              // `getSharding` on the block arg would remove them, so
-              // need to get the right `in_shardings` explicitly using
-              // `getInSharding`.
-              getSharding(operand),
-              manualComputationOp.getInSharding(argNumber),
-              [&operand](TensorShardingAttr sharding) {
-                setSharding(operand, sharding);
-              },
-              // Similarly as above, since `setSharding` will add the
-              // manual axes back, but they already exist, we set the
-              // `in_shardings` explicitly using `setInSharding`.
-              [&manualComputationOp, argNumber](TensorShardingAttr sharding) {
-                manualComputationOp.setInSharding(argNumber, sharding);
-              },
-              createIdentityShardingRule(
-                  cast(operand.getType())),
-              manualComputationOp, symbolTable, &rewriter, factorPropagation,
-              shardingGroupMap)
-              .succeeded();
+
+      // Since this is propagating outside of the region of the
+      // `ManualComputationOp`, make sure we keep the manual axes
+      // as we may be able to propagate those backwards.
+      // `getSharding` on the block arg would remove them, so
+      // need to get the right `in_shardings` explicitly using
+      // `getInSharding`.
+      TensorShardingAttr operandSharding = getSharding(operand);
+      TensorShardingAttr resultsSharding =
+          manualComputationOp.getInSharding(argNumber);
+      PropagationTensorParams operandsParams = PropagationTensorParams(
+          /*tensors=*/operand,
+          /*shardings=*/operandSharding,
+          /*setShardingCallback=*/
+          [&operand](TensorShardingAttr sharding, int64_t) {
+            setSharding(operand, sharding);
+          });
+
+      PropagationTensorParams resultsParams = PropagationTensorParams(
+          /*tensors=*/blockArg,
+          /*shardings=*/resultsSharding,
+          /*setShardingCallback=*/
+          [&manualComputationOp, argNumber](TensorShardingAttr sharding,
+                                            int64_t) {
+            manualComputationOp.setInSharding(argNumber, sharding);
+          });
+
+      updated |= propagateTensorShardings(
+                     operandsParams, resultsParams,
+                     createIdentityShardingRule(
+                         cast(operand.getType())),
+                     PropagationDirection::BOTH, factorPropagation,
+                     /*conservativePropagation=*/false, manualComputationOp,
+                     symbolTable, &rewriter, shardingGroupMap)
+                     .succeeded();
    }
 
    // 2. Propagate between the uses of the `ManualComputationOp` and the
@@ -578,23 +612,37 @@
      // Since this is propagating on the border of the local region of manual
      // axes and global program, only use shardings without the manual axes.
      // `setSharding` will add them back for `out_shardings`.
-      updated |=
-          propagateTensorShardings(
-              returnValue.get(), opResult, getSharding(returnValue.get()),
-              manualComputationOp.getOutShardingWithoutManualAxes(
-                  operandNumber),
-              [&returnValue](TensorShardingAttr sharding) {
-                setSharding(returnValue.get(), sharding);
-              },
-              [&](TensorShardingAttr sharding) {
-                manualComputationOp.setOutShardingAddingManualAxes(
-                    operandNumber, sharding);
-              },
-              createIdentityShardingRule(
-                  cast(opResult.getType())),
-              manualComputationOp, symbolTable, &rewriter, factorPropagation,
-              shardingGroupMap)
-              .succeeded();
+      Value value = returnValue.get();
+      TensorShardingAttr operandSharding = getSharding(value);
+      TensorShardingAttr resultsSharding =
+          manualComputationOp.getOutShardingWithoutManualAxes(operandNumber);
+
+      PropagationTensorParams operandsParams = PropagationTensorParams(
+          /*tensors=*/value,
+          /*shardings=*/operandSharding,
+          /*setShardingCallback=*/
+          [&value](TensorShardingAttr sharding, int64_t) {
+            setSharding(value, sharding);
+          });
+
+      PropagationTensorParams resultsParams = PropagationTensorParams(
+          /*tensors=*/opResult,
+          /*shardings=*/resultsSharding,
+          /*setShardingCallback=*/
+          [&manualComputationOp, operandNumber](TensorShardingAttr sharding,
+                                                int64_t) {
+            manualComputationOp.setOutShardingAddingManualAxes(operandNumber,
+                                                               sharding);
+          });
+
+      updated |= propagateTensorShardings(
+                     operandsParams, resultsParams,
+                     createIdentityShardingRule(
+                         cast(opResult.getType())),
+                     PropagationDirection::BOTH, factorPropagation,
+                     /*conservativePropagation=*/false, manualComputationOp,
+                     symbolTable, &rewriter, shardingGroupMap)
+                     .succeeded();
    }
 
    return success(updated);
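
Reviewer note (not part of the patch): the change above is essentially a parameter-object refactor. The long positional argument lists threaded through propagateTensorShardings, updateTensorShardings, and updateTensorSharding are folded into PropagationTensorParams (tensors + shardings + set-sharding callback) and PropagationSharedParams (sharding-group map, mesh name, mesh, notify callback). The sketch below illustrates the same pattern in isolation; it is a minimal, self-contained C++ example, and every name in it (Tensor, Sharding, TensorParams, SharedParams, propagate) is a hypothetical stand-in rather than the Shardy API.

// Minimal sketch of the parameter-object pattern used in the patch above.
// All types and functions here are hypothetical stand-ins, not Shardy code.
#include <cstdint>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct Tensor { std::string name; };
struct Sharding { std::string spec; };

using SetShardingPerTensorCallback = std::function<void(Sharding, int64_t)>;
using NotifyOpModifiedCallback = std::function<void(const std::string&)>;

// Bundles the per-tensor-list arguments (mirrors PropagationTensorParams).
struct TensorParams {
  std::vector<Tensor> tensors;
  std::vector<Sharding> shardings;
  SetShardingPerTensorCallback setShardingCallback;
};

// Bundles the state shared by every call in one propagation step
// (mirrors PropagationSharedParams).
struct SharedParams {
  std::string meshName;
  std::optional<NotifyOpModifiedCallback> notifyOpModified;
};

// After the refactor, call sites pass two bundles instead of ~10 positional
// arguments, and adding a new shared field no longer touches every signature.
void propagate(const TensorParams& operands, const TensorParams& results,
               const SharedParams& shared) {
  for (size_t i = 0; i < operands.tensors.size(); ++i) {
    // Pretend propagation decided on a new sharding for operand i.
    Sharding newSharding{"devices=[2]<@" + shared.meshName + ">"};
    operands.setShardingCallback(newSharding, static_cast<int64_t>(i));
    if (shared.notifyOpModified) {
      (*shared.notifyOpModified)("operand " + operands.tensors[i].name);
    }
  }
  (void)results;  // Results would be handled symmetrically.
}

int main() {
  std::vector<Tensor> operands{{"arg0"}, {"arg1"}};
  std::vector<Sharding> operandShardings(operands.size());

  TensorParams operandParams{
      operands, operandShardings,
      [&](Sharding s, int64_t i) { operandShardings[i] = s; }};
  TensorParams resultParams{{}, {}, [](Sharding, int64_t) {}};
  SharedParams shared{"mesh", [](const std::string& what) {
                        std::cout << "modified: " << what << "\n";
                      }};

  propagate(operandParams, resultParams, shared);
  std::cout << "arg0 -> " << operandShardings[0].spec << "\n";
}

One consequence visible in the diff: because the shared state travels as a single struct, updateTensorShardings keeps its two-overload shape while dropping roughly half of its parameters, and the single-operand overload reduces to a thin forwarding call.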