-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine.hpp
168 lines (149 loc) · 4.42 KB
/
engine.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#pragma once
#include <algorithm>
#include <array>
#include <cmath>
#include <iostream>
#include <memory>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_set>
#include <utility>
#include <vector>
// Kind of operation that produced a Value from its operands.
enum class Operation
{
    Null,           // no operation recorded (the default for Inputs::operation)
    Addition,       // printed as "+" by toString
    Multiplication, // printed as "*" by toString
    Power,          // printed as "pow" by toString
    RELU            // printed as "RELU" by toString
};
// Returns the short display name of an operation ("null", "+", "*", "pow",
// "RELU"). The returned view refers to a string literal.
//
// Throws std::runtime_error if `op` holds a value that is not one of the
// enumerators handled below.
//
// Declared `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit produces
// multiple-definition link errors.
inline std::string_view toString(Operation op)
{
    switch (op)
    {
        case Operation::Null:           return "null";
        case Operation::Addition:       return "+";
        case Operation::Multiplication: return "*";
        case Operation::Power:          return "pow";
        case Operation::RELU:           return "RELU";
    }
    // Only reachable for a value cast from an out-of-range integer.
    throw std::runtime_error("Unhandled op");
}
struct Value;
// Describes how a Value was produced: the operation that was applied and the
// operand nodes it was applied to. A default-constructed Inputs has
// Operation::Null and no operands.
struct Inputs
{
    // Operation applied to `values`.
    Operation operation{ Operation::Null };

    // Operand nodes, held by shared ownership.
    std::vector<std::shared_ptr<Value>> values;

    // Exponent; Value::backwardsOnce() reads it only for Operation::Power.
    double power{ 0.0 };
};
// Note: do not construct a Value directly unless you have specific requirements.
// Use Value::make(...) instead, as the ValuePtr type it returns has all the operators defined on it.
//
// Scalar floating-point number type which allows building and evaluating
// mathematical expression trees forwards and backwards:
// - Forwards: resolve/simplify the mathematical expression value
// - Backwards: calculate the partial derivative for all input terms in the tree
//   by applying the chain rule backwards
//
// This is done by saving the input expressions/terms for each 'Value'
// and traversing the tree as needed.
// One node of a scalar expression graph: the forward result, the operation
// and operands that produced it, and the gradient accumulated for it.
struct Value {
    double _value;       // forward result of this node
    Inputs _inputs;      // operation and operand nodes that produced _value
    double _grad = 0.0;  // gradient accumulated by backwardsOnce() of consumers

    // Appends `value` and every node reachable through its inputs to `topo`
    // in post-order (operands before the node that uses them). Nodes already
    // present in `visited` are skipped, so shared sub-expressions appear once.
    //
    // NOTE: recursive; the recursion depth equals the depth of the graph.
    static void buildTopo(
        std::vector<Value*>& topo,
        std::unordered_set<Value*>& visited,
        Value* value)
    {
        // insert() reports whether the node was new, so one hash lookup both
        // tests and marks it.
        if (!visited.insert(value).second) {
            return;
        }
        // Bind by reference: copying each shared_ptr would only add
        // reference-count traffic.
        for (const auto& next : value->_inputs.values) {
            buildTopo(topo, visited, next.get());
        }
        topo.push_back(value);
    }

    // Creates a node with the given forward value and, optionally, the
    // inputs that produced it. `inputs` is taken by value and moved into
    // place so the operand vector is not copied a second time.
    Value(double value, Inputs inputs = Inputs{})
        : _value(value), _inputs(std::move(inputs))
    {}

    // Resets this node's gradient only; operand nodes are not touched.
    void zeroGrad() {
        _grad = 0.0;
    }

    // Applies the chain rule for this node alone: adds this node's
    // contribution to the gradient of each direct operand. Assumes
    // _inputs.values holds as many operands as the operation needs
    // (two for Addition/Multiplication, one for Power/RELU).
    void backwardsOnce() {
        switch (_inputs.operation) {
            case Operation::Addition:
                // d(a+b)/da = d(a+b)/db = 1
                _inputs.values[0]->_grad += _grad;
                _inputs.values[1]->_grad += _grad;
                break;
            case Operation::Multiplication: {
                // d(a*b)/da = b, d(a*b)/db = a
                const auto& a = _inputs.values[0];
                const auto& b = _inputs.values[1];
                a->_grad += b->_value * _grad;
                b->_grad += a->_value * _grad;
                break;
            }
            case Operation::Power: {
                // d(x^p)/dx = p * x^(p-1)
                const auto& base = _inputs.values[0];
                base->_grad += (_inputs.power * std::pow(base->_value, _inputs.power - 1)) * _grad;
                break;
            }
            case Operation::RELU:
                // Slope is 1 where the output is positive, 0 otherwise.
                _inputs.values[0]->_grad += (_value > 0.0 ? 1.0 : 0.0) * _grad;
                break;
            case Operation::Null:
                // Leaf node: nothing to propagate.
                break;
        }
    }

    // Runs reverse-mode differentiation from this node: seeds its gradient
    // with 1 and visits all reachable nodes in reverse topological order.
    //
    // Gradients of the other nodes are added to, not reset, so a second call
    // accumulates on top of the values left by the first.
    void backwards()
    {
        std::vector<Value*> topo;
        std::unordered_set<Value*> visited;
        buildTopo(topo, visited, this);

        _grad = 1.0;
        // Reverse post-order: every node is processed before its operands.
        for (auto it = topo.rbegin(); it != topo.rend(); ++it) {
            (*it)->backwardsOnce();
        }
    }

    // Prints this node and its first two operands to std::cout: first
    // operand, then this node's line, then second operand. Each operand is
    // indented by the full length of this node's line (its own indent and
    // the trailing newline included).
    void printTree(int indents = 0) const
    {
        // Leaf nodes print an empty operation rather than "null".
        const std::string_view operation =
            _inputs.operation != Operation::Null ? toString(_inputs.operation) : "";

        std::ostringstream current;
        current << std::string(indents, ' ') << "value=" << _value << " grad=" << _grad << " " << operation << '\n';
        const std::string line = current.str();
        const int childIndent = static_cast<int>(line.size());

        const auto& operands = _inputs.values;
        if (!operands.empty()) {
            operands[0]->printTree(childIndent);
        }
        std::cout << line;
        if (operands.size() > 1) {
            operands[1]->printTree(childIndent);
        }
    }

    // Forwards its arguments to the Value constructor and returns the new
    // node wrapped in a std::shared_ptr<Value>.
    template<typename... Args>
    static auto make(Args&&... args)
    {
        return std::make_shared<Value>(std::forward<Args>(args)...);
    }
};
// Shared-ownership handle to a Value node (the type std::make_shared<Value> yields).
using ValuePtr = std::shared_ptr<Value>;
// Creates a node holding a + b and records both operands as its inputs.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline ValuePtr operator+(ValuePtr a, ValuePtr b)
{
    return std::make_shared<Value>(a->_value + b->_value, Inputs{ Operation::Addition, { a, b }});
}
// Creates a node holding a * b and records both operands as its inputs.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline ValuePtr operator*(ValuePtr a, ValuePtr b)
{
    return std::make_shared<Value>(a->_value * b->_value, Inputs{ Operation::Multiplication, { a, b }});
}
// Creates a node holding a raised to the constant exponent `value`
// (computed with std::pow). The exponent is stored in Inputs::power and is
// not itself a node.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline ValuePtr power(ValuePtr a, double value)
{
    return std::make_shared<Value>(std::pow(a->_value, value), Inputs{ Operation::Power, { a }, value });
}
// Unary negation, expressed as a multiplication by a freshly created
// constant node holding -1.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline ValuePtr operator-(ValuePtr a)
{
    return a * std::make_shared<Value>(-1.0);
}
// Division, expressed as a * b^(-1) so that it reuses the multiplication
// and power nodes.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline ValuePtr operator/(ValuePtr a, ValuePtr b)
{
    return a * power(b, -1.0);
}
// Binary subtraction, expressed as a + (-b) so that it reuses the addition
// and negation operators.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline ValuePtr operator-(ValuePtr a, ValuePtr b)
{
    return a + (-b);
}
// Creates a node holding a's value when it is greater than zero and 0.0
// otherwise, recording `a` as its single input.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline ValuePtr relu(ValuePtr a)
{
    return std::make_shared<Value>((a->_value > 0.0 ? a->_value : 0.0), Inputs{ Operation::RELU, { a } });
}
// Streams the forward value of the node `value` points to and returns the
// stream. `value` is dereferenced unconditionally, so it must not be null.
//
// `inline` because the definition lives in a header; without it, including
// the header from several translation units yields multiple definitions.
inline std::ostream& operator<<(std::ostream& os, const ValuePtr& value)
{
    os << value->_value;
    return os;
}