diff --git a/src/brush/deap_api/nsga2.py b/src/brush/deap_api/nsga2.py index a7d44f4d..a1a8c8b2 100644 --- a/src/brush/deap_api/nsga2.py +++ b/src/brush/deap_api/nsga2.py @@ -6,7 +6,7 @@ def nsga2(toolbox, NGEN, MU, CXPB, use_batch, verbosity, rnd_flt): # NGEN = 250 - # MU = 100 + # MU = 100 # CXPB = 0.9 # rnd_flt: random number generator to sample crossover prob @@ -18,15 +18,18 @@ def calculate_statistics(ind): stats = tools.Statistics(calculate_statistics) - stats.register("ave", np.mean, axis=0) + stats.register("avg", np.mean, axis=0) + stats.register("med", np.median, axis=0) stats.register("std", np.std, axis=0) stats.register("min", np.min, axis=0) stats.register("max", np.max, axis=0) logbook = tools.Logbook() - logbook.header = "gen", "evals", "ave (O1 train, O2 train, O1 val, O2 val)", \ + logbook.header = "gen", "evals", "avg (O1 train, O2 train, O1 val, O2 val)", \ + "med (O1 train, O2 train, O1 val, O2 val)", \ "std (O1 train, O2 train, O1 val, O2 val)", \ - "min (O1 train, O2 train, O1 val, O2 val)" + "min (O1 train, O2 train, O1 val, O2 val)", \ + "max (O1 train, O2 train, O1 val, O2 val)" pop = toolbox.population(n=MU) @@ -68,14 +71,12 @@ def calculate_statistics(ind): if rnd_flt() < CXPB: off1, off2 = toolbox.mate(ind1, ind2) else: - off1, off2 = ind1, ind2 - + off1 = toolbox.mutate(off1) + off2 = toolbox.mutate(off2) + # avoid inserting empty solutions - if off1 != None: off1 = toolbox.mutate(off1) - if off1 != None: offspring.extend([off1]) - - if off2 != None: off2 = toolbox.mutate(off2) - if off2 != None: offspring.extend([off2]) + if off1 is not None: offspring.extend([off1]) + if off2 is not None: offspring.extend([off2]) # archive.update(offspring) # Evaluate the individuals with an invalid fitness