Skip to content

Commit

Permalink
Add single value add_* and set_front_end_attribute functions to HloIn…
Browse files Browse the repository at this point in the history
…struction.

Eliminates the need to create temporary attribute object when you simply want to add or set one or few attributes.
Replace some of the usage.

PiperOrigin-RevId: 720014857
  • Loading branch information
toli-y authored and Google-ML-Automation committed Jan 27, 2025
1 parent 45e5c1a commit 06424d7
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 28 deletions.
5 changes: 2 additions & 3 deletions xla/frontend_attributes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ limitations under the License.
namespace xla {

void SetDisjointReadWriteRegionsAttr(HloInstruction* instruction) {
FrontendAttributes attrs;
(*attrs.mutable_map())[xla::kXlaDisjointReadWriteRegions] = "true";
instruction->add_frontend_attributes(attrs);
instruction->set_frontend_attribute(xla::kXlaDisjointReadWriteRegions,
"true");
}

bool HasDisjointReadWriteRegionsAttr(HloInstruction* instruction) {
Expand Down
30 changes: 22 additions & 8 deletions xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,16 @@ class HloInstruction {
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
const std::string& suffix, HloCloneContext* context = nullptr) const;

// Implementation for non-common logic of CloneWithNewOperands.
// CloneWithNewOperands forwards to this method for some of the intstruction
// types.
virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
// TODO(b/80131774): This should be pure virtual.
LOG(FATAL) << "Unimplemented method.";
}

// Returns the computations this instruction directly calls (if any).
const PtrVec<HloComputation*>& called_computations() const {
return rare()->called_computations;
Expand Down Expand Up @@ -2260,6 +2270,18 @@ class HloInstruction {
}
}

bool add_frontend_attribute(const std::string& key,
const std::string& value) {
auto it =
mutable_rare()->frontend_attributes.mutable_map()->insert({key, value});
return it.second;
}

void set_frontend_attribute(const std::string& key,
const std::string& value) {
(*mutable_rare()->frontend_attributes.mutable_map())[key] = value;
}

bool has_frontend_attributes() const {
return has_rare() && !rare()->frontend_attributes.map().empty();
}
Expand Down Expand Up @@ -2798,14 +2820,6 @@ class HloInstruction {
bool ignore_channel_id_values,
bool ignore_commutative_operand_order) const;

// Implementation for non-common logic of CloneWithNewOperands.
virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
// TODO(b/80131774): This should be pure virtual.
LOG(FATAL) << "Unimplemented method.";
}

// Implementation for non-common logic of PrintExtraAttributes.
virtual void PrintExtraAttributesImpl(AttributePrinter& printer,
const HloPrintOptions& options) const {}
Expand Down
10 changes: 4 additions & 6 deletions xla/service/collective_permute_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,10 @@ absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done));

if (!pipeline_decision.empty()) {
xla::FrontendAttributes attributes;
(*attributes.mutable_map())[kSendRecvPipelineAttr] = pipeline_decision;
send->add_frontend_attributes(attributes);
send_done->add_frontend_attributes(attributes);
recv->add_frontend_attributes(attributes);
recv_done->add_frontend_attributes(attributes);
send->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
send_done->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
recv->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
recv_done->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
}
return DecomposedCp{send, recv, cp->source_target_pairs()};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ absl::StatusOr<bool> CollectivePermuteValidIterationAnnotator::Run(
std::reverse(sendRecvValidation.begin(), sendRecvValidation.end());
}

xla::FrontendAttributes attributes;
std::string iteration_instances =
"{" +
absl::StrJoin(sendRecvValidation, ",",
Expand All @@ -149,10 +148,9 @@ absl::StatusOr<bool> CollectivePermuteValidIterationAnnotator::Run(
item.second, "}");
}) +
"}";
(*attributes.mutable_map())[kSendRecvValidationAttr] =
iteration_instances;

inst->add_frontend_attributes(attributes);
inst->set_frontend_attribute(kSendRecvValidationAttr,
iteration_instances);
VLOG(1) << "Adding " << kSendRecvValidationAttr << " to " << inst->name()
<< ": " << iteration_instances;
changed = true;
Expand Down
6 changes: 2 additions & 4 deletions xla/service/gpu/transforms/windowed_einsum_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1427,10 +1427,8 @@ absl::StatusOr<bool> WindowedEinsumHandler::Run(
// The loop is fully unrolled but has a trip count of 1
// To prevent it from being inlined by while loop simplifier,
// we add this attribute to it.
xla::FrontendAttributes attributes;
(*attributes.mutable_map())["skip-simplify-while-loops_trip-count-one"] =
"true";
result.new_while_op->add_frontend_attributes(attributes);
result.new_while_op->set_frontend_attribute(
"skip-simplify-while-loops_trip-count-one", "true");
TF_RETURN_IF_ERROR(
PostProcessUnrolledLoop(result.new_while_op, stream_id));
}
Expand Down
4 changes: 1 addition & 3 deletions xla/service/spmd/stateful_rng_spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ absl::Status StatefulRngSpmdPartitioner::HandleRotateRightWhilePreprocessing(
->config()
.debug_options()
.xla_gpu_unsafe_pipelined_loop_annotator()) {
xla::FrontendAttributes attributes;
(*attributes.mutable_map())["is_pipelined_while_loop"] = "true";
while_loop->add_frontend_attributes(attributes);
while_loop->add_frontend_attribute("is_pipelined_while_loop", "true");
}
return absl::OkStatus();
}
Expand Down

0 comments on commit 06424d7

Please sign in to comment.