-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadjoint.cpp
106 lines (80 loc) · 3.42 KB
/
adjoint.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include "adjoint.h"
#include "primal.h"
#include <cassert>
// Enzyme activity markers, passed ahead of arguments in __enzyme_autodiff*
// calls. Enzyme recognizes these globals by name (per the Enzyme docs), so the
// names must not change. Their values are never read by this file.
//   enzyme_dup       - the next argument is followed by its shadow/adjoint slot
//   enzyme_dupnoneed - like enzyme_dup, but the primal value need not be kept
//   enzyme_out       - active by-value argument whose adjoint is returned
//                      (not used by the visible code in this file)
//   enzyme_const     - the next argument is not differentiated
int enzyme_dup;
int enzyme_dupnoneed;
int enzyme_out;
int enzyme_const;
// Prototypes for Enzyme's reverse-mode entry points. The first argument is the
// primal function (cast to void*). The remaining arguments are
// activity-annotated: a marker int followed by the primal value and, for the
// dup variants, its shadow. Each call site gets its own suffixed name so it
// can have a typed prototype. NOTE(review): this relies on Enzyme matching
// the `__enzyme_autodiff` name prefix; confirm with the Enzyme version in use.
// The square and sub prototypes are variadic. Arguments passed through `...`
// undergo default argument passing, so by-reference data must be passed as
// explicit pointers to them.
extern double __enzyme_autodiff_square(void*, ...);
extern void __enzyme_autodiff_add(void*, int, double&, double&, int, double&, double&);
extern void __enzyme_autodiff_sub(void*, ...);
// Takes a typed function pointer rather than void*, so a captureless lambda
// can be passed directly.
extern void __enzyme_autodiff_mul(double (*) (const double&, const double&), int, double&, double&, int, double&, double&);
// Three (marker, value, shadow) triples: x, y and an output slot.
extern void __enzyme_autodiff_div(void*, int, double&, double&, int, double&, double&, int, double&, double&);
// c is passed without a shadow (const), x with its shadow.
extern void __enzyme_autodiff_dot(void*, int, const std::vector<double>&, int, const std::vector<double>&, std::vector<double>&);
// Arguments: alpha, A, (x, dx), beta, (y, dy).
extern void __enzyme_autodiff_gemv(void*,
int, double,
int, const std::vector<std::vector<double>>&,
int, const std::vector<double>&, std::vector<double>&,
int, double,
int, std::vector<double>&, std::vector<double>&);
// Derivative of square() at x.
// x is passed with no activity marker. Enzyme treats an unannotated by-value
// double as active and returns its adjoint directly, so the call's return
// value is the derivative.
double grad_square(double x) {
    const double derivative = __enzyme_autodiff_square((void*)square, x);
    return derivative;
}
// Gradient of add(x, y) with respect to both operands.
// Both operands are marked enzyme_dup. Enzyme accumulates each operand's
// adjoint into the paired shadow, so the shadows start at zero.
// Returns {d/dx, d/dy}.
std::pair<double, double> grad_add(double x, double y) {
    double dx = 0.0;
    double dy = 0.0;
    __enzyme_autodiff_add((void*)add, enzyme_dup, x, dx, enzyme_dup, y, dy);
    return {dx, dy};
}
// Gradient of sub(x, y) with respect to both operands.
// Returns {d/dx, d/dy}.
std::pair<double, double> grad_sub(double x, double y) {
    double dx = 0.0;
    double dy = 0.0;
    // The sub prototype is variadic. A reference passed through `...` would
    // arrive as a copied double, so the addresses are passed explicitly.
    __enzyme_autodiff_sub((void*)sub,
                          enzyme_dup, &x, &dx,
                          enzyme_dup, &y, &dy);
    return std::make_pair(dx, dy);
}
// Gradient of mul(x, y) with respect to both operands.
// Returns {d/dx, d/dy}. Enzyme accumulates each operand's adjoint into its
// enzyme_dup shadow (grad.first / grad.second), which start at zero.
std::pair<double, double> grad_mul(double x, double y) {
    auto grad = std::make_pair(0.0, 0.0);
    // A captureless lambda converts implicitly to the
    // double (*)(const double&, const double&) parameter that the
    // __enzyme_autodiff_mul prototype declares.
    auto mul_wrap = [](const double& a, const double& b) {
        return mul(a, b);
    };
    __enzyme_autodiff_mul(mul_wrap,
                          enzyme_dup, x, grad.first,
                          enzyme_dup, y, grad.second);
    return grad;
}
// Gradient of div with respect to x and y. Returns {d/dx, d/dy}.
// div is differentiated with three duplicated arguments:
//   - x and y, whose adjoints accumulate into grad.first / grad.second;
//   - an output slot `result` (enzyme_dupnoneed), whose shadow is seeded
//     with 1.0.
// NOTE(review): assumes div writes x / y into its third argument; confirm in
// primal.h.
std::pair<double, double> grad_div(double x, double y) {
    auto grad = std::make_pair(0.0, 0.0);
    // Brace-initialized so the output slot never holds an indeterminate value.
    // dupnoneed only tells Enzyme the primal result is not needed.
    double result{0.0};
    // Reverse-pass seed: d(result)/d(result) = 1.
    double dresult = 1.0;
    __enzyme_autodiff_div((void*)div,
                          enzyme_dup, x, grad.first,
                          enzyme_dup, y, grad.second,
                          enzyme_dupnoneed, result, dresult);
    return grad;
}
// Gradient of dot(c, x) with respect to x only.
// c is passed as enzyme_const, so it is not differentiated. The result is
// therefore the gradient of f(x) = <c, x>, which for a standard inner product
// equals c.
std::vector<double> grad_dot(const std::vector<double>& c, const std::vector<double>& x) {
    assert(c.size() == x.size());
    // Shadow of x. Enzyme accumulates the adjoint here, so it starts at zero
    // and must have the same length as x.
    std::vector<double> dx(x.size(), 0.0);
    __enzyme_autodiff_dot((void*)dot, enzyme_const, c, enzyme_dup, x, dx);
    return dx;
}
// Gradient of gemv with respect to x only; returns dx.
// NOTE(review): presumably gemv computes y <- alpha*A*x + beta*y; confirm in
// primal.h.
// alpha, A and beta are enzyme_const and are not differentiated. y is
// enzyme_dupnoneed, with the caller-supplied dy as its seed, so Enzyme need
// not keep the primal value written to y.
// NOTE(review): Enzyme may modify dy (and y) during the call; callers should
// not rely on their contents afterwards.
std::vector<double> grad_gemv(double alpha,
                              const std::vector<std::vector<double>>& A,
                              const std::vector<double>& x,
                              double beta,
                              std::vector<double>& y, std::vector<double>& dy) {
    // A must have at least one row so that A[0] is valid. A single-row matrix
    // is a legitimate gemv input, so only the empty case is rejected. dy is
    // the shadow of y and must match it in length.
    assert(!A.empty() && A.size() == y.size() && dy.size() == y.size() &&
           A[0].size() == x.size());
    // Shadow of x; the adjoint accumulates here, so it starts at zero.
    std::vector<double> dx(x.size(), 0.0);
    __enzyme_autodiff_gemv((void*)gemv,
                           enzyme_const, alpha,
                           enzyme_const, A,
                           enzyme_dup, x, dx,
                           enzyme_const, beta,
                           enzyme_dupnoneed, y, dy);
    return dx;
}