Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CK_TILE] Add GetName for GEMM kernels #1791

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
2 changes: 2 additions & 0 deletions include/ck_tile/ops/common/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <string>

#include "ck_tile/core.hpp"

namespace ck_tile {
aledudek marked this conversation as resolved.
Show resolved Hide resolved

// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
Expand Down
33 changes: 20 additions & 13 deletions include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKernelArgs = typename Base::GemmKernelArgs;

static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t KBatch = 1;

struct GemmTransKernelArg
{
GemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;

GemmTransKernelArg() = default;
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
};

CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
Expand All @@ -69,19 +89,6 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format on
}

struct GemmTransKernelArg
{
GemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;

GemmTransKernelArg() = default;
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
};

__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using _SS_ = std::string;

return _SS_("pipeline_AgBgCrCompV3_") +
_TS_(MPerBlock) + "x" + _TS_(NPerBlock) + "x" + _TS_(KPerBlock) + "x" + _TS_(BlockSize) + "_" +
_TS_(BlockSize) + "_" +
_TS_(VectorSizeA) + "x" + _TS_(VectorSizeB) + "x" + _TS_(VectorSizeC) + "_" +
_TS_(kPadM) + "x" + _TS_(kPadN) + "x" + _TS_(kPadK);
#undef _TS_
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ostream>
#include <sstream>

#include "ck_tile/core.hpp"

Expand Down Expand Up @@ -71,3 +72,10 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::TailNumber& s)
}
return os;
}

inline std::string GemmPipelineSchedulerToString(const ck_tile::GemmPipelineScheduler& s)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[[nodiscard]] fn?

{
std::ostringstream oss;
Copy link
Contributor

@mozga-amd mozga-amd Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thread_local std::ostringstream oss ? then an object of this storage class is created when the thread of execution passes through its definition and is destroyed immediately when execution leaves the lexical.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oss.str("")? - it could be fine to clean buffer before

oss << s;
return oss.str();
Comment on lines +78 to +80
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this print just integer value?

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ struct GemmPipelineProblemBase

return _SS_("gemm_problem_") +
_TS_(VectorLoadSize) + "x" + _TS_(kBlockSize) + "_" +
aledudek marked this conversation as resolved.
Show resolved Hide resolved
_TS_(kPadM) + "x" + _TS_(kPadN) + "x" + _TS_(kPadK);
_TS_(kPadM) + "x" + _TS_(kPadN) + "x" + _TS_(kPadK) + "_" +
Copy link
Contributor

@mozga-amd mozga-amd Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about the concat() function, to pass Scheduler and other params?

GemmPipelineSchedulerToString(Scheduler);
#undef _TS_
// clang-format on
}
Expand Down