Remove `use_parameter_layout_on_device`.
With the removal of calls to `UpdateEntryComputationLayout`, it turns out this
functionality is not necessary (and potentially harmful).

PiperOrigin-RevId: 718045607
nvgrw authored and Google-ML-Automation committed Jan 27, 2025
1 parent 1a94880 commit 0b16b62
Showing 6 changed files with 75 additions and 49 deletions.
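
For callers, the visible change is that the `HloRunnerPjRt` constructor no longer takes a `use_parameter_layout_on_device` flag: the runner now always tries to honor the parameter layouts it is given and falls back only when the client cannot. A rough before/after sketch of a construction site, modeled on the `HloPjRtInterpreterReferenceMixin` change below (variable names are illustrative):

    // Before: callers had to opt in to device-side parameter layouts.
    auto runner_before = std::make_unique<HloRunnerPjRt>(
        std::make_unique<InterpreterClient>(),
        InterpreterClient::DeviceShapeRepresentation,
        InterpreterClient::ShapeSizeBytes,
        /*use_parameter_layout_on_device=*/true);

    // After: the flag is gone; the runner respects the provided on-device
    // layout whenever the client supports layout-aware transfers.
    auto runner_after = std::make_unique<HloRunnerPjRt>(
        std::make_unique<InterpreterClient>(),
        InterpreterClient::DeviceShapeRepresentation,
        InterpreterClient::ShapeSizeBytes);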
1 change: 1 addition & 0 deletions xla/service/BUILD
@@ -4700,6 +4700,7 @@ cc_library(
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base:nullability",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:die_if_null",
         "@com_google_absl//absl/status",
105 changes: 63 additions & 42 deletions xla/service/hlo_runner_pjrt.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/base/nullability.h"
 #include "absl/log/die_if_null.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
@@ -159,6 +160,23 @@ std::vector<std::vector<PjRtBuffer*>> BufferMatToPointerMat(
   return argument_ptrs;
 }
 
+constexpr int kDeviceIdx = 0;
+
+absl::StatusOr<absl::Nonnull<PjRtMemorySpace*>> GetMemorySpaceFromLayout(
+    absl::Nonnull<PjRtDevice*> const device, const Layout& layout) {
+  PjRtMemorySpace* memory_space = nullptr;
+  if (layout.memory_space() == Layout::kHostMemorySpace) {
+    TF_ASSIGN_OR_RETURN(memory_space, device->memory_space_by_kind(
+                                          PinnedHostMemorySpace::kKind));
+  } else {
+    TF_ASSIGN_OR_RETURN(memory_space, device->default_memory_space());
+  }
+  TF_RET_CHECK(memory_space != nullptr)
+      << "Memory space " << layout.memory_space()
+      << " does not exist on device " << device->id();
+  return memory_space;
+}
+
 }  // namespace
 
 // TODO(b/245550554): Remove the use of PjRtWrappedExecutable.
@@ -190,17 +208,13 @@ absl::StatusOr<ExecutionOutput> PjRtWrappedExecutable::ExecuteAsyncOnStream(
       "PjRtWrappedExecutable: Unimplemented ExecuteAsyncOnStream");
 }
 
-static const int kDeviceIdx = 0;
-
 HloRunnerPjRt::HloRunnerPjRt(
     std::unique_ptr<PjRtClient> pjrt_client,
     DeviceShapeRepresentationFn device_shape_representation_fn,
-    DeviceShapeSizeFn device_shape_size_fn,
-    const bool use_parameter_layout_on_device)
+    DeviceShapeSizeFn device_shape_size_fn)
     : pjrt_client_(std::move(pjrt_client)),
       device_shape_representation_fn_(device_shape_representation_fn),
-      device_shape_size_fn_(device_shape_size_fn),
-      use_parameter_layout_on_device_(use_parameter_layout_on_device) {}
+      device_shape_size_fn_(device_shape_size_fn) {}
 
 HloRunnerPjRt::~HloRunnerPjRt() = default;
 
@@ -254,50 +268,55 @@ absl::StatusOr<Literal> HloRunnerPjRt::TransferLiteralFromDevice(
 }
 
 absl::StatusOr<std::unique_ptr<PjRtBuffer>>
-HloRunnerPjRt::TransferLiteralToDevice(const Literal& literal,
-                                       const Layout& parameter_layout) {
-  auto devices = pjrt_client_->addressable_devices();
-  PjRtDevice* device = devices[kDeviceIdx];
-
-  auto get_pjrt_memory_space = [](PjRtDevice* pjrt_device,
-                                  int64_t xla_memory_space) {
-    if (xla_memory_space == Layout::kHostMemorySpace) {
-      return pjrt_device->memory_space_by_kind(PinnedHostMemorySpace::kKind);
-    }
-    return pjrt_device->default_memory_space();
-  };
-  TF_ASSIGN_OR_RETURN(
-      PjRtMemorySpace * pjrt_memory_space,
-      get_pjrt_memory_space(device, parameter_layout.memory_space()));
-  TF_ASSIGN_OR_RETURN(
-      auto assignment,
-      use_parameter_layout_on_device_
-          ? pjrt_client_->BufferFromHostLiteral(literal, pjrt_memory_space,
-                                                &parameter_layout)
-          : pjrt_client_->BufferFromHostLiteral(literal, pjrt_memory_space));
-  return std::move(assignment);
+HloRunnerPjRt::TransferLiteralToDevice(
+    const Literal& literal, absl::Nonnull<PjRtMemorySpace*> const memory_space,
+    const Layout& on_device_layout) {
+  // Whenever possible, we want to respect the provided on-device layout. This
+  // layout was either provided by the user or was inferred by the compiler. The
+  // runtime should ideally not select a layout of its own accord.
+  //
+  // Not all clients implement this functionality.
+  absl::StatusOr<std::unique_ptr<PjRtBuffer>> buffer =
+      pjrt_client_->BufferFromHostLiteral(literal, memory_space,
+                                          &on_device_layout);
+  if (buffer.ok() || !absl::IsUnimplemented(buffer.status())) {
+    return buffer;
+  }
+  // Fall back to the two-argument version of BufferFromHostLiteral.
+  return pjrt_client_->BufferFromHostLiteral(literal, memory_space);
 }
 
 absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 HloRunnerPjRt::TransferLiteralsToDevice(
     const ComputationLayout& entry_layout,
     absl::Span<const Literal* const> literals) {
+  // Note: This function is used for single (default) device execution.
+  if (pjrt_client_->addressable_device_count() <= kDeviceIdx) {
+    return absl::InternalError("No addressable devices available");
+  }
+  PjRtDevice* device = pjrt_client_->addressable_devices()[kDeviceIdx];
+  TF_RET_CHECK(device != nullptr)
+      << "Device with ordinal " << kDeviceIdx << " is null.";
+
   TF_ASSIGN_OR_RETURN(bool flatten, MustFlattenInputTuple(entry_layout));
   TF_ASSIGN_OR_RETURN(std::vector<Layout> parameter_layouts,
                       entry_layout.FlattenedParameterLayouts());
 
-  auto transfer_literals = [&parameter_layouts, this](
-                               absl::Span<const Literal* const> input_literals)
+  auto transfer_literals =
+      [&, this](absl::Span<const Literal* const> input_literals)
       -> absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> {
     TF_RET_CHECK(parameter_layouts.size() == input_literals.size());
     std::vector<std::unique_ptr<PjRtBuffer>> buffers;
     buffers.reserve(input_literals.size());
    for (int i = 0; i < input_literals.size(); ++i) {
      const Literal* literal = input_literals[i];
      TF_RET_CHECK(literal != nullptr);
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<PjRtBuffer> buffer,
-          TransferLiteralToDevice(*literal, parameter_layouts[i]));
+      const Layout& on_device_layout = parameter_layouts[i];
+      TF_ASSIGN_OR_RETURN(absl::Nonnull<PjRtMemorySpace*> memory_space,
+                          GetMemorySpaceFromLayout(device, on_device_layout));
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer,
+                          TransferLiteralToDevice(*literal, memory_space,
+                                                  parameter_layouts[i]));
       TF_RETURN_IF_ERROR(buffer->GetReadyFuture().Await());
       buffers.push_back(std::move(buffer));
     }
@@ -321,8 +340,10 @@ absl::StatusOr<Literal> HloRunnerPjRt::Execute(
     std::unique_ptr<HloModule> module,
     absl::Span<const Literal* const> arguments, bool run_hlo_passes,
     ExecutionProfile* profile) {
-  // TODO (b/245550554) : Remove UpdateEntryComputationLayout from runner.
-  UpdateEntryComputationLayout(module.get());
+  if (run_hlo_passes) {
+    // TODO - b/391868033: Remove calls to UpdateEntryComputationLayout.
+    UpdateEntryComputationLayout(module.get());
+  }
   TF_ASSIGN_OR_RETURN(auto executable,
                       CreateExecutable(std::move(module), run_hlo_passes));
 
@@ -409,7 +430,10 @@ absl::StatusOr<std::unique_ptr<Executable>> HloRunnerPjRt::CreateExecutable(
 absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicated(
     std::unique_ptr<HloModule> module,
     const HloRunnerInterface::ReplicatedExecuteOptions& options) {
-  UpdateEntryComputationLayout(module.get());
+  if (options.run_hlo_passes) {
+    // TODO - b/391868033: Remove calls to UpdateEntryComputationLayout.
+    UpdateEntryComputationLayout(module.get());
+  }
 
   TF_ASSIGN_OR_RETURN(
       auto device_assignment,
@@ -560,12 +584,9 @@ absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicatedImpl(
       TF_RET_CHECK(argument != nullptr);
       TF_ASSIGN_OR_RETURN(PjRtMemorySpace * memory_space,
                           device_ptr->default_memory_space());
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<PjRtBuffer> assignment,
-          use_parameter_layout_on_device_
-              ? pjrt_client_->BufferFromHostLiteral(*argument, memory_space,
-                                                    &argument->shape().layout())
-              : pjrt_client_->BufferFromHostLiteral(*argument, memory_space));
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> assignment,
+                          TransferLiteralToDevice(*argument, memory_space,
+                                                  argument->shape().layout()));
       replica_buffers.push_back(std::move(assignment));
     }
     argument_buffer_slices.push_back(std::move(replica_buffers));
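
The heart of the change is the try-then-fall-back transfer in the new `TransferLiteralToDevice` above: the layout-aware `BufferFromHostLiteral` overload is attempted first, and the runner retries without an explicit layout only when the client reports the overload as unimplemented. A condensed, free-standing sketch of that control flow (the function name here is illustrative, not part of the commit):

    // Try the layout-aware overload; fall back only on Unimplemented.
    absl::StatusOr<std::unique_ptr<PjRtBuffer>> TransferWithLayoutFallback(
        PjRtClient* client, PjRtMemorySpace* memory_space,
        const Literal& literal, const Layout& on_device_layout) {
      absl::StatusOr<std::unique_ptr<PjRtBuffer>> buffer =
          client->BufferFromHostLiteral(literal, memory_space,
                                        &on_device_layout);
      // Any error other than Unimplemented is surfaced to the caller
      // unchanged; only a missing implementation triggers the retry.
      if (buffer.ok() || !absl::IsUnimplemented(buffer.status())) {
        return buffer;
      }
      return client->BufferFromHostLiteral(literal, memory_space);
    }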
9 changes: 5 additions & 4 deletions xla/service/hlo_runner_pjrt.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/base/nullability.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -47,14 +48,14 @@ class HloRunnerPjRt : public HloRunnerInterface {
   explicit HloRunnerPjRt(
       std::unique_ptr<PjRtClient> pjrt_client,
       DeviceShapeRepresentationFn device_shape_representation_fn,
-      DeviceShapeSizeFn device_shape_size_fn,
-      bool use_parameter_layout_on_device = false);
+      DeviceShapeSizeFn device_shape_size_fn);
 
   ~HloRunnerPjRt() override;
 
   // Transfers data between the host and device.
   absl::StatusOr<std::unique_ptr<PjRtBuffer>> TransferLiteralToDevice(
-      const Literal& literal, const Layout& parameter_layout);
+      const Literal& literal, absl::Nonnull<PjRtMemorySpace*> memory_space,
+      const Layout& on_device_layout);
   absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
   TransferLiteralsToDevice(const ComputationLayout& entry_layout,
                            absl::Span<const Literal* const> literals);
@@ -112,6 +113,7 @@ class HloRunnerPjRt : public HloRunnerInterface {
   absl::string_view Name() const override;
 
   void UpdateEntryComputationLayout(HloModule* module) {
+    // TODO - b/391868033: Remove UpdateEntryComputationLayout from this class.
     xla::UpdateEntryComputationLayout(module, device_shape_representation_fn_);
   }
 
@@ -143,7 +145,6 @@ class HloRunnerPjRt : public HloRunnerInterface {
   std::unique_ptr<PjRtClient> pjrt_client_;
   DeviceShapeRepresentationFn device_shape_representation_fn_;
   DeviceShapeSizeFn device_shape_size_fn_;
-  bool use_parameter_layout_on_device_ = false;
 };
 
 }  // namespace xla
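
The new `TransferLiteralToDevice` signature shifts responsibility for picking a `PjRtMemorySpace` to the caller. Per the `GetMemorySpaceFromLayout` helper added in hlo_runner_pjrt.cc, the mapping is: a layout tagged `Layout::kHostMemorySpace` resolves to the device's pinned-host memory space, and anything else resolves to the device default. A minimal sketch of that lookup (the helper name here is illustrative):

    // Resolve the PjRt memory space implied by an XLA layout.
    absl::StatusOr<PjRtMemorySpace*> ResolveMemorySpace(PjRtDevice* device,
                                                        const Layout& layout) {
      if (layout.memory_space() == Layout::kHostMemorySpace) {
        return device->memory_space_by_kind(PinnedHostMemorySpace::kKind);
      }
      return device->default_memory_space();
    }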
3 changes: 1 addition & 2 deletions xla/tests/hlo_pjrt_interpreter_reference_mixin.h
@@ -39,8 +39,7 @@ class HloPjRtInterpreterReferenceMixin
           std::make_unique<HloRunnerPjRt>(
               std::make_unique<InterpreterClient>(),
               InterpreterClient::DeviceShapeRepresentation,
-              InterpreterClient::ShapeSizeBytes,
-              /*use_parameter_layout_on_device=*/true),
+              InterpreterClient::ShapeSizeBytes),
           std::forward<BaseArgs>(base_args)...) {}
   ~HloPjRtInterpreterReferenceMixin() override = default;
 };
2 changes: 1 addition & 1 deletion xla/tests/hlo_runner_agnostic_test_base.cc
@@ -94,7 +94,6 @@ HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule(
       instruction_can_change_layout_func());
   TF_RETURN_IF_ERROR(
       module->ParseHloStringAndVerifyModule(hlo_text, parser_options));
-  UpdateEntryComputationLayout(module.get());
   return std::move(module);
 }
 
@@ -109,6 +108,7 @@ HloRunnerAgnosticTestBase::AddEntryComputationAndUpdateEntryComputationLayout(
 
 void HloRunnerAgnosticTestBase::UpdateEntryComputationLayout(
     HloModule* const module) const {
+  // TODO - b/391868033: Remove UpdateEntryComputationLayout from this class.
   xla::UpdateEntryComputationLayout(
       module, test_runner_->device_shape_representation_fn());
 }
4 changes: 4 additions & 0 deletions xla/tests/hlo_test_base.cc
@@ -174,12 +174,16 @@ absl::StatusOr<std::unique_ptr<HloModule>> HloTestBase::GetOptimizedModule(
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloModule> module,
       ParseAndReturnVerifiedModule(hlo, GetModuleConfigForTest()));
+  // TODO - b/391868033: Remove calls to UpdateEntryComputationLayout.
+  UpdateEntryComputationLayout(module.get());
   return backend().compiler()->RunHloPasses(
       std::move(module), backend().default_stream_executor(), GetAllocator());
 }
 
 absl::StatusOr<std::unique_ptr<HloModule>> HloTestBase::GetOptimizedModule(
     std::unique_ptr<HloModule> hlo_module) {
+  // TODO - b/391868033: Remove calls to UpdateEntryComputationLayout.
+  UpdateEntryComputationLayout(hlo_module.get());
   return backend().compiler()->RunHloPasses(std::move(hlo_module),
                                             backend().default_stream_executor(),
                                             GetAllocator());
