diff --git a/src/main.py b/src/main.py
index e88ad00..3d0862d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -582,7 +582,7 @@ def crossover(parents, num_crossovers, d_model):
     return crossover_population
 
 
-def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):
+def fine_tune(model, data, target_length, lambda_factors, n_hat, steps):
     """
     Fine-tune the LongRoPE model.
 
@@ -592,31 +592,33 @@ def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):
         target_length (int): Target context window length.
         lambda_factors (list): Lambda factors for interpolation.
         n_hat (int): Threshold for applying interpolation.
-        num_epochs (int, optional): Number of fine-tuning epochs. Defaults to 3.
+        steps (int): Number of fine-tuning steps, as specified in the paper.
 
     Returns:
         nn.Module: Fine-tuned LongRoPE model.
     """
-    model.lambda_factors = lambda_factors
-    model.n_hat = n_hat
+    model.lambda_factors[f"{target_length // 1000}k"] = lambda_factors
+    model.n_hat[f"{target_length // 1000}k"] = n_hat
     optimizer = optim.Adam(model.parameters(), lr=1e-4)
 
-    for epoch in range(num_epochs):
-        for seq in data:
-            optimizer.zero_grad()
+    for step in range(steps):
+        optimizer.zero_grad()
 
-            seq_len = seq.size(0)
-            if seq_len <= target_length:
-                input_ids = seq.unsqueeze(0)
-            else:
-                start_idx = random.randint(0, seq_len - target_length)
-                input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)
+        # Randomly select a sequence from the data
+        seq = random.choice(data)
+        seq_len = seq.size(0)
 
-            output = model(input_ids)
-            loss = torch.mean(output)
+        if seq_len <= target_length:
+            input_ids = seq.unsqueeze(0)
+        else:
+            start_idx = random.randint(0, seq_len - target_length)
+            input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)
+
+        output = model(input_ids)
+        loss = torch.mean(output)
 
-            loss.backward()
-            optimizer.step()
+        loss.backward()
+        optimizer.step()
 
     return model
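
For reference, here is a minimal, self-contained sketch of what the patched loop now does: per-target-length bookkeeping in dicts keyed like "2k", a fixed step budget with random sequence sampling instead of epoch sweeps over the data, and random window cropping of over-length sequences. The ToyModel class, vocabulary size, lambda values, and step count below are illustrative assumptions, not values from the patch; the mean-of-outputs loss simply mirrors the placeholder objective already present in the patched code.

import random

import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    """Hypothetical stand-in for the LongRoPE model, for illustration only.

    It stores per-length interpolation state in dicts keyed like "2k",
    which is what the patched fine_tune() now assumes of the model.
    """

    def __init__(self, vocab_size=100, d_model=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.lambda_factors = {}
        self.n_hat = {}

    def forward(self, input_ids):
        return self.embed(input_ids)


model = ToyModel()
data = [torch.randint(0, 100, (n,)) for n in (500, 3000)]  # toy token sequences
target_length = 2048

# Same key scheme as the patch: 2048 // 1000 -> "2k".
model.lambda_factors[f"{target_length // 1000}k"] = [1.0] * 8
model.n_hat[f"{target_length // 1000}k"] = 128

optimizer = optim.Adam(model.parameters(), lr=1e-4)
for step in range(3):  # tiny step budget, just for the sketch
    optimizer.zero_grad()
    seq = random.choice(data)  # step-based random sampling, not epoch sweeps
    if seq.size(0) > target_length:
        # Crop a random target_length window, as in the patched loop.
        start = random.randint(0, seq.size(0) - target_length)
        seq = seq[start : start + target_length]
    output = model(seq.unsqueeze(0))
    loss = torch.mean(output)  # placeholder objective, mirroring the patch
    loss.backward()
    optimizer.step()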