Skip to content

Commit

Permalink
Added specific short context lengths (4096, 8192).
Browse files Browse the repository at this point in the history
Implemented a separate search and fine-tuning process for these lengths. Ensure the model model maintains good performance on shorter contexts even after being extended to very long contexts.
  • Loading branch information
jshuadvd committed Jun 30, 2024
1 parent 6eec8a7 commit 440b21b
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,26 +640,36 @@ def progressive_extension(

def short_context_recovery(model, data, base_length, lambda_factors_base, n_hat_base):
"""
Recover performance on shorter context lengths.
This function ensures that the model maintains good performance on shorter contexts (4k and 8k)
even after being extended to very long contexts.
Args:
model (nn.Module): LongRoPE model.
data (list): List of input sequences.
base_length (int): Base context window length.
lambda_factors_base (list): Base lambda factors.
n_hat_base (int): Base n_hat.
model (nn.Module): Extended LongRoPE model.
data (list): List of input sequences for fine-tuning and evaluation.
base_length (int): Original context window length of the model.
lambda_factors_base (list): Base lambda factors for the extended model.
n_hat_base (int): Base n_hat for the extended model.
Returns:
nn.Module: Recovered LongRoPE model.
nn.Module: LongRoPE model with recovered short context performance.
"""
short_lengths = [base_length // 2, base_length // 4]
short_lengths = [4096, 8192] # Specific lengths mentioned in the paper

for length in short_lengths:
extension_ratio = length / base_length
lambda_factors, n_hat = search_lambda_factors(
model, data, extension_ratio, max_length=length
model,
data,
extension_ratio,
population_size=64,
num_mutations=16,
num_crossovers=16,
max_iterations=40,
)
model = fine_tune(model, data, length, lambda_factors, n_hat)
# Fine-tune for short context recovery
model = fine_tune(model, data, length, lambda_factors, n_hat, steps=100)

# Store base factors for use during inference
model.lambda_factors_base = lambda_factors_base
model.n_hat_base = n_hat_base

Expand Down

0 comments on commit 440b21b

Please sign in to comment.