Skip to content

Commit

Permalink
Merge pull request #96 from ROCm/r0.4.30-blas-and-dots-fixes
Browse files Browse the repository at this point in the history
Avoid lazy init of blas handles, fix for non-canonical dots, rocBlas ResetStream
  • Loading branch information
hsharsha authored Jan 23, 2025
2 parents a3c22a0 + 1e29fe9 commit d6ee835
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 29 deletions.
8 changes: 1 addition & 7 deletions xla/service/gpu/gemm_algorithm_picker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,7 @@ class GemmAutotuner {
explicit GemmAutotuner(const AutotuneConfig& autotune_config)
: autotune_config_(autotune_config) {}

// On destruction, if stream_ is set, asks the BLAS support object of that
// stream's parent executor (when AsBlas() returns one) to reset its stream.
~GemmAutotuner() {
if (stream_ != nullptr) {
if (auto blas = stream_->parent()->AsBlas()) blas->ResetStream();
}
}

const AutotuneConfig& config() { return autotune_config_; }
const AutotuneConfig& config() const { return autotune_config_; }

size_t num_algorithms_left() const { return num_algorithms_left_; }

Expand Down
24 changes: 21 additions & 3 deletions xla/service/gpu/gemm_algorithm_picker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class GemmAlgorithmPickerTest : public HloTestBase,
// Returns the default stream executor of the test backend.
se::StreamExecutor *stream_exec() {
return backend().default_stream_executor();
}
const se::DeviceDescription& gpu_device_desc() {
const se::DeviceDescription& device_desc() {
return stream_exec()->GetDeviceDescription();
}
const se::GpuComputeCapability& gpu_comp() {
return gpu_device_desc().gpu_compute_capability();
return device_desc().gpu_compute_capability();
}

void SetUp() override {
Expand All @@ -82,6 +82,15 @@ class GemmAlgorithmPickerTest : public HloTestBase,
}
};

// Sanity check: the BLAS support object of the default stream executor must
// exist and report a non-empty version string.
TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) {
  auto* blas_support = stream_exec()->AsBlas();
  ASSERT_NE(blas_support, nullptr);

  std::string blas_version;
  const auto version_status = blas_support->GetVersion(&blas_version);
  ASSERT_TRUE(version_status.ok());
  VLOG(0) << "Blas version: " << blas_version;
  ASSERT_FALSE(blas_version.empty());
}

TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) {
constexpr absl::string_view kHlo = R"(
HloModule module
Expand Down Expand Up @@ -117,6 +126,15 @@ TF_ASSERT_OK_AND_ASSIGN(auto module,
if(num_left1 < 2) {
GTEST_SKIP() << "Too few algorithms left after the first step";
}

// Test that the function to get current stream value works fine:
auto* blas = stream_exec()->AsBlas();
ASSERT_TRUE(blas != nullptr);
TF_ASSERT_OK_AND_ASSIGN(bool is_main_stream, blas->IsMainStreamSet());
// ROCM only: CUDA blas API does not reset stream after each blas call.
if (std::holds_alternative<se::RocmComputeCapability>(gpu_comp())) {
ASSERT_TRUE(is_main_stream);
}
}

// Clear cache before the second run!
Expand Down Expand Up @@ -254,7 +272,7 @@ ENTRY main {
changed = false;

DevicelessConfig deviceless_config{
gpu_device_desc().model_str(), gpu_comp()};
device_desc().model_str(), gpu_comp()};
AutotuneConfig deviceless_cfg{deviceless_config, opts};
TF_ASSERT_OK_AND_ASSIGN(
changed,
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ cc_library(
"//xla/stream_executor/platform",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
Expand Down
10 changes: 5 additions & 5 deletions xla/stream_executor/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ limitations under the License.
#include <type_traits>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/stream_executor/data_type.h"
#include "xla/stream_executor/device_memory.h"
Expand Down Expand Up @@ -221,9 +221,10 @@ class BlasSupport {
virtual ~BlasSupport() {}

virtual gpu::BlasLt *GetBlasLt() = 0;
// resets the underlying blas stream to its default value
virtual bool ResetStream() = 0;

// For tests only: returns true if the underlying Blas library has stream 0
// (the null/default stream) set as its current stream, false if another
// stream is set, or an error status if the current stream cannot be queried.
virtual absl::StatusOr<bool> IsMainStreamSet() const = 0;
// Performs a BLAS y <- ax+y operation.
virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,
const DeviceMemory<float> &x, int incx,
Expand All @@ -233,7 +234,6 @@ class BlasSupport {
virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count,
const DeviceMemory<float> &x, int incx,
DeviceMemory<float> *y, int incy) = 0;

// Computes the product of a vector by a scalar: x <- a*x.
virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
DeviceMemory<float> *x, int incx) = 0;
Expand Down Expand Up @@ -750,13 +750,13 @@ class BlasSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// BlasSupport base class.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
absl::StatusOr<bool> IsMainStreamSet() const override; \
bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, \
const DeviceMemory<float> &x, int incx, \
DeviceMemory<float> *y, int incy) override; \
bool DoBlasCopy(Stream *stream, uint64_t elem_count, \
const DeviceMemory<float> &x, int incx, \
DeviceMemory<float> *y, int incy) override; \
bool ResetStream() override; \
bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, \
DeviceMemory<float> *x, int incx) override; \
bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, \
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ cc_library(
"//xla/stream_executor/platform:dso_loader",
"//xla/tsl/util:determinism_hdr_lib",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
Expand Down
33 changes: 20 additions & 13 deletions xla/stream_executor/rocm/rocm_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ limitations under the License.

#include <complex>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
#include "rocm/rocm_config.h"
Expand Down Expand Up @@ -149,25 +151,28 @@ ROCMBlas::~ROCMBlas() {
}
}

// Resets the rocBLAS handle to the null stream by calling SetStream(nullptr)
// while holding mu_. Returns SetStream's result.
bool ROCMBlas::ResetStream() {
absl::MutexLock lock{&mu_};
return SetStream(nullptr);
}

// Points the rocBLAS handle at `stream`, or at the null stream when `stream`
// is nullptr. Runs under a ScopedActivateExecutorContext for parent_.
// Returns true on success; on failure logs the rocBLAS status and returns
// false.
bool ROCMBlas::SetStream(Stream *stream) {
CHECK(blas_ != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};

GpuStreamHandle handle = (stream != nullptr) ? AsGpuStreamValue(stream) : 0;

if (auto ret = wrap::rocblas_set_stream(blas_, handle);
auto handle = (stream != nullptr) ? AsGpuStreamValue(stream) : nullptr;
if (auto ret = wrap::rocblas_set_stream(blas_, handle);
ret != rocblas_status_success) {
LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
return false;
}
return true;
}

// Reports whether the rocBLAS handle is currently bound to the null (default)
// stream. Intended for tests.
//
// Returns true when rocblas_get_stream yields a null stream handle, false when
// it yields any other stream, and an InternalError carrying the rocBLAS status
// if the query itself fails.
absl::StatusOr<bool> ROCMBlas::IsMainStreamSet() const {
  // mu_ guards the rocBLAS handle.
  absl::MutexLock lock{&mu_};
  CHECK(blas_ != nullptr);
  GpuStreamHandle handle{};
  if (auto ret = wrap::rocblas_get_stream(blas_, &handle);
      ret != rocblas_status_success) {
    // Include the rocBLAS status in the error, as SetStream does when it logs
    // a failure, so the cause is not lost.
    return absl::InternalError(absl::StrCat(
        "failed to get the current stream value: ", ToString(ret)));
  }
  return handle == nullptr;
}

namespace {

// Helper functions transforming blas arguments into rocBLAS arguments.
Expand Down Expand Up @@ -343,12 +348,12 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
absl::MutexLock lock{&mu_};

CHECK(blas_ != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
if (!SetStream(stream)) {
return absl::InternalError("Setting stream failed");
}

gpu::ScopedActivateExecutorContext sac{parent_};

rocblas_status ret;
// set the atomics mode, leaving default to library
bool allow_atomics = !OpDeterminismRequired();
if (!allow_atomics) {
Expand All @@ -371,7 +376,9 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
}
#endif

auto ret = rocblas_func(blas_, std::forward<Args>(args)...);
ret = rocblas_func(blas_, std::forward<Args>(args)...);
SetStream(nullptr); // Resetting stream after the function call

if (ret != rocblas_status_success) {
auto err_str =
absl::StrFormat("%s failed with: %s", FuncT::kName, ToString(ret));
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/rocm/rocm_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class ROCMBlas : public blas::BlasSupport {
ScratchAllocator *scratch_allocator);

// mutex that guards the rocBLAS handle for this device.
absl::Mutex mu_;
mutable absl::Mutex mu_;

// GpuExecutor which instantiated this ROCMBlas.
// Immutable post-initialization.
Expand Down
1 change: 1 addition & 0 deletions xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ xla_test(
"data/sharded_2_devices.hlo",
"data/single_device.hlo",
"data/single_device_tupled.hlo",
"data/sharded_computation.hlo",
],
tags = ["nomac"],
deps = [
Expand Down

0 comments on commit d6ee835

Please sign in to comment.