Fused conv fix and revert memory management #21

Merged
14 changes: 8 additions & 6 deletions xla/service/gpu/BUILD
@@ -3990,6 +3990,7 @@ cc_library(
":conv_algorithm_picker",
":cublas_pad_for_gemms",
":cublas_padding_requirements",
":cudnn_fused_conv_rewriter",
":cusolver_rewriter",
":gemm_algorithm_picker",
":gemm_rewriter",
@@ -4598,7 +4599,6 @@ cc_library(
name = "cudnn_fused_conv_rewriter",
srcs = ["cudnn_fused_conv_rewriter.cc"],
hdrs = ["cudnn_fused_conv_rewriter.h"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":backend_configs_cc",
":cublas_cudnn",
@@ -4627,10 +4627,7 @@ cc_library(
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:ml_dtypes",
"@tsl//tsl/platform:statusor",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
]),
],
)

xla_test(
@@ -4646,13 +4643,15 @@ xla_test(
backends = [
"gpu_a100",
] + if_oss(["gpu"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
shard_count = 10,
deps = [
":backend_configs_cc",
":cublas_cudnn",
":cudnn_fused_conv_rewriter",
":gpu_conv_rewriter",
":stream_executor_util",
"//xla:comparison_util",
"//xla:error_spec",
"//xla/hlo/ir:hlo",
@@ -4667,6 +4666,7 @@ xla_test(
"//xla/service:reshape_mover",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor:stream_executor_headers",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
@@ -4681,6 +4681,8 @@ xla_test(
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers"
]),
)

4 changes: 4 additions & 0 deletions xla/service/gpu/amdgpu_compiler.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "xla/service/float_normalization.h"
#include "xla/service/gpu/autotuner_util.h"
#include "xla/service/gpu/conv_algorithm_picker.h"
#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "xla/service/gpu/cublas_pad_for_gemms.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
#include "xla/service/gpu/cusolver_rewriter.h"
@@ -109,6 +110,9 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
pipeline.AddPass<GpusolverRewriter>();
pipeline.AddPass<GpuConvRewriter>();
pipeline.AddPass<GpuConvPaddingLegalization>();
auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version,
0);

// The conv padding/vectorization passes which we need to get rid of. They
// also leave behind unnecessary tuple/get-tuple-element pairs that
66 changes: 41 additions & 25 deletions xla/service/gpu/cudnn_fused_conv_rewriter.cc
@@ -39,28 +39,23 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/comparison_util.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
#include "tsl/platform/ml_dtypes.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cudnn/cudnn.h"
#endif

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/primitive_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/ml_dtypes.h"
#include "tsl/platform/statusor.h"

namespace xla {
@@ -96,6 +91,10 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
return IsConvCustomCall(instr) && !IsConvDepthwise(instr);
}

bool IsROCm(se::GpuComputeCapability cc) {
return std::holds_alternative<se::RocmComputeCapability>(cc);
}

// elu, relu6, and leaky-relu activations are supported in cudnn via the
// "runtime fusion" engine, which JIT compiles C++ code. This can be slow to
// compile, so we guard it with a debug option.
@@ -106,8 +105,12 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default
// due to apparent bugs in cudnn 8.9.0. See debug_options_flags.cc for details.
bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts,
se::CudaComputeCapability cc) {
return debug_opts.xla_gpu_use_runtime_fusion() && cc.IsAtLeast(7, 5);
se::GpuComputeCapability cc) {
const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&cc);
if(cuda_cc != nullptr)
return debug_opts.xla_gpu_use_runtime_fusion() && cuda_cc->IsAtLeast(7, 5);
else
return true;
}

bool IsSuitableForCudnnRuntimeFusion(HloInstruction* conv) {
@@ -658,10 +661,17 @@ CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution,
// 5. Optionally calculate the maximum of the absolute of the result.
// 6. Optionally cast the output back to FP8.
absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
se::CudaComputeCapability cc) {
se::CudaComputeCapability cc,
se::dnn::VersionInfo dnn_version,
int32_t toolkit_version) {
bool changed = false;

#if CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900
if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) {
return false;
}
if (toolkit_version < 12000) {
return false;
}
if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) {
return false;
}
@@ -759,7 +769,6 @@ absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
changed = true;
}
}
#endif // CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900
return changed;
}

@@ -984,7 +993,7 @@ absl::StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
}

absl::StatusOr<bool> FuseElu(HloComputation* comp,
se::CudaComputeCapability cc) {
se::GpuComputeCapability cc) {
if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
cc)) {
return false;
@@ -1085,7 +1094,7 @@ absl::StatusOr<bool> FuseRelu(HloComputation* comp) {
}

absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
se::CudaComputeCapability cc) {
se::GpuComputeCapability cc) {
if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
cc)) {
return false;
@@ -1134,7 +1143,7 @@ absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
}

absl::StatusOr<bool> FuseLeakyRelu(HloComputation* comp,
se::CudaComputeCapability cc) {
se::GpuComputeCapability cc) {
if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
cc)) {
return false;
@@ -1254,7 +1263,7 @@ absl::StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
return changed;
}

absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,
se::GpuComputeCapability cc) {
if(IsROCm(cc))
return false;
bool changed = false;
for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
HloInstruction* gte = nullptr;
@@ -1480,9 +1492,13 @@ absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
bool changed = false;
// Rewrite FP8 convolutions and supported adjacent pointwise ops into a
// ForwardGraph Custom Call.
TF_ASSIGN_OR_RETURN(changed, F8GraphConv(comp, compute_capability_));
if (changed) {
return changed;
if(!IsROCm(compute_capability_)) {
auto cc = std::get<se::CudaComputeCapability>(compute_capability_);
TF_ASSIGN_OR_RETURN(
changed, F8GraphConv(comp, cc, dnn_version_, toolkit_version_));
if (changed) {
return changed;
}
}
// Fuse "inside out" starting with the operations closest to the conv.
TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp));
@@ -1516,7 +1532,7 @@ absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
any_changed |= changed;

TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp));
TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_));
any_changed |= changed;

// f16 convs' bias+side-input can appear before or after conversion to f16.
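
Note on the pattern introduced in this file: the rewriter now receives a se::GpuComputeCapability, a std::variant over the CUDA and ROCm capability types, and branches at runtime with std::holds_alternative / std::get_if instead of GOOGLE_CUDA preprocessor guards. Below is a minimal, self-contained sketch of that dispatch pattern; CudaComputeCapability and RocmComputeCapability here are simplified stand-ins, not the real stream_executor types.

#include <iostream>
#include <variant>

// Simplified stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaComputeCapability {
  int major = 0;
  int minor = 0;
  bool IsAtLeast(int other_major, int other_minor) const {
    return major > other_major ||
           (major == other_major && minor >= other_minor);
  }
};
struct RocmComputeCapability {};

using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// Same shape as the IsROCm() helper added in this file: true when the
// variant holds the ROCm alternative.
bool IsROCm(const GpuComputeCapability& cc) {
  return std::holds_alternative<RocmComputeCapability>(cc);
}

// Same shape as ShouldUseCudnnRuntimeFusion(): gate on sm_75+ for CUDA,
// no capability restriction on the ROCm path.
bool ShouldUseRuntimeFusion(const GpuComputeCapability& cc, bool flag_enabled) {
  if (const auto* cuda_cc = std::get_if<CudaComputeCapability>(&cc)) {
    return flag_enabled && cuda_cc->IsAtLeast(7, 5);
  }
  return true;
}

int main() {
  GpuComputeCapability ampere = CudaComputeCapability{8, 0};
  GpuComputeCapability mi250 = RocmComputeCapability{};
  std::cout << IsROCm(ampere) << IsROCm(mi250) << '\n';                       // prints 01
  std::cout << ShouldUseRuntimeFusion(ampere, /*flag_enabled=*/false)         // 0
            << ShouldUseRuntimeFusion(mi250, /*flag_enabled=*/false) << '\n'; // 1
}
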
21 changes: 18 additions & 3 deletions xla/service/gpu/cudnn_fused_conv_rewriter.h
@@ -16,12 +16,15 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
#define XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_

#include <cstdint>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/dnn.h"

namespace xla {
namespace gpu {
@@ -98,8 +101,18 @@ namespace gpu {
// pass returns an error -- cudnn will not be able to run it.
class CudnnFusedConvRewriter : public HloModulePass {
public:
explicit CudnnFusedConvRewriter(se::CudaComputeCapability cc)
: compute_capability_(cc) {}
CudnnFusedConvRewriter(se::CudaComputeCapability cc,
se::dnn::VersionInfo dnn_version,
int32_t toolkit_version)
: compute_capability_(cc),
dnn_version_(dnn_version),
toolkit_version_(toolkit_version) {}
CudnnFusedConvRewriter(se::RocmComputeCapability cc,
se::dnn::VersionInfo dnn_version,
int32_t toolkit_version)
: compute_capability_(cc),
dnn_version_(dnn_version),
toolkit_version_(toolkit_version) {}

absl::string_view name() const override {
return "cudnn-fused-convolution-rewriter";
@@ -111,7 +124,9 @@ class CudnnFusedConvRewriter : public HloModulePass {
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
const se::CudaComputeCapability compute_capability_;
const se::GpuComputeCapability compute_capability_;
const se::dnn::VersionInfo dnn_version_;
const int32_t toolkit_version_;
};

} // namespace gpu
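
For reference, a sketch of how callers could construct the pass with the widened constructor set. The architectures and version numbers below are illustrative assumptions, not values taken from this PR; the ROCm path does not appear to consult toolkit_version_ inside this pass, which is presumably why amdgpu_compiler.cc passes 0.

#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/dnn.h"

namespace se = ::stream_executor;

// Hypothetical helper showing both overloads; all version numbers are examples.
void AddFusedConvRewriterPass(xla::HloPassPipeline& pipeline, bool is_rocm) {
  if (is_rocm) {
    pipeline.AddPass<xla::gpu::CudnnFusedConvRewriter>(
        se::RocmComputeCapability("gfx90a"),  // assumed gfx arch string
        se::dnn::VersionInfo(3, 0, 0),        // assumed MIOpen version
        /*toolkit_version=*/0);               // not consulted on the ROCm path
  } else {
    pipeline.AddPass<xla::gpu::CudnnFusedConvRewriter>(
        se::CudaComputeCapability(9, 0),      // Hopper-class device (assumed)
        se::dnn::VersionInfo(8, 9, 0),        // cuDNN 8.9 (assumed)
        /*toolkit_version=*/12000);           // CUDA 12.0 (assumed)
  }
}
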