Skip to content

Commit

Permalink
Enable triton sparse gemm only for CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
hsharsha committed Jul 5, 2024
1 parent a52b6a3 commit deee85c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5387,6 +5387,7 @@ cc_library(
name = "elemental_ir_emitter",
srcs = ["elemental_ir_emitter.cc"],
hdrs = ["elemental_ir_emitter.h"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":algorithm_util",
":float8_fnuz_ir_emitter",
Expand Down
2 changes: 2 additions & 0 deletions xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2921,10 +2921,12 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
"Algorithm not supported by the ElementalIrEmitter: %s",
PrecisionConfig::Algorithm_Name(hlo->precision_config().algorithm())));
}
#ifdef GOOGLE_CUDA
const HloDotInstruction* dot = Cast<HloDotInstruction>(hlo);
if (dot->sparse_operands()) {
return Unimplemented("Sparse dot is supported by Triton emitter only.");
}
#endif

auto lhs_generator = operand_to_generator.at(hlo->operand(0));
auto rhs_generator = operand_to_generator.at(hlo->operand(1));
Expand Down

0 comments on commit deee85c

Please sign in to comment.