diff --git a/xla/service/BUILD b/xla/service/BUILD index 1234b7d754a316..1956acf2176492 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4660,17 +4660,14 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", - "//xla/hlo/parser:hlo_parser", "//xla/service/gpu:gpu_executable_run_options", + "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@tsl//tsl/platform:blocking_counter", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) @@ -4695,6 +4692,7 @@ cc_library( "//xla/pjrt:host_memory_spaces", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/tsl/platform:env", diff --git a/xla/service/hlo_runner.cc b/xla/service/hlo_runner.cc index 2288f560f722f0..0f87c13a12caf0 100644 --- a/xla/service/hlo_runner.cc +++ b/xla/service/hlo_runner.cc @@ -12,30 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#define EIGEN_USE_THREADS #include "xla/service/hlo_runner.h" #include #include #include +#include #include #include "unsupported/Eigen/CXX11/Tensor" #include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/parser/hlo_parser.h" -#include "xla/layout_util.h" #include "xla/service/executable.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/hlo_module_util.h" +#include "xla/service/hlo_runner_interface.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "tsl/platform/blocking_counter.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla { @@ -714,4 +711,17 @@ absl::string_view HloRunner::Name() const { return backend_->platform()->Name(); } +bool HloRunner::HasProperty(const HloRunnerPropertyTag::Type tag) const { + if (tag == HloRunnerPropertyTag::kUsingGpuRocm) { + const stream_executor::DeviceDescription& device_description = + backend().default_stream_executor()->GetDeviceDescription(); + return std::holds_alternative( + device_description.gpu_compute_capability()); + } + if (tag == HloRunnerPropertyTag::kCpu) { + return backend().platform()->Name() == "Host"; + } + return false; +} + } // namespace xla diff --git a/xla/service/hlo_runner.h b/xla/service/hlo_runner.h index 779698695cba71..856d42f1016b65 100644 --- a/xla/service/hlo_runner.h +++ b/xla/service/hlo_runner.h @@ -197,9 +197,7 @@ class HloRunner : public HloRunnerInterface { int device_count() const override { return backend().device_count(); } - bool HasProperty(const HloRunnerPropertyTag::Type tag) const override { - return false; - } + bool HasProperty(HloRunnerPropertyTag::Type tag) const override; private: absl::StatusOr ExecuteWithExecutionInputs( diff --git a/xla/service/hlo_runner_interface.h b/xla/service/hlo_runner_interface.h index 88f095e13b3e04..0f64789384bfa9 100644 --- a/xla/service/hlo_runner_interface.h +++ b/xla/service/hlo_runner_interface.h @@ -73,6 +73,10 @@ class HloRunnerPropertyTag final { // Default, reserved value for HloRunnerPropertyTag. Perhaps this could be // used as a sentinel value for a tag that is not present. Do not use. static constexpr Type kDefault = 0; + // Indicates that the runner is using ROCm. + static constexpr Type kUsingGpuRocm = 1; + // Indicates that this runner is a CPU runner. + static constexpr Type kCpu = 2; private: HloRunnerPropertyTag() = default; diff --git a/xla/service/hlo_runner_pjrt.cc b/xla/service/hlo_runner_pjrt.cc index d8adce70e0e053..08f79010dacdb9 100644 --- a/xla/service/hlo_runner_pjrt.cc +++ b/xla/service/hlo_runner_pjrt.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/service/computation_layout.h" @@ -657,4 +658,14 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( absl::string_view HloRunnerPjRt::Name() const { return "HloRunnerPjRt"; } +bool HloRunnerPjRt::HasProperty(const HloRunnerPropertyTag::Type tag) const { + if (tag == HloRunnerPropertyTag::kUsingGpuRocm) { + return pjrt_client_->platform_name() == xla::RocmName(); + } + if (tag == HloRunnerPropertyTag::kCpu) { + return pjrt_client_->platform_name() == xla::CpuName(); + } + return false; +} + } // namespace xla diff --git a/xla/service/hlo_runner_pjrt.h b/xla/service/hlo_runner_pjrt.h index c1c3b71542a782..c0a6c5340c37f4 100644 --- a/xla/service/hlo_runner_pjrt.h +++ b/xla/service/hlo_runner_pjrt.h @@ -125,9 +125,7 @@ class HloRunnerPjRt : public HloRunnerInterface { int device_count() const override { return pjrt_client_->device_count(); } - bool HasProperty(const HloRunnerPropertyTag::Type tag) const override { - return false; - } + bool HasProperty(HloRunnerPropertyTag::Type tag) const override; private: absl::StatusOr GenerateDefaultCompileOptions(