elemenntwise_rtc_gpu.cc
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/operator.h"
#include "caffe2/cuda_rtc/common_rtc.h"
namespace caffe2 {
namespace {
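// Builds, names, and compiles a single elementwise kernel at runtime via
// NVRTC; each instance uses a unique kernel name so that separately compiled
// functions do not collide.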
class ElementwiseRTCFunction
: public CudaRTCFunction<ElementwiseRTCFunction> {
public:
ElementwiseRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}
template <typename... Args>
string KernelName(Args... /*args*/) {
return name_;
}
template <typename... Args>
string GetSource(Args... args);
private:
string name_;
};
template<>
string ElementwiseRTCFunction::GetSource(
int input_size, int output_size,
const string command_string) {
std::stringstream ss;
ss << "extern \"C\" __global__ void " << name_ <<
"(const size_t nthreads, \n";
  // Insert the parameter list. Pre-decrement so that a comma follows every
  // parameter except the last one (post-decrement would leave a trailing
  // comma before the closing parenthesis).
  int remain_params = input_size + output_size;
  for (int i = 0; i < input_size; ++i) {
    ss << "const float* in" << i << ((--remain_params) ? ", \n" : "");
  }
  for (int i = 0; i < output_size; ++i) {
    ss << "float* out" << i << ((--remain_params) ? ", \n" : "");
  }
ss << ") {\n"
"for (int index = blockIdx.x * blockDim.x + threadIdx.x;\n"
"index < nthreads; index += blockDim.x * gridDim.x) {\n"
<< command_string << "\n"
<< "}\n}";
return ss.str();
}
} // namespace
/**
* A GPU operator that can generate limited elementwise operations.
*
* ElementwiseRTCOp allows one to do a simple and limited thing: it takes in
* multiple inputs and multiple outputs, as well as a raw string argument
* rtc_src. The runtime then generates the following kernel code:
*
* __global__ void kernel_name(const size_t nthreads, ...) {
* for(int index = blockIdx.x * blockDim.x + threadIdx.x;
* index < nthreads; index += blockDim.x * gridDim.x) {
* rtc_src
* }
* }
 * where the "..." part is auto-generated, so one can refer to the inputs and
 * outputs as in0, in1, ..., out0, out1, ... in the rtc_src string.
 *
 * For example, if one wants to do an elementwise vector multiplication, one
 * can take two inputs and one output, and write rtc_src as
 *   out0[index] = in0[index] * in1[index];
*
* This op is currently highly experimental. We do not have a gradient
* registered for it either.
*/
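// Illustrative usage (blob names are hypothetical, not from this file): a
// NetDef entry that multiplies two vectors elementwise could look like
//
//   op {
//     input: "X"
//     input: "Y"
//     output: "Z"
//     type: "ElementwiseRTC"
//     engine: "NVRTC"
//     arg { name: "rtc_src" s: "out0[index] = in0[index] * in1[index];" }
//   }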
class ElementwiseRTCOp final : public Operator<CUDAContext> {
public:
ElementwiseRTCOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws) {
const string src = OperatorBase::GetSingleArgument<string>(
"rtc_src", "");
CAFFE_ENFORCE(src.size(), "Op should have a non-zero source code size.");
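    // Compilation happens once here, at construction time; RunOnDevice only
    // launches the already-compiled function.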
func_.Compile(InputSize(), OutputSize(), src);
}
~ElementwiseRTCOp() override {}
bool RunOnDevice() override {
static_assert(sizeof(void*) == sizeof(size_t),
"The argbuffer relies on the assumption that void* and "
"size_t have the same size.");
vector<size_t> argBuffer_vec(InputSize() + OutputSize() + 1);
size_t* argBuffer = argBuffer_vec.data();
CAFFE_ENFORCE(
Input(0).numel() < std::numeric_limits<int>::max(),
"The kernel function currently only supports int index.");
argBuffer[0] = Input(0).numel();
void** ptr_buffer = reinterpret_cast<void**>(argBuffer + 1);
for (int i = 0; i < InputSize(); ++i) {
ptr_buffer[i] = const_cast<float*>(Input(i).data<float>());
}
for (int i = 0; i < OutputSize(); ++i) {
Output(i)->ResizeLike(Input(0));
ptr_buffer[i + InputSize()] = Output(i)->mutable_data<float>();
}
    // Size of the packed buffer in bytes; sizeof(argBuffer) alone would give
    // the size of the pointer, not of the buffer it points to.
    size_t argBufferSize = argBuffer_vec.size() * sizeof(size_t);
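    // CU_LAUNCH_PARAM_* is the CUDA driver convention for passing a packed
    // argument buffer through cuLaunchKernel's `extra` parameter instead of
    // individual kernel parameters; LaunchEx presumably forwards it there.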
void* config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER, argBuffer,
CU_LAUNCH_PARAM_BUFFER_SIZE, &argBufferSize,
CU_LAUNCH_PARAM_END
};
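    // One-dimensional launch: CAFFE_GET_BLOCKS / CAFFE_CUDA_NUM_THREADS size
    // the grid to Input(0), and the grid-stride loop in the generated kernel
    // covers any remainder.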
func_.LaunchEx(
CAFFE_GET_BLOCKS(Input(0).numel()),
1,
1,
CAFFE_CUDA_NUM_THREADS,
1,
1,
0,
context_.cuda_stream(),
config);
return true;
}
private:
ElementwiseRTCFunction func_;
};
namespace {
REGISTER_CUDA_OPERATOR_WITH_ENGINE(ElementwiseRTC, NVRTC, ElementwiseRTCOp);
} // namespace
} // namespace caffe2