forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathsm90_tile_scheduler.hpp
139 lines (117 loc) · 5.5 KB
/
sm90_tile_scheduler.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/static_tile_scheduler.hpp"
namespace cutlass::gemm::kernel::detail {
///////////////////////////////////////////////////////////////////////////////
// Persistent Thread Block (TB) scheduler.
//
// CRTP specialization of StaticPersistentTileScheduler. On top of what it inherits
// from the base, this class provides:
//   * get_work_idx_m_and_n(): maps a linear block index to (work_idx_m, work_idx_n),
//     applying the swizzle and the requested raster order;
//   * get_workspace_size() / initialize_workspace(): no-op workspace hooks.
class PersistentTileSchedulerSm90:
public StaticPersistentTileScheduler<PersistentTileSchedulerSm90> {

  using BaseScheduler = StaticPersistentTileScheduler<PersistentTileSchedulerSm90>;

public:

  // Inherit the constructors of the CRTP base.
  using StaticPersistentTileScheduler::StaticPersistentTileScheduler;

  using Params = PersistentTileSchedulerSm90Params;
  using RasterOrder = typename Params::RasterOrder;
  using RasterOrderOptions = typename Params::RasterOrderOptions;
  using Arguments = BaseScheduler::Arguments;

  static constexpr bool IsDynamicPersistent = false;

  // Computes (work_idx_m, work_idx_n) from blk_per_grid_dim while applying swizzle.
  //
  // Convenience overload: the position of the calling CTA inside its cluster is read
  // from cute::block_id_in_cluster() and forwarded to the explicit overload below.
  static CUTLASS_DEVICE
  cute::tuple<int32_t, int32_t>
  get_work_idx_m_and_n(
      uint64_t blk_per_grid_dim,
      FastDivmodU64Pow2 const& divmod_cluster_shape_major,
      FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
      FastDivmodU64 const& divmod_cluster_blk_major,
      int32_t log_swizzle_size,
      RasterOrder raster_order) {

    // Only the first two components are needed; the third is discarded.
    auto [block_m_in_cluster, block_n_in_cluster, _] = cute::block_id_in_cluster();

    return get_work_idx_m_and_n(
      blk_per_grid_dim,
      divmod_cluster_shape_major,
      divmod_cluster_shape_minor,
      divmod_cluster_blk_major,
      log_swizzle_size,
      raster_order,
      block_m_in_cluster,
      block_n_in_cluster);
  }

  // Computes (work_idx_m, work_idx_n) from blk_per_grid_dim while applying swizzle.
  //
  // Parameters:
  //   blk_per_grid_dim            linear block index to decompose
  //   divmod_cluster_shape_major  divmod by the cluster extent along the major (raster) mode
  //   divmod_cluster_shape_minor  only its `divisor` is read: cluster extent along the minor mode
  //   divmod_cluster_blk_major    divmod applied to the cluster index after the swizzle bits
  //                               have been shifted out
  //   log_swizzle_size            log2 of the swizzle size (swizzle size is 1 << log_swizzle_size)
  //   raster_order                RasterOrder::AlongN makes N the major mode; any other value
  //                               makes M the major mode
  //   cta_m_in_cluster            M coordinate of the CTA inside its cluster
  //   cta_n_in_cluster            N coordinate of the CTA inside its cluster
  //
  // Returns a tuple ordered as (work_idx_m, work_idx_n). Both values are narrowed to
  // int32_t from 64-bit intermediates.
  static CUTLASS_DEVICE
  cute::tuple<int32_t, int32_t>
  get_work_idx_m_and_n(
      uint64_t blk_per_grid_dim,
      FastDivmodU64Pow2 const& divmod_cluster_shape_major,
      FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
      FastDivmodU64 const& divmod_cluster_blk_major,
      int32_t log_swizzle_size,
      RasterOrder raster_order,
      uint64_t cta_m_in_cluster,
      uint64_t cta_n_in_cluster) {

    bool const raster_along_n = (raster_order == RasterOrder::AlongN);

    // Split the linear block index into a linear cluster index and the offset of
    // the block within the cluster along the major mode.
    uint64_t linear_cluster_id = 0;
    uint64_t offset_in_cluster_major = 0;
    divmod_cluster_shape_major(linear_cluster_id, offset_in_cluster_major, blk_per_grid_dim);

    // Along the minor mode the offset within the cluster is the CTA's own coordinate:
    // M when rasterizing along N, N otherwise.
    uint64_t const offset_in_cluster_minor = raster_along_n ? cta_m_in_cluster : cta_n_in_cluster;

    // Kept as a plain int expression so that the integer promotions in the mask and
    // the multiplication below are the usual ones for `1 << log_swizzle_size`.
    auto const swizzle_size = (1 << log_swizzle_size);

    // Low log_swizzle_size bits select the position inside a swizzle group; the
    // remaining high bits are decomposed by divmod_cluster_blk_major.
    uint64_t const swizzle_offset = linear_cluster_id & (swizzle_size - 1);
    uint64_t const unswizzled_id = linear_cluster_id >> log_swizzle_size;

    uint64_t swizzle_group_minor = 0;
    uint64_t cluster_major = 0;
    divmod_cluster_blk_major(swizzle_group_minor, cluster_major, unswizzled_id);

    uint64_t const cluster_minor = swizzle_group_minor * swizzle_size + swizzle_offset;

    // Scale cluster coordinates by the cluster extent and add the in-cluster offsets.
    auto const work_idx_minor = static_cast<int32_t>(
      cluster_minor * divmod_cluster_shape_minor.divisor + offset_in_cluster_minor);
    auto const work_idx_major = static_cast<int32_t>(
      cluster_major * divmod_cluster_shape_major.divisor + offset_in_cluster_major);

    // The result is always ordered (m, n); which of the two is "major" depends on
    // the raster order.
    if (raster_along_n) {
      return {work_idx_minor, work_idx_major};
    }
    return {work_idx_major, work_idx_minor};
  }

  // The basic tile scheduler does not require any additional workspace:
  // all arguments are ignored and the reported size is always zero.
  template <class ProblemShape, class ElementAccumulator>
  static size_t
  get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1, uint32_t = 1) {
    return 0;
  }

  // There is no workspace to set up, so this ignores every argument and
  // unconditionally reports success.
  template <class ProblemShape, class ElementAccumulator>
  static cutlass::Status
  initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&,
      uint32_t, const uint32_t = 1, uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }
};
}