Commit

update the fine_tune method to use the steps mentioned in the paper for each stage (400 for 128k, 600 for 256k)
jshuadvd committed Jul 3, 2024
1 parent c83bd6a commit 5f6c465
Showing 1 changed file with 19 additions and 17 deletions.
src/main.py (36 changes: 19 additions & 17 deletions)
@@ -582,7 +582,7 @@ def crossover(parents, num_crossovers, d_model):
     return crossover_population
 
 
-def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):
+def fine_tune(model, data, target_length, lambda_factors, n_hat, steps):
     """
     Fine-tune the LongRoPE model.
 
@@ -592,31 +592,33 @@ def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):
         target_length (int): Target context window length.
         lambda_factors (list): Lambda factors for interpolation.
         n_hat (int): Threshold for applying interpolation.
-        num_epochs (int, optional): Number of fine-tuning epochs. Defaults to 3.
+        steps (int): Number of fine-tuning steps, as specified in the paper.
 
     Returns:
         nn.Module: Fine-tuned LongRoPE model.
     """
-    model.lambda_factors = lambda_factors
-    model.n_hat = n_hat
+    model.lambda_factors[f"{target_length // 1000}k"] = lambda_factors
+    model.n_hat[f"{target_length // 1000}k"] = n_hat
     optimizer = optim.Adam(model.parameters(), lr=1e-4)
 
-    for epoch in range(num_epochs):
-        for seq in data:
-            optimizer.zero_grad()
+    for step in range(steps):
+        optimizer.zero_grad()
 
-            seq_len = seq.size(0)
-            if seq_len <= target_length:
-                input_ids = seq.unsqueeze(0)
-            else:
-                start_idx = random.randint(0, seq_len - target_length)
-                input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)
+        # Randomly select a sequence from the data
+        seq = random.choice(data)
+        seq_len = seq.size(0)
 
-            output = model(input_ids)
-            loss = torch.mean(output)
+        if seq_len <= target_length:
+            input_ids = seq.unsqueeze(0)
+        else:
+            start_idx = random.randint(0, seq_len - target_length)
+            input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)
 
-            loss.backward()
-            optimizer.step()
+        output = model(input_ids)
+        loss = torch.mean(output)
 
+        loss.backward()
+        optimizer.step()
 
     return model
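For reference, the step counts in the commit message come from the LongRoPE paper's progressive extension schedule: 400 fine-tuning steps for the 128k stage and 600 for the 256k stage. Below is a minimal sketch of how the updated signature might be driven across the two stages; the datasets, lambda-factor lists, and n_hat values are hypothetical placeholders, not names from this repository.

# Hypothetical two-stage usage; all inputs below are placeholders.

# Stage 1: extend to a 128k context window (400 steps per the paper)
model = fine_tune(
    model,
    data_128k,
    target_length=128_000,
    lambda_factors=lambda_factors_128k,
    n_hat=n_hat_128k,
    steps=400,
)

# Stage 2: extend to a 256k context window (600 steps per the paper)
model = fine_tune(
    model,
    data_256k,
    target_length=256_000,
    lambda_factors=lambda_factors_256k,
    n_hat=n_hat_256k,
    steps=600,
)

Note that the rewritten loop samples one random sequence per optimizer step instead of iterating epochs over the whole dataset, which is what lets the step budget (400 or 600) directly bound the amount of fine-tuning per stage.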

