diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index d2fe121d50199..4414ce51dd89a 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -32,16 +32,17 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/stream_executor/integrations:stream_executor_allocator", "//xla/tsl/framework:allocator", "//xla/tsl/framework:bfc_allocator", "//xla/tsl/framework:device_id_impl", + "//xla/tsl/platform:statusor", "//xla/tsl/util:env_var", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/pjrt/gpu/gpu_helpers.cc b/xla/pjrt/gpu/gpu_helpers.cc index e604f771ebacf..bc0f17fcc3ed4 100644 --- a/xla/pjrt/gpu/gpu_helpers.cc +++ b/xla/pjrt/gpu/gpu_helpers.cc @@ -37,14 +37,15 @@ limitations under the License. #include "xla/service/platform_util.h" #include "xla/stream_executor/integrations/device_host_allocator.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/integrations/stream_executor_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/allocator.h" #include "xla/tsl/framework/bfc_allocator.h" #include "xla/tsl/framework/device_id.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/env_var.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" namespace xla { @@ -98,11 +99,20 @@ absl::StatusOr> CreateBFCAllocator( } int device_ordinal = executor->device_ordinal(); - auto sub_allocator = std::make_unique( - executor, tsl::PlatformDeviceId(device_ordinal), - /*memory_type=*/ - enable_unified_memory ? stream_executor::MemoryType::kUnified - : stream_executor::MemoryType::kDevice); + std::unique_ptr sub_allocator; + + if (enable_unified_memory) { + TF_ASSIGN_OR_RETURN( + auto unified_memory_allocator, + executor->CreateMemoryAllocator(stream_executor::MemoryType::kUnified)); + sub_allocator = std::make_unique( + std::move(unified_memory_allocator), + stream_executor::MemoryType::kUnified, device_ordinal); + } else { + sub_allocator = std::make_unique( + executor, tsl::PlatformDeviceId(device_ordinal), + stream_executor::MemoryType::kDevice); + } int64_t free_memory; int64_t total_memory; diff --git a/xla/stream_executor/integrations/stream_executor_allocator.h b/xla/stream_executor/integrations/stream_executor_allocator.h index cce17df9325ef..312de7bbba961 100644 --- a/xla/stream_executor/integrations/stream_executor_allocator.h +++ b/xla/stream_executor/integrations/stream_executor_allocator.h @@ -36,8 +36,8 @@ class StreamExecutorAllocator : public tsl::SubAllocator { public: StreamExecutorAllocator(std::unique_ptr memory_allocator, MemoryType memory_type, int index, - const std::vector& alloc_visitors, - const std::vector& free_visitors); + const std::vector& alloc_visitors = {}, + const std::vector& free_visitors = {}); ~StreamExecutorAllocator() override = default; void* Alloc(size_t alignment, size_t num_bytes,