Skip to content

Commit

Permalink
Fixed bugs in the malloc-async allocator: use the ROCm-style error-string API under TENSORFLOW_USE_ROCM, correct the `GpuMemAccessDesc` type-alias typo, and add the missing ROCm build dependencies.
Browse files Browse the repository at this point in the history
  • Loading branch information
weihanmines committed Jan 11, 2024
1 parent 43bd036 commit bf154a5
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
7 changes: 5 additions & 2 deletions xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,9 @@ tsl_gpu_library(
srcs = [
"gpu_cudamallocasync_allocator.cc",
],
hdrs = ["gpu_cudamallocasync_allocator.h"],
hdrs = ["gpu_cudamallocasync_allocator.h",
"gpu_types.h",
],
cuda_deps = [
"//xla/stream_executor/cuda:cuda_activation",
"//xla/stream_executor/cuda:cuda_executor",
Expand All @@ -545,7 +547,8 @@ tsl_gpu_library(
"@tsl//tsl/platform:macros",
"@tsl//tsl/platform:mutex",
"@tsl//tsl/util:env_var",
],
] + if_rocm_is_configured([
"//xla/stream_executor/rocm:rocm_activation"]),
)

cc_library(
Expand Down
11 changes: 8 additions & 3 deletions xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,16 @@ using cuuint64_t = uint64_t;
namespace stream_executor {

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
static std::string GetCudaErrorMessage(GpuStatus result) {
static std::string GetCudaErrorMessage(gpu::GpuStatus result) {
const char* error;
GpuGetErrorString(result, &error);
const char* name;
#if GOOGLE_CUDA
GpuGetErrorString(result, &error);
GpuGetErrorName(result, &name);
#elif TENSORFLOW_USE_ROCM
error = GpuGetErrorString(result);
name = GpuGetErrorName(result);
#endif
return absl::StrCat("CUDA error: ", error ? error : "<unknown>", " (",
name ? name : "Unknown", ")");
}
Expand Down Expand Up @@ -281,7 +286,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
DCHECK(all_pools_->size() == all_ids_->size());
for (int i = 0; i < all_pools_->size(); ++i) {
// Set the current pool access to the previous GPUs.
GpuMemAccessDesc map;
gpu::GpuMemAccessDesc map;
map.flags = GPU_MEM_ACCESS_FLAGS_PROT_READWRITE;
map.location.id = (*all_ids_)[i].value();

Expand Down
4 changes: 2 additions & 2 deletions xla/stream_executor/gpu/gpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ using GpuGraphExecHandle = hipGraphExec_t;
using GpuGraphNodeHandle = hipGraphNode_t;
using GpuGraphConditionalHandle = UnsupportedGpuFeature;
using GpuMemoryPoolHandle = hipMemPool_t;
using GpuMemAccessDes = hipMemAccessDesc;
using GpuMemAccessDesc = hipMemAccessDesc;
#else // CUDA

using GpuContextHandle = CUcontext;
Expand All @@ -85,7 +85,7 @@ using GpuGraphHandle = CUgraph;
using GpuGraphExecHandle = CUgraphExec;
using GpuGraphNodeHandle = CUgraphNode;
using GpuMemoryPoolHandle = CUmemoryPool;
using GpuMemAccessDes = CUmemAccessDesc;
using GpuMemAccessDesc = CUmemAccessDesc;

#if CUDA_VERSION >= 12030
using GpuGraphConditionalHandle = CUgraphConditionalHandle;
Expand Down

0 comments on commit bf154a5

Please sign in to comment.