diff --git a/xla/hlo/ir/hlo_computation.cc b/xla/hlo/ir/hlo_computation.cc
index 45651b3d80c21..f04850e93ee7b 100644
--- a/xla/hlo/ir/hlo_computation.cc
+++ b/xla/hlo/ir/hlo_computation.cc
@@ -1358,14 +1358,15 @@ Status HloComputation::ReplaceWithNewEntryComputationParameter(
 
 absl::StatusOr<bool> HloComputation::ReplaceInstruction(
     HloInstruction* old_instruction, HloInstruction* new_instruction,
-    bool preserve_sharding, bool relay_control_dependency) {
+    bool preserve_sharding, bool relay_control_dependency,
+    bool remove_unused_operands) {
   TF_RET_CHECK(
       ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
       << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
       << ShapeUtil::HumanString(new_instruction->shape());
-  return ReplaceInstructionWithDifferentShape(old_instruction, new_instruction,
-                                              preserve_sharding,
-                                              relay_control_dependency);
+  return ReplaceInstructionWithDifferentShape(
+      old_instruction, new_instruction, preserve_sharding,
+      relay_control_dependency, remove_unused_operands);
 }
 
 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
diff --git a/xla/hlo/ir/hlo_computation.h b/xla/hlo/ir/hlo_computation.h
index 5b67b15788ca1..9c63f104d58ab 100644
--- a/xla/hlo/ir/hlo_computation.h
+++ b/xla/hlo/ir/hlo_computation.h
@@ -581,9 +581,11 @@ class HloComputation {
   // return false. Otherwise, when the replacement happens, if |new_instruction|
   // doesn't have any sharding information it will receive the sharding
   // information of |old_instruction|, and function will return true.
-  absl::StatusOr<bool> ReplaceInstruction(
-      HloInstruction* old_instruction, HloInstruction* new_instruction,
-      bool preserve_sharding, bool relay_control_dependency = false);
+  absl::StatusOr<bool> ReplaceInstruction(HloInstruction* old_instruction,
+                                          HloInstruction* new_instruction,
+                                          bool preserve_sharding,
+                                          bool relay_control_dependency = false,
+                                          bool remove_unused_operands = true);
 
   // Same as above, with preserve_sharding=false. Since this replacement always
   // happens, it returns just a Status as opposed to StatusOr<bool>
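Note on the API change above: remove_unused_operands defaults to true, so existing callers keep the current behavior of deleting operands that become dead after the replacement; passing false leaves them in the computation, which the sharding-propagation change below relies on so that the operand of a removed Sharding custom-call survives and can still be annotated. A minimal usage sketch of the new overload follows; the helper and variable names are hypothetical and not part of this patch.

// Sketch only: replace old_inst with new_inst but keep old_inst's operands
// alive even if they become unused. Every name except
// HloComputation::ReplaceInstruction is hypothetical.
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"

namespace xla {

absl::StatusOr<bool> ReplaceButKeepOperands(HloComputation* computation,
                                            HloInstruction* old_inst,
                                            HloInstruction* new_inst) {
  // preserve_sharding=false: the replacement is unconditional.
  // relay_control_dependency=false: control edges are not relayed to new_inst.
  // remove_unused_operands=false: operands left dead by the replacement stay
  // in the computation instead of being removed.
  return computation->ReplaceInstruction(old_inst, new_inst,
                                         /*preserve_sharding=*/false,
                                         /*relay_control_dependency=*/false,
                                         /*remove_unused_operands=*/false);
}

}  // namespace xla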
diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc
index 5c15d22615dea..b18512ad1a9da 100644
--- a/xla/service/sharding_propagation.cc
+++ b/xla/service/sharding_propagation.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
 #include
@@ -1562,41 +1563,69 @@ absl::StatusOr<bool> ProcessShardingInstruction(
   const bool use_shard_group = instruction_to_shard_group_id &&
                                shard_group_id_to_shard_as_group &&
                                shard_group_id_to_shard_like_group;
-  auto process_shard_group_instruction = [&](HloInstruction* instruction,
-                                             HloSharding sharding) {
-    if (use_shard_group && sharding.IsShardGroup()) {
-      // Store shard group relations.
-      const int64_t shard_group_id = sharding.GetShardGroup().shard_group_id;
-      (*instruction_to_shard_group_id)[instruction] = shard_group_id;
-      if (sharding.IsShardAs()) {
-        auto& shard_as_group =
-            (*shard_group_id_to_shard_as_group)[shard_group_id];
-        if (!shard_as_group.empty()) {
-          CHECK(ShapeUtil::SameDimensions(instruction->shape(),
-                                          (*shard_as_group.begin())->shape()))
-              << "Instruction: " << instruction->ToString()
-              << " has different shape from the shapes of the other "
-                 "instructions within the same shard_as group: "
-              << (*shard_as_group.begin())->shape().ToString();
-        }
-        shard_as_group.insert(instruction);
+  // Process shard group instruction and returns if current instruction needs
+  // to be removed.
+  auto process_shard_group_instruction =
+      [&](HloInstruction* instruction,
+          bool replaced_with_copy) -> absl::StatusOr<bool> {
+    if (use_shard_group && instruction->has_sharding() &&
+        instruction->sharding().IsShardGroup()) {
+      if (instruction->IsCustomCall("Sharding")) {
+        CHECK(instruction->operand(0)->opcode() != HloOpcode::kParameter ||
+              (allow_spmd_sharding_propagation_to_parameters_vector &&
+               allow_spmd_sharding_propagation_to_parameters_vector->size() ==
+                   module->entry_computation()->num_parameters() &&
+               allow_spmd_sharding_propagation_to_parameters_vector->at(
+                   instruction->operand(0)->parameter_number())));
+      }
+      if (instruction->IsCustomCall("Sharding") && !replaced_with_copy) {
+        // Pass shard group to operand sharding custom-call if it's not
+        // replaced with a copy, meaning that the shardings are to annotate
+        // shard_group or shard_barrier only.
+        HloSharding operand_sharding = instruction->operand(0)->has_sharding()
+                                           ? instruction->operand(0)->sharding()
+                                           : HloSharding::Unknown();
+        operand_sharding.SetShardGroup(instruction->sharding().GetShardGroup());
+        instruction->mutable_operand(0)->set_sharding(operand_sharding);
+        return true;
       } else {
-        auto& shard_like_group =
-            (*shard_group_id_to_shard_like_group)[shard_group_id];
-        if (!shard_like_group.empty()) {
-          CHECK(ShapeUtil::SameDimensions(instruction->shape(),
-                                          (*shard_like_group.begin())->shape()))
-              << "Instruction: " << instruction->ToString()
-              << " has different shape from the shapes of the other "
-                 "instructions within the same shard_like group: "
-              << (*shard_like_group.begin())->shape().ToString();
+        // Otherwise store the shard group relations.
+        const int64_t shard_group_id =
+            instruction->sharding().GetShardGroup().shard_group_id;
+        (*instruction_to_shard_group_id)[instruction] = shard_group_id;
+        if (instruction->sharding().IsShardAs()) {
+          auto& shard_as_group =
+              (*shard_group_id_to_shard_as_group)[shard_group_id];
+          if (!shard_as_group.empty()) {
+            CHECK(ShapeUtil::SameDimensions(instruction->shape(),
+                                            (*shard_as_group.begin())->shape()))
+                << "Instruction: " << instruction->ToString()
+                << " has different shape from the shapes of the other "
+                   "instructions within the same shard_as group: "
+                << (*shard_as_group.begin())->shape().ToString();
+          }
+          shard_as_group.insert(instruction);
+        } else {
+          auto& shard_like_group =
+              (*shard_group_id_to_shard_like_group)[shard_group_id];
+          if (!shard_like_group.empty()) {
+            CHECK(ShapeUtil::SameDimensions(
+                instruction->shape(), (*shard_like_group.begin())->shape()))
+                << "Instruction: " << instruction->ToString()
+                << " has different shape from the shapes of the other "
+                   "instructions within the same shard_like group: "
+                << (*shard_like_group.begin())->shape().ToString();
+          }
+          shard_like_group.insert(instruction);
         }
-        shard_like_group.insert(instruction);
+        HloSharding sharding = instruction->sharding();
+        sharding.ClearShardGroup();
+        instruction->set_sharding(std::move(sharding));
       }
-      sharding.ClearShardGroup();
     }
-    return sharding;
+    return false;
   };
+
   for (HloComputation* computation : module->computations(execution_threads)) {
     auto instructions = computation->MakeInstructionPostOrder();
     for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
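The rewritten lambda distinguishes two cases: a Sharding custom-call that only annotates a shard group (or shard barrier) hands the group over to its operand's sharding and reports that it can be removed, while every other annotated instruction is recorded in the shard_as/shard_like maps and has the group stripped from its own sharding. A sketch of the first case in isolation, using only HloSharding calls that appear above; the helper name is illustrative, not part of this patch.

// Sketch of the "annotate only" branch: copy the shard group carried by a
// Sharding custom-call onto its operand's sharding.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_sharding.h"

namespace xla {

void MoveShardGroupToOperand(HloInstruction* sharding_custom_call) {
  // Start from the operand's existing sharding, or Unknown if it has none.
  HloSharding operand_sharding =
      sharding_custom_call->operand(0)->has_sharding()
          ? sharding_custom_call->operand(0)->sharding()
          : HloSharding::Unknown();
  // Attach the shard_as / shard_like group from the custom-call.
  operand_sharding.SetShardGroup(
      sharding_custom_call->sharding().GetShardGroup());
  sharding_custom_call->mutable_operand(0)->set_sharding(operand_sharding);
}

}  // namespace xla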
@@ -1612,44 +1641,48 @@ absl::StatusOr<bool> ProcessShardingInstruction(
             Cast<HloCustomCallInstruction>(instruction)->opaque(),
             &unspec_dims));
 
-        // Replace it with a copy node so that it does not need special
-        // handling.
-        if (replace_sharding_with_copy) {
+        bool replaced_with_copy =
+            replace_sharding_with_copy &&
+            (!original_sharding.IsUnknown() ||
+             instruction->operand(0)->opcode() == HloOpcode::kParameter);
+        // Replace the sharding instruction with a copy node so that it does not
+        // need special handling.
+        if (replaced_with_copy) {
           auto copy = computation->AddInstruction(HloInstruction::CreateUnary(
               instruction->shape(), HloOpcode::kCopy,
               instruction->mutable_operand(0)));
-          TF_RETURN_IF_ERROR(
-              computation->ReplaceInstruction(instruction, copy));
-          // Add into shard group.
-          HloSharding sharding =
-              process_shard_group_instruction(copy, original_sharding);
-          copy->set_sharding(sharding);
+          TF_ASSIGN_OR_RETURN(
+              std::ignore, computation->ReplaceInstruction(
+                               instruction, copy, /*preserve_sharding=*/false,
+                               /*relay_control_dependency=*/false,
+                               /*remove_unused_operands=*/false));
+          copy->set_sharding(original_sharding);
           instruction = copy;
           changed = true;
         }
-        // Strip the sharding of the shard group related annotations.
+
+        TF_ASSIGN_OR_RETURN(
+            bool shard_group_remove_instruction,
+            process_shard_group_instruction(instruction, replaced_with_copy));
         if (!unspec_dims.empty()) {
           absl::c_sort(unspec_dims);
           unspecified_dims->emplace(instruction, std::move(unspec_dims));
         } else if (!instruction->operand(0)->has_sharding()) {
-          HloSharding sharding = original_sharding;
-          if (instruction->operand(0)->opcode() != HloOpcode::kParameter ||
-              (allow_spmd_sharding_propagation_to_parameters_vector &&
-               allow_spmd_sharding_propagation_to_parameters_vector->size() ==
-                   module->entry_computation()->num_parameters() &&
-               allow_spmd_sharding_propagation_to_parameters_vector->at(
-                   instruction->operand(0)->parameter_number()))) {
-            // Add operand(i.e. the annotated op) into shard group.
-            sharding = process_shard_group_instruction(
-                instruction->mutable_operand(0), sharding);
-          }
-          instruction->mutable_operand(0)->set_sharding(std::move(sharding));
+          instruction->mutable_operand(0)->set_sharding(
+              instruction->sharding());
         }
-      } else if (instruction->has_sharding()) {
-        // Handle shard group in parameters/outputs.
-        HloSharding sharding = process_shard_group_instruction(
-            instruction, instruction->sharding());
-        instruction->set_sharding(std::move(sharding));
+        if (shard_group_remove_instruction) {
+          TF_ASSIGN_OR_RETURN(std::ignore,
+                              computation->ReplaceInstruction(
+                                  instruction, instruction->mutable_operand(0),
+                                  /*preserve_sharding=*/false,
+                                  /*relay_control_dependency=*/false,
+                                  /*remove_unused_operands=*/false));
+        }
+      } else {
+        TF_ASSIGN_OR_RETURN(std::ignore,
+                            process_shard_group_instruction(
+                                instruction, /*replaced_with_copy=*/false));
       }
     }
   }
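In the hunk above, a Sharding custom-call is lowered to a kCopy only when it carries a concrete (non-unknown) sharding or annotates a parameter; a bare shard_as/shard_like marker on a non-parameter is instead folded into its operand, and the custom-call is later replaced by that operand when the lambda reports it as removable. A standalone sketch of that predicate; the function name is illustrative, the logic mirrors the patch.

// Sketch of the replaced_with_copy decision introduced above.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding.h"

namespace xla {

bool ShouldReplaceWithCopy(const HloInstruction* sharding_custom_call,
                           bool replace_sharding_with_copy) {
  const HloSharding& original_sharding = sharding_custom_call->sharding();
  // Lower to a kCopy when the annotation carries a concrete sharding, or when
  // it annotates a parameter; a pure "unknown shard_as / shard_like" marker on
  // a non-parameter is handled through its operand instead.
  return replace_sharding_with_copy &&
         (!original_sharding.IsUnknown() ||
          sharding_custom_call->operand(0)->opcode() == HloOpcode::kParameter);
}

}  // namespace xla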
@@ -2975,6 +3008,23 @@ absl::StatusOr<bool> ShardingPropagation::Run(
           &shard_group_id_to_shard_like_group,
           &allow_spmd_sharding_propagation_to_parameters_vector_));
   any_changed |= changed;
+
+  for (const auto& [shard_group_id, shard_as_group] :
+       shard_group_id_to_shard_as_group) {
+    VLOG(5) << "Shard-As group " << shard_group_id << " contains:";
+    for (auto instruction : shard_as_group) {
+      VLOG(5) << " " << instruction->ToString();
+    }
+  }
+
+  for (const auto& [shard_group_id, shard_like_group] :
+       shard_group_id_to_shard_like_group) {
+    VLOG(5) << "Shard-Like group " << shard_group_id << " contains:";
+    for (auto instruction : shard_like_group) {
+      VLOG(5) << " " << instruction->ToString();
+    }
+  }
+
   // Check sizes of the given allow_spmd_sharding_propagation vectors
   if (allow_spmd_sharding_propagation_to_output_) {
     CHECK(!module->entry_computation()->root_instruction()->has_sharding() ||
diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc
index 100e03ed2710f..b9a9c0a3b372d 100644
--- a/xla/service/sharding_propagation_test.cc
+++ b/xla/service/sharding_propagation_test.cc
@@ -10776,7 +10776,7 @@ TEST_F(ShardingPropagationTest, PropagateShardAsBetweenInputOutput2) {
 HloModule jit_f, entry_computation_layout={(f32[8]{0:T(256)})->(f32[8]{0:T(256)}, f32[8]{0:T(256)})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4
 
 ENTRY main.9 {
-  Arg_0.1 = f32[8]{0} parameter(0), sharding={replicated}
+  Arg_0.1 = f32[8]{0} parameter(0)
   custom-call.6 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0}, metadata={op_name="jit(f)/jit(main)/shard_alike" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=206}
   custom-call.4 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4]<=[4]}, metadata={op_name="jit(f)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[4]<=[4]}) resource_env=ResourceEnv(mesh=Mesh(), ()) unconstrained_dims=set()]" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=204}
   constant.0 = f32[] constant(2)
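For reference, the VLOG(5) loops added in ShardingPropagation::Run above could equally be written as a small helper. The sketch below is illustrative only: the ShardGroupMap alias and include paths are assumptions, since the actual container types used in Run are not visible in this patch.

// Sketch: dump the members of each shard_as or shard_like group at VLOG(5).
#include <cstdint>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tsl/platform/logging.h"
#include "xla/hlo/ir/hlo_instruction.h"

namespace xla {

using ShardGroupMap =
    absl::flat_hash_map<int64_t, absl::flat_hash_set<HloInstruction*>>;

void DumpShardGroups(absl::string_view kind, const ShardGroupMap& groups) {
  for (const auto& [shard_group_id, members] : groups) {
    VLOG(5) << kind << " group " << shard_group_id << " contains:";
    for (const HloInstruction* instruction : members) {
      VLOG(5) << " " << instruction->ToString();
    }
  }
}

}  // namespace xla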