forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPowKernel.cpp
166 lines (159 loc) · 5.07 KB
/
PowKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#include <cmath>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/Pow.h>
#include <ATen/native/cpu/Loops.h>
namespace at { namespace native {
namespace {
// CPU kernel for elementwise pow where both the base and the exponent are
// operands of `iter`.
//
// Floating-point and complex dtypes are run through cpu_kernel_vec with a
// scalar std::pow implementation and a Vec256 implementation. Every other
// dtype is dispatched as an integral type and evaluated with native::powi.
void pow_tensor_tensor_kernel(TensorIterator& iter) {
  const auto dtype = iter.dtype();
  const bool float_or_complex = isFloatingType(dtype) || isComplexType(dtype);

  if (!float_or_complex) {
    // Integral path: scalar loop only, no vectorized variant is supplied.
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "pow", [&]() {
      cpu_kernel(iter, [](scalar_t b, scalar_t e) -> scalar_t {
        return native::powi(b, e);
      });
    });
    return;
  }

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(dtype, "pow", [&]() {
    using Vec = Vec256<scalar_t>;
    cpu_kernel_vec(
        iter,
        [](scalar_t b, scalar_t e) -> scalar_t { return std::pow(b, e); },
        [](Vec b, Vec e) -> Vec { return b.pow(e); });
  });
}
// CPU kernel for elementwise pow of the iterator's input raised to a single
// scalar exponent.
//
// Floating-point and complex dtypes special-case the exponents 2, 3, 0.5,
// -0.5, -1 and -2 with multiply / sqrt / reciprocal forms, and fall back to
// std::pow (scalar) and Vec256::pow (vector) for any other exponent. All
// remaining dtypes are dispatched as integral types and use native::powi with
// the exponent converted to the tensor's scalar type.
void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
  const auto dtype = iter.dtype();

  if (isFloatingType(dtype)) {
    // The exponent is held as a double regardless of the tensor's
    // floating-point dtype.
    const auto exponent = exp_scalar.to<double>();
    AT_DISPATCH_FLOATING_TYPES(dtype, "pow", [&]() {
      using Vec = Vec256<scalar_t>;
      if (exponent == 2) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return x * x; },
            [](Vec v) -> Vec { return v * v; });
      } else if (exponent == 3) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return x * x * x; },
            [](Vec v) -> Vec { return v * v * v; });
      } else if (exponent == 0.5) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return std::sqrt(x); },
            [](Vec v) -> Vec { return v.sqrt(); });
      } else if (exponent == -0.5) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return 1.0 / std::sqrt(x); },
            [](Vec v) -> Vec { return v.rsqrt(); });
      } else if (exponent == -1) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return 1.0 / x; },
            [](Vec v) -> Vec { return v.reciprocal(); });
      } else if (exponent == -2) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return 1.0 / (x * x); },
            [](Vec v) -> Vec { return (v * v).reciprocal(); });
      } else {
        // General exponent: no shortcut applies.
        cpu_kernel_vec(
            iter,
            [exponent](scalar_t x) -> scalar_t { return std::pow(x, exponent); },
            [exponent](Vec v) -> Vec { return v.pow(exponent); });
      }
    });
    return;
  }

  if (isComplexType(dtype)) {
    // Same set of special-cased exponents as the floating-point path, applied
    // to complex values; the exponent is held as a complex<double>.
    const auto exponent = exp_scalar.to<c10::complex<double>>();
    AT_DISPATCH_COMPLEX_TYPES(dtype, "pow", [&]() {
      using Vec = Vec256<scalar_t>;
      if (exponent == 2.0) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return x * x; },
            [](Vec v) -> Vec { return v * v; });
      } else if (exponent == 3.0) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return x * x * x; },
            [](Vec v) -> Vec { return v * v * v; });
      } else if (exponent == 0.5) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return std::sqrt(x); },
            [](Vec v) -> Vec { return v.sqrt(); });
      } else if (exponent == -0.5) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return scalar_t(1.0) / std::sqrt(x); },
            [](Vec v) -> Vec { return v.rsqrt(); });
      } else if (exponent == -1.0) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return scalar_t(1.0) / x; },
            [](Vec v) -> Vec { return v.reciprocal(); });
      } else if (exponent == -2.0) {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t { return scalar_t(1.0) / (x * x); },
            [](Vec v) -> Vec { return (v * v).reciprocal(); });
      } else {
        // std::pow cannot accept mixed complex data types, so the exponent is
        // converted to scalar_t before each call.
        cpu_kernel_vec(
            iter,
            [exponent](scalar_t x) -> scalar_t {
              return std::pow(x, scalar_t(exponent));
            },
            [exponent](Vec v) -> Vec { return v.pow(scalar_t(exponent)); });
      }
    });
    return;
  }

  // Integral path: scalar loop only, exponent converted to the tensor's type.
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "pow", [&]() {
    const scalar_t exponent = exp_scalar.to<scalar_t>();
    cpu_kernel(iter, [exponent](scalar_t x) -> scalar_t {
      return native::powi(x, exponent);
    });
  });
}
} // anonymous namespace
// Bind the two kernels defined in this file to their pow dispatch stubs.
// NOTE(review): the stubs are presumably declared in ATen/native/Pow.h — confirm.
REGISTER_DISPATCH(pow_tensor_tensor_stub, &pow_tensor_tensor_kernel);
REGISTER_DISPATCH(pow_tensor_scalar_stub, &pow_tensor_scalar_kernel);
}} // namespace at::native