forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSobolEngineOps.cpp
159 lines (135 loc) · 6.04 KB
/
SobolEngineOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/SobolEngineOpsUtils.h>
#include <vector>
namespace at {
namespace native {
/// This is the core function to draw samples from a `SobolEngine` given
/// its state variables (`sobolstate` and `quasi`). `dimension` can be
/// inferred from `sobolstate`, but choosing to pass it explicitly to avoid
/// an extra operation to obtain the size of the first dimension of
/// `sobolstate`.
std::tuple<Tensor, Tensor> _sobol_engine_draw(const Tensor& quasi, int64_t n, const Tensor& sobolstate,
int64_t dimension, int64_t num_generated, optional<ScalarType> dtype) {
AT_CHECK(sobolstate.dtype() == at::kLong,
"sobolstate needs to be of type ", at::kLong);
AT_CHECK(quasi.dtype() == at::kLong,
"quasi needs to be of type ", at::kLong);
Tensor wquasi = quasi.clone();
auto result_dtype = dtype.has_value() ? dtype.value() : at::kFloat;
Tensor result = at::empty({n, dimension}, sobolstate.options().dtype(result_dtype));
AT_DISPATCH_FLOATING_TYPES(result_dtype, "_sobol_engine_draw", [&]() -> void {
// We deal with `data` and `strides` due to performance issues.
int64_t l;
int64_t* wquasi_data = wquasi.data<int64_t>();
int64_t* sobolstate_data = sobolstate.data<int64_t>();
scalar_t* result_data = result.data<scalar_t>();
int64_t wquasi_stride = wquasi.stride(0);
int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1);
int64_t result_row_stride = result.stride(0), result_col_stride = result.stride(1);
for (int64_t i = 0; i < n; i++, num_generated++) {
l = rightmost_zero(num_generated);
for (int64_t j = 0; j < dimension; j++) {
wquasi_data[j * wquasi_stride] ^= sobolstate_data[j * sobolstate_row_stride + l * sobolstate_col_stride];
result_data[i * result_row_stride + j * result_col_stride] = wquasi_data[j * wquasi_stride];
}
}
});
result.mul_(RECIPD);
return std::tuple<Tensor, Tensor>(result, wquasi);
}
/// Core routine that fast-forwards a `SobolEngine` by `n` points. It updates
/// the `quasi` state variable in place without materializing the skipped
/// points.
///
/// Each skipped point XORs column `rightmost_zero(num_generated)` of
/// `sobolstate` into `quasi`, using the same update rule as the drawing loop of
/// `_sobol_engine_draw`. `dimension` could be inferred from `sobolstate`, but
/// it is passed explicitly for the same reason as in `_sobol_engine_draw`.
///
/// @param quasi          `kLong` state tensor, updated in place along its
///                       first dimension.
/// @param n              number of points to skip.
/// @param sobolstate     2-D `kLong` state tensor.
/// @param dimension      number of leading entries of `quasi` / rows of
///                       `sobolstate` that are used.
/// @param num_generated  number of points generated before this call.
/// @return `quasi`.
/// @throws (via AT_CHECK) if `quasi` or `sobolstate` is not `kLong`, if
///         `dimension` exceeds the size of either tensor, or if a selected
///         column index falls outside `sobolstate`. If the column check fires,
///         `quasi` already holds the updates of the points processed before
///         the failing one.
Tensor& _sobol_engine_ff_(Tensor& quasi, int64_t n, const Tensor& sobolstate,
                          int64_t dimension, int64_t num_generated) {
  AT_CHECK(sobolstate.dtype() == at::kLong,
           "sobolstate needs to be of type ", at::kLong);
  AT_CHECK(quasi.dtype() == at::kLong,
           "quasi needs to be of type ", at::kLong);
  // The loop below indexes raw pointers, so validate the row range here.
  AT_CHECK(dimension <= sobolstate.size(0) && dimension <= quasi.size(0),
           "dimension (", dimension, ") exceeds the size of sobolstate (", sobolstate.size(0),
           " rows) or quasi (", quasi.size(0), " entries)");

  // We deal with `data` and `strides` due to performance issues.
  int64_t* quasi_data = quasi.data<int64_t>();
  int64_t* sobolstate_data = sobolstate.data<int64_t>();

  int64_t quasi_stride = quasi.stride(0);
  int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1);
  int64_t sobolstate_cols = sobolstate.size(1);

  for (int64_t i = 0; i < n; i++, num_generated++) {
    int64_t l = rightmost_zero(num_generated);
    // Guard against reading past the last column once the engine is exhausted.
    AT_CHECK(l >= 0 && l < sobolstate_cols,
             "cannot fast-forward further: column index ", l, " is out of range for sobolstate with ",
             sobolstate_cols, " columns (num_generated = ", num_generated, ")");
    for (int64_t j = 0; j < dimension; j++) {
      quasi_data[j * quasi_stride] ^= sobolstate_data[j * sobolstate_row_stride + l * sobolstate_col_stride];
    }
  }
  return quasi;
}
/// Randomizes (scrambles) the `sobolstate` state variable of a `SobolEngine`
/// in place.
///
/// The arguments are the `sobolstate` state variable to randomize and `ltm`,
/// a stack of random lower-triangular matrices consisting of 0s and 1s.
/// `dimension` is passed explicitly, as in the other `_sobol_engine_*`
/// functions.
///
/// For each dimension `d` and column `j`, the new entry `sobolstate[d][j]` is
/// assembled bit by bit. Bit `MAXBIT - 1 - p` is the parity of
/// `sum_k bitsubseq(ltm_dots[d][p], k, 1) * bitsubseq(old_entry, k, 1)` over
/// `k < MAXBIT`. Assuming `bitsubseq(x, k, 1)` extracts bit `k` of `x`, this
/// is the parity of `ltm_dots[d][p] & old_entry`.
///
/// @param sobolstate  2-D `kLong` tensor, modified in place; its first
///                    `dimension` rows and first `MAXBIT` columns are rewritten.
/// @param ltm         random 0/1 matrices; only a copy with its diagonals set to
///                    1 is used, so `ltm` itself is not modified.
/// @param dimension   number of leading rows of `sobolstate` to scramble.
/// @return `sobolstate`.
/// @throws (via AT_CHECK) if `sobolstate` is not `kLong`, or if `sobolstate`
///         or the reduced `ltm` is smaller than `dimension` x `MAXBIT`.
Tensor& _sobol_engine_scramble_(Tensor& sobolstate, const Tensor& ltm, int64_t dimension) {
  AT_CHECK(sobolstate.dtype() == at::kLong,
           "sobolstate needs to be of type ", at::kLong);

  /// Require a tensor accessor for `sobolstate`
  auto ss_a = sobolstate.accessor<int64_t, 2>();
  // Accessor `operator[]` is not bounds-checked, so make sure every index the
  // scrambling loop touches (rows < dimension, columns < MAXBIT) exists.
  AT_CHECK(dimension <= sobolstate.size(0) && sobolstate.size(1) >= MAXBIT,
           "sobolstate must have at least ", dimension, " rows and ", MAXBIT,
           " columns, got ", sobolstate.size(0), " x ", sobolstate.size(1));

  /// For every matrix in `ltm`, the diagonal is made 1 (on a copy).
  /// Each row of each matrix needs a dot product with a specific vector.
  /// Instead, `cdot_pow2` reduces all the matrices at once, and the required
  /// product of the m^{th} row of the d^{th} matrix is read as
  /// `ltm_d_a[d][m]` (m and d are zero-indexed).
  Tensor diag_true = ltm.clone();
  diag_true.diagonal(0, -2, -1).fill_(1);
  Tensor ltm_dots = cdot_pow2(diag_true);
  auto ltm_d_a = ltm_dots.accessor<int64_t, 2>();
  AT_CHECK(ltm_dots.size(0) >= dimension && ltm_dots.size(1) >= MAXBIT,
           "ltm is too small: expected at least ", dimension, " x ", MAXBIT,
           " row products, got ", ltm_dots.size(0), " x ", ltm_dots.size(1));

  /// Main scrambling loop
  for (int64_t d = 0; d < dimension; ++d) {
    for (int64_t j = 0; j < MAXBIT; ++j) {
      int64_t vdj = ss_a[d][j], l = 1, t2 = 0;
      // Walk the rows from last to first; row `p` contributes bit `MAXBIT - 1 - p`.
      for (int64_t p = MAXBIT - 1; p >= 0; --p) {
        int64_t lsmdp = ltm_d_a[d][p];
        int64_t t1 = 0;
        for (int64_t k = 0; k < MAXBIT; ++k) {
          t1 += (bitsubseq(lsmdp, k, 1) * bitsubseq(vdj, k, 1));
        }
        t1 = t1 % 2;
        t2 = t2 + t1 * l;
        l = l << 1;
      }
      ss_a[d][j] = t2;
    }
  }
  return sobolstate;
}
/// Initializes the main state variable (`sobolstate`) of a `SobolEngine` in
/// place. `dimension` is passed explicitly, as in the other `_sobol_engine_*`
/// functions.
///
/// Row 0 is first set to all 1s. For each dimension `d`, let `p = poly[d]`
/// and `m = bit_length(p) - 1`.
/// - The first `m` entries of row `d` are copied from `initsobolstate[d]`.
/// - Every further entry `j < MAXBIT` follows
///   `v[j] = v[j - m] XOR (XOR over k < m with bit (m - 1 - k) of p set of
///   2^(k + 1) * v[j - k - 1])`.
/// - Finally, column `c` of `sobolstate` is multiplied by `2^(MAXBIT - 1 - c)`.
///
/// The recurrence appears to match Bratley & Fox, "Algorithm 659" (ACM TOMS
/// 14(1), 1988). NOTE(review): confirm against the `poly`/`initsobolstate`
/// tables.
///
/// @param sobolstate  2-D `kLong` tensor, modified in place.
/// @param dimension   number of leading rows of `sobolstate` to initialize.
/// @return `sobolstate`.
/// @throws (via AT_CHECK) if `sobolstate` is not `kLong` or is smaller than
///         `dimension` x `MAXBIT`.
Tensor& _sobol_engine_initialize_state_(Tensor& sobolstate, int64_t dimension) {
  AT_CHECK(sobolstate.dtype() == at::kLong,
           "sobolstate needs to be of type ", at::kLong);

  /// Use a tensor accessor for `sobolstate`
  auto ss_a = sobolstate.accessor<int64_t, 2>();
  // Accessor writes below are not bounds-checked. Validate the index ranges
  // (rows < dimension, columns < MAXBIT) before mutating anything.
  AT_CHECK(dimension <= sobolstate.size(0) && sobolstate.size(1) >= MAXBIT,
           "sobolstate must have at least ", dimension, " rows and ", MAXBIT,
           " columns, got ", sobolstate.size(0), " x ", sobolstate.size(1));

  /// First row of `sobolstate` is 1
  // NOTE(review): the loop below also visits d = 0. The all-1s fill survives
  // only if `poly[0]` yields m == 0; confirm against the tables.
  sobolstate.select(0, 0).fill_(1);

  // NOTE(review): `poly[d]` and `initsobolstate[d]` are not bounds-checked
  // here; presumably the caller caps `dimension` at the table size — confirm.
  for (int64_t d = 0; d < dimension; ++d) {
    int64_t p = poly[d];
    int64_t m = bit_length(p) - 1;
    // Seed the first m entries of row d from the initial table.
    for (int64_t i = 0; i < m; ++i) {
      ss_a[d][i] = initsobolstate[d][i];
    }
    // Fill the remaining entries with the recurrence described above.
    for (int64_t j = m; j < MAXBIT; ++j) {
      int64_t newv = ss_a[d][j - m];
      int64_t pow2 = 1;
      for (int64_t k = 0; k < m; ++k) {
        pow2 <<= 1;
        if ((p >> (m - 1 - k)) & 1) {
          newv = newv ^ (pow2 * ss_a[d][j - k - 1]);
        }
      }
      ss_a[d][j] = newv;
    }
  }

  /// Scale the columns: sobolstate * [2^(MAXBIT-1), 2^(MAXBIT-2), ..., 2, 1]
  Tensor pow2s = at::pow(2, at::native::arange((MAXBIT - 1), -1, -1, sobolstate.options()));
  sobolstate.mul_(pow2s);
  return sobolstate;
}
} // namespace native
} // namespace at