Skip to content

Commit

Permalink
Update the crossover method to use dictionaries instead of tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 4, 2024
1 parent f510479 commit efce3b9
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,33 +561,34 @@ def mutate(parents, num_mutations, d_model):
return mutated_population


def mutate(parents, num_mutations, d_model):
def crossover(parents, num_crossovers, d_model):
"""
Perform mutation on the parent population.
Perform crossover on the parent population.
Args:
parents (list): Parent population.
num_mutations (int): Number of mutations to perform.
num_crossovers (int): Number of crossovers to perform.
d_model (int): Dimension of the model.
Returns:
list: Mutated population.
list: Crossover population.
"""
mutated_population = []
for _ in range(num_mutations):
parent = random.choice(parents)
child = {"lambda_i": parent["lambda_i"].clone(), "n_hat": parent["n_hat"]}
crossover_population = []
for _ in range(num_crossovers):
parent1 = random.choice(parents)
parent2 = random.choice(parents)
child = {"lambda_i": parent1["lambda_i"].clone(), "n_hat": parent1["n_hat"]}

for i in range(d_model):
if random.random() < 0.1:
child["lambda_i"][i] *= random.uniform(0.8, 1.2)
if random.random() < 0.5:
child["lambda_i"][i] = parent2["lambda_i"][i]

if random.random() < 0.1:
child["n_hat"] = random.randint(0, d_model)
if random.random() < 0.5:
child["n_hat"] = parent2["n_hat"]

mutated_population.append(child)
crossover_population.append(child)

return mutated_population
return crossover_population


def fine_tune(model, train_data, val_data, target_length, lambda_factors, n_hat, steps):
Expand Down

0 comments on commit efce3b9

Please sign in to comment.