-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReservoir.java
207 lines (182 loc) · 7.31 KB
/
Reservoir.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Contains most of the information for the Data Reservoir Compute AI,
// including all trainable weights for all Compute class operators.
// One reason for this is to simplify access for evolution-based algorithms.
// All Compute-based subclasses operate on computeSize-sized data arrays.
// They get data using the gather method, which uses random-projection-based
// dimension reduction of the entire reservoir. A change in one single
// value in the reservoir produces a unique pattern change in the gathered data.
// The gather method can also select specific places in the reservoir to get
// information from, because of weighting prior to the random projection process.
// The reservoir is composed of 3 parts: <input><writable section><general>.
// When a compute object is finished it can either scatter the result to a
// specific place in the writable section or again use random projection with
// weighting-based selection to write to the general section.
// This should allow for complex connectivity (e.g. modular) to emerge.
package s6regen;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
/**
 * Shared state for a set of {@code Compute} units: one flat float reservoir laid
 * out as {@code <input><writable section><general>}, plus a single weight vector
 * that every compute unit draws from in sequence.
 *
 * <p>Usage order: construct, {@link #addComputeUnit(Compute)} for each unit, then
 * {@link #prepareForUse()} once. After that, {@link #setInput(float[])},
 * {@link #computeAll()} and {@link #getOutput(float[])} can be called repeatedly.
 *
 * <p>The working arrays ({@code reservoir}, {@code computeBuffers}, {@code rng})
 * are transient and are rebuilt by {@link #prepareForUse()} after deserialization.
 */
public class Reservoir implements Serializable {

    /** Added under the square root in the normalize methods so the divisor is never zero. */
    static final float MIN_SQ = 1e-20f;

    /** Length of the arrays that gather, scatter and scatterWritable operate on. */
    final int computeSize;
    private final int reservoirSize;
    private final int inputSize;
    private final int writableSize;
    private final int outputSize;

    /** Running argument passed to WHT.fastRP; reset to 0 by computeAll and advanced by gather/scatter. */
    int hashIndex;

    /** Read cursor into {@link #weights}; reset to 0 by computeAll and advanced by gather/scatter. */
    int weightIndex;

    /** Sum of weightSize() over all registered compute units, set by prepareForUse. */
    private int weightSize;

    /** All trainable weights, allocated and randomized by prepareForUse if still null. */
    float[] weights;

    private transient float[] reservoir;

    /** Scratch buffers for the compute units, sized to the largest buffersRequired() seen. */
    transient float[][] computeBuffers;

    private transient RNG rng;

    private final ArrayList<Compute> list; // list of all compute units for the AI

    /**
     * Creates an empty reservoir description. No arrays are allocated until
     * {@link #prepareForUse()} is called.
     *
     * <p>The size constraints below are checked with {@code assert}, so they are only
     * enforced when assertions are enabled.
     *
     * @param computeSize   length of the per-unit data arrays; power of 2 and at least 16
     * @param reservoirSize total reservoir length; multiple of computeSize
     * @param inputSize     length of the input section at the start of the reservoir;
     *                      multiple of computeSize
     * @param writableSize  length of the writable section following the input section;
     *                      multiple of computeSize
     * @param outputSize    number of values {@link #getOutput(float[])} copies out
     */
    public Reservoir(int computeSize, int reservoirSize, int inputSize, int writableSize, int outputSize) {
        assert computeSize >= 16 : "Requirement for WTH class";
        assert (computeSize & (computeSize - 1)) == 0 : "Power of 2 for WHT class";
        assert reservoirSize % computeSize == 0 : "Make multiple of computeSize";
        assert inputSize % computeSize == 0 : "Make multiple of computeSize";
        assert writableSize % computeSize == 0 : "Make multiple of computeSize";
        this.computeSize = computeSize;
        this.reservoirSize = reservoirSize;
        this.inputSize = inputSize;
        this.writableSize = writableSize;
        this.outputSize = outputSize;
        list = new ArrayList<>();
    }

    /**
     * Registers a compute unit. Units are run by {@link #computeAll()} in the order
     * they were added.
     *
     * @param c the compute unit to append
     */
    public void addComputeUnit(Compute c) {
        list.add(c);
    }

    /**
     * Allocates the working arrays and, if no weights exist yet, creates and
     * randomizes them. Call after adding compute units; also called from
     * readObject() to set up after deserialization.
     *
     * <p>The weight total is recomputed from scratch on every call. It must not be
     * accumulated into the existing {@code weightSize}: that field is serialized, so
     * after deserialization it already holds the total, and adding to it would leave
     * {@code weightSize} larger than {@code weights.length}.
     *
     * <p>Existing weights are kept as they are, so adding more compute units and
     * calling this again is still not supported.
     */
    public void prepareForUse() {
        reservoir = new float[reservoirSize]; // transient
        int totalWeights = 0;
        int maxBuffers = 0;
        for (Compute c : list) {
            totalWeights += c.weightSize();
            int needed = c.buffersRequired();
            if (maxBuffers < needed) {
                maxBuffers = needed;
            }
        }
        weightSize = totalWeights; // assign, never add: keeps this method idempotent
        computeBuffers = new float[maxBuffers][computeSize]; // transient
        rng = new RNG(); // transient
        if (weights == null) { // if deserializing, weights are typically != null
            weights = new float[weightSize];
            for (int i = 0; i < weightSize; i++) { // randomize initial weights
                weights[i] = rng.nextFloatSym();
            }
        }
    }

    /**
     * Runs every registered compute unit once, in registration order, starting with
     * the hash and weight cursors at 0. With assertions enabled, checks afterwards
     * that the units consumed exactly {@code weightSize} weights.
     */
    public void computeAll() {
        hashIndex = 0;
        weightIndex = 0;
        for (Compute c : list) {
            c.compute();
        }
        assert weightIndex == weightSize : "Error in Compute subclass weightIndex or weightSize";
    }

    /** Clears all held state (such as in associative memory) in every compute unit. */
    public void resetHeldStateAll() {
        for (Compute c : list) {
            c.resetHeldState();
        }
    }

    /**
     * Copies the first {@code inputSize} values of {@code input} into the input
     * section at the start of the reservoir.
     *
     * @param input source array holding at least {@code inputSize} values
     */
    public void setInput(float[] input) {
        System.arraycopy(input, 0, reservoir, 0, inputSize);
    }

    /**
     * Copies {@code outputSize} values into {@code output}, read from the reservoir
     * starting at the beginning of the general section
     * ({@code inputSize + writableSize}).
     *
     * @param output destination array with room for at least {@code outputSize} values
     */
    public void getOutput(float[] output) {
        System.arraycopy(reservoir, inputSize + writableSize, output, 0, outputSize);
    }

    /**
     * Replaces every weight with {@code rng.mutateXSym(weight, mutatePrecision)}.
     *
     * @param mutatePrecision passed through unchanged to {@code RNG.mutateXSym}
     */
    public void mutate(long mutatePrecision) {
        for (int i = 0; i < weightSize; i++) {
            weights[i] = rng.mutateXSym(weights[i], mutatePrecision);
        }
    }

    /**
     * @return the total number of weights used by all compute units, as computed by
     *         {@link #prepareForUse()}
     */
    public int getWeightSize() {
        return weightSize;
    }

    /**
     * Copies all weights into {@code vec}.
     *
     * @param vec destination array with room for at least {@link #getWeightSize()} values
     */
    public void getWeights(float[] vec) {
        System.arraycopy(weights, 0, vec, 0, weightSize);
    }

    /**
     * Overwrites all weights with the first {@link #getWeightSize()} values of {@code vec}.
     *
     * @param vec source array holding at least {@link #getWeightSize()} values
     */
    public void setWeights(float[] vec) {
        System.arraycopy(vec, 0, weights, 0, weightSize);
    }

    /**
     * Reduces the whole reservoir into {@code g}. The first computeSize reservoir
     * values are weighted and stored in {@code g}; then, for each further
     * computeSize-sized chunk, {@code g} is passed through WHT.fastRP and the weighted
     * chunk is added to it. A final WHT.fastRP call is applied before returning.
     *
     * <p>Consumes one weight per reservoir element and advances {@code weightIndex}
     * and {@code hashIndex} accordingly.
     *
     * @param g destination array of length computeSize; previous contents are overwritten
     */
    void gather(float[] g) {
        int wtIdx = weightIndex; // Get as local as an optimization
        float[] wt = weights; // Check later if this is so necessary
        float[] res = reservoir;
        int i = 0;
        while (i < computeSize) {
            g[i] = res[i] * wt[wtIdx++];
            i++;
        }
        while (i < reservoirSize) {
            WHT.fastRP(g, hashIndex++);
            for (int j = 0; j < computeSize; j++) {
                g[j] += res[i++] * wt[wtIdx++];
            }
        }
        weightIndex = wtIdx; // set weightIndex to correct value
        WHT.fastRP(g, hashIndex++);
    }

    /**
     * Blends {@code s} into the general section of the reservoir. For each
     * computeSize-sized chunk of that section, {@code s} is first passed through
     * WHT.fastRP and then mixed into the chunk element by element.
     *
     * <p>{@code s} is destroyed by this method; it is assumed the compute unit will
     * not need it again. Consumes one weight per general-section element and advances
     * {@code weightIndex} and {@code hashIndex} accordingly.
     *
     * @param s source array of length computeSize; modified in place
     */
    void scatter(float[] s) {
        int wtIdx = weightIndex; // Get as local as an optimization
        float[] wt = weights;
        float[] res = reservoir;
        int i = inputSize + writableSize;
        while (i < reservoirSize) {
            WHT.fastRP(s, hashIndex++);
            // Note: an earlier variant used the raw weight p as the blend factor and left
            // the reservoir value unchanged when p < 0. The squared form below replaced it.
            for (int j = 0; j < computeSize; j++) {
                float p = wt[wtIdx++];
                // Squaring gives smooth non-linear blending to make things easier for
                // evolution: if the weight is low (e.g. 0.1) the reservoir is hardly affected.
                p *= p;
                res[i] = s[j] * p + (1f - p) * res[i];
                i++;
            }
        }
        weightIndex = wtIdx; // put back the new index
    }

    /**
     * Copies computeSize values from {@code s} directly into the reservoir at offset
     * {@code inputSize + location}. Uses no weights and no random projection.
     *
     * @param s        source array holding at least computeSize values
     * @param location offset relative to the end of the input section
     */
    void scatterWritable(float[] s, int location) {
        System.arraycopy(s, 0, reservoir, inputSize + location, computeSize);
    }

    /**
     * @return the reservoir length, which is the number of weights one
     *         {@link #gather(float[])} call consumes
     */
    int sizeGather() {
        return reservoirSize;
    }

    /**
     * @return the length of the general section, which is the number of weights one
     *         {@link #scatter(float[])} call consumes
     */
    int sizeScatter() {
        return reservoirSize - inputSize - writableSize;
    }

    /**
     * Scales the general section in place by 1 / sqrt(mean square + MIN_SQ), i.e.
     * towards a root-mean-square of 1.
     */
    void normalizeGeneral() {
        int start = inputSize + writableSize;
        float sumSq = 0f;
        for (int i = start; i < reservoirSize; i++) {
            sumSq += reservoir[i] * reservoir[i];
        }
        float adj = 1f / (float) Math.sqrt((sumSq / (reservoirSize - start)) + MIN_SQ);
        for (int i = start; i < reservoirSize; i++) {
            reservoir[i] *= adj;
        }
    }

    /**
     * Scales the input section in place by 1 / sqrt(mean square + MIN_SQ), i.e.
     * towards a root-mean-square of 1.
     */
    void normalizeInput() {
        float sumSq = 0f;
        for (int i = 0; i < inputSize; i++) {
            sumSq += reservoir[i] * reservoir[i];
        }
        float adj = 1f / (float) Math.sqrt((sumSq / inputSize) + MIN_SQ);
        for (int i = 0; i < inputSize; i++) {
            reservoir[i] *= adj;
        }
    }

    /**
     * Serialization hook: restores the non-transient fields, then rebuilds the
     * transient working arrays via {@link #prepareForUse()}.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        prepareForUse(); // Set up all the buffers and working arrays
    }
}