Skip to content

Commit

Permalink
Add kUsingGpuRocm property tag.
Browse files Browse the repository at this point in the history
Tests can query this tag to determine whether they are running under ROCm.

PiperOrigin-RevId: 718909039
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 23, 2025
1 parent d71325f commit e795171
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 16 deletions.
6 changes: 2 additions & 4 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4660,17 +4660,14 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/hlo/parser:hlo_parser",
"//xla/service/gpu:gpu_executable_run_options",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:stream_executor_memory_allocator",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:blocking_counter",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -4695,6 +4692,7 @@ cc_library(
"//xla/pjrt:host_memory_spaces",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/tsl/platform:env",
Expand Down
19 changes: 13 additions & 6 deletions xla/service/hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "xla/service/hlo_runner.h"

#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/layout_util.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/hlo_module_util.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/transfer_manager.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -714,4 +711,14 @@ absl::string_view HloRunner::Name() const {
return backend_->platform()->Name();
}

bool HloRunner::HasProperty(const HloRunnerPropertyTag::Type tag) const {
if (tag == HloRunnerPropertyTag::kUsingGpuRocm) {
const stream_executor::DeviceDescription& device_description =
backend().default_stream_executor()->GetDeviceDescription();
return std::holds_alternative<stream_executor::RocmComputeCapability>(
device_description.gpu_compute_capability());
}
return false;
}

} // namespace xla
4 changes: 1 addition & 3 deletions xla/service/hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ class HloRunner : public HloRunnerInterface {

int device_count() const override { return backend().device_count(); }

bool HasProperty(const HloRunnerPropertyTag::Type tag) const override {
return false;
}
bool HasProperty(HloRunnerPropertyTag::Type tag) const override;

private:
absl::StatusOr<ExecutionOutput> ExecuteWithExecutionInputs(
Expand Down
2 changes: 2 additions & 0 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class HloRunnerPropertyTag final {
// Default, reserved value for HloRunnerPropertyTag. Perhaps this could be
// used as a sentinel value for a tag that is not present. Do not use.
static constexpr Type kDefault = 0;
// Indicates that the runner is using ROCm.
static constexpr Type kUsingGpuRocm = 1;

private:
HloRunnerPropertyTag() = default;
Expand Down
8 changes: 8 additions & 0 deletions xla/service/hlo_runner_pjrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/service/computation_layout.h"
Expand Down Expand Up @@ -657,4 +658,11 @@ absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicatedImpl(

absl::string_view HloRunnerPjRt::Name() const { return "HloRunnerPjRt"; }

bool HloRunnerPjRt::HasProperty(const HloRunnerPropertyTag::Type tag) const {
if (tag == HloRunnerPropertyTag::kUsingGpuRocm) {
return pjrt_client_->platform_name() == xla::RocmName();
}
return false;
}

} // namespace xla
4 changes: 1 addition & 3 deletions xla/service/hlo_runner_pjrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ class HloRunnerPjRt : public HloRunnerInterface {

int device_count() const override { return pjrt_client_->device_count(); }

bool HasProperty(const HloRunnerPropertyTag::Type tag) const override {
return false;
}
bool HasProperty(HloRunnerPropertyTag::Type tag) const override;

private:
absl::StatusOr<CompileOptions> GenerateDefaultCompileOptions(
Expand Down

0 comments on commit e795171

Please sign in to comment.