forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDescriptors.h
157 lines (137 loc) · 4.79 KB
/
Descriptors.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#pragma once
#include <ATen/miopen/Exceptions.h>
#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
namespace at { namespace native {
// Returns the size in bytes of a single element of `dataType`.
// Half and bfloat16 occupy 2 bytes and float occupies 4; any other value
// takes the 8-byte fallback.
// NOTE(review): the 8-byte fallback is only right for 8-byte types; if other
// MIOpen element types (e.g. integer types) can ever reach this function they
// would presumably be sized wrongly — confirm callers only pass the three
// explicitly handled types.
inline int dataSize(miopenDataType_t dataType)
{
  switch (dataType) {
    case miopenHalf:
    case miopenBFloat16:
      return 2;
    case miopenFloat:
      return 4;
    default:
      return 8;
  }
}
// Rewrites `stride` in place, but only for dims whose size is 1: each such
// dim gets the product of the sizes of all dims after it (its row-major
// contiguous stride). Dims with size != 1 keep their original stride.
// A size-1 dim is never stepped across, so its stride does not affect
// addressing; this just gives it a canonical value.
// `size` and `stride` must each point to at least `dim` ints.
static inline void fixSizeOneDimStride(int dim, const int *size, int *stride) {
  int64_t trailing = 1;  // product of size[d+1 .. dim-1]
  for (int d = dim; d-- > 0;) {
    if (size[d] != 1) {
      trailing *= size[d];
      continue;
    }
    stride[d] = trailing;
  }
}
// Deleter for the std::unique_ptr held by Descriptor: destroys a non-null
// descriptor by calling `dtor` and passes the returned status to MIOPEN_CHECK.
// Null pointers are ignored.
// NOTE(review): this is invoked from std::unique_ptr's destructor, which is
// noexcept. If MIOPEN_CHECK reports failure by throwing, a failed destroy
// here would end in std::terminate — confirm that is acceptable.
template <typename T, miopenStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* ptr) {
    if (ptr == nullptr) {
      return;
    }
    MIOPEN_CHECK(dtor(ptr));
  }
};
// A generic class for wrapping MIOpen descriptor types. All you need
// is to give the underlying type the descriptor handle points to (for
// example, a miopenTensorDescriptor_t points to a miopenTensorDescriptor),
// the constructor and the destructor. Subclasses are responsible
// for defining a set() function to actually set the descriptor.
//
// Descriptors default construct to a nullptr, and have a descriptor
// created the first time you call mut_desc() (which set() and other
// initializing functions use).
//
// Ownership is held by a std::unique_ptr, so Descriptor is move-only, and
// the descriptor is released through DescriptorDeleter<T, dtor>.
//
// Template parameters:
//   T    - the pointee type of the MIOpen handle.
//   ctor - MIOpen function that allocates a T and writes it through T**.
//   dtor - MIOpen function that destroys a T.
template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
class Descriptor
{
public:
  // Use desc() to access the underlying descriptor pointer in
  // a read-only fashion. Most client code should use this.
  // If the descriptor was never initialized, this will return
  // nullptr. (The returned pointer is non-const because MIOpen's C API
  // takes non-const handles; read-only is a convention here.)
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }
  // Use mut_desc() to access the underlying descriptor pointer
  // if you intend to modify what it points to (e.g., using
  // miopenSetFooDescriptor). This will ensure that the descriptor
  // is initialized. Code in this file will use this function.
  T* mut_desc() { init(); return desc_.get(); }
protected:
  // Creates the descriptor with `ctor` if it does not exist yet;
  // subsequent calls are no-ops. A failing `ctor` status is reported via
  // MIOPEN_CHECK before ownership is taken, so desc_ stays null in that case.
  void init() {
    if (desc_ == nullptr) {
      // Initialized so the handle never holds an indeterminate value.
      T* raw_desc = nullptr;
      MIOPEN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }
private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
// Owns an MIOpen tensor descriptor (miopenTensorDescriptor).
// The public set() overloads and print() are declared here and defined
// elsewhere. The private overload takes raw int arrays: it first
// canonicalizes the strides of size-1 dims in place (fixSizeOneDimStride),
// then programs the descriptor, creating it on first use.
class TensorDescriptor
  : public Descriptor<miopenTensorDescriptor,
                      &miopenCreateTensorDescriptor,
                      &miopenDestroyTensorDescriptor>
{
public:
  // The underlying descriptor stays null until something initializes it.
  TensorDescriptor() = default;
  // Equivalent to default construction followed by set(t, pad).
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }

  void set(const at::Tensor &t, size_t pad = 0);
  void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

private:
  // Note: `stride` is modified in place before being handed to MIOpen.
  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
    fixSizeOneDimStride(dim, size, stride);
    miopenTensorDescriptor* const target = mut_desc();
    MIOPEN_CHECK(miopenSetTensorDescriptor(target, dataType, dim, size, stride));
  }
};

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
// Filter descriptor. It is backed by the same MIOpen object type as
// activations (miopenTensorDescriptor), created and destroyed through the
// tensor-descriptor functions.
class FilterDescriptor
  : public Descriptor<miopenTensorDescriptor,
                      &miopenCreateTensorDescriptor,
                      &miopenDestroyTensorDescriptor>
{
public:
  // Defined elsewhere.
  // NOTE(review): `pad` is int64_t here but size_t in TensorDescriptor::set;
  // left as-is so the declaration keeps matching its out-of-line definition.
  void set(const at::Tensor &t, int64_t pad = 0);

private:
  // Canonicalizes size-1 dim strides in place, then programs the
  // descriptor (creating it on first use).
  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
    fixSizeOneDimStride(dim, size, stride);
    miopenTensorDescriptor* const filter = mut_desc();
    MIOPEN_CHECK(miopenSetTensorDescriptor(filter, dataType, dim, size, stride));
  }
};
// Owns an MIOpen convolution descriptor (miopenConvolutionDescriptor).
struct ConvolutionDescriptor
  : public Descriptor<miopenConvolutionDescriptor,
                      &miopenCreateConvolutionDescriptor,
                      &miopenDestroyConvolutionDescriptor>
{
  // Initializes an N-d convolution descriptor from `pad`, `stride` and
  // `upscale` (dilation) — presumably arrays of `dim` entries — and the
  // convolution mode, then sets the group count.
  // NOTE(review): `dataType` is accepted but not used by this function.
  void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups) {
    miopenConvolutionDescriptor* const conv = mut_desc();
    MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(conv, dim, pad, stride, upscale, c_mode));
    MIOPEN_CHECK(miopenSetConvolutionGroupCount(conv, groups));
  }
};
// Owns an MIOpen RNN descriptor (miopenRNNDescriptor).
struct RNNDescriptor
  : public Descriptor<miopenRNNDescriptor,
                      &miopenCreateRNNDescriptor,
                      &miopenDestroyRNNDescriptor>
{
  // Forwards every setting to miopenSetRNNDescriptor, creating the
  // descriptor on first use.
  // NOTE(review): hidden_size / num_layers are int64_t here; if the MIOpen
  // API takes narrower integers they are converted implicitly — confirm
  // callers stay within range.
  void set(int64_t hidden_size, int64_t num_layers,
           miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction,
           miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode,
           miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
    miopenRNNDescriptor* const rnn = mut_desc();
    MIOPEN_CHECK(miopenSetRNNDescriptor(rnn, hidden_size, num_layers, input_mode,
                                        direction, rnn_mode, bias_mode,
                                        algorithm, datatype));
  }
};
// Holds a scalar in the representation selected by `dataType`
// (presumably a scaling factor passed to MIOpen — confirm at call sites).
// Half, float and bfloat16 all store the value narrowed to float in `f`;
// every other data type stores the full double in `d`.
union Constant
{
  float f;
  double d;
  Constant(miopenDataType_t dataType, double value) {
    switch (dataType) {
      case miopenHalf:
      case miopenFloat:
      case miopenBFloat16:
        f = static_cast<float>(value);
        break;
      default:
        d = value;
        break;
    }
  }
};
}} // namespace