Skip to content

Commit

Permalink
Add kCpu property tag.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715953981
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 23, 2025
1 parent 2fe62ae commit a7799a8
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 16 deletions.
6 changes: 2 additions & 4 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4660,17 +4660,14 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/hlo/parser:hlo_parser",
"//xla/service/gpu:gpu_executable_run_options",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:stream_executor_memory_allocator",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:blocking_counter",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -4695,6 +4692,7 @@ cc_library(
"//xla/pjrt:host_memory_spaces",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/tsl/platform:env",
Expand Down
22 changes: 16 additions & 6 deletions xla/service/hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "xla/service/hlo_runner.h"

#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/layout_util.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/hlo_module_util.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/transfer_manager.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -714,4 +711,17 @@ absl::string_view HloRunner::Name() const {
return backend_->platform()->Name();
}

bool HloRunner::HasProperty(const HloRunnerPropertyTag::Type tag) const {
if (tag == HloRunnerPropertyTag::kUsingGpuRocm) {
const stream_executor::DeviceDescription& device_description =
backend().default_stream_executor()->GetDeviceDescription();
return std::holds_alternative<stream_executor::RocmComputeCapability>(
device_description.gpu_compute_capability());
}
if (tag == HloRunnerPropertyTag::kCpu) {
return backend().platform()->Name() == "Host";
}
return false;
}

} // namespace xla
4 changes: 1 addition & 3 deletions xla/service/hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ class HloRunner : public HloRunnerInterface {

int device_count() const override { return backend().device_count(); }

bool HasProperty(const HloRunnerPropertyTag::Type tag) const override {
return false;
}
bool HasProperty(HloRunnerPropertyTag::Type tag) const override;

private:
absl::StatusOr<ExecutionOutput> ExecuteWithExecutionInputs(
Expand Down
4 changes: 4 additions & 0 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class HloRunnerPropertyTag final {
// Default, reserved value for HloRunnerPropertyTag. Perhaps this could be
// used as a sentinel value for a tag that is not present. Do not use.
static constexpr Type kDefault = 0;
// Indicates that the runner is using ROCm.
static constexpr Type kUsingGpuRocm = 1;
// Indicates that this runner is a CPU runner.
static constexpr Type kCpu = 2;

private:
HloRunnerPropertyTag() = default;
Expand Down
11 changes: 11 additions & 0 deletions xla/service/hlo_runner_pjrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/service/computation_layout.h"
Expand Down Expand Up @@ -657,4 +658,14 @@ absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicatedImpl(

absl::string_view HloRunnerPjRt::Name() const { return "HloRunnerPjRt"; }

bool HloRunnerPjRt::HasProperty(const HloRunnerPropertyTag::Type tag) const {
if (tag == HloRunnerPropertyTag::kUsingGpuRocm) {
return pjrt_client_->platform_name() == xla::RocmName();
}
if (tag == HloRunnerPropertyTag::kCpu) {
return pjrt_client_->platform_name() == xla::CpuName();
}
return false;
}

} // namespace xla
4 changes: 1 addition & 3 deletions xla/service/hlo_runner_pjrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ class HloRunnerPjRt : public HloRunnerInterface {

int device_count() const override { return pjrt_client_->device_count(); }

bool HasProperty(const HloRunnerPropertyTag::Type tag) const override {
return false;
}
bool HasProperty(HloRunnerPropertyTag::Type tag) const override;

private:
absl::StatusOr<CompileOptions> GenerateDefaultCompileOptions(
Expand Down

0 comments on commit a7799a8

Please sign in to comment.