forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcum_gpu.ts
94 lines (88 loc) · 3.13 KB
/
cum_gpu.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType, UniformType} from './shader_compiler';
/**
 * Accumulation operator for CumProgram.
 *
 * Each value is a literal GLSL operator token: CumProgram splices it into the
 * generated shader as a compound assignment (`val *= ...` / `val += ...`), and
 * also uses it to choose the identity value for exclusive scans.
 */
export enum CumOpType {
  // Cumulative product; CumProgram uses '1.0' as the identity value.
  Prod = '*',
  // Cumulative sum; CumProgram uses '0.0' as the identity value.
  Sum = '+',
}
/**
 * WebGL program that performs a single pass of a cumulative sum or product
 * along the last axis of the input `x`.
 *
 * Non-exclusive mode: every output element combines its own value with the
 * element `pow2 = 2^index` positions away along the last axis (behind it, or
 * ahead of it when `reverse` is set). This is one step of a doubling-stride
 * (Hillis-Steele style) scan. NOTE(review): the caller presumably runs this
 * program repeatedly with increasing values of the `index` uniform; that loop
 * is not in this file, so confirm against the kernel that uses it.
 *
 * Exclusive mode: the shader becomes a roll by one position along the last
 * axis. Each element takes its neighbour's value, and the edge element takes
 * the operator's identity value.
 */
export class CumProgram implements GPGPUProgram {
  variableNames = ['x'];
  userCode: string;
  customUniforms = [{name: 'index', type: 'float' as UniformType}];

  /**
   * @param op Accumulation operator. It also selects the identity value:
   *     '1.0' for Prod and '0.0' for Sum.
   * @param outputShape Output shape. The accumulation runs along its last
   *     dimension.
   * @param exclusive When true, generate a one-position roll (seeded with the
   *     identity) instead of a scan step.
   * @param reverse When true, accumulate from the end of the last axis towards
   *     its start.
   * @throws Error if the rank of `outputShape` is not 1-4. The coordinate
   *     helpers raise this error.
   */
  constructor(
      public op: CumOpType, public outputShape: number[], exclusive: boolean,
      reverse: boolean) {
    const rank = this.outputShape.length;
    // Both coordinate expressions appear twice in the shader; build them once.
    const coords = getCoords(rank, 'coords', this.op);
    const finalCoord = getFinalCoord(rank, 'coords', this.op);
    const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0';
    const val = exclusive ? initVal : `getX(${coords})`;
    const length = this.outputShape[rank - 1];
    let condition: string;
    let idxString: string;
    // When exclusive is set, the cum op becomes a roll op that copies the
    // value from the neighbouring index. The reverse flag specifies the
    // direction of the copy.
    if (exclusive) {
      condition = reverse ? `end != ${length - 1}` : 'end != 0';
      idxString = reverse ? 'end + 1' : 'end - 1';
    } else {
      condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2';
      idxString = reverse ? 'end + pow2' : 'end - pow2';
    }
    // GLSL's pow() may be evaluated as exp2(y * log2(x)) with limited
    // precision. It can therefore yield e.g. 15.9999 for 2^4, which int()
    // would truncate to 15. Adding 0.5 before truncating recovers the exact
    // power of two. round() is avoided because GLSL ES 1.00 lacks it.
    this.userCode = `
      void main() {
        ${getCoordsDataType(rank)} coords = getOutputCoords();
        int end = ${finalCoord};
        float val = ${val};
        int pow2 = int(pow(2.0, index) + 0.5);
        if (${condition}) {
          int idx = ${idxString};
          ${finalCoord} = idx;
          val ${this.op}= getX(${coords});
        }
        setOutput(val);
      }
    `;
  }
}
/**
 * Builds the comma-separated coordinate argument list for `getX(...)`.
 *
 * Rank 1 uses the coordinate variable directly. Ranks 2-4 expand to the
 * leading `.x`, `.y`, `.z`, `.w` swizzle components of that variable.
 *
 * @param rank Number of dimensions. Only 1-4 are supported.
 * @param name Name of the GLSL coordinate variable.
 * @param op Operator, used only in the error message.
 * @returns Argument list such as `coords.x, coords.y`.
 * @throws Error for any rank outside 1-4.
 */
function getCoords(rank: number, name: string, op: CumOpType): string {
  if (rank === 1) {
    return `${name}`;
  }
  const isSwizzleRank = rank === 2 || rank === 3 || rank === 4;
  if (!isSwizzleRank) {
    throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
  }
  const components = ['x', 'y', 'z', 'w'].slice(0, rank);
  return components.map((component) => `${name}.${component}`).join(', ');
}
/**
 * Returns the GLSL expression that addresses the last (innermost) coordinate
 * of the coordinate variable `name`.
 *
 * @param rank Number of dimensions. Only 1-4 are supported.
 * @param name Name of the GLSL coordinate variable.
 * @param op Operator, used only in the error message.
 * @returns `name` itself for rank 1; otherwise `name.y`, `name.z` or `name.w`.
 * @throws Error for any rank outside 1-4.
 */
function getFinalCoord(rank: number, name: string, op: CumOpType): string {
  switch (rank) {
    case 1:
      return `${name}`;
    case 2:
      return `${name}.y`;
    case 3:
      return `${name}.z`;
    case 4:
      return `${name}.w`;
    default:
      throw new Error(
          `Cumulative ${op} for rank ${rank} is not yet supported`);
  }
}