Skip to content

Commit

Permalink
[XLA:GPU] Cleanup includes and enforce op sorting in `triton/support_test.cc`
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 718293626
  • Loading branch information
dimitar-asenov authored and Google-ML-Automation committed Jan 22, 2025
1 parent 5064da2 commit 28a75be
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 101 deletions.
3 changes: 1 addition & 2 deletions xla/backends/gpu/codegen/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -994,15 +994,14 @@ xla_cc_test(
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
"//xla/stream_executor:device_description",
"//xla/tsl/platform:status_matchers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:protobuf",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
225 changes: 126 additions & 99 deletions xla/backends/gpu/codegen/triton/support_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ limitations under the License.
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tsl/platform/status_matchers.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -319,33 +318,39 @@ ENTRY triton_computation {
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc);
}

// Unary elementwise opcodes under test, listed one per line in alphabetical
// order.
// NOTE(review): presumably consumed by the UnaryElementwiseTestSuite
// instantiation that follows; the call site is not fully visible here.
constexpr std::array kTestedOpsUnaryElementwise = {
    HloOpcode::kAbs,
    HloOpcode::kCbrt,
    HloOpcode::kCeil,
    HloOpcode::kClz,
    HloOpcode::kCos,
    HloOpcode::kErf,
    HloOpcode::kExp,
    HloOpcode::kExpm1,
    HloOpcode::kFloor,
    HloOpcode::kImag,
    HloOpcode::kIsFinite,
    HloOpcode::kLog,
    HloOpcode::kLog1p,
    HloOpcode::kLogistic,
    HloOpcode::kNegate,
    HloOpcode::kNot,
    HloOpcode::kPopulationCount,
    HloOpcode::kReal,
    HloOpcode::kReducePrecision,
    HloOpcode::kRoundNearestAfz,
    HloOpcode::kRoundNearestEven,
    HloOpcode::kRsqrt,
    HloOpcode::kSign,
    HloOpcode::kSin,
    HloOpcode::kSqrt,
    HloOpcode::kTan,
    HloOpcode::kTanh
};
// Unary elementwise opcodes under test.
// NOTE(review): presumably consumed by the UnaryElementwiseTestSuite
// instantiation that follows; the call site is not fully visible here.
//
// Every entry, including the last, ends with a comma (matching
// kTestedOpsBinaryElementwise) so that appending or re-sorting an opcode is a
// pure line move and never requires touching a neighbouring line.
constexpr std::array kTestedOpsUnaryElementwise = {
    // clang-format off
    // go/keep-sorted start
    HloOpcode::kAbs,
    HloOpcode::kCbrt,
    HloOpcode::kCeil,
    HloOpcode::kClz,
    HloOpcode::kCos,
    HloOpcode::kErf,
    HloOpcode::kExp,
    HloOpcode::kExpm1,
    HloOpcode::kFloor,
    HloOpcode::kImag,
    HloOpcode::kIsFinite,
    HloOpcode::kLog,
    HloOpcode::kLog1p,
    HloOpcode::kLogistic,
    HloOpcode::kNegate,
    HloOpcode::kNot,
    HloOpcode::kPopulationCount,
    HloOpcode::kReal,
    HloOpcode::kReducePrecision,
    HloOpcode::kRoundNearestAfz,
    HloOpcode::kRoundNearestEven,
    HloOpcode::kRsqrt,
    HloOpcode::kSign,
    HloOpcode::kSin,
    HloOpcode::kSqrt,
    HloOpcode::kTan,
    HloOpcode::kTanh,
    // go/keep-sorted end
    // clang-format on
};

INSTANTIATE_TEST_SUITE_P(
UnaryElementwiseTestSuite, UnaryElementwiseTest,
Expand Down Expand Up @@ -495,22 +500,27 @@ ENTRY triton_computation {
}

constexpr std::array kTestedOpsBinaryElementwise = {
HloOpcode::kAnd,
HloOpcode::kOr,
HloOpcode::kXor,
// clang-format off
// go/keep-sorted start
HloOpcode::kAdd,
HloOpcode::kMultiply,
HloOpcode::kMaximum,
HloOpcode::kMinimum,
HloOpcode::kSubtract,
HloOpcode::kAnd,
HloOpcode::kAtan2,
HloOpcode::kCompare,
HloOpcode::kDivide,
HloOpcode::kRemainder,
HloOpcode::kMaximum,
HloOpcode::kMinimum,
HloOpcode::kMultiply,
HloOpcode::kOr,
HloOpcode::kPower,
HloOpcode::kRemainder,
HloOpcode::kShiftLeft,
HloOpcode::kShiftRightArithmetic,
HloOpcode::kShiftRightLogical,
HloOpcode::kCompare};
HloOpcode::kSubtract,
HloOpcode::kXor,
// go/keep-sorted end
// clang-format on
};

INSTANTIATE_TEST_SUITE_P(
BinaryElementwiseTestSuite, BinaryElementwiseTest,
Expand Down Expand Up @@ -1064,13 +1074,24 @@ ENTRY triton_computation {
}

constexpr std::array kTestedOpsCollectives = {
HloOpcode::kAllGather, HloOpcode::kAllGatherStart,
HloOpcode::kAllGatherDone, HloOpcode::kAllReduce,
HloOpcode::kAllReduceStart, HloOpcode::kAllReduceDone,
HloOpcode::kAsyncDone, HloOpcode::kAsyncStart,
HloOpcode::kAsyncUpdate, HloOpcode::kAllToAll,
HloOpcode::kCollectivePermute, HloOpcode::kReduceScatter,
HloOpcode::kCollectiveBroadcast};
// clang-format off
// go/keep-sorted start
HloOpcode::kAllGather,
HloOpcode::kAllGatherDone,
HloOpcode::kAllGatherStart,
HloOpcode::kAllReduce,
HloOpcode::kAllReduceDone,
HloOpcode::kAllReduceStart,
HloOpcode::kAllToAll,
HloOpcode::kAsyncDone,
HloOpcode::kAsyncStart,
HloOpcode::kAsyncUpdate,
HloOpcode::kCollectiveBroadcast,
HloOpcode::kCollectivePermute,
HloOpcode::kReduceScatter
// go/keep-sorted end
// clang-format on
};

INSTANTIATE_TEST_SUITE_P(
CollectiveTestSuite, CollectiveTest,
Expand Down Expand Up @@ -1209,60 +1230,66 @@ INSTANTIATE_TEST_SUITE_P(RngTestSuite, RngTest,
AllTestCombinationsForOpcodes(kTestedOpsRng),
TritonSupportTestTypeAndOpcodeAndDeviceToString);

// Opcodes grouped as unsupported (per the constant's name), listed one per
// line in alphabetical order.
// NOTE(review): the code that consumes this list is not visible in this
// excerpt.
constexpr std::array kUnsupportedOps = {
    HloOpcode::kAddDependency,
    HloOpcode::kAfterAll,
    HloOpcode::kBatchNormGrad,
    HloOpcode::kBatchNormInference,
    HloOpcode::kBatchNormTraining,
    HloOpcode::kBitcastConvert,
    HloOpcode::kCall,
    HloOpcode::kCholesky,
    HloOpcode::kCollectivePermuteDone,
    HloOpcode::kCollectivePermuteStart,
    HloOpcode::kComplex,
    HloOpcode::kConcatenate,
    HloOpcode::kConditional,
    HloOpcode::kConvolution,
    HloOpcode::kCopy,
    HloOpcode::kCopyDone,
    HloOpcode::kCopyStart,
    HloOpcode::kCustomCall,
    HloOpcode::kDomain,
    HloOpcode::kDot,
    HloOpcode::kDynamicReshape,
    HloOpcode::kDynamicSlice,
    HloOpcode::kDynamicUpdateSlice,
    HloOpcode::kFft,
    HloOpcode::kFusion,
    HloOpcode::kGather,
    HloOpcode::kGetDimensionSize,
    HloOpcode::kGetTupleElement,
    HloOpcode::kInfeed,
    HloOpcode::kMap,
    HloOpcode::kOptimizationBarrier,
    HloOpcode::kOutfeed,
    HloOpcode::kPad,
    HloOpcode::kPartitionId,
    HloOpcode::kRaggedAllToAll,
    HloOpcode::kRaggedDot,
    HloOpcode::kRecv,
    HloOpcode::kRecvDone,
    HloOpcode::kReduceWindow,
    HloOpcode::kReplicaId,
    HloOpcode::kReverse,
    HloOpcode::kRngBitGenerator,
    HloOpcode::kRngGetAndUpdateState,
    HloOpcode::kScatter,
    HloOpcode::kSelectAndScatter,
    HloOpcode::kSend,
    HloOpcode::kSendDone,
    HloOpcode::kSetDimensionSize,
    HloOpcode::kSort,
    HloOpcode::kStochasticConvert,
    HloOpcode::kTopK,
    HloOpcode::kTriangularSolve,
    HloOpcode::kTuple,
    HloOpcode::kWhile
};
// Opcodes grouped as unsupported (per the constant's name).
// NOTE(review): the code that consumes this list is not visible in this
// excerpt.
//
// Every entry, including the last, ends with a comma (matching
// kTestedOpsBinaryElementwise) so that appending or re-sorting an opcode is a
// pure line move and never requires touching a neighbouring line.
constexpr std::array kUnsupportedOps = {
    // clang-format off
    // go/keep-sorted start
    HloOpcode::kAddDependency,
    HloOpcode::kAfterAll,
    HloOpcode::kBatchNormGrad,
    HloOpcode::kBatchNormInference,
    HloOpcode::kBatchNormTraining,
    HloOpcode::kBitcastConvert,
    HloOpcode::kCall,
    HloOpcode::kCholesky,
    HloOpcode::kCollectivePermuteDone,
    HloOpcode::kCollectivePermuteStart,
    HloOpcode::kComplex,
    HloOpcode::kConcatenate,
    HloOpcode::kConditional,
    HloOpcode::kConvolution,
    HloOpcode::kCopy,
    HloOpcode::kCopyDone,
    HloOpcode::kCopyStart,
    HloOpcode::kCustomCall,
    HloOpcode::kDomain,
    HloOpcode::kDot,
    HloOpcode::kDynamicReshape,
    HloOpcode::kDynamicSlice,
    HloOpcode::kDynamicUpdateSlice,
    HloOpcode::kFft,
    HloOpcode::kFusion,
    HloOpcode::kGather,
    HloOpcode::kGetDimensionSize,
    HloOpcode::kGetTupleElement,
    HloOpcode::kInfeed,
    HloOpcode::kMap,
    HloOpcode::kOptimizationBarrier,
    HloOpcode::kOutfeed,
    HloOpcode::kPad,
    HloOpcode::kPartitionId,
    HloOpcode::kRaggedAllToAll,
    HloOpcode::kRaggedDot,
    HloOpcode::kRecv,
    HloOpcode::kRecvDone,
    HloOpcode::kReduceWindow,
    HloOpcode::kReplicaId,
    HloOpcode::kReverse,
    HloOpcode::kRngBitGenerator,
    HloOpcode::kRngGetAndUpdateState,
    HloOpcode::kScatter,
    HloOpcode::kSelectAndScatter,
    HloOpcode::kSend,
    HloOpcode::kSendDone,
    HloOpcode::kSetDimensionSize,
    HloOpcode::kSort,
    HloOpcode::kStochasticConvert,
    HloOpcode::kTopK,
    HloOpcode::kTriangularSolve,
    HloOpcode::kTuple,
    HloOpcode::kWhile,
    // go/keep-sorted end
    // clang-format on
};

absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
absl::flat_hash_set<HloOpcode> ret;
Expand Down

0 comments on commit 28a75be

Please sign in to comment.