Switch CollectiveMemoryAllocation to use CreateMemoryAllocator infrastructure.

PiperOrigin-RevId: 720275726
klucke authored and Google-ML-Automation committed Jan 28, 2025
1 parent f6f7a0a commit 191dbc1
Showing 7 changed files with 40 additions and 40 deletions.
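
At a glance, the change swaps PJRT's direct construction of a collective se::DeviceMemAllocator for the executor-provided CreateMemoryAllocator path. A minimal before/after sketch, using the names that appear in this diff (error handling abbreviated; not a verbatim excerpt):

// Before: PJRT built a sub-allocator that called the executor's
// CollectiveMemoryAllocate/Deallocate hooks itself.
auto sub_allocator = std::make_unique<se::DeviceMemAllocator>(
    executor, tsl::PlatformDeviceId(device_ordinal),
    /*memory_type=*/stream_executor::MemoryType::kCollective);

// After: the executor owns the collective allocation logic, and PJRT only
// adapts the returned MemoryAllocator into a tsl::SubAllocator.
TF_ASSIGN_OR_RETURN(auto collective_memory_allocator,
                    executor->CreateMemoryAllocator(
                        stream_executor::MemoryType::kCollective));
auto sub_allocator = std::make_unique<se::StreamExecutorAllocator>(
    std::move(collective_memory_allocator),
    /*memory_type=*/stream_executor::MemoryType::kCollective, device_ordinal);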
9 changes: 6 additions & 3 deletions xla/pjrt/gpu/gpu_helpers.cc
@@ -152,9 +152,12 @@ absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
se::StreamExecutor* executor, double memory_fraction,
size_t collective_memory_size) {
int device_ordinal = executor->device_ordinal();
-  auto sub_allocator = std::make_unique<se::DeviceMemAllocator>(
-      executor, tsl::PlatformDeviceId(device_ordinal),
-      /*memory_type=*/stream_executor::MemoryType::kCollective);
+  TF_ASSIGN_OR_RETURN(auto collective_memory_allocator,
+                      executor->CreateMemoryAllocator(
+                          stream_executor::MemoryType::kCollective));
+  auto sub_allocator = std::make_unique<se::StreamExecutorAllocator>(
+      std::move(collective_memory_allocator),
+      /*memory_type=*/stream_executor::MemoryType::kCollective, device_ordinal);

int64_t free_memory;
int64_t total_memory;
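
For context, the sub-allocator constructed above is presumably handed to the tsl::BFCAllocator that this function returns; a rough sketch of that wiring, where the Options fields and allocator name are assumptions rather than part of this diff:

tsl::BFCAllocator::Options opts;
opts.allow_growth = true;  // assumption: let the arena grow on demand
return std::make_unique<tsl::BFCAllocator>(
    std::move(sub_allocator), collective_memory_size,
    absl::StrCat("collective_bfc_", device_ordinal), opts);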
21 changes: 21 additions & 0 deletions xla/stream_executor/cuda/cuda_executor.cc
@@ -590,6 +590,27 @@ CudaExecutor::CreateMemoryAllocator(MemoryType type) {
}
});
});
+  } else if (type == MemoryType::kCollective) {
+    return std::make_unique<GenericMemoryAllocator>(
+        [this](uint64_t size)
+            -> absl::StatusOr<std::unique_ptr<MemoryAllocation>> {
+          TF_ASSIGN_OR_RETURN(
+              void* ptr, CudaCollectives::CollectiveMemoryAllocate(this, size));
+          VLOG(2) << "allocated " << ptr << " for context " << cuda_context_
+                  << " of " << size << " bytes of collective memory";
+          return std::make_unique<GenericMemoryAllocation>(
+              ptr, size, [this](void* location, uint64_t size) {
+                auto status =
+                    CudaCollectives::CollectiveMemoryDeallocate(this, location);
+                if (!status.ok()) {
+                  LOG(ERROR) << "failed to free collective memory at "
+                             << location << "; result: " << status;
+                } else {
+                  VLOG(2) << "deallocated collective memory at " << location
+                          << " for context " << cuda_context_;
+                }
+              });
+        });
+  }
return absl::UnimplementedError(
absl::StrFormat("Unsupported memory type %d", type));
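
The lambda pair above gives each allocation an RAII deleter. A hypothetical caller-side sketch (not part of this diff; it assumes MemoryAllocation exposes the raw pointer via opaque()):

TF_ASSIGN_OR_RETURN(std::unique_ptr<MemoryAllocator> allocator,
                    executor->CreateMemoryAllocator(MemoryType::kCollective));
TF_ASSIGN_OR_RETURN(std::unique_ptr<MemoryAllocation> allocation,
                    allocator->Allocate(/*size=*/1 << 20));
void* ptr = allocation->opaque();  // collective (NCCL-registered) device memory
// ... use ptr in collective operations ...
allocation.reset();  // deleter runs CudaCollectives::CollectiveMemoryDeallocate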
8 changes: 0 additions & 8 deletions xla/stream_executor/cuda/cuda_executor.h
@@ -68,14 +68,6 @@ class CudaExecutor : public GpuExecutor {
absl::StatusOr<DeviceMemoryBase> GetMemoryRange(
const DeviceMemoryBase& location) override;

-  absl::StatusOr<void*> CollectiveMemoryAllocate(uint64_t size) override {
-    return CudaCollectives::CollectiveMemoryAllocate(this, size);
-  }
-
-  absl::Status CollectiveMemoryDeallocate(void* location) override {
-    return CudaCollectives::CollectiveMemoryDeallocate(this, location);
-  }
-
absl::StatusOr<std::unique_ptr<EventBasedTimer>> CreateEventBasedTimer(
Stream* stream, bool use_delay_kernel) override;
absl::StatusOr<DeviceMemoryBase> GetSymbol(
12 changes: 10 additions & 2 deletions xla/stream_executor/cuda/cuda_executor_test.cc
@@ -108,6 +108,16 @@ TEST(CudaExecutorTest, CreateUnifiedMemoryAllocatorWorks) {
allocation.reset();
}

+TEST(CudaExecutorTest, CreateCollectiveMemoryAllocatorWorks) {
+  TF_ASSERT_OK_AND_ASSIGN(Platform * platform,
+                          PlatformManager::PlatformWithName("CUDA"));
+  TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor,
+                          platform->ExecutorForDevice(0));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<MemoryAllocator> allocator,
+      executor->CreateMemoryAllocator(MemoryType::kCollective));
+}
+
TEST(CudaExecutorTest, CreateUnsupportedMemoryAllocatorsFail) {
TF_ASSERT_OK_AND_ASSIGN(Platform * platform,
PlatformManager::PlatformWithName("CUDA"));
@@ -116,8 +126,6 @@ TEST(CudaExecutorTest, CreateUnsupportedMemoryAllocatorsFail) {
EXPECT_THAT(executor->CreateMemoryAllocator(MemoryType::kHost), Not(IsOk()));
EXPECT_THAT(executor->CreateMemoryAllocator(MemoryType::kDevice),
Not(IsOk()));
-  EXPECT_THAT(executor->CreateMemoryAllocator(MemoryType::kCollective),
-              Not(IsOk()));
}
} // namespace
} // namespace stream_executor::gpu
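
The new test stops at constructing the allocator. A possible extension, mirroring the unified-memory test earlier in this file (hypothetical, and assuming the device can satisfy a collective allocation):

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<MemoryAllocation> allocation,
                        allocator->Allocate(1024));
EXPECT_NE(allocation->opaque(), nullptr);
allocation.reset();  // frees through the allocator's deleter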
12 changes: 3 additions & 9 deletions xla/stream_executor/integrations/device_mem_allocator.h
@@ -40,6 +40,7 @@ class DeviceMemAllocator : public tsl::SubAllocator {
memory_type_(memory_type) {
CHECK(stream_exec_ != nullptr);
CHECK(memory_type_ != MemoryType::kUnified);
+    CHECK(memory_type_ != MemoryType::kCollective);
}

~DeviceMemAllocator() override = default;
@@ -51,11 +52,7 @@ class DeviceMemAllocator : public tsl::SubAllocator {
void* ptr = nullptr;
*bytes_received = num_bytes;
if (num_bytes > 0) {
-      if (memory_type_ == MemoryType::kCollective) {
-        auto status_or = stream_exec_->CollectiveMemoryAllocate(num_bytes);
-        CHECK(status_or.ok()) << status_or.status().message();
-        ptr = status_or.value();
-      } else if (memory_type_ == MemoryType::kHost) {
+      if (memory_type_ == MemoryType::kHost) {
// Convert size_t to long unsigned int
long unsigned int value = static_cast<long unsigned int>(num_bytes);
auto status_or = stream_exec_->HostMemoryAllocate(value);
@@ -73,10 +70,7 @@ class DeviceMemAllocator : public tsl::SubAllocator {

if (ptr != nullptr) {
VisitFree(ptr, device_id_.value(), num_bytes);
-      if (memory_type_ == MemoryType::kCollective) {
-        auto status = stream_exec_->CollectiveMemoryDeallocate(ptr);
-        CHECK(status.ok()) << status.message();
-      } else if (memory_type_ == MemoryType::kHost) {
+      if (memory_type_ == MemoryType::kHost) {
stream_exec_->HostMemoryDeallocate(ptr);
} else {
DeviceMemoryBase device_ptr(ptr);
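
With the added CHECK, DeviceMemAllocator now refuses collective memory at construction time; that responsibility moves to StreamExecutorAllocator. A small illustration (hypothetical code that would now fail):

// CHECK-fails after this change: kCollective is no longer a valid
// memory type for DeviceMemAllocator.
auto bad_sub_allocator = std::make_unique<se::DeviceMemAllocator>(
    executor, tsl::PlatformDeviceId(device_ordinal),
    /*memory_type=*/stream_executor::MemoryType::kCollective);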
4 changes: 0 additions & 4 deletions xla/stream_executor/mock_stream_executor.h
@@ -65,10 +65,6 @@ class MockStreamExecutor : public StreamExecutor {
MOCK_METHOD(DeviceMemoryBase, Allocate, (uint64_t size, int64_t memory_space),
(override));
MOCK_METHOD(void, Deallocate, (DeviceMemoryBase * mem), (override));
-  MOCK_METHOD(absl::StatusOr<void*>, CollectiveMemoryAllocate, (uint64_t size),
-              (override));
-  MOCK_METHOD(absl::Status, CollectiveMemoryDeallocate, (void* mem),
-              (override));
MOCK_METHOD(absl::StatusOr<std::unique_ptr<MemoryAllocation>>,
HostMemoryAllocate, (uint64_t size), (override));
MOCK_METHOD(void, HostMemoryDeallocate, (void* mem), (override));
14 changes: 0 additions & 14 deletions xla/stream_executor/stream_executor.h
@@ -156,20 +156,6 @@ class StreamExecutor {
// Deallocation of a nullptr-representative value is permitted.
virtual void Deallocate(DeviceMemoryBase* mem) = 0;

-  // Allocates collective device memory using ncclMemAlloc.
-  // See
-  // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html
-  // for more details on User Buffer Registration.
-  virtual absl::StatusOr<void*> CollectiveMemoryAllocate(uint64_t size) {
-    return absl::UnimplementedError("Not implemented");
-  }
-
-  // Deallocates collective device memory previously allocated with
-  // CollectiveMemoryAllocate.
-  virtual absl::Status CollectiveMemoryDeallocate(void* mem) {
-    return absl::UnimplementedError("Not implemented");
-  }
-
// Allocates a region of host memory and registers it with the platform API.
// Memory allocated in this manner is required for use in asynchronous memcpy
// operations, such as Stream::Memcpy.
