-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdepth_repertoire.py
447 lines (388 loc) · 16.3 KB
/
depth_repertoire.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
from __future__ import annotations
import jax
import jax.numpy as jnp
import flax.struct
from qdax.core.containers.mapelites_repertoire import get_cells_indices
from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from functools import partial
from ..utils import jax_jit
class DeepMapElitesRepertoire(flax.struct.PyTreeNode):
'''
Class for the deep repertoire in Map Elites.
Args:
genotypes: a PyTree containing the genotypes of the best solutions ordered
by the centroids. Each leaf has a shape (num_centroids, num_features). The
PyTree can be a simple Jax array or a more complex nested structure such
as to represent parameters of neural network in Flax.
genotypes_depth: a PyTree containing all the genotypes ordered by the centroids.
Each leaf has a shape (num_centroids * depth, num_features). The PyTree
can be a simple Jax array or a more complex nested structure such as to
represent parameters of neural network in Flax.
fitnesses: an array that contains the fitness of best solutions in each cell of
the repertoire, ordered by centroids. The array shape is (num_centroids,).
fitnesses_depth: an array that contains the fitness of all solutions in each
cell of the repertoire, ordered by centroids. The array shape
is (num_centroids * depth).
descriptors: an array that contains the descriptors of best solutions in each
cell of the repertoire, ordered by centroids. The array shape
is (num_centroids, num_descriptors).
descriptors_depth: an array that contains the descriptors of all solutions in
each cell of the repertoire, ordered by centroids. The array shape
is (num_centroids * depth, num_descriptors).
centroids: an array the contains the centroids of the tesselation. The array
shape is (num_centroids, num_descriptors).
'''
genotypes: Genotype
genotypes_depth: Genotype
fitnesses: Fitness
fitnesses_depth: Fitness
descriptors: Descriptor
descriptors_depth: Descriptor
centroids: Centroid
dims: jax.Array
@partial(jax_jit, static_argnames=('num_samples',))
def sample(self, random_key: RNGKey, num_samples: int) -> tuple[Genotype, RNGKey]:
'''
Sample elements in the grid. Sample only from the best individuals ('first
layer of the depth') contained in genotypes, fitnesses and descriptors.
Args:
random_key: a jax PRNG random key
num_samples: the number of elements to be sampled
Returns:
samples: a batch of genotypes sampled in the repertoire
random_key: an updated jax PRNG random key
'''
random_key, sub_key = jax.random.split(random_key)
grid_empty = self.fitnesses == -jnp.inf
p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty)
samples = jax.tree_map(
lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p),
self.genotypes,
)
return samples, random_key
@jax_jit
def _cell_min_fitnesses(
self,
cells_indices: jax.Array,
batch_of_indices: jax.Array,
batch_of_fitnesses: Fitness,
) -> Fitness:
'''
Sub-method for add(). Give the minimum fitness in each cell of cells_indices,
given current indivs in cells and new indivs to add in cells.
!!!WARNING!!! This is strict min fitness and should be used with >, not >=.
Args:
cells_indices: the cells to consider
batch_of_indices: indices of new indivs
batch_of_fitnesses: fitnesses of new indivs
Returns: minimum fitness for each cell in cells_indices
'''
@partial(jax_jit, static_argnames=('depth',))
def _get_cell_min_fitnesses(
idx: jax.Array,
fitnesses_depth_reshape: Fitness,
depth: int,
batch_of_indices: jax.Array,
batch_of_fitnesses: Fitness,
) -> jax.Array:
'''
Applied using vmap on all cells_indices.
'''
filter_fitnesses = jnp.where(
batch_of_indices == idx, batch_of_fitnesses, -jnp.inf
)
all_fitnesses = jnp.concatenate(
[filter_fitnesses, fitnesses_depth_reshape], axis=0
)
min_fitnesses, _ = jax.lax.top_k(all_fitnesses, depth + 1)
return min_fitnesses[depth]
get_cell_min_fitnesses_fn = partial(
_get_cell_min_fitnesses,
depth=self.dims.shape[0],
batch_of_indices=batch_of_indices,
batch_of_fitnesses=batch_of_fitnesses,
)
return jax.vmap(get_cell_min_fitnesses_fn)(
cells_indices,
jnp.reshape(
self.fitnesses_depth, (self.centroids.shape[0], self.dims.shape[0])
)[cells_indices],
)
@jax_jit
def _indices_to_occurence(
self,
batch_of_indices: jax.Array,
) -> jax.Array:
'''
Sub-method for add(). Return an array similar to the batch_of_indices
replacing each indice with its occurence number in the batch.
Args:
batch_of_indices: indices of new indivs
Returns: batch_of_occurences: number of occurence for each indice
'''
@partial(jax_jit, static_argnames=('num_centroids',))
def _cumulative_count(
idx: jax.Array,
indices: jax.Array,
batch_of_indices: jax.Array,
num_centroids: int,
) -> jax.Array:
filter_batch_of_indices = jnp.where(
indices.ravel() <= idx, batch_of_indices, num_centroids
)
count_indices = jnp.bincount(filter_batch_of_indices, length=num_centroids)
return count_indices.at[batch_of_indices[idx]].get() - 1
num_centroids = self.centroids.shape[0]
# Get occurence
indices = jnp.arange(0, batch_of_indices.size, step=1)
cumulative_count = partial(
_cumulative_count,
indices=indices,
batch_of_indices=batch_of_indices,
num_centroids=num_centroids,
)
batch_of_occurence = jax.vmap(cumulative_count)(indices)
# Filter out-of-bond individuals
out_of_bound = self.dims.shape[0] * num_centroids
batch_of_occurence = jnp.where(
batch_of_indices < out_of_bound, batch_of_occurence, out_of_bound
)
return batch_of_occurence
@jax_jit
def _place_indivs(
self,
batch_of_indices: jax.Array,
batch_of_fitnesses: Fitness,
) -> jax.Array:
'''
Sub-method for add(). Return indices to place new indiv in the depth grid.
Args:
batch_of_indices: indices of new indivs
batch_of_fitnesses: fitnesses of new indivs
Returns: indices to place each new indiv
'''
num_centroids = self.centroids.shape[0]
depth = self.dims.shape[0]
out_of_bound = num_centroids * depth # Index of non-added individuals
# Get minimum fitness in each cell after addition
min_fitnesses = self._cell_min_fitnesses(
jnp.arange(0, num_centroids, step=1),
batch_of_indices,
batch_of_fitnesses,
)
# Filter individuals and keep those greater than min
batch_of_indices = jnp.where(
batch_of_fitnesses > min_fitnesses[batch_of_indices],
batch_of_indices,
out_of_bound,
)
# Get in-cell indices of individuals
batch_of_cell_indices = self._indices_to_occurence(batch_of_indices)
batch_of_cell_indices = jnp.where(
batch_of_indices < out_of_bound,
batch_of_indices * depth + batch_of_cell_indices,
out_of_bound,
)
# Filter empty slots using minimum fitness
@jax_jit
def _get_empty_slots(
slots: jax.Array,
fitness: Fitness,
min_fitness: Fitness,
out_of_bound: int,
) -> jax.Array:
return jnp.where(fitness > min_fitness, out_of_bound, slots)
get_empty_slots = partial(_get_empty_slots, out_of_bound=out_of_bound)
empty_slots = jax.vmap(get_empty_slots)(
jnp.reshape(
jnp.arange(0, num_centroids * depth, step=1), (num_centroids, depth)
),
jnp.reshape(self.fitnesses_depth, (num_centroids, depth)),
min_fitnesses,
)
# Sort the indices in each cell
empty_slots = jnp.sort(empty_slots, axis=1)
# Transforms in-cell indices to account for empty slots
final_batch_of_indices = jnp.where(
batch_of_cell_indices < out_of_bound,
empty_slots.ravel()[batch_of_cell_indices],
out_of_bound,
)
return final_batch_of_indices
@jax_jit
def add(
self,
batch_of_genotypes: Genotype,
batch_of_descriptors: Descriptor,
batch_of_fitnesses: Fitness,
batch_of_extra_scores: ExtraScores,
) -> DeepMapElitesRepertoire:
'''
Add a batch of elements to the repertoire.
Args:
batch_of_genotypes: a batch of genotypes to be added to the repertoire.
Similarly to the self.genotypes argument, this is a PyTree in which
the leaves have a shape (batch_size, num_features)
batch_of_descriptors: an array that contains the descriptors of the
aforementioned genotypes. Its shape is (batch_size, num_descriptors)
batch_of_fitnesses: an array that contains the fitnesses of the
batch_of_extra_scores: unused tree that contains the extra_scores of
aforementioned genotypes. Its shape is (batch_size,)
Returns:
The updated MAP-Elites repertoire.
'''
out_of_bound = (
self.dims.shape[0] * self.centroids.shape[0]
) # Index of non-added individuals
# Get indices for given descriptors
batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
# Filter dead individuals
batch_of_indices = jnp.where(
batch_of_fitnesses > -jnp.inf,
batch_of_indices,
out_of_bound,
)
# Get final indices of individuals added to top layer of the grid
# (i.e. best indivs added in: genotypes, fitnesses, descriptors)
best_fitnesses = jax.ops.segment_max(
batch_of_fitnesses,
batch_of_indices,
num_segments=self.centroids.shape[0],
)
filter_fitnesses = jnp.where(
best_fitnesses[batch_of_indices] == batch_of_fitnesses,
batch_of_fitnesses,
-jnp.inf,
)
current_fitnesses = jnp.take_along_axis(self.fitnesses, batch_of_indices, 0)
final_batch_of_max_indices = jnp.where(
filter_fitnesses > current_fitnesses,
batch_of_indices,
out_of_bound,
)
# Get final indices of individuals added to the depth of the grid
# (i.e. indivs in: genotypes_depth, fitnesses_depth, descriptors_depth)
final_batch_of_indices = self._place_indivs(
batch_of_indices, batch_of_fitnesses
)
# Create new grid
new_grid_genotypes_depth = jax.tree_map(
lambda grid_genotypes, new_genotypes: grid_genotypes.at[
final_batch_of_indices
].set(new_genotypes),
self.genotypes_depth,
batch_of_genotypes,
)
new_grid_genotypes = jax.tree_map(
lambda grid_genotypes, new_genotypes: grid_genotypes.at[
final_batch_of_max_indices
].set(new_genotypes),
self.genotypes,
batch_of_genotypes,
)
# Compute new fitness and descriptors
new_fitnesses = self.fitnesses.at[final_batch_of_max_indices].set(
batch_of_fitnesses
)
new_fitnesses_depth = self.fitnesses_depth.at[final_batch_of_indices].set(
batch_of_fitnesses
)
new_descriptors = self.descriptors.at[final_batch_of_max_indices].set(
batch_of_descriptors
)
new_descriptors_depth = self.descriptors_depth.at[final_batch_of_indices].set(
batch_of_descriptors
)
return DeepMapElitesRepertoire(
genotypes=new_grid_genotypes,
genotypes_depth=new_grid_genotypes_depth,
fitnesses=new_fitnesses,
fitnesses_depth=new_fitnesses_depth,
descriptors=new_descriptors,
descriptors_depth=new_descriptors_depth,
centroids=self.centroids,
dims=self.dims,
)
@classmethod
def init(
cls,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
centroids: Centroid,
depth: int,
) -> DeepMapElitesRepertoire:
'''
Initialize a Map-Elites repertoire with an initial population of genotypes.
Requires the definition of centroids that can be computed with any method
such as CVT or Euclidean mapping.
Note: this function has been kept outside of the object MapElites, so it can
be called easily called from other modules.
Args:
genotypes: initial genotypes, pytree in which leaves
have shape (batch_size, num_features)
fitnesses: fitness of the initial genotypes of shape (batch_size,)
descriptors: descriptors of the initial genotypes
of shape (batch_size, num_descriptors)
extra_scores: unused extra_scores of the initial genotypes
centroids: tesselation centroids of shape (batch_size, num_descriptors)
centroids: tesselation centroids of shape (batch_size, num_descriptors)
Returns:
an initialized MAP-Elite repertoire
'''
# Initialize grid with default values
num_centroids = centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
default_fitnesses_depth = -jnp.inf * jnp.ones(shape=(num_centroids * depth))
default_genotypes = jax.tree_map(
lambda x: jnp.zeros(shape=(num_centroids,) + x.shape[1:]),
genotypes,
)
default_genotypes_depth = jax.tree_map(
lambda x: jnp.zeros(shape=(num_centroids * depth,) + x.shape[1:]),
genotypes,
)
default_descriptors = jnp.zeros(shape=(num_centroids, centroids.shape[-1]))
default_descriptors_depth = jnp.zeros(
shape=(num_centroids * depth, centroids.shape[-1])
)
dims = jnp.zeros(shape=(depth))
repertoire = DeepMapElitesRepertoire(
genotypes=default_genotypes,
genotypes_depth=default_genotypes_depth,
fitnesses=default_fitnesses,
fitnesses_depth=default_fitnesses_depth,
descriptors=default_descriptors,
descriptors_depth=default_descriptors_depth,
centroids=centroids,
dims=dims,
)
# Add initial values to the grid
new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)
return new_repertoire
@jax_jit
def empty(self) -> DeepMapElitesRepertoire:
'''
Empty the grid from all existing individuals.
Returns:
An empty DeepMapElitesRepertoire
'''
new_fitnesses = jnp.full_like(self.fitnesses, -jnp.inf)
new_fitnesses_depth = jnp.full_like(self.fitnesses_depth, -jnp.inf)
new_descriptors = jnp.zeros_like(self.descriptors)
new_descriptors_depth = jnp.zeros_like(self.descriptors_depth)
new_genotypes = jax.tree_map(lambda x: jnp.zeros_like(x), self.genotypes)
new_genotypes_depth = jax.tree_map(
lambda x: jnp.zeros_like(x), self.genotypes_depth
)
return DeepMapElitesRepertoire(
genotypes=new_genotypes,
genotypes_depth=new_genotypes_depth,
fitnesses=new_fitnesses,
fitnesses_depth=new_fitnesses_depth,
descriptors=new_descriptors,
descriptors_depth=new_descriptors_depth,
centroids=self.centroids,
dims=self.dims,
)